From c288cc1e1e9cd6def2f72d496334299b7d85ea7e Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Thu, 5 Mar 2026 11:48:07 -0500 Subject: [PATCH] Add missing extension YAMLs and test for completeness Closes #722 --- core/build.gradle.kts | 1 + .../extension/DefaultExtensionCatalog.java | 10 +++ .../substrait/extension/SimpleExtension.java | 5 ++ .../DefaultExtensionCatalogTest.java | 66 +++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java diff --git a/core/build.gradle.kts b/core/build.gradle.kts index df0f3f364..54dae67f8 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -87,6 +87,7 @@ dependencies { testImplementation(platform(libs.junit.bom)) testImplementation(libs.protobuf.java.util) testImplementation(libs.guava) + testImplementation(libs.bundles.jackson) testImplementation(libs.junit.jupiter) testRuntimeOnly(libs.junit.platform.launcher) diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..48162a978 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -14,6 +14,10 @@ public class DefaultExtensionCatalog { public static final String FUNCTIONS_AGGREGATE_APPROX = "extension:io.substrait:functions_aggregate_approx"; + /** Extension identifier for aggregate functions with decimal output. */ + public static final String FUNCTIONS_AGGREGATE_DECIMAL_OUTPUT = + "extension:io.substrait:functions_aggregate_decimal_output"; + /** Extension identifier for generic aggregate functions. */ public static final String FUNCTIONS_AGGREGATE_GENERIC = "extension:io.substrait:functions_aggregate_generic"; @@ -37,6 +41,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for geometry functions. */ public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for logarithmic functions. */ public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic"; @@ -78,7 +85,10 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "logarithmic", "rounding", "rounding_decimal", + "set", "string") + // TODO(#688): functions_list.yaml is not loaded here because it uses lambda type + // expressions (e.g. func any2>) that are not yet supported by the type parser. .stream() .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index ab9058e77..a257a95a4 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -680,6 +680,11 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } + /** Returns true if the given URN has any functions or types loaded in this collection. */ + public boolean containsUrn(String urn) { + return urnSupplier.get().contains(urn) || types().stream().anyMatch(t -> t.urn().equals(urn)); + } + private void checkUrn(String name) { if (urnSupplier.get().contains(name)) { return; diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java new file mode 100644 index 000000000..cbbabcbe0 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -0,0 +1,66 @@ +package io.substrait.extension; + +import static io.substrait.extension.DefaultExtensionCatalog.DEFAULT_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; + +/** + * Verifies that every extension YAML in substrait/extensions is loaded by {@link + * DefaultExtensionCatalog}. + */ +class DefaultExtensionCatalogTest { + + private static final Set UNSUPPORTED_FILES = + Set.of( + // TODO: aggregate_decimal_output defines count and approx_count_distinct with + // decimal<38,0> return types instead of i64. When loaded alongside aggregate_generic, + // the same function key (e.g. count:any) maps to the same Calcite operator twice, + // which breaks the reverse lookup in FunctionConverter.getSqlOperatorFromSubstraitFunc. + // Fixing this requires either deduplicating the operator map or adding type-based + // disambiguation for aggregate functions. + "functions_aggregate_decimal_output.yaml", + "functions_geometry.yaml", // user-defined types not supported in Calcite type conversion + "functions_list.yaml", // TODO(#688): remove once lambda types are supported + "type_variations.yaml", // type variations not yet supported by extension loader + "unknown.yaml" // unknown type extension not yet loaded + ); + + private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory()); + + @Test + void allExtensionYamlFilesAreLoaded() throws IOException { + List yamlFiles = getExtensionYamlFiles(); + + for (File file : yamlFiles) { + if (UNSUPPORTED_FILES.contains(file.getName())) { + continue; + } + String urn = parseUrn(file); + assertTrue( + DEFAULT_COLLECTION.containsUrn(urn), + file.getName() + " not loaded by DefaultExtensionCatalog (urn: " + urn + ")"); + } + } + + private static String parseUrn(File yamlFile) throws IOException { + JsonNode doc = YAML_MAPPER.readTree(yamlFile); + JsonNode urnNode = doc.get("urn"); + return urnNode == null ? null : urnNode.asText(); + } + + private static List getExtensionYamlFiles() { + File extensionsDir = new File("../substrait/extensions"); + assertTrue(extensionsDir.isDirectory(), "substrait/extensions directory not found"); + File[] files = extensionsDir.listFiles((dir, name) -> name.endsWith(".yaml")); + assertTrue(files != null && files.length > 0, "No YAML files found"); + return List.of(files); + } +}