diff --git a/.gitignore b/.gitignore index 63b93fd42..5eab9ab03 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ out/** */bin .metals .bloop +.project diff --git a/core/src/main/antlr/SubstraitLexer.g4 b/core/src/main/antlr/SubstraitLexer.g4 new file mode 100644 index 000000000..4dfdf5272 --- /dev/null +++ b/core/src/main/antlr/SubstraitLexer.g4 @@ -0,0 +1,132 @@ +lexer grammar SubstraitLexer; + +options { + caseInsensitive = true; +} + +// Whitespace and comment handling +LineComment : '//' ~[\r\n]* -> channel(HIDDEN) ; +BlockComment : ( '/*' ( ~'*' | '*'+ ~[*/] ) '*'* '*/' ) -> channel(HIDDEN) ; +Whitespace : [ \t\r]+ -> channel(HIDDEN) ; + +fragment DIGIT: [0-9]; + +// Syntactic keywords. +If : 'IF'; +Then : 'THEN'; +Else : 'ELSE'; +Func : 'FUNC'; + +// TYPES +Boolean : 'BOOLEAN'; +I8 : 'I8'; +I16 : 'I16'; +I32 : 'I32'; +I64 : 'I64'; +FP32 : 'FP32'; +FP64 : 'FP64'; +String : 'STRING'; +Binary : 'BINARY'; +Timestamp: 'TIMESTAMP'; +Timestamp_TZ: 'TIMESTAMP_TZ'; +Date : 'DATE'; +Time : 'TIME'; +Interval_Year: 'INTERVAL_YEAR'; +Interval_Day: 'INTERVAL_DAY'; +Interval_Compound: 'INTERVAL_COMPOUND'; +UUID : 'UUID'; +Decimal : 'DECIMAL'; +Precision_Time: 'PRECISION_TIME'; +Precision_Timestamp: 'PRECISION_TIMESTAMP'; +Precision_Timestamp_TZ: 'PRECISION_TIMESTAMP_TZ'; +FixedChar: 'FIXEDCHAR'; +VarChar : 'VARCHAR'; +FixedBinary: 'FIXEDBINARY'; +Struct : 'STRUCT'; +NStruct : 'NSTRUCT'; +List : 'LIST'; +Map : 'MAP'; +UserDefined: 'U!'; + +// short names for types +Bool: 'BOOL'; +Str: 'STR'; +VBin: 'VBIN'; +Ts: 'TS'; +TsTZ: 'TSTZ'; +IYear: 'IYEAR'; +IDay: 'IDAY'; +ICompound: 'ICOMPOUND'; +Dec: 'DEC'; +PT: 'PT'; +PTs: 'PTS'; +PTsTZ: 'PTSTZ'; +FChar: 'FCHAR'; +VChar: 'VCHAR'; +FBin: 'FBIN'; + +Any: 'ANY'; +AnyVar: Any [0-9]; + +DoubleColon: '::'; + +// MATH +Plus : '+'; +Minus : '-'; +Asterisk : '*'; +ForwardSlash : '/'; +Percent : '%'; + +// COMPARE +Eq : '='; +Ne : '!='; +Gte : '>='; +Lte : '<='; +Gt : '>'; +Lt : '<'; +Bang : '!'; + +// ORGANIZE +OAngleBracket: Lt; +CAngleBracket: Gt; +OParen: '('; +CParen: ')'; +OBracket: '['; +CBracket: ']'; +Comma: ','; +Colon: ':'; +QMark: '?'; +Hash: '#'; +Dot: '.'; + + +// OPERATIONS +And : 'AND'; +Or : 'OR'; +Assign : ':='; +Arrow : '->'; + + + +fragment Int + : '1'..'9' Digit* + | '0' + ; + +fragment Digit + : '0'..'9' + ; + +Number + : '-'? Int + ; + +Identifier + : ('A'..'Z' | '_' | '$') ('A'..'Z' | '_' | '$' | Digit)* + ; + +Newline + : ( '\r' '\n'? + | '\n' + ) + ; diff --git a/core/src/main/antlr/SubstraitType.g4 b/core/src/main/antlr/SubstraitType.g4 index 9a7c0e7f9..d14df3976 100644 --- a/core/src/main/antlr/SubstraitType.g4 +++ b/core/src/main/antlr/SubstraitType.g4 @@ -1,209 +1,83 @@ grammar SubstraitType; -// -fragment A : [aA]; -fragment B : [bB]; -fragment C : [cC]; -fragment D : [dD]; -fragment E : [eE]; -fragment F : [fF]; -fragment G : [gG]; -fragment H : [hH]; -fragment I : [iI]; -fragment J : [jJ]; -fragment K : [kK]; -fragment L : [lL]; -fragment M : [mM]; -fragment N : [nN]; -fragment O : [oO]; -fragment P : [pP]; -fragment Q : [qQ]; -fragment R : [rR]; -fragment S : [sS]; -fragment T : [tT]; -fragment U : [uU]; -fragment V : [vV]; -fragment W : [wW]; -fragment X : [xX]; -fragment Y : [yY]; -fragment Z : [zZ]; +options { + caseInsensitive = true; +} +import SubstraitLexer; -If : I F; -Then : T H E N; -Else : E L S E; +startRule: expr EOF; -// TYPES -Boolean : B O O L E A N; -I8 : I '8'; -I16 : I '16'; -I32 : I '32'; -I64 : I '64'; -FP32 : F P '32'; -FP64 : F P '64'; -String : S T R I N G; -Binary : B I N A R Y; -Timestamp: T I M E S T A M P; -TimestampTZ: T I M E S T A M P '_' T Z; -Date : D A T E; -Time : T I M E; -IntervalYear: I N T E R V A L '_' Y E A R; -IntervalDay: I N T E R V A L '_' D A Y; -IntervalCompound: I N T E R V A L '_' C O M P O U N D; -UUID : U U I D; -Decimal : D E C I M A L; -PrecisionTimestamp: P R E C I S I O N '_' T I M E S T A M P; -PrecisionTimestampTZ: P R E C I S I O N '_' T I M E S T A M P '_' T Z; -FixedChar: F I X E D C H A R; -VarChar : V A R C H A R; -FixedBinary: F I X E D B I N A R Y; -Struct : S T R U C T; -NStruct : N S T R U C T; -List : L I S T; -Map : M A P; -ANY : A N Y; -UserDefined: U '!'; - - -// OPERATIONS -And : A N D; -Or : O R; -Assign : ':='; - -// COMPARE -Eq : '='; -NotEquals: '!='; -Gte : '>='; -Lte : '<='; -Gt : '>'; -Lt : '<'; -Bang : '!'; - - -// MATH -Plus : '+'; -Minus : '-'; -Asterisk : '*'; -ForwardSlash : '/'; -Percent : '%'; - -// ORGANIZE -OBracket : '['; -CBracket : ']'; -OParen : '('; -CParen : ')'; -SColon : ';'; -Comma : ','; -QMark : '?'; -Colon : ':'; -SingleQuote: '\''; - - -Number - : '-'? Int - ; - -Identifier - : ('a'..'z' | 'A'..'Z' | '_' | '$') ('a'..'z' | 'A'..'Z' | '_' | '$' | Digit)* - ; - -LineComment - : '//' ~[\r\n]* -> channel(HIDDEN) - ; - -BlockComment - : ( '/*' - ( '/'* BlockComment - | ~[/*] - | '/'+ ~[/*] - | '*'+ ~[/*] - )* - '*'* - '*/' - ) -> channel(HIDDEN) - ; - -Whitespace - : [ \t]+ -> channel(HIDDEN) - ; - -Newline - : ( '\r' '\n'? - | '\n' - ) - ; - - -fragment Int - : '1'..'9' Digit* - | '0' - ; - -fragment Digit - : '0'..'9' - ; - -start: expr EOF; +typeStatement: typeDef EOF; scalarType - : Boolean #Boolean - | I8 #i8 - | I16 #i16 - | I32 #i32 - | I64 #i64 - | FP32 #fp32 - | FP64 #fp64 - | String #string - | Binary #binary - | Timestamp #timestamp - | TimestampTZ #timestampTz - | Date #date - | Time #time - | IntervalYear #intervalYear - | UUID #uuid - | UserDefined Identifier #userDefined + : Boolean #boolean + | I8 #i8 + | I16 #i16 + | I32 #i32 + | I64 #i64 + | FP32 #fp32 + | FP64 #fp64 + | String #string + | Binary #binary + | Timestamp #timestamp + | Timestamp_TZ #timestampTz + | Date #date + | Time #time + | Interval_Year #intervalYear + | UUID #uuid ; parameterizedType - : FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar - | VarChar isnull='?'? Lt len=numericParameter Gt #varChar - | FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary - | Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal - | IntervalDay isnull='?'? Lt precision=numericParameter Gt #intervalDay - | IntervalCompound isnull='?'? Lt precision=numericParameter Gt #intervalCompound - | PrecisionTimestamp isnull='?'? Lt precision=numericParameter Gt #precisionTimestamp - | PrecisionTimestampTZ isnull='?'? Lt precision=numericParameter Gt #precisionTimestampTZ - | Struct isnull='?'? Lt expr (Comma expr)* Gt #struct - | NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct - | List isnull='?'? Lt expr Gt #list - | Map isnull='?'? Lt key=expr Comma value=expr Gt #map + : FixedChar isnull=QMark? Lt length=numericParameter Gt #fixedChar + | VarChar isnull=QMark? Lt length=numericParameter Gt #varChar + | FixedBinary isnull=QMark? Lt length=numericParameter Gt #fixedBinary + | Decimal isnull=QMark? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal + | Interval_Day isnull=QMark? Lt precision=numericParameter Gt #precisionIntervalDay + | Interval_Compound isnull=QMark? Lt precision=numericParameter Gt #precisionIntervalCompound + | Precision_Time isnull=QMark? Lt precision=numericParameter Gt #precisionTime + | Precision_Timestamp isnull=QMark? Lt precision=numericParameter Gt #precisionTimestamp + | Precision_Timestamp_TZ isnull=QMark? Lt precision=numericParameter Gt #precisionTimestampTZ + | Struct isnull=QMark? Lt expr (Comma expr)* Gt #struct + | NStruct isnull=QMark? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct + | List isnull=QMark? Lt expr Gt #list + | Map isnull=QMark? Lt key=expr Comma value=expr Gt #map + | Func isnull=QMark? Lt params=funcParams Arrow returnType=expr Gt #func + | UserDefined Identifier isnull=QMark? (Lt expr (Comma expr)* Gt)? #userDefined + ; + +funcParams + : expr #singleFuncParam + | OParen expr (Comma expr)* CParen #funcParamsWithParens ; numericParameter - : Number #numericLiteral - | Identifier #numericParameterName - | expr #numericExpression + : Number #numericLiteral + | Identifier #numericParameterName + | expr #numericExpression ; -anyType: ANY; +anyType + : Any isnull=QMark? + | AnyVar isnull=QMark? + ; -type - : scalarType isnull='?'? +typeDef + : scalarType isnull=QMark? | parameterizedType - | anyType isnull='?'? + | anyType ; -// : (OParen innerExpr CParen | innerExpr) - expr - : OParen expr CParen #ParenExpression - | Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=type Newline* #MultilineDefinition - | type #TypeLiteral - | number=Number #LiteralNumber - | identifier=Identifier isnull='?'? #TypeParam - | Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall - | left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | NotEquals | Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr - | If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr - | (Bang) expr #NotExpr - | ifExpr=expr QMark thenExpr=expr Colon elseExpr=expr #Ternary + : OParen expr CParen #ParenExpression + | Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=typeDef Newline* #MultilineDefinition + | typeDef #TypeLiteral + | Number #LiteralNumber + | Identifier isnull=QMark? #ParameterName + | Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall + | left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | Ne | + Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr + | If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr + | (Bang) expr #NotExpr + | ifExpr=expr QMark thenExpr=expr Colon elseExpr=expr #Ternary ; diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 6b89840f6..4c3f314e5 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -67,6 +67,13 @@ public ParameterizedType intervalCompoundE(String precision) { .build(); } + public ParameterizedType precisionTimeE(String precision) { + return ParameterizedType.PrecisionTime.builder() + .nullable(nullable) + .precision(parameter(precision, false)) + .build(); + } + public ParameterizedType precisionTimestampE(String precision) { return ParameterizedType.PrecisionTimestamp.builder() .nullable(nullable) diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index a183c1959..cc9bc068c 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -117,6 +117,21 @@ public static ImmutableTypeExpression.IntervalCompound.Builder builder() { } } + @Value.Immutable + abstract class PrecisionTime extends BaseTypeExpression implements NullableType { + + public abstract TypeExpression precision(); + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableTypeExpression.PrecisionTime.Builder builder() { + return ImmutableTypeExpression.PrecisionTime.builder(); + } + } + @Value.Immutable abstract class PrecisionTimestamp extends BaseTypeExpression implements NullableType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index b7524911b..808a0dd5a 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -48,6 +48,10 @@ public TypeExpression intervalCompoundE(TypeExpression precision) { .build(); } + public TypeExpression precisionTimeE(TypeExpression precision) { + return TypeExpression.PrecisionTime.builder().nullable(nullable).precision(precision).build(); + } + public TypeExpression precisionTimestampE(TypeExpression precision) { return TypeExpression.PrecisionTimestamp.builder() .nullable(nullable) diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 31d632c71..2ef76b50f 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -14,6 +14,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.IntervalCompound expr) throws E; + R visit(TypeExpression.PrecisionTime expr) throws E; + R visit(TypeExpression.PrecisionTimestamp expr) throws E; R visit(TypeExpression.PrecisionTimestampTZ expr) throws E; @@ -62,6 +64,11 @@ public R visit(TypeExpression.Decimal expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.PrecisionTime expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.PrecisionTimestamp expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 13936dc07..b714a7772 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -32,10 +32,10 @@ public static class ParseDeserializer extends StdDeserializer { private static final long serialVersionUID = 2105956703553161270L; - private final BiFunction converter; + private final BiFunction converter; public ParseDeserializer( - Class clazz, BiFunction converter) { + Class clazz, BiFunction converter) { super(clazz); this.converter = converter; } diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index ec96875bf..8555ab219 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -5,10 +5,12 @@ import io.substrait.function.ParameterizedTypeCreator; import io.substrait.function.TypeExpression; import io.substrait.function.TypeExpressionCreator; +import io.substrait.type.ImmutableType; import io.substrait.type.SubstraitTypeParser; import io.substrait.type.SubstraitTypeVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.function.Function; @@ -21,17 +23,18 @@ public class ParseToPojo { - public static Type type(String urn, SubstraitTypeParser.StartContext ctx) { + public static Type type(String urn, SubstraitTypeParser.StartRuleContext ctx) { Visitor visitor = Visitor.simple(urn); return (Type) ctx.accept(visitor); } public static ParameterizedType parameterizedType( - String urn, SubstraitTypeParser.StartContext ctx) { + String urn, SubstraitTypeParser.StartRuleContext ctx) { return (ParameterizedType) ctx.accept(Visitor.parameterized(urn)); } - public static TypeExpression typeExpression(String urn, SubstraitTypeParser.StartContext ctx) { + public static TypeExpression typeExpression( + String urn, SubstraitTypeParser.StartRuleContext ctx) { return ctx.accept(Visitor.expression(urn)); } @@ -78,10 +81,15 @@ private void checkExpression() { } @Override - public TypeExpression visitStart(final SubstraitTypeParser.StartContext ctx) { + public TypeExpression visitStartRule(SubstraitTypeParser.StartRuleContext ctx) { return ctx.expr().accept(this); } + @Override + public TypeExpression visitTypeStatement(SubstraitTypeParser.TypeStatementContext ctx) { + return ctx.typeDef().accept(this); + } + @Override public Type visitBoolean(final SubstraitTypeParser.BooleanContext ctx) { return withNull(ctx).BOOLEAN; @@ -109,7 +117,7 @@ public Type visitI64(final SubstraitTypeParser.I64Context ctx) { @Override public TypeExpression visitTypeLiteral(final SubstraitTypeParser.TypeLiteralContext ctx) { - return ctx.type().accept(this); + return ctx.typeDef().accept(this); } @Override @@ -158,7 +166,8 @@ public Type visitIntervalYear(final SubstraitTypeParser.IntervalYearContext ctx) } @Override - public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayContext ctx) { + public TypeExpression visitPrecisionIntervalDay( + final SubstraitTypeParser.PrecisionIntervalDayContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); if (precision instanceof Integer) { @@ -174,8 +183,8 @@ public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayCont } @Override - public TypeExpression visitIntervalCompound( - final SubstraitTypeParser.IntervalCompoundContext ctx) { + public TypeExpression visitPrecisionIntervalCompound( + final SubstraitTypeParser.PrecisionIntervalCompoundContext ctx) { boolean nullable = ctx.isnull != null; Object precision = i(ctx.precision); if (precision instanceof Integer) { @@ -196,16 +205,46 @@ public Type visitUuid(final SubstraitTypeParser.UuidContext ctx) { } @Override - public Type visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) { + public TypeExpression visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) { String name = ctx.Identifier().getSymbol().getText(); - return withNull(ctx).userDefined(urn, name); + boolean nullable = ctx.isnull != null; + List paramExprs = ctx.expr(); + if (paramExprs.isEmpty()) { + return withNull(nullable).userDefined(urn, name); + } + List params = new ArrayList<>(); + for (SubstraitTypeParser.ExprContext paramExpr : paramExprs) { + TypeExpression te = paramExpr.accept(this); + if (te instanceof Type) { + params.add(ImmutableType.ParameterDataType.builder().type((Type) te).build()); + } else if (te instanceof TypeExpression.IntegerLiteral) { + params.add( + ImmutableType.ParameterIntegerValue.builder() + .value(((TypeExpression.IntegerLiteral) te).value()) + .build()); + } else if (te instanceof ParameterizedType.StringLiteral) { + params.add( + ImmutableType.ParameterStringValue.builder() + .value(((ParameterizedType.StringLiteral) te).value()) + .build()); + } else { + throw new UnsupportedOperationException( + "Unsupported type parameter in user-defined type: " + te); + } + } + return Type.UserDefined.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(params) + .build(); } @Override public TypeExpression visitFixedChar(final SubstraitTypeParser.FixedCharContext ctx) { boolean nullable = ctx.isnull != null; return of( - ctx.len, + ctx.length, withNull(nullable)::fixedChar, withNullP(nullable)::fixedCharE, withNullE(nullable)::fixedCharE); @@ -232,7 +271,7 @@ private TypeExpression of( public TypeExpression visitVarChar(final SubstraitTypeParser.VarCharContext ctx) { boolean nullable = ctx.isnull != null; return of( - ctx.len, + ctx.length, withNull(nullable)::varChar, withNullP(nullable)::varCharE, withNullE(nullable)::varCharE); @@ -242,7 +281,7 @@ public TypeExpression visitVarChar(final SubstraitTypeParser.VarCharContext ctx) public TypeExpression visitFixedBinary(final SubstraitTypeParser.FixedBinaryContext ctx) { boolean nullable = ctx.isnull != null; return of( - ctx.len, + ctx.length, withNull(nullable)::fixedBinary, withNullP(nullable)::fixedBinaryE, withNullE(nullable)::fixedBinaryE); @@ -276,6 +315,22 @@ public TypeExpression visitDecimal(final SubstraitTypeParser.DecimalContext ctx) return withNullE(nullable).decimalE(ctx.precision.accept(this), ctx.scale.accept(this)); } + @Override + public TypeExpression visitPrecisionTime(final SubstraitTypeParser.PrecisionTimeContext ctx) { + boolean nullable = ctx.isnull != null; + Object precision = i(ctx.precision); + if (precision instanceof Integer) { + return withNull(nullable).precisionTime((Integer) precision); + } + if (precision instanceof String) { + checkParameterizedOrExpression(); + return withNullP(nullable).precisionTimeE((String) precision); + } + + checkExpression(); + return withNullE(nullable).precisionTimeE(ctx.precision.accept(this)); + } + @Override public TypeExpression visitPrecisionTimestamp( final SubstraitTypeParser.PrecisionTimestampContext ctx) { @@ -354,6 +409,23 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) throw new UnsupportedOperationException(); } + @Override + public TypeExpression visitFunc(final SubstraitTypeParser.FuncContext ctx) { + throw new UnsupportedOperationException(); + } + + @Override + public TypeExpression visitSingleFuncParam( + final SubstraitTypeParser.SingleFuncParamContext ctx) { + throw new UnsupportedOperationException(); + } + + @Override + public TypeExpression visitFuncParamsWithParens( + final SubstraitTypeParser.FuncParamsWithParensContext ctx) { + throw new UnsupportedOperationException(); + } + @Override public TypeExpression visitList(final SubstraitTypeParser.ListContext ctx) { boolean nullable = ctx.isnull != null; @@ -390,7 +462,7 @@ public TypeExpression visitMap(final SubstraitTypeParser.MapContext ctx) { private TypeCreator withNull(SubstraitTypeParser.ScalarTypeContext required) { return Type.withNullability( - ((SubstraitTypeParser.TypeContext) required.parent).isnull != null); + ((SubstraitTypeParser.TypeDefContext) required.parent).isnull != null); } private TypeCreator withNull(boolean nullable) { @@ -406,7 +478,7 @@ private ParameterizedTypeCreator withNullP(boolean nullable) { } @Override - public TypeExpression visitType(final SubstraitTypeParser.TypeContext ctx) { + public TypeExpression visitTypeDef(final SubstraitTypeParser.TypeDefContext ctx) { if (ctx.scalarType() != null) { return ctx.scalarType().accept(this); } else if (ctx.parameterizedType() != null) { @@ -418,7 +490,7 @@ public TypeExpression visitType(final SubstraitTypeParser.TypeContext ctx) { } @Override - public TypeExpression visitTypeParam(final SubstraitTypeParser.TypeParamContext ctx) { + public TypeExpression visitParameterName(final SubstraitTypeParser.ParameterNameContext ctx) { checkParameterizedOrExpression(); boolean nullable = ctx.isnull != null; return ParameterizedType.StringLiteral.builder() @@ -538,8 +610,9 @@ public TypeExpression visitNumericExpression( @Override public TypeExpression visitAnyType(SubstraitTypeParser.AnyTypeContext anyType) { - boolean nullable = ((SubstraitTypeParser.TypeContext) anyType.parent).isnull != null; - return withNullP(nullable).parameter("any"); + boolean nullable = anyType.isnull != null; + String name = anyType.AnyVar() != null ? anyType.AnyVar().getText().toLowerCase() : "any"; + return withNullP(nullable).parameter(name); } @Override diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index b8e94793b..7d7742578 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -28,7 +28,7 @@ public static TypeExpression parseExpression(String str, String urn) { return parse(str, urn, ParseToPojo::typeExpression); } - private static SubstraitTypeParser.StartContext parse(String str) { + private static SubstraitTypeParser.StartRuleContext parse(String str) { SubstraitTypeLexer lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); lexer.removeErrorListeners(); lexer.addErrorListener(TypeErrorListener.INSTANCE); @@ -36,11 +36,11 @@ private static SubstraitTypeParser.StartContext parse(String str) { SubstraitTypeParser parser = new io.substrait.type.SubstraitTypeParser(tokenStream); parser.removeErrorListeners(); parser.addErrorListener(TypeErrorListener.INSTANCE); - return parser.start(); + return parser.startRule(); } public static T parse( - String str, String urn, BiFunction func) { + String str, String urn, BiFunction func) { return func.apply(urn, parse(str)); } diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index e9c39ccf5..010ad123a 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -5,6 +5,8 @@ import io.substrait.function.ParameterizedTypeCreator; import io.substrait.function.TypeExpression; import io.substrait.function.TypeExpressionCreator; +import io.substrait.type.ImmutableType; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import org.junit.jupiter.api.Test; @@ -107,6 +109,17 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.struct(r.I8, n.I16), "STRUCT?"); test(v, n.list(n.I8), "LIST?"); test(v, n.map(r.I16, n.I8), "MAP?"); + + test(v, r.intervalDay(6), "INTERVAL_DAY<6>"); + test(v, n.intervalDay(9), "INTERVAL_DAY?<9>"); + test(v, r.intervalCompound(6), "INTERVAL_COMPOUND<6>"); + test(v, n.intervalCompound(9), "INTERVAL_COMPOUND?<9>"); + test(v, r.precisionTime(6), "PRECISION_TIME<6>"); + test(v, n.precisionTime(9), "PRECISION_TIME?<9>"); + test(v, r.precisionTimestamp(6), "PRECISION_TIMESTAMP<6>"); + test(v, n.precisionTimestamp(9), "PRECISION_TIMESTAMP?<9>"); + test(v, r.precisionTimestampTZ(6), "PRECISION_TIMESTAMP_TZ<6>"); + test(v, n.precisionTimestampTZ(9), "PRECISION_TIMESTAMP_TZ?<9>"); } private void parameterizedTests(ParseToPojo.Visitor v) { @@ -114,12 +127,63 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.structE(r.I8, r.I16, n.I8, pr.parameter("K")), "STRUCT"); test(v, pr.parameter("any"), "any"); test(v, pn.parameter("any"), "any?"); + test(v, pr.parameter("any1"), "any1"); + test(v, pn.parameter("any1"), "any1?"); test(v, pn.listE(pr.parameter("any")), "list?"); test(v, pn.listE(pn.parameter("any")), "list?"); test(v, pn.structE(r.I8, r.I16, n.I8, pr.parameter("K")), "STRUCT?"); test(v, pr.decimalE("P", "S"), "DECIMAL"); test(v, pr.decimalE("P", "0"), "DECIMAL"); test(v, pr.decimalE("14", "S"), "DECIMAL<14, S>"); + + test(v, pr.intervalDayE("P"), "INTERVAL_DAY

"); + test(v, pr.intervalCompoundE("P"), "INTERVAL_COMPOUND

"); + test(v, pn.precisionTimeE("P"), "PRECISION_TIME?

"); + test(v, pr.precisionTimeE("P"), "PRECISION_TIME

"); + test(v, pr.precisionTimestampE("P"), "PRECISION_TIMESTAMP

"); + test(v, pr.precisionTimestampTZE("P"), "PRECISION_TIMESTAMP_TZ

"); + } + + @Test + void userDefinedWithTypeParams() { + ParseToPojo.Visitor v = ParseToPojo.Visitor.simple(URN); + test( + v, + Type.UserDefined.builder() + .nullable(false) + .urn(URN) + .name("foo") + .addTypeParameters(ImmutableType.ParameterDataType.builder().type(r.I8).build()) + .build(), + "u!foo"); + test( + v, + Type.UserDefined.builder() + .nullable(true) + .urn(URN) + .name("foo") + .addTypeParameters(ImmutableType.ParameterDataType.builder().type(r.I8).build()) + .build(), + "u!foo?"); + test( + v, + Type.UserDefined.builder() + .nullable(false) + .urn(URN) + .name("foo") + .addTypeParameters(ImmutableType.ParameterIntegerValue.builder().value(1).build()) + .build(), + "u!foo<1>"); + test( + v, + Type.UserDefined.builder() + .nullable(false) + .urn(URN) + .name("foo") + .addTypeParameters(ImmutableType.ParameterDataType.builder().type(r.I8).build()) + .addTypeParameters(ImmutableType.ParameterIntegerValue.builder().value(2).build()) + .build(), + "u!foo"); } private static void test(ParseToPojo.Visitor visitor, TypeExpression expected, String toParse) { diff --git a/spark/spark_dialect.yaml b/spark/spark_dialect.yaml index 79678251a..9bcfd5e6a 100644 --- a/spark/spark_dialect.yaml +++ b/spark/spark_dialect.yaml @@ -57,14 +57,17 @@ supported_types: system_metadata: name: "TimestampNTZType" supported_as_column: true + max_precision: 9 - type: "PRECISION_TIMESTAMP_TZ" system_metadata: name: "TimestampType" supported_as_column: true + max_precision: 9 - type: "INTERVAL_DAY" system_metadata: name: "DayTimeIntervalType" supported_as_column: true + max_precision: 9 - type: "INTERVAL_YEAR" system_metadata: name: "YearMonthIntervalType" diff --git a/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala b/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala index 2a527e6d1..1a9db477b 100644 --- a/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala +++ b/spark/src/main/scala/io/substrait/spark/utils/DialectGenerator.scala @@ -7,8 +7,8 @@ import io.substrait.spark.expression.Sig import org.apache.spark.sql.catalyst.expressions.{BinaryOperator, Expression, Literal} import org.apache.spark.sql.types.{ByteType, DateType, DayTimeIntervalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, TimestampNTZType, TimestampType, YearMonthIntervalType} -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.dataformat.yaml.YAMLMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import com.networknt.schema.{InputFormat, SchemaRegistry, SpecificationVersion} import io.substrait.extension.SimpleExtension @@ -29,7 +29,7 @@ case class Dialect( // Types section case class TypeMetadata(name: String, supported_as_column: Boolean) -case class SupportedType(`type`: String, system_metadata: TypeMetadata) +case class SupportedType(`type`: String, system_metadata: TypeMetadata, max_precision: Option[Integer] = None) // Functions section case class FunctionMetadata(name: String, notation: String) @@ -83,7 +83,12 @@ class DialectGenerator { def generateYaml(): String = { // Generate the dialect YAML - val mapper = new ObjectMapper(new YAMLFactory()).registerModules(DefaultScalaModule) + val mapper = YAMLMapper + .builder() + .defaultPropertyInclusion( + JsonInclude.Value.ALL_NON_ABSENT) + .addModule(DefaultScalaModule) + .build() val yaml = mapper.writeValueAsString(generate()) // Validate against the substrait dialect schema @@ -113,9 +118,9 @@ class DialectGenerator { SupportedType("FIXED_CHAR", TypeMetadata("StringType", true)), SupportedType("BINARY", TypeMetadata("BinaryType", true)), SupportedType("BOOL", TypeMetadata("BooleanType", true)), - SupportedType("PRECISION_TIMESTAMP", TypeMetadata("TimestampNTZType", true)), - SupportedType("PRECISION_TIMESTAMP_TZ", TypeMetadata("TimestampType", true)), - SupportedType("INTERVAL_DAY", TypeMetadata("DayTimeIntervalType", true)), + SupportedType("PRECISION_TIMESTAMP", TypeMetadata("TimestampNTZType", true), max_precision = Some(9)), + SupportedType("PRECISION_TIMESTAMP_TZ", TypeMetadata("TimestampType", true), max_precision = Some(9)), + SupportedType("INTERVAL_DAY", TypeMetadata("DayTimeIntervalType", true), max_precision = Some(9)), SupportedType("INTERVAL_YEAR", TypeMetadata("YearMonthIntervalType", true)), SupportedType("LIST", TypeMetadata("ArrayType", true)), SupportedType("MAP", TypeMetadata("MapType", true)), diff --git a/substrait b/substrait index a9b90657d..e4ce3f871 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit a9b90657db1e51bba69fcbc4a8c1edac3b661975 +Subproject commit e4ce3f8710d91cb95cce010c6b59eb1d618bb2a6