diff --git a/.gitignore b/.gitignore index 5eab9ab03..4093b00d1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ out/** .metals .bloop .project +.classpath +.settings +bin/ diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 340aa5b04..6343bc904 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E { return visitFallback(expr, context); } + /** + * Visits a Lambda expression. + * + * @param expr the Lambda expression + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.Lambda expr, C context) throws E { + return visitFallback(expr, context); + } + /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 2197b0bd0..72510ad2c 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -758,6 +758,34 @@ public R accept( } } + @Value.Immutable + abstract class Lambda implements Expression { + public abstract Type.Struct parameters(); + + public abstract Expression body(); + + @Override + public Type getType() { + List paramTypes = parameters().fields(); + Type returnType = body().getType(); + + // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for + // declaring otherwise. 
+ // See: https://github.com/substrait-io/substrait/issues/976 + return Type.withNullability(false).func(paramTypes, returnType); + } + + public static ImmutableExpression.Lambda.Builder builder() { + return ImmutableExpression.Lambda.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + /** * Base interface for user-defined literals. * diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 05540a924..147c7f8b7 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -321,6 +321,16 @@ public interface ExpressionVisitor outerReferenceStepsOut(); + public abstract Optional lambdaParameterReferenceStepsOut(); + @Override public Type getType() { return type(); @@ -38,13 +40,18 @@ public R accept( public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() - && !outerReferenceStepsOut().isPresent(); + && !outerReferenceStepsOut().isPresent() + && !lambdaParameterReferenceStepsOut().isPresent(); } public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + public boolean isLambdaParameterReference() { + return lambdaParameterReferenceStepsOut().isPresent(); + } + public FieldReference dereferenceStruct(int index) { Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + public static FieldReference newLambdaParameterReference( + int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(paramIndex)) + .type(lambdaParamsType.fields().get(paramIndex)) + 
.lambdaParameterReferenceStepsOut(stepsOut) + .build(); + } + public interface ReferenceSegment { FieldReference apply(FieldReference reference); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index eb2e45784..869aa6cc4 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -387,6 +387,18 @@ public Expression visit( }); } + @Override + public Expression visit( + io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambda( + io.substrait.proto.Expression.Lambda.newBuilder() + .setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct()) + .setBody(expr.body().accept(this, context))) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.UserDefinedAnyLiteral expr, @@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() .setStepsOut(expr.outerReferenceStepsOut().get())); + } else if (expr.lambdaParameterReferenceStepsOut().isPresent()) { + out.setLambdaParameterReference( + io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder() + .setStepsOut(expr.lambdaParameterReferenceStepsOut().get())); } else { out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index e4a9fffea..290e88c49 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ 
b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,6 +37,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; + private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -75,6 +76,25 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getDirectReference().getStructField().getField(), rootType, reference.getOuterReference().getStepsOut()); + case LAMBDA_PARAMETER_REFERENCE: + { + io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = + reference.getLambdaParameterReference(); + + int stepsOut = lambdaParamRef.getStepsOut(); + Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + + // Check for unsupported nested field access + if (reference.getDirectReference().getStructField().hasChild()) { + throw new UnsupportedOperationException( + "Nested field access in lambda parameters is not yet supported"); + } + + return FieldReference.newLambdaParameterReference( + reference.getDirectReference().getStructField().getField(), + lambdaParameters, + stepsOut); + } case ROOTTYPE_NOT_SET: default: throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); @@ -260,6 +280,27 @@ public Type visit(Type.Struct type) throws RuntimeException { } } + case LAMBDA: + { + io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); + Type.Struct parameters = + (Type.Struct) + protoTypeConverter.from( + io.substrait.proto.Type.newBuilder() + .setStruct(protoLambda.getParameters()) + .build()); + + lambdaParameterStack.push(parameters); + + Expression body; + try { + body = from(protoLambda.getBody()); + } finally { + lambdaParameterStack.pop(); + } + + return 
Expression.Lambda.builder().parameters(parameters).body(body).build(); + } // TODO enum. case ENUM: throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); @@ -579,4 +620,42 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } + + /** + * A stack for tracking lambda parameter types during expression parsing. + * + *

<p>When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. + * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: + *
<ul> + *
<li>stepsOut=0 refers to the innermost (current) lambda + *
<li>stepsOut=1 refers to the next enclosing lambda + *
<li>stepsOut=N refers to N levels up + *
</ul>
+ */ + private static class LambdaParameterStack { + private final List stack = new ArrayList<>(); + + void push(Type.Struct parameters) { + stack.add(parameters); + } + + void pop() { + if (stack.isEmpty()) { + throw new IllegalArgumentException("Lambda parameter stack is empty"); + } + stack.remove(stack.size() - 1); + } + + Type.Struct get(int stepsOut) { + int index = stack.size() - 1 - stepsOut; + if (index < 0 || index >= stack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, stack.size())); + } + return stack.get(index); + } + } } diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..d8b6b1fa6 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for set functions. */ public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for string functions. 
*/ public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; @@ -75,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "arithmetic", "comparison", "datetime", + "list", "logarithmic", "rounding", "rounding_decimal", diff --git a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java index c3360093d..b802d93e0 100644 --- a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator { T listE(T type); T mapE(T key, T value); + + T funcE(Iterable parameterTypes, T returnType); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..f35c21f7a 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -200,6 +200,23 @@ R accept(final ParameterizedTypeVisitor parameter } } + @Value.Immutable + abstract class Func extends BaseParameterizedType implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract ParameterizedType returnType(); + + public static ImmutableParameterizedType.Func.Builder builder() { + return ImmutableParameterizedType.Func.builder(); + } + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + } + @Value.Immutable abstract class ListType extends BaseParameterizedType implements NullableType { public abstract ParameterizedType name(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 4c3f314e5..b35d8a781 100644 --- 
a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -103,6 +103,16 @@ public ParameterizedType listE(ParameterizedType type) { return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } + @Override + public ParameterizedType funcE( + Iterable parameterTypes, ParameterizedType returnType) { + return ParameterizedType.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + @Override public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..755c99777 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.StringLiteral stringLiteral) throws E; + R visit(ParameterizedType.Func expr) throws E; + abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E { public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } + + @Override + public R visit(ParameterizedType.Func expr) throws E { + throw t(); + } } } diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index d6fc1bdb8..5c942f46a 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -150,6 +150,11 @@ public String visit(final 
Type.Map expr) { return "map"; } + @Override + public String visit(Type.Func type) throws RuntimeException { + return "func"; + } + @Override public String visit(final Type.UserDefined expr) { return String.format("u!%s", expr.name()); @@ -210,6 +215,11 @@ public String visit(ParameterizedType.Map expr) throws RuntimeException { return "map"; } + @Override + public String visit(ParameterizedType.Func expr) throws RuntimeException { + return "func"; + } + @Override public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { if (expr.value().toLowerCase().startsWith("any")) { diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index cc9bc068c..1ee1cae56 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -206,6 +206,22 @@ R acceptE(final TypeExpressionVisitor visitor) th } } + @Value.Immutable + abstract class Func extends BaseTypeExpression implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract TypeExpression returnType(); + + public static ImmutableTypeExpression.Func.Builder builder() { + return ImmutableTypeExpression.Func.builder(); + } + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract class BinaryOperation extends BaseTypeExpression { public enum OpType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index 808a0dd5a..62e1daee2 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -86,6 +86,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) { return 
TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } + @Override + public TypeExpression funcE( + Iterable parameterTypes, TypeExpression returnType) { + return TypeExpression.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public static class Assign { String name; TypeExpression expr; diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 2ef76b50f..e1bef4398 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -26,6 +26,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Map expr) throws E; + R visit(TypeExpression.Func expr) throws E; + R visit(TypeExpression.BinaryOperation expr) throws E; R visit(TypeExpression.NotOperation expr) throws E; @@ -104,6 +106,11 @@ public R visit(TypeExpression.Map expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.Func expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.BinaryOperation expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 1e9254716..11fc14f27 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -439,6 +439,18 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.Lambda lambda, EmptyVisitationContext context) + throws E { + Optional newBody = lambda.body().accept(this, context); + + if (allEmpty(newBody)) { + return Optional.empty(); + } + return Optional.of( + Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + } + // 
utilities protected Optional> visitExprList( diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index e71b0b00c..36e5b6dcc 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -242,5 +242,10 @@ public Integer visit(Type.Map type) throws RuntimeException { public Integer visit(Type.UserDefined type) throws RuntimeException { return 0; } + + @Override + public Integer visit(Type.Func type) throws RuntimeException { + return 0; + } } } diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..e9d711d96 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException { public String visit(Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } + + @Override + public String visit(Type.Func type) throws RuntimeException { + return String.format( + "func%s<%s -> %s>", + n(type), + type.parameterTypes().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), + type.returnType().accept(this)); + } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 86eaa733c..e82f8c1cf 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -361,6 +361,22 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract class Func implements Type { + public abstract java.util.List parameterTypes(); + + public abstract Type returnType(); + + public static ImmutableType.Func.Builder builder() { + return ImmutableType.Func.builder(); + } + + @Override + public R 
accept(TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 999769cd9..6a897417e 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -84,6 +84,14 @@ public final Type intervalCompound(int precision) { return Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } + public Type.Func func(java.util.List parameterTypes, Type returnType) { + return Type.Func.builder() + .nullable(nullable) + .parameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index 9cf772232..d76b7a1f9 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -52,6 +52,8 @@ public interface TypeVisitor { R visit(Type.Decimal type) throws E; + R visit(Type.Func type) throws E; + R visit(Type.Struct type) throws E; R visit(Type.ListType type) throws E; @@ -192,6 +194,11 @@ public R visit(Type.PrecisionTimestampTZ type) throws E { throw t(); } + @Override + public R visit(Type.Func type) throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index 8555ab219..ddad886b1 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package 
io.substrait.type.parser; +import io.substrait.function.ImmutableParameterizedType; import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; @@ -411,19 +412,61 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) @Override public TypeExpression visitFunc(final SubstraitTypeParser.FuncContext ctx) { - throw new UnsupportedOperationException(); + boolean nullable = ctx.isnull != null; + + // Process function parameters + List paramExprs; + if (ctx.params instanceof SubstraitTypeParser.SingleFuncParamContext) { + paramExprs = + java.util.Collections.singletonList( + ((SubstraitTypeParser.SingleFuncParamContext) ctx.params).expr().accept(this)); + } else if (ctx.params instanceof SubstraitTypeParser.FuncParamsWithParensContext) { + paramExprs = + ((SubstraitTypeParser.FuncParamsWithParensContext) ctx.params) + .expr().stream() + .map(e -> e.accept(this)) + .collect(java.util.stream.Collectors.toList()); + } else { + throw new UnsupportedOperationException( + "Unknown funcParams type: " + ctx.params.getClass()); + } + + // Process return type + TypeExpression returnExpr = ctx.returnType.accept(this); + + // If all types are instances of Type, we return a Type + if (paramExprs.stream().allMatch(p -> p instanceof Type) && returnExpr instanceof Type) { + ImmutableType.Func.Builder builder = ImmutableType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((Type) p)); + return builder.returnType((Type) returnExpr).build(); + } + + // If all types are instances of ParameterizedType, we return a ParameterizedType + if (paramExprs.stream().allMatch(p -> p instanceof ParameterizedType) + && returnExpr instanceof ParameterizedType) { + checkParameterizedOrExpression(); + ImmutableParameterizedType.Func.Builder builder = + ParameterizedType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> 
builder.addParameterTypes((ParameterizedType) p)); + return builder.returnType((ParameterizedType) returnExpr).build(); + } + + throw new UnsupportedOperationException( + "func type with TypeExpression-level parameter or return types are not yet supported"); } @Override public TypeExpression visitSingleFuncParam( final SubstraitTypeParser.SingleFuncParamContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitSingleFuncParam is handled in visitFunc directly"); } @Override public TypeExpression visitFuncParamsWithParens( final SubstraitTypeParser.FuncParamsWithParensContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitFuncParamsWithParens is handled in visitFunc directly"); } @Override diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 67d7bc9b5..3909d4033 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -142,6 +142,16 @@ public final T visit(final Type.PrecisionTimestampTZ expr) { return typeContainer(expr).precisionTimestampTZ(expr.precision()); } + @Override + public final T visit(final Type.Func expr) { + return typeContainer(expr) + .func( + expr.parameterTypes().stream() + .map(t -> t.accept(this)) + .collect(java.util.stream.Collectors.toList()), + expr.returnType().accept(this)); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 57b1f26b5..47842382f 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -119,6 +119,8 @@ public final T precisionTimestampTZ(int 
precision) { public abstract T intervalCompound(I precision); + public abstract T func(Iterable parameterTypes, T returnType); + public final T struct(T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index bdb600c1c..24231aebc 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -77,6 +77,13 @@ public Type from(io.substrait.proto.Type type) { case PRECISION_TIMESTAMP_TZ: return n(type.getPrecisionTimestampTz().getNullability()) .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case FUNC: + return n(type.getFunc().getNullability()) + .func( + type.getFunc().getParameterTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()), + from(type.getFunc().getReturnType())); case STRUCT: return n(type.getStruct().getNullability()) .struct( diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 6422904c4..c0e785db4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -154,6 +154,16 @@ public Type precisionTimestampTZ(Integer precision) { .build()); } + @Override + public Type func(Iterable parameterTypes, Type returnType) { + return wrap( + Type.Func.newBuilder() + .addAllParameterTypes(parameterTypes) + .setReturnType(returnType) + .setNullability(nullability) + .build()); + } + @Override public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); @@ -237,6 +247,8 @@ protected Type wrap(final Object o) { return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); } else if (o instanceof 
Type.PrecisionTimestampTZ) { return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Func) { + return bldr.setFunc((Type.Func) o).build(); } else if (o instanceof Type.Struct) { return bldr.setStruct((Type.Struct) o).build(); } else if (o instanceof Type.List) { diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java new file mode 100644 index 000000000..7aed4e961 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -0,0 +1,13 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import org.junit.jupiter.api.Test; + +class DefaultExtensionCatalogTest { + + @Test + void defaultCollectionLoads() { + assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } +} diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index 010ad123a..92989795b 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -8,6 +8,7 @@ import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import org.junit.jupiter.api.Test; class TestTypeParser { @@ -120,6 +121,11 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.precisionTimestamp(9), "PRECISION_TIMESTAMP?<9>"); test(v, r.precisionTimestampTZ(6), "PRECISION_TIMESTAMP_TZ<6>"); test(v, n.precisionTimestampTZ(9), "PRECISION_TIMESTAMP_TZ?<9>"); + + test(v, r.func(List.of(r.I8), r.I32), "func i32>"); + test(v, r.func(List.of(r.I8, r.I8), r.I32), "func<(i8, i8) -> i32>"); + test(v, n.func(List.of(r.I8), r.I32), "func? 
i32>"); + test(v, r.func(List.of(n.I8), n.I32), "func i32?>"); } private void parameterizedTests(ParseToPojo.Visitor v) { @@ -142,6 +148,16 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.precisionTimeE("P"), "PRECISION_TIME

"); test(v, pr.precisionTimestampE("P"), "PRECISION_TIMESTAMP

"); test(v, pr.precisionTimestampTZE("P"), "PRECISION_TIMESTAMP_TZ

"); + + test(v, pr.funcE(List.of(pr.parameter("any")), r.I64), "func i64>"); + test(v, pr.funcE(List.of(pr.parameter("any"), r.I64), r.I64), "func<(any, i64) -> i64>"); + test(v, pr.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func any1>"); + test(v, pn.funcE(List.of(pr.parameter("any")), n.I64), "func? i64?>"); + test(v, pn.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func? any1>"); + test( + v, + pr.funcE(List.of(pr.parameter("any1"), r.I8), pr.parameter("any1")), + "func<(any1, i8) -> any1>"); } @Test diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java new file mode 100644 index 000000000..c0a979a94 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -0,0 +1,361 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.type.Type; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests + * from substrait-go. + */ +class LambdaExpressionRoundtripTest extends TestBase { + + /** Test that lambdas with no parameters are valid. 
Building: () -> i32(42) : func<() -> i32> */ + @Test + void zeroParameterLambda() { + Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(emptyParams).body(body).build(); + + verifyRoundTrip(lambda); + + // Verify the lambda type + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(0, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.returnType()); + } + + /** Test valid stepsOut=0 references. Building: ($0: i32) -> $0 : func i32> */ + @Test + void validStepsOut0() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Lambda body references parameter 0 with stepsOut=0 + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify types + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(1, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid field index with multiple parameters. 
Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + // Reference the 3rd parameter (string) + FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify return type is string + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.STRING, funcType.returnType()); + } + + /** Test type resolution for different parameter types. */ + @Test + void typeResolution() { + // Test cases: (paramTypes, fieldIndex, expectedReturnType) + record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + + List testCases = + List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); + + for (TestCase tc : testCases) { + Type.Struct params = + Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); + + FieldReference paramRef = + FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + verifyRoundTrip(lambda); + + // Verify the body type matches expected + assertEquals( + tc.expectedType, + lambda.body().getType(), + "Body type should match referenced parameter type"); + + // Verify lambda return type + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals( + tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); + } + } + + /** + * Test nested lambda with outer reference. 
Building: ($0: i64, $1: i64) -> (($0: i32) -> + * outer[$0] : i64) : func<(i64, i64) -> func<i32 -> i64>> + */ + @Test + void nestedLambdaWithOuterRef() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter 0 with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + verifyRoundTrip(outerLambda); + + // Verify structure + assertInstanceOf(Expression.Lambda.class, outerLambda.body()); + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + assertEquals(1, resultInner.parameters().fields().size()); + } + + /** + * Test outer reference type resolution in nested lambdas. 
Building: ($0: i32, $1: i64, $2: + * string) -> (($0: fp64) -> outer[$2] : string) : func<...> + */ + @Test + void outerRefTypeResolution() { + Type.Struct outerParams = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); + + // Inner references outer's field 2 (string) with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + verifyRoundTrip(outerLambda); + + // Verify inner lambda's return type is string (from outer param 2) + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + Type.Func innerFuncType = (Type.Func) resultInner.getType(); + assertEquals( + R.STRING, + innerFuncType.returnType(), + "Inner lambda return type should be string from outer.$2"); + + // Verify body's type is also string + assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); + } + + /** + * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func<i32 -> + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + verifyRoundTrip(lambda); + + // Verify the nested FieldRef has its type resolved + Expression.Cast resultCast = (Expression.Cast) lambda.body(); + assertInstanceOf(FieldReference.class, resultCast.input()); + FieldReference resultFieldRef = (FieldReference) resultCast.input(); + + assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + + // Verify lambda return type is i64 (cast output) + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.I64, funcType.returnType()); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). 
Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func<i32 -> string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + verifyRoundTrip(lambda); + + // Navigate to the deeply nested FieldRef (2 levels deep) + Expression.Cast resultOuter = (Expression.Cast) lambda.body(); + Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); + FieldReference resultFieldRef = (FieldReference) resultInner.input(); + + // Verify type is resolved even at depth 2 + assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType()); + } + + /** + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func<i32 -> i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + verifyRoundTrip(lambda); + } + + /** Test lambda getType returns correct Func type. 
*/ + @Test + void lambdaGetTypeReturnsFunc() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + Type lambdaType = lambda.getType(); + + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + } + + // ==================== Validation Error Tests ==================== + + /** + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. Building: + * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + */ + @Test + void invalidOuterRef_stepsOutTooHigh() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Create a parameter reference with stepsOut=1 but no outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(invalidRef).build(); + + // Convert to proto - this should work + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); + + // Converting back should fail because stepsOut=1 references non-existent outer lambda + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent outer lambda"); + } + + /** + * Test that invalid field index (out of bounds) fails during proto conversion. 
Building: ($0: + * i32) -> $5 : INVALID (only has 1 param) + */ + @Test + void invalidFieldIndex_outOfBounds() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Create a reference to field 5, but lambda only has 1 parameter (index 0) + // This will fail at build time since newLambdaParameterReference accesses fields.get(5) + assertThrows( + IndexOutOfBoundsException.class, + () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, + "Should fail when field index is out of bounds"); + } + + /** + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> + * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + */ + @Test + void nestedInvalidOuterRef() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references stepsOut=2, but only 1 outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + // Convert to proto + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); + + // Converting back should fail because stepsOut=2 references non-existent grandparent + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent grandparent lambda"); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index a22756cc9..2d345642c 100644 --- 
a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -202,6 +202,12 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context return ""; } + @Override + public String visit(Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..f5f8d93e9 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -149,6 +149,11 @@ public String visit(Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } + @Override + public String visit(Type.Func type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + @Override public String visit(Struct type) throws RuntimeException { StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 5fd9efdcb..d7147a594 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -42,6 +42,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import 
org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSubQuery; @@ -423,6 +424,23 @@ public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeExce return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } + @Override + public RexNode visit(Expression.Lambda expr, Context context) throws RuntimeException { + List parameters = + IntStream.range(0, expr.parameters().fields().size()) + .mapToObj( + i -> + new RexLambdaRef( + i, + "p" + i, + typeConverter.toCalcite(typeFactory, expr.parameters().fields().get(i)))) + .collect(Collectors.toList()); + + RexNode body = expr.body().accept(this, context); + + return rexBuilder.makeLambdaCall(body, parameters); + } + @Override public RexNode visit(Switch expr, Context context) throws RuntimeException { RexNode match = expr.match().accept(this, context); @@ -697,6 +715,23 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } return rexInputRef; + } else if (expr.isLambdaParameterReference()) { + // as of now calcite doesn't support nested lambda functions + // https://github.com/substrait-io/substrait-java/issues/711 + int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); + if (stepsOut != 0) { + throw new UnsupportedOperationException( + "Calcite does not support nested lambdas (stepsOut=" + stepsOut + ")"); + } + + final ReferenceSegment segment = expr.segments().get(0); + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return new RexLambdaRef(field.offset(), "p" + field.offset(), calciteType); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } } return visitFallback(expr, context); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 
90f04a326..cc2b098e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -5,14 +5,32 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.calcite.sql.SqlBasicFunction; +import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; public class FunctionMappings { // Static list of signature mapping between Calcite SQL operators and Substrait base function // names. + /** The transform:list_func function; applies a lambda to each element of an array. */ + public static final SqlFunction TRANSFORM = + SqlBasicFunction.create( + "transform", + opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + + /** The filter:list_func function; filters elements of an array using a predicate lambda. 
*/ + public static final SqlFunction FILTER = + SqlBasicFunction.create( + "filter", + opBinding -> opBinding.getOperandType(0), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -100,7 +118,9 @@ public class FunctionMappings { s(SqlLibraryOperators.RPAD, "rpad"), s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), - s(SqlLibraryOperators.PARSE_DATE, "strptime_date")) + s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), + s(TRANSFORM, "transform"), + s(FILTER, "filter")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..b56fa4fd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -128,6 +128,11 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.Func type) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } + @Override public Boolean visit(Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime @@ -234,4 +239,9 @@ public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return false; } + + @Override + public Boolean visit(ParameterizedType.Func expr) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } } diff --git 
a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..176b246e7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -7,6 +7,7 @@ import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; +import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -202,12 +203,30 @@ public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { @Override public Expression visitLambda(RexLambda rexLambda) { - throw new UnsupportedOperationException("RexLambda not supported"); + List paramTypes = + rexLambda.getParameters().stream() + .map(param -> typeConverter.toSubstrait(param.getType())) + .collect(Collectors.toList()); + + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + + Expression body = rexLambda.getExpression().accept(this); + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @Override public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { - throw new UnsupportedOperationException("RexLambdaRef not supported"); + int fieldIndex = rexLambdaRef.getIndex(); + Type paramType = typeConverter.toSubstrait(rexLambdaRef.getType()); + + return FieldReference.builder() + .addSegments(FieldReference.StructField.of(fieldIndex)) + .type(paramType) + .lambdaParameterReferenceStepsOut( + 0) // Always 0 since Calcite doesn't support nested Lambda expressions for now + // https://github.com/substrait-io/substrait-java/issues/711 + .build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java new file mode 100644 
index 000000000..add746ee1 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -0,0 +1,149 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not + * support nested lambda expressions for the moment, so all tests use stepsOut=0. + */ +class LambdaExpressionTest extends PlanTestBase { + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + @Test + void lambdaExpressionZeroParameters() { + Type.Struct params = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test valid field index with multiple parameters. 
Building: ($0: i32, $1: i64, $2: string) -> $0 + * : func<(i32, i64, string) -> i32> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func<i32 -> + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). 
Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func<i32 -> string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func<i32 -> i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not + * support nested lambda expressions. 
+ */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java new file mode 100644 index 000000000..427dc67ba --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -0,0 +1,38 @@ +package io.substrait.isthmus; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.Plan; +import io.substrait.plan.ProtoPlanConverter; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +class LambdaRoundtripTest extends PlanTestBase { + + public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Plan.Builder builder = io.substrait.proto.Plan.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + @Test + void testBasicLambdaRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/basic-lambda.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + 
assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFieldRefRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-field-ref.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFunctionRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-with-function.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } +} diff --git a/isthmus/src/test/resources/lambdas/basic-lambda.json b/isthmus/src/test/resources/lambdas/basic-lambda.json new file mode 100644 index 000000000..114e3ad6d --- /dev/null +++ b/isthmus/src/test/resources/lambdas/basic-lambda.json @@ -0,0 +1,132 @@ +{ + "version": { + "majorNumber": 0, + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + 
"nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-field-ref.json b/isthmus/src/test/resources/lambdas/lambda-field-ref.json new file mode 100644 index 000000000..58c041582 --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-field-ref.json @@ -0,0 +1,135 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + 
"arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-with-function.json b/isthmus/src/test/resources/lambdas/lambda-with-function.json new file mode 100644 index 000000000..9c0a1a55b --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-with-function.json @@ -0,0 +1,172 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + }, + { + "extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "multiply:i32_i32" + } + }, + { + "extensionFunction": { + "extensionUrnReference": 2, + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + 
"nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + }, + { + "value": { + "literal": { + "i32": 2 + } + } + } + ] + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index c280e1fb1..0db03388e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -158,4 +158,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(`type`: Type.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || 
typeToMatch.isInstanceOf[ParameterizedType.Func] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] }