Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ out/**
.metals
.bloop
.project
.classpath
.settings
bin/
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Visits a Lambda expression by delegating to {@code visitFallback}.
 *
 * @param expr the Lambda expression
 * @param context the visitation context
 * @return whatever {@code visitFallback} returns for this expression
 * @throws E if the fallback visitation fails
 */
@Override
public O visit(Expression.Lambda expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a scalar function invocation.
*
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,34 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * A lambda expression: a struct of parameters plus a body expression.
 *
 * <p>The type of a lambda is a function type built from the fields of {@link #parameters()} and
 * the type of {@link #body()}.
 */
@Value.Immutable
abstract class Lambda implements Expression {
  /** Struct whose fields are used as the lambda's parameter types. */
  public abstract Type.Struct parameters();

  /** Expression whose type is used as the lambda's return type. */
  public abstract Expression body();

  /**
   * Derives the function type of this lambda.
   *
   * @return a non-nullable function type from the parameter fields to the body's type
   */
  @Override
  public Type getType() {
    // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for
    // declaring otherwise.
    // See: https://github.com/substrait-io/substrait/issues/976
    return Type.withNullability(false).func(parameters().fields(), body().getType());
  }

  /** Creates a builder for {@link Lambda} instances. */
  public static ImmutableExpression.Lambda.Builder builder() {
    return ImmutableExpression.Lambda.builder();
  }

  /** Dispatches to the visitor overload that accepts a {@link Lambda}. */
  @Override
  public <R, C extends VisitationContext, E extends Throwable> R accept(
      ExpressionVisitor<R, C, E> visitor, C context) throws E {
    return visitor.visit(this, context);
  }
}

/**
* Base interface for user-defined literals.
*
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
*/
R visit(Expression.NestedStruct expr, C context) throws E;

/**
 * Visit a Lambda expression.
 *
 * @param expr the {@link Expression.Lambda} being visited
 * @param context visitation context
 * @return visit result
 * @throws E on visit failure
 */
R visit(Expression.Lambda expr, C context) throws E;

/**
* Visit a user-defined any literal.
*
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/java/io/substrait/expression/FieldReference.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {

public abstract Optional<Integer> outerReferenceStepsOut();

/**
 * Steps-out value of a lambda parameter reference; empty when this reference does not refer to a
 * lambda parameter.
 */
public abstract Optional<Integer> lambdaParameterReferenceStepsOut();

@Override
public Type getType() {
return type();
Expand All @@ -38,13 +40,18 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
public boolean isSimpleRootReference() {
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
&& !outerReferenceStepsOut().isPresent()
&& !lambdaParameterReferenceStepsOut().isPresent();
}

/**
 * Reports whether this reference points outside the current scope.
 *
 * @return {@code true} when an outer-reference steps-out value is present and greater than zero
 */
public boolean isOuterReference() {
  return outerReferenceStepsOut().filter(steps -> steps > 0).isPresent();
}

/**
 * Reports whether this reference targets a lambda parameter.
 *
 * @return {@code true} when a lambda-parameter steps-out value is present, whatever its value
 *     (including zero)
 */
public boolean isLambdaParameterReference() {
return lambdaParameterReferenceStepsOut().isPresent();
}

public FieldReference dereferenceStruct(int index) {
Type newType = StructFieldFinder.getReferencedType(type(), index);
return dereference(newType, StructField.of(index));
Expand Down Expand Up @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
index, currentOffset));
}

/**
 * Creates a reference to a single parameter of a lambda.
 *
 * @param paramIndex zero-based index into {@code lambdaParamsType.fields()}
 * @param lambdaParamsType struct whose fields are the lambda's parameter types
 * @param stepsOut value recorded as the lambda-parameter steps-out; must not be negative
 * @return a one-segment field reference typed as the selected parameter
 * @throws IllegalArgumentException if {@code stepsOut} is negative
 * @throws IndexOutOfBoundsException if {@code paramIndex} does not select an existing parameter
 */
public static FieldReference newLambdaParameterReference(
    int paramIndex, Type.Struct lambdaParamsType, int stepsOut) {
  if (stepsOut < 0) {
    throw new IllegalArgumentException(
        String.format("Lambda parameter stepsOut must not be negative, got %d", stepsOut));
  }

  // Validate up front so the failure names the offending index and the parameter count instead
  // of surfacing as a bare list-access error.
  final List<Type> parameterTypes = lambdaParamsType.fields();
  if (paramIndex < 0 || paramIndex >= parameterTypes.size()) {
    throw new IndexOutOfBoundsException(
        String.format(
            "Lambda parameter index %d is out of bounds for %d parameter(s)",
            paramIndex, parameterTypes.size()));
  }

  return ImmutableFieldReference.builder()
      .addSegments(StructField.of(paramIndex))
      .type(parameterTypes.get(paramIndex))
      .lambdaParameterReferenceStepsOut(stepsOut)
      .build();
}

public interface ReferenceSegment {
FieldReference apply(FieldReference reference);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ public Expression visit(
});
}

/**
 * Converts a Lambda expression into its protobuf form: the parameters are converted through the
 * type converter and the body is converted by this same visitor.
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
    throws RuntimeException {
  final io.substrait.proto.Expression.Lambda.Builder lambda =
      io.substrait.proto.Expression.Lambda.newBuilder();
  lambda.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct());
  lambda.setBody(expr.body().accept(this, context));

  return io.substrait.proto.Expression.newBuilder().setLambda(lambda).build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
Expand Down Expand Up @@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
out.setOuterReference(
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
.setStepsOut(expr.outerReferenceStepsOut().get()));
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
out.setLambdaParameterReference(
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
} else {
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack();

public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -75,6 +76,25 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getDirectReference().getStructField().getField(),
rootType,
reference.getOuterReference().getStepsOut());
case LAMBDA_PARAMETER_REFERENCE:
{
// Resolve the reference against the parameter struct of the lambda selected by stepsOut.
// lambdaParameterStack.get throws IllegalArgumentException when stepsOut does not match a
// lambda currently on the stack.
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
reference.getLambdaParameterReference();

int stepsOut = lambdaParamRef.getStepsOut();
Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut);

// Check for unsupported nested field access
// NOTE(review): this runs after the stack lookup, so an invalid stepsOut is reported before
// an unsupported nested access. The reference is also read as a direct struct-field
// reference without checking the reference type first - confirm that is intended.
if (reference.getDirectReference().getStructField().hasChild()) {
throw new UnsupportedOperationException(
"Nested field access in lambda parameters is not yet supported");
}

return FieldReference.newLambdaParameterReference(
reference.getDirectReference().getStructField().getField(),
lambdaParameters,
stepsOut);
}
case ROOTTYPE_NOT_SET:
default:
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
Expand Down Expand Up @@ -260,6 +280,27 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case LAMBDA:
{
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
// The parameters arrive as a bare proto struct; wrap them in a proto Type so the type
// converter can be used, then narrow the result back to a struct.
Type.Struct parameters =
(Type.Struct)
protoTypeConverter.from(
io.substrait.proto.Type.newBuilder()
.setStruct(protoLambda.getParameters())
.build());

// Make the parameters visible to lambda parameter references while the body is converted.
lambdaParameterStack.push(parameters);

Expression body;
try {
body = from(protoLambda.getBody());
} finally {
// Always pop, even when converting the body throws, so the stack does not keep an entry
// for a lambda that is no longer being converted.
lambdaParameterStack.pop();
}

return Expression.Lambda.builder().parameters(parameters).body(body).build();
}
// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down Expand Up @@ -579,4 +620,42 @@ public Expression.SortField fromSortField(SortField s) {
/**
 * Builds a {@link FunctionOption} from its protobuf form, copying the option name and the
 * preference list as the option's values.
 *
 * @param o the protobuf function option
 * @return the converted function option
 */
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
  final String optionName = o.getName();
  final List<String> preferences = o.getPreferenceList();
  return FunctionOption.builder().name(optionName).addAllValues(preferences).build();
}

/**
 * Tracks the parameter structs of the lambdas currently being converted, innermost last.
 *
 * <p>Each lambda pushes its parameters before its body is converted and pops them afterwards. A
 * lambda parameter reference selects one of the enclosing lambdas through "stepsOut":
 *
 * <ul>
 *   <li>stepsOut=0 is the innermost (current) lambda
 *   <li>stepsOut=1 is the lambda directly enclosing it
 *   <li>stepsOut=N is the lambda N levels further out
 * </ul>
 */
private static class LambdaParameterStack {
  private final List<Type.Struct> stack = new ArrayList<>();

  /** Registers the parameters of a lambda that is about to be converted. */
  void push(Type.Struct parameters) {
    stack.add(parameters);
  }

  /**
   * Removes the innermost lambda's parameters.
   *
   * @throws IllegalArgumentException if nothing has been pushed
   */
  void pop() {
    final int top = stack.size() - 1;
    if (top < 0) {
      throw new IllegalArgumentException("Lambda parameter stack is empty");
    }
    stack.remove(top);
  }

  /**
   * Looks up the parameters of the lambda that is {@code stepsOut} levels out from the innermost.
   *
   * @param stepsOut 0 for the innermost lambda, increasing outwards
   * @return the parameter struct of the selected lambda
   * @throws IllegalArgumentException if {@code stepsOut} is negative or not smaller than the
   *     current depth
   */
  Type.Struct get(int stepsOut) {
    final int depth = stack.size();
    if (stepsOut < 0 || stepsOut >= depth) {
      throw new IllegalArgumentException(
          String.format(
              "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
              stepsOut, depth));
    }
    return stack.get(depth - 1 - stepsOut);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog {
/** Extension identifier for set functions. */
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";

/** Extension identifier for list functions. */
public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list";

/** Extension identifier for string functions. */
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";

Expand All @@ -75,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
"arithmetic",
"comparison",
"datetime",
"list",
"logarithmic",
"rounding",
"rounding_decimal",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator<T, I> {
T listE(T type);

T mapE(T key, T value);

/**
 * Creates a function type from the given parameter types and return type.
 *
 * @param parameterTypes the types of the function's parameters
 * @param returnType the type the function returns
 * @return the function type
 */
T funcE(Iterable<? extends T> parameterTypes, T returnType);
}
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/function/ParameterizedType.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ <R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameter
}
}

/** Parameterized function type, described by its parameter types and its return type. */
@Value.Immutable
abstract class Func extends BaseParameterizedType implements NullableType {
/** Types of the function's parameters. */
public abstract java.util.List<ParameterizedType> parameterTypes();

/** Type the function returns. */
public abstract ParameterizedType returnType();

/** Creates a builder for {@link Func} instances. */
public static ImmutableParameterizedType.Func.Builder builder() {
return ImmutableParameterizedType.Func.builder();
}

/** Dispatches to the visitor overload that accepts a {@link Func}. */
@Override
<R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameterizedTypeVisitor)
throws E {
return parameterizedTypeVisitor.visit(this);
}
}

@Value.Immutable
abstract class ListType extends BaseParameterizedType implements NullableType {
public abstract ParameterizedType name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ public ParameterizedType listE(ParameterizedType type) {
return ParameterizedType.ListType.builder().nullable(nullable).name(type).build();
}

/**
 * Builds a parameterized function type from the given parameter and return types, using this
 * creator's {@code nullable} setting.
 */
@Override
public ParameterizedType funcE(
Iterable<? extends ParameterizedType> parameterTypes, ParameterizedType returnType) {
return ParameterizedType.Func.builder()
.nullable(nullable)
.addAllParameterTypes(parameterTypes)
.returnType(returnType)
.build();
}

@Override
public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) {
return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor<R, E extends Throwable> extends TypeVi

R visit(ParameterizedType.StringLiteral stringLiteral) throws E;

/**
 * Visits a parameterized function type.
 *
 * @param expr the function type being visited
 * @return the visit result
 * @throws E on visit failure
 */
R visit(ParameterizedType.Func expr) throws E;

abstract class ParameterizedTypeThrowsVisitor<R, E extends Throwable>
extends TypeVisitor.TypeThrowsVisitor<R, E> implements ParameterizedTypeVisitor<R, E> {

Expand Down Expand Up @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E {
public R visit(ParameterizedType.StringLiteral stringLiteral) throws E {
throw t();
}

/** Rejects parameterized function types by throwing the exception produced by {@code t()}. */
@Override
public R visit(ParameterizedType.Func expr) throws E {
throw t();
}
}
}
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ public String visit(final Type.Map expr) {
return "map";
}

/** Returns the fixed string {@code "func"} for any function type. */
@Override
public String visit(Type.Func type) throws RuntimeException {
return "func";
}

@Override
public String visit(final Type.UserDefined expr) {
return String.format("u!%s", expr.name());
Expand Down Expand Up @@ -210,6 +215,11 @@ public String visit(ParameterizedType.Map expr) throws RuntimeException {
return "map";
}

/** Returns the fixed string {@code "func"} for any parameterized function type. */
@Override
public String visit(ParameterizedType.Func expr) throws RuntimeException {
return "func";
}

@Override
public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException {
if (expr.value().toLowerCase().startsWith("any")) {
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/function/TypeExpression.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,22 @@ <R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) th
}
}

/** Function type as a type expression, described by its parameter types and its return type. */
@Value.Immutable
abstract class Func extends BaseTypeExpression implements NullableType {
/** Type expressions for the function's parameters. */
public abstract java.util.List<TypeExpression> parameterTypes();

/** Type expression for the value the function returns. */
public abstract TypeExpression returnType();

/** Creates a builder for {@link Func} instances. */
public static ImmutableTypeExpression.Func.Builder builder() {
return ImmutableTypeExpression.Func.builder();
}

/** Dispatches to the visitor overload that accepts a {@link Func}. */
@Override
<R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract class BinaryOperation extends BaseTypeExpression {
public enum OpType {
Expand Down
Loading
Loading