class FoundrySqlServerClientV2(APIClient):
    """FoundrySqlServerClientV2 implements the newer foundry-sql-server API.

    This client uses a different API flow compared to V1:
    - Executes queries via POST to ``sql-endpoint/v1/queries/query``
    - Polls POST to ``sql-endpoint/v1/queries/status`` for query completion
    - Retrieves results via POST to ``sql-endpoint/v1/queries/stream`` with tickets
    """

    api_name = "foundry-sql-server"

    @overload
    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["pandas"],
        branch: Ref = ...,
        sql_dialect: FurnaceSqlDialect = ...,
        arrow_compression_codec: ArrowCompressionCodec = ...,
        timeout: int = ...,
        experimental_use_trino: bool = ...,
    ) -> pd.DataFrame: ...

    @overload
    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["polars"],
        branch: Ref = ...,
        sql_dialect: FurnaceSqlDialect = ...,
        arrow_compression_codec: ArrowCompressionCodec = ...,
        timeout: int = ...,
        experimental_use_trino: bool = ...,
    ) -> pl.DataFrame: ...

    @overload
    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["spark"],
        branch: Ref = ...,
        sql_dialect: FurnaceSqlDialect = ...,
        arrow_compression_codec: ArrowCompressionCodec = ...,
        timeout: int = ...,
        experimental_use_trino: bool = ...,
    ) -> pyspark.sql.DataFrame: ...

    @overload
    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["arrow"],
        branch: Ref = ...,
        sql_dialect: FurnaceSqlDialect = ...,
        arrow_compression_codec: ArrowCompressionCodec = ...,
        timeout: int = ...,
        experimental_use_trino: bool = ...,
    ) -> pa.Table: ...

    @overload
    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["pandas", "polars", "spark", "arrow"] = ...,
        branch: Ref = ...,
        sql_dialect: FurnaceSqlDialect = ...,
        arrow_compression_codec: ArrowCompressionCodec = ...,
        timeout: int = ...,
        experimental_use_trino: bool = ...,
    ) -> pd.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame: ...

    def query_foundry_sql(
        self,
        query: str,
        return_type: Literal["pandas", "polars", "spark", "arrow"] = "pandas",
        branch: Ref = "master",
        sql_dialect: FurnaceSqlDialect = "SPARK",
        arrow_compression_codec: ArrowCompressionCodec = "NONE",
        timeout: int = 600,
        experimental_use_trino: bool = False,
    ) -> pd.DataFrame | pl.DataFrame | pa.Table | pyspark.sql.DataFrame:
        """Queries the Foundry SQL server using the V2 API.

        Uses Arrow IPC to communicate with the Foundry SQL Server Endpoint.

        Example:
            df = client.query_foundry_sql(
                query="SELECT * FROM `ri.foundry.main.dataset.abc` LIMIT 10"
            )

        Args:
            query: The SQL Query
            return_type: The return type (pandas, polars, spark, or arrow).
                Note: "raw" is not supported in V2.
            branch: The dataset branch to query
            sql_dialect: The SQL dialect to use (only SPARK is supported for V2)
            arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
            timeout: Query timeout in seconds
            experimental_use_trino: If True, modifies the query to use Trino backend by adding /*+ backend(trino) */ hint

        Returns:
            :external+pandas:py:class:`~pandas.DataFrame` | :external+polars:py:class:`~polars.DataFrame` | :external+pyarrow:py:class:`~pyarrow.Table` | :external+spark:py:class:`~pyspark.sql.DataFrame`:

            A pandas DataFrame, polars, Spark DataFrame or pyarrow.Table with the result.

        Raises:
            FoundrySqlQueryFailedError: If the query fails
            FoundrySqlQueryClientTimedOutError: If the query times out
            TypeError: If an invalid sql_dialect or arrow_compression_codec is provided
            ValueError: If an unsupported return_type is provided

        """  # noqa: E501
        if experimental_use_trino:
            # Case-insensitive replacement of the first SELECT keyword only;
            # the hint routes this query to the Trino backend.
            import re

            query = re.sub(r"\bSELECT\b", "SELECT /*+ backend(trino) */", query, count=1, flags=re.IGNORECASE)

        response_json = self.api_query(
            query=query,
            dialect=sql_dialect,
            branch=branch,
            arrow_compression_codec=arrow_compression_codec,
            timeout=timeout,
        ).json()

        query_handle = self._extract_query_handle(response_json)
        # Use a monotonic clock for the client-side timeout: time.time() can
        # jump when the wall clock is adjusted (NTP, DST), which would make the
        # timeout fire too early or never.
        deadline = time.monotonic() + timeout

        # The execute response has no "status" key, so the loop always polls at
        # least once before reading tickets.
        while response_json.get("status", {}).get("type") != "ready":
            time.sleep(0.2)
            response = self.api_status(query_handle)
            response_json = response.json()

            if response_json.get("status", {}).get("type") == "failed":
                raise FoundrySqlQueryFailedError(response, query=query, branch=branch, dialect=sql_dialect)
            if time.monotonic() > deadline:
                raise FoundrySqlQueryClientTimedOutError(response, timeout=timeout)

        ticket = self._extract_ticket(response_json)

        arrow_stream_reader = self.read_stream_results_arrow(ticket)

        if return_type == "pandas":
            return arrow_stream_reader.read_pandas()

        if return_type == "polars":
            from foundry_dev_tools._optional.polars import pl

            arrow_table = arrow_stream_reader.read_all()
            return pl.from_arrow(arrow_table)

        if return_type == "spark":
            from foundry_dev_tools.utils.converter.foundry_spark import (
                arrow_stream_to_spark_dataframe,
            )

            return arrow_stream_to_spark_dataframe(arrow_stream_reader)

        if return_type == "arrow":
            return arrow_stream_reader.read_all()

        msg = (
            f"Unsupported return_type: {return_type}. "
            f"V2 API supports: pandas, polars, spark, arrow (raw is not supported)"
        )
        raise ValueError(msg)

    def _extract_query_handle(self, response_json: dict[str, Any]) -> dict[str, Any]:
        """Extract query handle from execute response.

        Args:
            response_json: Response JSON from execute API

        Returns:
            Query handle dict

        Raises:
            KeyError: If the response JSON doesn't contain the expected structure

        """
        response_type = response_json.get("type")
        if not response_type:
            msg = f"Response JSON missing 'type' field. Response: {response_json}"
            raise KeyError(msg)

        # The payload is keyed by its own "type" value, e.g. {"type": "running", "running": {...}}.
        type_data = response_json.get(response_type)
        if not type_data:
            msg = f"Response JSON missing '{response_type}' field. Response: {response_json}"
            raise KeyError(msg)

        query_handle = type_data.get("queryHandle")
        if not query_handle:
            msg = f"Response JSON missing 'queryHandle' in '{response_type}'. Response: {response_json}"
            raise KeyError(msg)

        return query_handle

    def _extract_ticket(self, response_json: dict[str, Any]) -> dict[str, Any]:
        """Extract tickets from success response.

        Args:
            response_json: Success response JSON from status API

        Returns:
            Ticket dict with id, tickets list, and type.
                Example: {"id": 0, "tickets": [...], "type": "furnace"}

        Raises:
            KeyError: If the response JSON doesn't contain the expected structure

        """
        try:
            status = response_json["status"]
            ready = status["ready"]
            ticket_groups = ready["tickets"]
        except KeyError as exc:
            msg = (
                f"Response JSON missing expected structure. "
                f"Expected path: status.ready.tickets. Response: {response_json}"
            )
            raise KeyError(msg) from exc

        # we combine all tickets into one to get the full data
        # if performance is a concern this should be done in parallel
        return {
            "id": 0,
            "tickets": [ticket for ticket_group in ticket_groups for ticket in ticket_group["tickets"]],
            "type": "furnace",
        }

    def read_stream_results_arrow(self, ticket: dict[str, Any]) -> pa.ipc.RecordBatchStreamReader:
        """Fetch query results using tickets and return Arrow stream reader.

        Args:
            ticket: dict of tickets e.g. { "id": 0, "tickets": ["ey...", ...], "type": "furnace", }

        Returns:
            Arrow RecordBatchStreamReader

        """
        from foundry_dev_tools._optional.pyarrow import pa

        response = self.api_stream_ticket(ticket)
        # The server may send compressed content; decode transparently so the
        # Arrow reader sees the raw IPC stream.
        response.raw.decode_content = True

        return pa.ipc.RecordBatchStreamReader(response.raw)

    def api_query(
        self,
        query: str,
        dialect: FurnaceSqlDialect,
        branch: Ref,
        arrow_compression_codec: ArrowCompressionCodec = "NONE",
        timeout: int = 600,
        **kwargs,
    ) -> requests.Response:
        """Execute a SQL query via the V2 API.

        Args:
            query: The SQL query string
            dialect: The SQL dialect to use (only SPARK is supported)
            branch: The dataset branch to query
            arrow_compression_codec: Arrow compression codec (NONE, LZ4, ZSTD)
            timeout: Query timeout in seconds (used for error context)
            **kwargs: gets passed to :py:meth:`APIClient.api_request`

        Returns:
            Response with query handle and initial status

        """
        assert_in_literal(dialect, FurnaceSqlDialect, "dialect")
        assert_in_literal(arrow_compression_codec, ArrowCompressionCodec, "arrow_compression_codec")

        return self.api_request(
            "POST",
            "sql-endpoint/v1/queries/query",
            json={
                "querySpec": {
                    "query": query,
                    "tableProviders": {},
                    "dialect": dialect,
                    "options": {"options": [{"option": "arrowCompressionCodec", "value": arrow_compression_codec}]},
                },
                "executionParams": {
                    "defaultBranchIds": [{"type": "datasetBranch", "datasetBranch": branch}],
                    "resultFormat": "ARROW",
                    "resultMode": "AUTO",
                },
            },
            error_handling=ErrorHandlingConfig(branch=branch, dialect=dialect, timeout=timeout),
            **kwargs,
        )

    def api_status(
        self,
        query_handle: dict[str, Any],
        **kwargs,
    ) -> requests.Response:
        """Get the status of a SQL query via the V2 API.

        Args:
            query_handle: Query handle dict from execute response
            **kwargs: gets passed to :py:meth:`APIClient.api_request`

        Returns:
            Response with query status

        """
        return self.api_request(
            "POST",
            "sql-endpoint/v1/queries/status",
            json=query_handle,
            **kwargs,
        )

    def api_stream_ticket(
        self,
        ticket: dict,
        **kwargs,
    ) -> requests.Response:
        """Stream query results using a ticket via the V2 API.

        Args:
            ticket: Ticket dict containing id, tickets list, and type.
                Example: {"id": 0, "tickets": ["eyJhbGc...", "eyJhbGc..."], "type": "furnace"}
            **kwargs: gets passed to :py:meth:`APIClient.api_request`

        Returns:
            Response with streaming Arrow data

        """
        return self.api_request(
            "POST",
            "sql-endpoint/v1/queries/stream",
            json=ticket,
            headers={
                "Accept": "application/octet-stream",
            },
            stream=True,
            **kwargs,
        )
class FoundrySqlQueryFailedError(FoundryAPIError):
    """Exception raised when a Foundry SQL query fails on the server side."""

    message = "Foundry SQL Query Failed."

    def __init__(self, response: requests.Response, **context_kwargs):
        """Build the error from a failed status response.

        Args:
            response: The HTTP response whose JSON carries status.failed details.
            **context_kwargs: Extra context (e.g. query, branch, dialect) merged
                into the error parameters; they take precedence over extracted ones.
        """
        info = ""
        kwargs: dict = {}

        try:
            failed_data = response.json().get("status", {}).get("failed", {})

            # Extract into a temporary dict first so that a mid-extraction
            # failure never leaks a partially populated kwargs/info to super().
            extracted: dict = {}
            if error_code := failed_data.get("errorCode"):
                extracted["error_code"] = error_code
            if error_name := failed_data.get("errorName"):
                extracted["error_name"] = error_name
            if error_instance_id := failed_data.get("errorInstanceId"):
                extracted["error_instance_id"] = error_instance_id

            # V2 responses carry rich parameters in camelCase; expose them as
            # snake_case attributes on the exception.
            parameters = failed_data.get("parameters") or {}
            for key, value in parameters.items():
                extracted[decamelize(key)] = value

            # Prefer userFriendlyMessage; fall back to the V1 errorMessage.
            info = parameters.get("userFriendlyMessage", "") or failed_data.get("errorMessage", "")
            kwargs = extracted
        except Exception:  # noqa: BLE001
            # Best-effort extraction: on any failure fall back to an empty
            # message instead of masking the original query failure.
            info = ""
            kwargs = {}

        # Legacy attribute kept for backward compatibility with V1 callers.
        self.error_message = info

        # Merge context kwargs (e.g. query, branch) with extracted error parameters.
        kwargs.update(context_kwargs)

        super().__init__(response=response, info=info, **kwargs)


class FurnaceSqlSqlParseError(FoundryAPIError):
    """Exception is thrown when SQL Query is not valid."""

    message = "Foundry SQL Query Parsing Failed."
FurnaceSqlDialect = Literal["SPARK"]
"""SQL dialect accepted by the Furnace SQL endpoint (V2 API); only SPARK is supported."""

ArrowCompressionCodec = Literal["NONE", "LZ4", "ZSTD"]
"""Compression codec applied to Arrow IPC result streams of Foundry SQL queries."""
# V2 Client Tests


def test_v2_smoke():
    """Test basic V2 client functionality with a simple query."""
    one_row_one_column = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT sepal_width FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 1",
    )
    assert one_row_one_column.shape == (1, 1)

    one_row_one_column = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT sepal_width FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 1",
        return_type="arrow",
    )
    assert one_row_one_column.num_columns == 1
    assert one_row_one_column.num_rows == 1
    assert one_row_one_column.column_names == ["sepal_width"]


def test_v2_multiple_rows():
    """Test V2 client with multiple rows."""
    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
    )
    assert result.shape[0] == 10
    assert result.shape[1] == 5  # iris dataset has 5 columns


def test_v2_return_type_arrow():
    """Test V2 client with Arrow return type."""
    import pyarrow as pa

    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 5",
        return_type="arrow",
    )
    assert isinstance(result, pa.Table)
    assert result.num_rows == 5


def test_v2_return_type_raw_not_supported():
    """Test V2 client with raw return type."""
    with pytest.raises(ValueError, match="Unsupported return_type: raw"):
        TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
            query=f"SELECT sepal_width, sepal_length FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 3",
            return_type="raw",  # type: ignore[arg-type]
        )


def test_v2_aggregation_query():
    """Test V2 client with aggregation query."""
    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"""
        SELECT
            COUNT(*) as total_count,
            AVG(sepal_width) as avg_sepal_width
        FROM `{TEST_SINGLETON.iris_new.rid}`
        """,
    )
    assert result.shape == (1, 2)
    assert "total_count" in result.columns
    assert "avg_sepal_width" in result.columns


def test_v2_query_failed():
    """Test V2 client with invalid SQL query."""
    with pytest.raises(FurnaceSqlSqlParseError):
        TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
            query=f"SELECT foo, bar, FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 100",
        )


def test_v2_disable_arrow_compression():
    """Test V2 client with arrow compression disabled."""
    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 5",
        arrow_compression_codec="NONE",
    )
    assert result.shape[0] == 5


def test_v2_with_where_clause():
    """Test V2 client with WHERE clause."""
    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"""
        SELECT * FROM `{TEST_SINGLETON.iris_new.rid}`
        WHERE is_setosa = 'setosa'
        LIMIT 20
        """,
    )
    assert result.shape[0] <= 20
    # Verify all returned rows have is_setosa = 'setosa'
    if result.shape[0] > 0:
        assert all(result["is_setosa"] == "setosa")


def test_v2_polars_return_type():
    polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 2",
        return_type="polars",
    )
    assert isinstance(polars_df, pl.DataFrame)
    assert polars_df.height == 2
    assert polars_df.width == 1


def test_v2_polars_parquet():
    polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_parquet.rid}` LIMIT 2",
        return_type="polars",
    )
    assert isinstance(polars_df, pl.DataFrame)
    assert polars_df.height == 2
    assert polars_df.width == 1

    polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_parquet.rid}` LIMIT 2",
        return_type="polars",
        experimental_use_trino=True,
    )
    assert isinstance(polars_df, pl.DataFrame)
    assert polars_df.height == 2
    assert polars_df.width == 1


def test_v2_polars_parquet_hive_partitioning():
    polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_hive_partitioned.rid}` LIMIT 2",
        return_type="polars",
        experimental_use_trino=True,
    )
    assert isinstance(polars_df, pl.DataFrame)
    assert polars_df.height == 2
    assert polars_df.width == 1

    polars_df = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_hive_partitioned.rid}` LIMIT 2", return_type="polars"
    )
    assert isinstance(polars_df, pl.DataFrame)
    assert polars_df.height == 2
    assert polars_df.width == 1


def test_v2_arrow_compression_codecs():
    """Test V2 client with different arrow compression codecs."""
    import pandas as pd

    def normalize(df):
        # LIMIT without ORDER BY gives no row-order guarantee, and the three
        # codec runs are separate queries — sort before comparing so the
        # equality check is deterministic instead of flaky.
        return df.sort_values(list(df.columns)).reset_index(drop=True)

    # Test with LZ4 compression
    result_lz4 = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
        arrow_compression_codec="LZ4",
    )
    assert result_lz4.shape[0] == 10
    assert result_lz4.shape[1] == 5

    # Test with ZSTD compression
    result_zstd = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
        arrow_compression_codec="ZSTD",
    )
    assert result_zstd.shape[0] == 10
    assert result_zstd.shape[1] == 5

    # Test with NONE compression (default)
    result_none = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT * FROM `{TEST_SINGLETON.iris_new.rid}` LIMIT 10",
        arrow_compression_codec="NONE",
    )
    assert result_none.shape[0] == 10
    assert result_none.shape[1] == 5

    # Verify all results have the same data (order-independent)
    pd.testing.assert_frame_equal(normalize(result_lz4), normalize(result_zstd))
    pd.testing.assert_frame_equal(normalize(result_lz4), normalize(result_none))


def test_v2_trino_engine_in_response(mocker):
    """Test that when experimental_use_trino=True, the API response indicates trino engine."""
    # Spy on the api_query method to capture the initial response
    api_query_spy = mocker.spy(TEST_SINGLETON.ctx.foundry_sql_server_v2, "api_query")

    # Execute query with trino enabled using parquet dataset (trino works with parquet)
    result = TEST_SINGLETON.ctx.foundry_sql_server_v2.query_foundry_sql(
        query=f"SELECT sepal_length FROM `{TEST_SINGLETON.iris_parquet.rid}` LIMIT 1",
        experimental_use_trino=True,
    )

    assert result.shape == (1, 1)

    # Verify the API response indicates TRINO backend
    response_json = api_query_spy.spy_return.json()
    backend = response_json[response_json["type"]]["queryStructure"]["metadata"]["backend"]
    assert backend == "TRINO"
def test_v2_experimental_use_trino(mocker, test_context_mock):
    """Verify that experimental_use_trino injects the Trino hint into the query (and only then)."""
    import pandas as pd

    mocker.patch("time.sleep")  # avoid real waiting during the poll loop

    # Stub out the Arrow stream so no real result decoding happens.
    fake_reader = mocker.MagicMock()
    fake_reader.read_pandas.return_value = pd.DataFrame({"col1": [1, 2, 3]})
    mocker.patch.object(
        test_context_mock.foundry_sql_server_v2,
        "read_stream_results_arrow",
        return_value=fake_reader,
    )

    query_url = build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query")
    status_url = build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status")

    # Capture every request hitting the execute endpoint via the matcher hook.
    capture = mocker.MagicMock()
    test_context_mock.mock_adapter.register_uri(
        "POST",
        query_url,
        json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-id-123", "type": "foundry"}}},
        additional_matcher=capture,
    )

    # Status endpoint reports ready on the first poll.
    ready_payload = {
        "status": {
            "type": "ready",
            "ready": {"tickets": [{"tickets": ["eyJhbGc...mock-ticket-1", "eyJhbGc...mock-ticket-2"]}]},
        }
    }
    test_context_mock.mock_adapter.register_uri("POST", status_url, json=ready_payload)

    # With trino enabled the hint must be injected after the first SELECT.
    df = test_context_mock.foundry_sql_server_v2.query_foundry_sql(
        "SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
        experimental_use_trino=True,
    )

    sent = capture.call_args_list[0].args[0].json()
    assert "SELECT /*+ backend(trino) */ * FROM" in sent["querySpec"]["query"]
    assert df.shape[0] == 3

    capture.reset_mock()

    # With trino disabled (default) the query must pass through untouched.
    df = test_context_mock.foundry_sql_server_v2.query_foundry_sql(
        "SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
        experimental_use_trino=False,
    )

    sent = capture.call_args_list[0].args[0].json()
    assert sent["querySpec"]["query"] == "SELECT * FROM `ri.foundry.main.dataset.test-dataset`"
    assert "backend(trino)" not in sent["querySpec"]["query"]
    assert df.shape[0] == 3


def test_v2_poll_for_query_completion_timeout(mocker, test_context_mock):
    """Verify that the V2 poll loop raises a client timeout when the query never becomes ready."""
    mocker.patch("time.sleep")  # avoid real waiting during the poll loop

    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query"),
        json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-timeout-123", "type": "foundry"}}},
    )
    # Status endpoint never leaves the running state.
    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status"),
        json={"status": {"type": "running", "running": {}}},
    )

    with pytest.raises(FoundrySqlQueryClientTimedOutError):
        test_context_mock.foundry_sql_server_v2.query_foundry_sql(
            "SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
            timeout=0.001,
        )


def test_v2_ansi_dialect_not_supported(test_context_mock):
    """Verify that the V2 client rejects the ANSI SQL dialect up front."""
    client = test_context_mock.foundry_sql_server_v2
    with pytest.raises(TypeError, match="'ANSI' is not a valid option for dialect"):
        client.query_foundry_sql(
            "SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
            sql_dialect="ANSI",  # type: ignore[arg-type]
        )


def test_v2_invalid_compression_codec(test_context_mock):
    """Verify that the V2 client rejects an unknown arrow compression codec up front."""
    client = test_context_mock.foundry_sql_server_v2
    with pytest.raises(TypeError, match="'INVALID' is not a valid option for arrow_compression_codec"):
        client.query_foundry_sql(
            "SELECT * FROM `ri.foundry.main.dataset.test-dataset`",
            arrow_compression_codec="INVALID",  # type: ignore[arg-type]
        )


def test_v2_query_failed_error_details(mocker, test_context_mock):
    """Verify that rich V2 error parameters are extracted onto the raised exception."""
    mocker.patch("time.sleep")

    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query"),
        json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-id", "type": "foundry"}}},
    )

    # V2 failure payload carrying the full parameter set.
    failed_payload = {
        "status": {
            "type": "failed",
            "failed": {
                "errorCode": "INVALID_ARGUMENT",
                "errorName": "SqlQueryService:SqlSyntaxError",
                "errorInstanceId": "c16cb2b7-01ec-42a9-9ee2-0e57e2aed4ba",
                "parameters": {
                    "endLine": 1,
                    "endColumn": 15350,
                    "dialect": "SPARK",
                    "queryFragment": "",
                    "startColumn": 15340,
                    "startLine": 1,
                    "userFriendlyMessage": (
                        "From line 1, column 15340 to line 1, column 15350: "
                        "Column 'COLUMN_NAME' not found in table 'my_table'; did you mean 'column_name'?"
                    ),
                },
            },
        }
    }
    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status"),
        json=failed_payload,
    )

    with pytest.raises(FoundrySqlQueryFailedError) as exc_info:
        test_context_mock.foundry_sql_server_v2.query_foundry_sql(
            "SELECT COLUMN_NAME FROM `ri.foundry.main.dataset.test-dataset`",
        )
    err = exc_info.value

    # Top-level error identifiers are exposed as attributes.
    assert err.error_code == "INVALID_ARGUMENT"
    assert err.error_name == "SqlQueryService:SqlSyntaxError"
    assert err.error_instance_id == "c16cb2b7-01ec-42a9-9ee2-0e57e2aed4ba"

    # Parameters become snake_case attributes.
    assert err.start_line == 1
    assert err.end_line == 1
    assert err.start_column == 15340
    assert err.end_column == 15350
    assert err.dialect == "SPARK"
    # query_fragment is in kwargs even if empty
    assert "query_fragment" in err.kwargs

    # userFriendlyMessage is both an attribute and the info text.
    assert err.user_friendly_message == (
        "From line 1, column 15340 to line 1, column 15350: "
        "Column 'COLUMN_NAME' not found in table 'my_table'; did you mean 'column_name'?"
    )

    rendered = str(err)
    assert "COLUMN_NAME" in rendered
    assert "my_table" in rendered


def test_v2_polling_error_includes_context(mocker, test_context_mock):
    """Verify that polling failures carry the query context (query/branch/dialect)."""
    mocker.patch("time.sleep")

    sql = "SELECT * FROM `ri.foundry.main.dataset.test-dataset`"

    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/query"),
        json={"type": "running", "running": {"queryHandle": {"queryId": "test-query-id", "type": "foundry"}}},
    )
    test_context_mock.mock_adapter.register_uri(
        "POST",
        build_api_url(TEST_HOST.url, "foundry-sql-server", "sql-endpoint/v1/queries/status"),
        json={
            "status": {
                "type": "failed",
                "failed": {
                    "errorCode": "ModuleGroupService:ErrorPollingModule",
                    "errorInstanceId": "5be87070-3aa3-4ed6-aa6a-d9b5041885af",
                    "errorMessage": "Error polling for job status. Please resubmit.",
                    "retryable": False,
                },
            }
        },
    )

    with pytest.raises(FoundrySqlQueryFailedError) as exc_info:
        test_context_mock.foundry_sql_server_v2.query_foundry_sql(sql)
    err = exc_info.value

    # Error details extracted from the failed payload.
    assert err.error_code == "ModuleGroupService:ErrorPollingModule"
    assert err.error_instance_id == "5be87070-3aa3-4ed6-aa6a-d9b5041885af"
    assert err.error_message == "Error polling for job status. Please resubmit."

    # Query context merged into the exception.
    assert err.query == sql
    assert err.branch == "master"
    assert err.dialect == "SPARK"

    rendered = str(err)
    assert "query = " + sql in rendered
    assert "branch = master" in rendered
    assert "dialect = SPARK" in rendered
    assert "ModuleGroupService:ErrorPollingModule" in rendered