diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index 34e76215f3..ceafc157c4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} object CometMapKeys extends CometExpressionSerde[MapKeys] { @@ -84,9 +84,35 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] { val keyType = expr.left.dataType.asInstanceOf[ArrayType].elementType val valueType = expr.right.dataType.asInstanceOf[ArrayType].elementType val returnType = MapType(keyType = keyType, valueType = valueType) - val mapFromArraysExpr = - scalarFunctionExprToProtoWithReturnType("map", returnType, false, keysExpr, valuesExpr) - optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*) + for { + andBinaryExprProto <- createAndBinaryExpr(expr, inputs, binding) + mapFromArraysExprProto <- scalarFunctionExprToProto("map", keysExpr, valuesExpr) + nullLiteralExprProto <- exprToProtoInternal(Literal(null, returnType), inputs, binding) + } yield { + val caseWhenExprProto = ExprOuterClass.CaseWhen + .newBuilder() + .addWhen(andBinaryExprProto) + .addThen(mapFromArraysExprProto) + .setElseExpr(nullLiteralExprProto) + .build() + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(caseWhenExprProto) + .build() + } + } + + private def createAndBinaryExpr( + expr: MapFromArrays, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + IsNotNull(expr.left), + IsNotNull(expr.right), + inputs, + 
binding, + (builder, binaryExpr) => builder.setAnd(binaryExpr)) } } diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql b/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql index 5d6ac3d550..3016eb5ff6 100644 --- a/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql +++ b/spark/src/test/resources/sql-tests/expressions/map/map_from_arrays.sql @@ -21,16 +21,43 @@ statement CREATE TABLE test_map_from_arrays(k array<string>, v array<int>) USING parquet statement -INSERT INTO test_map_from_arrays VALUES (array('a', 'b', 'c'), array(1, 2, 3)), (array(), array()), (NULL, NULL) +INSERT INTO test_map_from_arrays VALUES + (array('a', 'b', 'c'), array(1, 2, 3)), + (array(), array()), + (NULL, NULL), + (array('x'), NULL), + (NULL, array(99)) +-- basic functionality query spark_answer_only -SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NOT NULL +SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NOT NULL AND v IS NOT NULL --- Comet bug: map_from_arrays(NULL, NULL) causes native crash "map key cannot be null" --- https://github.com/apache/datafusion-comet/issues/3327 -query ignore(https://github.com/apache/datafusion-comet/issues/3327) -SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NULL +-- both inputs NULL should return NULL +query +SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NULL AND v IS NULL + +-- keys not null but values null should return NULL (Spark behavior) +query +SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NOT NULL AND v IS NULL + +-- keys null but values not null should return NULL (Spark behavior) +query +SELECT map_from_arrays(k, v) FROM test_map_from_arrays WHERE k IS NULL AND v IS NOT NULL + +-- all rows including nulls +query spark_answer_only +SELECT map_from_arrays(k, v) FROM test_map_from_arrays -- literal arguments query spark_answer_only SELECT map_from_arrays(array('a', 'b'), array(1, 2)) + +-- literal
null arguments +query +SELECT map_from_arrays(NULL, array(1, 2)) + +query +SELECT map_from_arrays(array('a'), NULL) + +query +SELECT map_from_arrays(NULL, NULL) \ No newline at end of file