From 7d997ef100bc42bfc04201117d6824bb0e6f1b9a Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 13 Mar 2026 14:53:16 +0000 Subject: [PATCH 01/39] refactor(adapters): move runtime config decisions into core helpers --- sqlspec/adapters/adbc/config.py | 29 +++++--- sqlspec/adapters/adbc/core.py | 50 ++++++++++++++ sqlspec/adapters/asyncpg/config.py | 32 ++++----- sqlspec/adapters/asyncpg/core.py | 50 ++++++++++++++ sqlspec/adapters/psqlpy/config.py | 38 +++++------ sqlspec/adapters/psqlpy/core.py | 50 ++++++++++++++ sqlspec/adapters/psycopg/config.py | 66 ++++++++----------- sqlspec/adapters/psycopg/core.py | 50 ++++++++++++++ .../test_adbc/test_extension_detection.py | 26 +++++++- .../unit/adapters/test_asyncpg/test_config.py | 35 +++++++++- .../unit/adapters/test_psqlpy/test_config.py | 35 +++++++++- .../unit/adapters/test_psycopg/test_config.py | 49 +++++++++++++- 12 files changed, 423 insertions(+), 87 deletions(-) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 79992ee27..5b0676846 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -8,11 +8,14 @@ from sqlspec.adapters.adbc.core import ( apply_driver_features, build_connection_config, + build_postgres_extension_probe_names, detect_postgres_extensions, get_statement_config, is_postgres_dialect, resolve_dialect_from_config, resolve_driver_connect_func, + resolve_postgres_extension_state, + resolve_runtime_statement_config, ) from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, AdbcSessionContext from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig @@ -276,15 +279,25 @@ def _detect_extensions_if_needed(self) -> None: connection = self.create_connection() try: - self._pgvector_available, self._paradedb_available = detect_postgres_extensions( + probe_names = build_postgres_extension_probe_names(self.driver_features) + pgvector_available, paradedb_available = detect_postgres_extensions( connection, - 
enable_pgvector=self.driver_features.get("enable_pgvector", False), - enable_paradedb=self.driver_features.get("enable_paradedb", False), + enable_pgvector="vector" in probe_names, + enable_paradedb="pg_search" in probe_names, ) finally: connection.close() - self._update_dialect_for_extensions() + detected_extensions: set[str] = set() + if pgvector_available: + detected_extensions.add("vector") + if paradedb_available: + detected_extensions.add("pg_search") + self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( + self.statement_config, + self.driver_features, + detected_extensions, + ) def provide_connection(self, *args: Any, **kwargs: Any) -> "AdbcConnectionContext": """Provide a connection context manager. @@ -315,10 +328,10 @@ def provide_session( A context manager that yields an AdbcDriver instance. """ self._detect_extensions_if_needed() - statement_config = ( - statement_config - or self.statement_config - or get_statement_config(resolve_dialect_from_config(self.connection_config)) + statement_config = resolve_runtime_statement_config( + statement_config, + self.statement_config, + get_statement_config(resolve_dialect_from_config(self.connection_config)), ) handler = _AdbcSessionConnectionHandler(self) diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index 102be5b23..5af07e572 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -41,6 +41,7 @@ __all__ = ( "apply_driver_features", "build_connection_config", + "build_postgres_extension_probe_names", "build_profile", "collect_rows", "create_mapped_exception", @@ -68,7 +69,9 @@ "resolve_many_rowcount", "resolve_parameter_casts", "resolve_parameter_styles", + "resolve_postgres_extension_state", "resolve_rowcount", + "resolve_runtime_statement_config", ) COLUMN_CACHE_MAX_SIZE: int = 256 @@ -214,6 +217,53 @@ def detect_postgres_extensions( return False, False +def 
build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": + """Return enabled PostgreSQL extension names to probe on first connection.""" + if driver_features is None: + return [] + + extensions: list[str] = [] + if driver_features.get("enable_pgvector", False): + extensions.append("vector") + if driver_features.get("enable_paradedb", False): + extensions.append("pg_search") + return extensions + + +def resolve_postgres_extension_state( + statement_config: "StatementConfig", + driver_features: "Mapping[str, Any] | None", + detected_extensions: "set[str] | None" = None, +) -> "tuple[StatementConfig, bool, bool]": + """Resolve detected PostgreSQL extension flags and promoted dialect.""" + detected = detected_extensions or set() + pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + paradedb_available = bool( + driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected + ) + + if statement_config.dialect == "postgres": + if paradedb_available: + statement_config = statement_config.replace(dialect="paradedb") + elif pgvector_available: + statement_config = statement_config.replace(dialect="pgvector") + + return statement_config, pgvector_available, paradedb_available + + +def resolve_runtime_statement_config( + statement_config: "StatementConfig | None", + configured_statement_config: "StatementConfig | None", + default_config: "StatementConfig", +) -> "StatementConfig": + """Resolve the effective runtime statement config for a session.""" + if statement_config is not None: + return statement_config + if configured_statement_config is not None: + return configured_statement_config + return default_config + + def normalize_driver_path(driver_name: str) -> str: """Normalize a driver name to an importable connect function path.""" stripped = driver_name.strip() diff --git a/sqlspec/adapters/asyncpg/config.py 
b/sqlspec/adapters/asyncpg/config.py index c3a741a66..a4436aeb8 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -13,9 +13,12 @@ from sqlspec.adapters.asyncpg.core import ( apply_driver_features, build_connection_config, + build_postgres_extension_probe_names, default_statement_config, register_json_codecs, register_pgvector_support, + resolve_postgres_extension_state, + resolve_runtime_statement_config, ) from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, AsyncpgExceptionHandler, AsyncpgSessionContext from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs @@ -459,29 +462,21 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None: # Detect extensions on first connection, update dialect if self._pgvector_available is None: - extensions = [ - name - for name, enabled in [ - ("vector", self.driver_features.get("enable_pgvector", False)), - ("pg_search", self.driver_features.get("enable_paradedb", False)), - ] - if enabled - ] + detected_extensions: set[str] = set() + extensions = build_postgres_extension_probe_names(self.driver_features) if extensions: try: results = await connection.fetch( "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", extensions ) - detected = {r["extname"] for r in results} - self._pgvector_available = "vector" in detected - self._paradedb_available = "pg_search" in detected + detected_extensions = {r["extname"] for r in results} except Exception: - self._pgvector_available = False - self._paradedb_available = False - else: - self._pgvector_available = False - self._paradedb_available = False - self._update_dialect_for_extensions() + detected_extensions = set() + self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( + self.statement_config, + self.driver_features, + detected_extensions, + ) if self._pgvector_available: await register_pgvector_support(connection) @@ -563,7 
+558,8 @@ def provide_session( return AsyncpgSessionContext( acquire_connection=factory.acquire_connection, release_connection=factory.release_connection, - statement_config=statement_config or (lambda: self.statement_config or default_statement_config), + statement_config=statement_config + or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)), driver_features=self.driver_features, prepare_driver=self._prepare_driver, ) diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index 5792b85c6..2ebbc7e96 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -38,6 +38,7 @@ "NormalizedStackOperation", "apply_driver_features", "build_connection_config", + "build_postgres_extension_probe_names", "build_profile", "build_statement_config", "collect_rows", @@ -50,6 +51,8 @@ "register_json_codecs", "register_pgvector_support", "resolve_many_rowcount", + "resolve_postgres_extension_state", + "resolve_runtime_statement_config", ) ASYNC_PG_STATUS_REGEX: "re.Pattern[str]" = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) @@ -256,6 +259,53 @@ def apply_driver_features( return statement_config, processed_features +def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": + """Return enabled PostgreSQL extension names to probe on first connection.""" + if driver_features is None: + return [] + + extensions: list[str] = [] + if driver_features.get("enable_pgvector", False): + extensions.append("vector") + if driver_features.get("enable_paradedb", False): + extensions.append("pg_search") + return extensions + + +def resolve_postgres_extension_state( + statement_config: "StatementConfig", + driver_features: "Mapping[str, Any] | None", + detected_extensions: "set[str] | None" = None, +) -> "tuple[StatementConfig, bool, bool]": + """Resolve detected PostgreSQL extension flags and promoted dialect.""" + detected = 
detected_extensions or set() + pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + paradedb_available = bool( + driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected + ) + + if statement_config.dialect == "postgres": + if paradedb_available: + statement_config = statement_config.replace(dialect="paradedb") + elif pgvector_available: + statement_config = statement_config.replace(dialect="pgvector") + + return statement_config, pgvector_available, paradedb_available + + +def resolve_runtime_statement_config( + statement_config: "StatementConfig | None", + configured_statement_config: "StatementConfig | None", + default_config: "StatementConfig", +) -> "StatementConfig": + """Resolve the effective runtime statement config for a session.""" + if statement_config is not None: + return statement_config + if configured_statement_config is not None: + return configured_statement_config + return default_config + + def parse_status(status: Any) -> int: """Parse AsyncPG status string to extract row count. 
diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 87a528a74..5c946fa5a 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -7,7 +7,14 @@ from typing_extensions import NotRequired from sqlspec.adapters.psqlpy._typing import PsqlpyConnection -from sqlspec.adapters.psqlpy.core import apply_driver_features, build_connection_config, default_statement_config +from sqlspec.adapters.psqlpy.core import ( + apply_driver_features, + build_connection_config, + build_postgres_extension_probe_names, + default_statement_config, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, PsqlpyExceptionHandler, PsqlpySessionContext from sqlspec.adapters.psqlpy.type_converter import register_pgvector from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs @@ -232,30 +239,22 @@ async def _ensure_connection_initialized(self, connection: "PsqlpyConnection") - """Ensure connection callback has been called exactly once for this connection.""" # Detect extensions on first connection, update dialect if self._pgvector_available is None: - extensions = [ - name - for name, enabled in [ - ("vector", self.driver_features.get("enable_pgvector", False)), - ("pg_search", self.driver_features.get("enable_paradedb", False)), - ] - if enabled - ] + detected_extensions: set[str] = set() + extensions = build_postgres_extension_probe_names(self.driver_features) if extensions: try: result = await connection.fetch( "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", [extensions] ) rows = result.result() if result else [] - detected = {r["extname"] for r in rows} - self._pgvector_available = "vector" in detected - self._paradedb_available = "pg_search" in detected + detected_extensions = {r["extname"] for r in rows} except Exception: - self._pgvector_available = False - self._paradedb_available = False - else: - 
self._pgvector_available = False - self._paradedb_available = False - self._update_dialect_for_extensions() + detected_extensions = set() + self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( + self.statement_config, + self.driver_features, + detected_extensions, + ) conn_id = id(connection) if conn_id in self._initialized_connection_ids: @@ -338,7 +337,8 @@ def provide_session( return PsqlpySessionContext( acquire_connection=factory.acquire_connection, release_connection=factory.release_connection, - statement_config=statement_config or (lambda: self.statement_config or default_statement_config), + statement_config=statement_config + or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)), driver_features=self.driver_features, prepare_driver=self._prepare_driver, ) diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index 85b719474..81ba74a06 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -38,6 +38,7 @@ "apply_driver_features", "build_connection_config", "build_insert_statement", + "build_postgres_extension_probe_names", "build_profile", "build_statement_config", "coerce_numeric_for_write", @@ -53,6 +54,8 @@ "get_parameter_casts", "normalize_scalar_parameter", "prepare_parameters_with_casts", + "resolve_postgres_extension_state", + "resolve_runtime_statement_config", "split_schema_and_table", ) @@ -257,6 +260,53 @@ def apply_driver_features( return statement_config, features +def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": + """Return enabled PostgreSQL extension names to probe on first connection.""" + if driver_features is None: + return [] + + extensions: list[str] = [] + if driver_features.get("enable_pgvector", False): + extensions.append("vector") + if driver_features.get("enable_paradedb", False): + extensions.append("pg_search") + return 
extensions + + +def resolve_postgres_extension_state( + statement_config: "StatementConfig", + driver_features: "Mapping[str, Any] | None", + detected_extensions: "set[str] | None" = None, +) -> "tuple[StatementConfig, bool, bool]": + """Resolve detected PostgreSQL extension flags and promoted dialect.""" + detected = detected_extensions or set() + pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + paradedb_available = bool( + driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected + ) + + if statement_config.dialect == "postgres": + if paradedb_available: + statement_config = statement_config.replace(dialect="paradedb") + elif pgvector_available: + statement_config = statement_config.replace(dialect="pgvector") + + return statement_config, pgvector_available, paradedb_available + + +def resolve_runtime_statement_config( + statement_config: "StatementConfig | None", + configured_statement_config: "StatementConfig | None", + default_config: "StatementConfig", +) -> "StatementConfig": + """Resolve the effective runtime statement config for a session.""" + if statement_config is not None: + return statement_config + if configured_statement_config is not None: + return configured_statement_config + return default_config + + def collect_rows(query_result: Any | None) -> "tuple[list[dict[str, Any]], list[str]]": """Collect psqlpy rows and column names. 
diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 3c86bc94e..dca788d02 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -7,7 +7,13 @@ from typing_extensions import NotRequired from sqlspec.adapters.psycopg._typing import PsycopgAsyncConnection, PsycopgSyncConnection -from sqlspec.adapters.psycopg.core import apply_driver_features, default_statement_config +from sqlspec.adapters.psycopg.core import ( + apply_driver_features, + build_postgres_extension_probe_names, + default_statement_config, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.adapters.psycopg.driver import ( PsycopgAsyncCursor, PsycopgAsyncDriver, @@ -281,30 +287,22 @@ def _configure_connection(self, conn: "PsycopgSyncConnection") -> None: # Detect extensions on first connection, update dialect if self._pgvector_available is None: - extensions = [ - name - for name, enabled in [ - ("vector", self.driver_features.get("enable_pgvector", False)), - ("pg_search", self.driver_features.get("enable_paradedb", False)), - ] - if enabled - ] + detected_extensions: set[str] = set() + extensions = build_postgres_extension_probe_names(self.driver_features) if extensions: try: cursor = conn.execute( "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (extensions,) ) results = cursor.fetchall() - detected = {r[0] for r in results} # type: ignore[index] - self._pgvector_available = "vector" in detected - self._paradedb_available = "pg_search" in detected + detected_extensions = {r[0] for r in results} # type: ignore[index] except Exception: - self._pgvector_available = False - self._paradedb_available = False - else: - self._pgvector_available = False - self._paradedb_available = False - self._update_dialect_for_extensions() + detected_extensions = set() + self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( + 
self.statement_config, + self.driver_features, + detected_extensions, + ) if self._pgvector_available: register_pgvector_sync(conn) @@ -367,7 +365,8 @@ def provide_session( return PsycopgSyncSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, - statement_config=statement_config or (lambda: self.statement_config or default_statement_config), + statement_config=statement_config + or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)), driver_features=self.driver_features, prepare_driver=self._prepare_driver, ) @@ -571,30 +570,22 @@ async def _configure_async_connection(self, conn: "PsycopgAsyncConnection") -> N # Detect extensions on first connection, update dialect if self._pgvector_available is None: - extensions = [ - name - for name, enabled in [ - ("vector", self.driver_features.get("enable_pgvector", False)), - ("pg_search", self.driver_features.get("enable_paradedb", False)), - ] - if enabled - ] + detected_extensions: set[str] = set() + extensions = build_postgres_extension_probe_names(self.driver_features) if extensions: try: cursor = await conn.execute( "SELECT extname FROM pg_extension WHERE extname = ANY(%s::text[])", (extensions,) ) results = await cursor.fetchall() - detected = {r[0] for r in results} # type: ignore[index] - self._pgvector_available = "vector" in detected - self._paradedb_available = "pg_search" in detected + detected_extensions = {r[0] for r in results} # type: ignore[index] except Exception: - self._pgvector_available = False - self._paradedb_available = False - else: - self._pgvector_available = False - self._paradedb_available = False - self._update_dialect_for_extensions() + detected_extensions = set() + self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( + self.statement_config, + self.driver_features, + detected_extensions, + ) if self._pgvector_available: await 
register_pgvector_async(conn) @@ -676,7 +667,8 @@ def provide_session( return PsycopgAsyncSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, - statement_config=statement_config or (lambda: self.statement_config or default_statement_config), + statement_config=statement_config + or (lambda: resolve_runtime_statement_config(None, self.statement_config, default_statement_config)), driver_features=self.driver_features, prepare_driver=self._prepare_driver, ) diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index bf7716194..6c0bb4c29 100644 --- a/sqlspec/adapters/psycopg/core.py +++ b/sqlspec/adapters/psycopg/core.py @@ -52,6 +52,7 @@ "build_async_pipeline_execution_result", "build_copy_from_command", "build_pipeline_execution_result", + "build_postgres_extension_probe_names", "build_profile", "build_statement_config", "build_truncate_command", @@ -63,7 +64,9 @@ "execute_with_optional_parameters_async", "pipeline_supported", "resolve_many_rowcount", + "resolve_postgres_extension_state", "resolve_rowcount", + "resolve_runtime_statement_config", ) TRANSACTION_STATUS_IDLE = 0 @@ -206,6 +209,53 @@ def apply_driver_features( return statement_config, features +def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": + """Return enabled PostgreSQL extension names to probe on first connection.""" + if driver_features is None: + return [] + + extensions: list[str] = [] + if driver_features.get("enable_pgvector", False): + extensions.append("vector") + if driver_features.get("enable_paradedb", False): + extensions.append("pg_search") + return extensions + + +def resolve_postgres_extension_state( + statement_config: "StatementConfig", + driver_features: "Mapping[str, Any] | None", + detected_extensions: "set[str] | None" = None, +) -> "tuple[StatementConfig, bool, bool]": + """Resolve detected PostgreSQL extension flags and promoted dialect.""" + 
detected = detected_extensions or set() + pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + paradedb_available = bool( + driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected + ) + + if statement_config.dialect == "postgres": + if paradedb_available: + statement_config = statement_config.replace(dialect="paradedb") + elif pgvector_available: + statement_config = statement_config.replace(dialect="pgvector") + + return statement_config, pgvector_available, paradedb_available + + +def resolve_runtime_statement_config( + statement_config: "StatementConfig | None", + configured_statement_config: "StatementConfig | None", + default_config: "StatementConfig", +) -> "StatementConfig": + """Resolve the effective runtime statement config for a session.""" + if statement_config is not None: + return statement_config + if configured_statement_config is not None: + return configured_statement_config + return default_config + + def collect_rows(fetched_data: "list[Any] | None", description: "list[Any] | None") -> "tuple[list[Any], list[str]]": """Collect psycopg rows and column names. 
diff --git a/tests/unit/adapters/test_adbc/test_extension_detection.py b/tests/unit/adapters/test_adbc/test_extension_detection.py index 69be22a48..12bd32a50 100644 --- a/tests/unit/adapters/test_adbc/test_extension_detection.py +++ b/tests/unit/adapters/test_adbc/test_extension_detection.py @@ -3,7 +3,13 @@ from unittest.mock import MagicMock from sqlspec.adapters.adbc.config import AdbcConfig -from sqlspec.adapters.adbc.core import apply_driver_features, detect_postgres_extensions, get_statement_config +from sqlspec.adapters.adbc.core import ( + apply_driver_features, + build_postgres_extension_probe_names, + detect_postgres_extensions, + get_statement_config, + resolve_postgres_extension_state, +) def test_apply_driver_features_sets_pgvector_default() -> None: @@ -30,6 +36,11 @@ def test_apply_driver_features_respects_user_overrides() -> None: assert features["enable_paradedb"] is False +def test_build_postgres_extension_probe_names_filters_disabled_features() -> None: + """Only enabled extension probes should be returned.""" + assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] + + def test_detect_postgres_extensions_returns_tuple() -> None: """detect_postgres_extensions returns (pgvector_available, paradedb_available).""" # Mock a connection with a cursor that returns pgvector extension @@ -94,6 +105,19 @@ def test_adbc_config_initializes_extension_flags_to_none() -> None: assert config._paradedb_available is None # pyright: ignore[reportPrivateUsage] +def test_resolve_postgres_extension_state_promotes_paradedb() -> None: + """Detected extensions should promote the runtime dialect.""" + statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( + get_statement_config("postgres"), + {"enable_pgvector": True, "enable_paradedb": True}, + {"vector", "pg_search"}, + ) + + assert statement_config.dialect == "paradedb" + assert pgvector_available is True + assert 
paradedb_available is True + + def test_adbc_config_update_dialect_for_extensions_pgvector() -> None: """Dialect switches to pgvector when pgvector is available.""" config = AdbcConfig(connection_config={"uri": "postgresql://localhost/test"}) diff --git a/tests/unit/adapters/test_asyncpg/test_config.py b/tests/unit/adapters/test_asyncpg/test_config.py index 6bc3f096b..8eac2ba6c 100644 --- a/tests/unit/adapters/test_asyncpg/test_config.py +++ b/tests/unit/adapters/test_asyncpg/test_config.py @@ -6,7 +6,11 @@ from sqlspec.adapters.asyncpg._typing import AsyncpgSessionContext from sqlspec.adapters.asyncpg.config import AsyncpgConfig -from sqlspec.adapters.asyncpg.core import build_statement_config +from sqlspec.adapters.asyncpg.core import ( + build_postgres_extension_probe_names, + build_statement_config, + resolve_postgres_extension_state, +) from sqlspec.core import StatementConfig @@ -42,6 +46,24 @@ def deserializer(_: str) -> object: assert parameter_config.json_deserializer is deserializer +def test_asyncpg_build_postgres_extension_probe_names_filters_disabled_features() -> None: + """Only enabled extension probes should be returned.""" + assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] + + +def test_asyncpg_resolve_postgres_extension_state_promotes_paradedb() -> None: + """Detected extensions should promote the runtime dialect.""" + statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( + StatementConfig(dialect="postgres"), + {"enable_pgvector": True, "enable_paradedb": True}, + {"vector", "pg_search"}, + ) + + assert statement_config.dialect == "paradedb" + assert pgvector_available is True + assert paradedb_available is True + + @pytest.mark.anyio async def test_asyncpg_session_context_resolves_callable_statement_config() -> None: """Session context should call statement_config when it's a callable.""" @@ -72,3 +94,14 @@ async def 
test_asyncpg_session_context_preserves_explicit_statement_config() -> async with context as driver: assert driver.statement_config is explicit_config + + +def test_asyncpg_provide_session_tracks_promoted_statement_config() -> None: + """Runtime statement config should resolve the current config dialect lazily.""" + config = AsyncpgConfig() + config.statement_config = config.statement_config.replace(dialect="pgvector") + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + assert session_config().dialect == "pgvector" diff --git a/tests/unit/adapters/test_psqlpy/test_config.py b/tests/unit/adapters/test_psqlpy/test_config.py index 764a7ef76..a89922afa 100644 --- a/tests/unit/adapters/test_psqlpy/test_config.py +++ b/tests/unit/adapters/test_psqlpy/test_config.py @@ -6,7 +6,11 @@ from sqlspec.adapters.psqlpy._typing import PsqlpySessionContext from sqlspec.adapters.psqlpy.config import PsqlpyConfig -from sqlspec.adapters.psqlpy.core import build_statement_config +from sqlspec.adapters.psqlpy.core import ( + build_postgres_extension_probe_names, + build_statement_config, + resolve_postgres_extension_state, +) from sqlspec.core import StatementConfig @@ -34,6 +38,24 @@ def serializer(_: object) -> str: assert parameter_config.json_serializer is serializer +def test_psqlpy_build_postgres_extension_probe_names_filters_disabled_features() -> None: + """Only enabled extension probes should be returned.""" + assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] + + +def test_psqlpy_resolve_postgres_extension_state_promotes_paradedb() -> None: + """Detected extensions should promote the runtime dialect.""" + statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( + StatementConfig(dialect="postgres"), + {"enable_pgvector": True, "enable_paradedb": True}, + {"vector", "pg_search"}, + ) + + assert 
statement_config.dialect == "paradedb" + assert pgvector_available is True + assert paradedb_available is True + + @pytest.mark.anyio async def test_psqlpy_session_context_resolves_callable_statement_config() -> None: """Session context should call statement_config when it's a callable.""" @@ -64,3 +86,14 @@ async def test_psqlpy_session_context_preserves_explicit_statement_config() -> N async with context as driver: assert driver.statement_config is explicit_config + + +def test_psqlpy_provide_session_tracks_promoted_statement_config() -> None: + """Runtime statement config should resolve the current config dialect lazily.""" + config = PsqlpyConfig() + config.statement_config = config.statement_config.replace(dialect="pgvector") + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + assert session_config().dialect == "pgvector" diff --git a/tests/unit/adapters/test_psycopg/test_config.py b/tests/unit/adapters/test_psycopg/test_config.py index ff07174e9..a87103062 100644 --- a/tests/unit/adapters/test_psycopg/test_config.py +++ b/tests/unit/adapters/test_psycopg/test_config.py @@ -5,8 +5,13 @@ import pytest from sqlspec.adapters.psycopg._typing import PsycopgAsyncSessionContext, PsycopgSyncSessionContext -from sqlspec.adapters.psycopg.config import PsycopgSyncConfig -from sqlspec.adapters.psycopg.core import build_statement_config, default_statement_config +from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig +from sqlspec.adapters.psycopg.core import ( + build_postgres_extension_probe_names, + build_statement_config, + default_statement_config, + resolve_postgres_extension_state, +) from sqlspec.core import SQL, StatementConfig @@ -34,6 +39,24 @@ def serializer(_: object) -> str: assert parameter_config.json_serializer is serializer +def test_psycopg_build_postgres_extension_probe_names_filters_disabled_features() -> None: + """Only enabled extension 
probes should be returned.""" + assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] + + +def test_psycopg_resolve_postgres_extension_state_promotes_paradedb() -> None: + """Detected extensions should promote the runtime dialect.""" + statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( + StatementConfig(dialect="postgres"), + {"enable_pgvector": True, "enable_paradedb": True}, + {"vector", "pg_search"}, + ) + + assert statement_config.dialect == "paradedb" + assert pgvector_available is True + assert paradedb_available is True + + def test_psycopg_numeric_placeholders_convert_to_pyformat() -> None: """Numeric placeholders should be rewritten for psycopg execution.""" @@ -95,3 +118,25 @@ def test_psycopg_sync_session_context_preserves_explicit_statement_config() -> N with context as driver: assert driver.statement_config is explicit_config + + +def test_psycopg_sync_provide_session_tracks_promoted_statement_config() -> None: + """Sync runtime statement config should resolve the current config dialect lazily.""" + config = PsycopgSyncConfig() + config.statement_config = config.statement_config.replace(dialect="pgvector") + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + assert session_config().dialect == "pgvector" + + +def test_psycopg_async_provide_session_tracks_promoted_statement_config() -> None: + """Async runtime statement config should resolve the current config dialect lazily.""" + config = PsycopgAsyncConfig() + config.statement_config = config.statement_config.replace(dialect="pgvector") + + session_config = config.provide_session()._statement_config # pyright: ignore[reportPrivateUsage] + + assert callable(session_config) + assert session_config().dialect == "pgvector" From 701d702c62604dd6c1a836945ded87893394eca1 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 
13 Mar 2026 17:08:30 +0000 Subject: [PATCH 02/39] refactor(config): move shared runtime helpers into core --- sqlspec/config.py | 110 ++++----- sqlspec/core/config_runtime.py | 105 +++++++++ .../unit/config/test_storage_capabilities.py | 209 +++++++++++++++++- 3 files changed, 357 insertions(+), 67 deletions(-) create mode 100644 sqlspec/core/config_runtime.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 5c23b9c7a..765d954e0 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -8,7 +8,14 @@ from typing_extensions import NotRequired, TypedDict -from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig +from sqlspec.core.config_runtime import ( + build_default_statement_config, + close_async_pool, + close_sync_pool, + create_async_pool, + create_sync_pool, + seed_runtime_driver_features, +) from sqlspec.exceptions import MissingDependencyError from sqlspec.extensions.events import EventRuntimeHints from sqlspec.loader import SQLFileLoader @@ -21,6 +28,7 @@ from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager, AbstractContextManager + from sqlspec.core import StatementConfig from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands from sqlspec.storage import StorageCapabilities @@ -1198,16 +1206,9 @@ def __init__( self._init_observability(observability_config) self._initialize_migration_components() - if statement_config is None: - default_parameter_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ) - self.statement_config = StatementConfig(dialect="sqlite", parameter_config=default_parameter_config) - else: - self.statement_config = statement_config - self.driver_features = driver_features or {} self._storage_capabilities = None - self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) 
+ self.statement_config = statement_config or build_default_statement_config("sqlite") + self.driver_features = seed_runtime_driver_features(driver_features, self.storage_capabilities()) self._promote_driver_feature_hooks() self._configure_observability_extensions() @@ -1369,14 +1370,9 @@ def __init__( self._init_observability(observability_config) self._initialize_migration_components() - if statement_config is None: - default_parameter_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ) - self.statement_config = StatementConfig(dialect="sqlite", parameter_config=default_parameter_config) - else: - self.statement_config = statement_config - self.driver_features = driver_features or {} + self.statement_config = statement_config or build_default_statement_config("sqlite") + self._storage_capabilities = None + self.driver_features = seed_runtime_driver_features(driver_features, self.storage_capabilities()) self._promote_driver_feature_hooks() self._configure_observability_extensions() @@ -1539,16 +1535,9 @@ def __init__( self._init_observability(observability_config) self._initialize_migration_components() - if statement_config is None: - default_parameter_config = ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ) - self.statement_config = StatementConfig(dialect="postgres", parameter_config=default_parameter_config) - else: - self.statement_config = statement_config - self.driver_features = driver_features or {} + self.statement_config = statement_config or build_default_statement_config("postgres") self._storage_capabilities = None - self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) + self.driver_features = seed_runtime_driver_features(driver_features, self.storage_capabilities()) self._promote_driver_feature_hooks() self._configure_observability_extensions() self._pool_lock = 
threading.Lock() @@ -1559,23 +1548,24 @@ def create_pool(self) -> PoolT: Returns: The created pool. """ - if self.connection_instance is not None: - return self.connection_instance - - with self._pool_lock: - if self.connection_instance is not None: - return self.connection_instance - - self.connection_instance = self._create_pool() - self.get_observability_runtime().emit_pool_create(self.connection_instance) - return self.connection_instance + existing_pool = self.connection_instance + if existing_pool is not None: + return existing_pool + + created_pool = create_sync_pool( + None, + self._pool_lock, + lambda: self.connection_instance, + self._create_pool, + self.get_observability_runtime().emit_pool_create, + ) + self.connection_instance = created_pool + return created_pool def close_pool(self) -> None: """Close the connection pool.""" pool = self.connection_instance - self._close_pool() - if pool is not None: - self.get_observability_runtime().emit_pool_destroy(pool) + close_sync_pool(pool, self._close_pool, self.get_observability_runtime().emit_pool_destroy) self.connection_instance = None def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: @@ -1742,18 +1732,9 @@ def __init__( self._init_observability(observability_config) self._initialize_migration_components() - if statement_config is None: - self.statement_config = StatementConfig( - parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} - ), - dialect="postgres", - ) - else: - self.statement_config = statement_config - self.driver_features = driver_features or {} + self.statement_config = statement_config or build_default_statement_config("postgres") self._storage_capabilities = None - self.driver_features.setdefault("storage_capabilities", self.storage_capabilities()) + self.driver_features = seed_runtime_driver_features(driver_features, self.storage_capabilities()) self._promote_driver_feature_hooks() 
self._configure_observability_extensions() self._pool_lock = asyncio.Lock() @@ -1764,23 +1745,24 @@ async def create_pool(self) -> PoolT: Returns: The created pool. """ - if self.connection_instance is not None: - return self.connection_instance - - async with self._pool_lock: - if self.connection_instance is not None: - return self.connection_instance - - self.connection_instance = await self._create_pool() - self.get_observability_runtime().emit_pool_create(self.connection_instance) - return self.connection_instance + existing_pool = self.connection_instance + if existing_pool is not None: + return existing_pool + + created_pool = await create_async_pool( + None, + self._pool_lock, + lambda: self.connection_instance, + self._create_pool, + self.get_observability_runtime().emit_pool_create, + ) + self.connection_instance = created_pool + return created_pool async def close_pool(self) -> None: """Close the connection pool.""" pool = self.connection_instance - await self._close_pool() - if pool is not None: - self.get_observability_runtime().emit_pool_destroy(pool) + await close_async_pool(pool, self._close_pool, self.get_observability_runtime().emit_pool_destroy) self.connection_instance = None async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT: diff --git a/sqlspec/core/config_runtime.py b/sqlspec/core/config_runtime.py new file mode 100644 index 000000000..923f35975 --- /dev/null +++ b/sqlspec/core/config_runtime.py @@ -0,0 +1,105 @@ +"""Compiled helpers for shared configuration runtime behavior.""" + +import asyncio +import threading +from typing import TYPE_CHECKING, Any, TypeVar + +from sqlspec.core import ParameterStyle, ParameterStyleConfig, StatementConfig + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + +__all__ = ( + "build_default_statement_config", + "close_async_pool", + "close_sync_pool", + "create_async_pool", + "create_sync_pool", + "seed_runtime_driver_features", +) + +PoolT = TypeVar("PoolT") + + +def 
build_default_statement_config(default_dialect: str) -> StatementConfig: + """Build the default statement config for a base config class.""" + return StatementConfig( + dialect=default_dialect, + parameter_config=ParameterStyleConfig( + default_parameter_style=ParameterStyle.QMARK, + supported_parameter_styles={ParameterStyle.QMARK}, + ), + ) + + +def seed_runtime_driver_features( + driver_features: "dict[str, Any] | None", storage_capabilities: "dict[str, Any] | None" +) -> "dict[str, Any]": + """Clone and seed driver feature state used on the runtime hot path.""" + seeded_features = dict(driver_features) if driver_features else {} + if storage_capabilities is not None: + seeded_features.setdefault("storage_capabilities", storage_capabilities) + return seeded_features + + +def create_sync_pool( + connection_instance: "PoolT | None", + lock: threading.Lock, + get_connection_instance: "Callable[[], PoolT | None]", + create_pool: "Callable[[], PoolT]", + emit_pool_create: "Callable[[PoolT], None]", +) -> PoolT: + """Create a sync pool once, using the existing lock/emit hooks.""" + if connection_instance is not None: + return connection_instance + + with lock: + existing_pool = get_connection_instance() + if existing_pool is not None: + return existing_pool + pool = create_pool() + emit_pool_create(pool) + return pool + + +def close_sync_pool( + connection_instance: "PoolT | None", + close_pool: "Callable[[], None]", + emit_pool_destroy: "Callable[[PoolT], None]", +) -> None: + """Close a sync pool and emit teardown hooks.""" + close_pool() + if connection_instance is not None: + emit_pool_destroy(connection_instance) + + +async def create_async_pool( + connection_instance: "PoolT | None", + lock: asyncio.Lock, + get_connection_instance: "Callable[[], PoolT | None]", + create_pool: "Callable[[], Awaitable[PoolT]]", + emit_pool_create: "Callable[[PoolT], None]", +) -> PoolT: + """Create an async pool once, using the existing lock/emit hooks.""" + if connection_instance 
is not None: + return connection_instance + + async with lock: + existing_pool = get_connection_instance() + if existing_pool is not None: + return existing_pool + pool = await create_pool() + emit_pool_create(pool) + return pool + + +async def close_async_pool( + connection_instance: "PoolT | None", + close_pool: "Callable[[], Awaitable[None]]", + emit_pool_destroy: "Callable[[PoolT], None]", +) -> None: + """Close an async pool and emit teardown hooks.""" + await close_pool() + if connection_instance is not None: + emit_pool_destroy(connection_instance) diff --git a/tests/unit/config/test_storage_capabilities.py b/tests/unit/config/test_storage_capabilities.py index 448e25bec..b4c90c120 100644 --- a/tests/unit/config/test_storage_capabilities.py +++ b/tests/unit/config/test_storage_capabilities.py @@ -1,8 +1,23 @@ -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager, asynccontextmanager, contextmanager from typing import TYPE_CHECKING, Any -from sqlspec.config import NoPoolSyncConfig -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase +import pytest + +from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig, NoPoolSyncConfig, SyncDatabaseConfig +from sqlspec.core.config_runtime import ( + build_default_statement_config, + close_async_pool, + close_sync_pool, + create_async_pool, + create_sync_pool, + seed_runtime_driver_features, +) +from sqlspec.driver import ( + AsyncDataDictionaryBase, + AsyncDriverAdapterBase, + SyncDataDictionaryBase, + SyncDriverAdapterBase, +) from tests.conftest import requires_interpreted pytestmark = requires_interpreted @@ -10,8 +25,14 @@ if TYPE_CHECKING: _NoPoolSyncConfigBase = NoPoolSyncConfig[Any, "_DummyDriver"] + _NoPoolAsyncConfigBase = NoPoolAsyncConfig[Any, "_AsyncDummyDriver"] + _SyncPoolConfigBase = SyncDatabaseConfig[Any, object, "_DummyDriver"] + _AsyncPoolConfigBase = AsyncDatabaseConfig[Any, object, "_AsyncDummyDriver"] else: 
_NoPoolSyncConfigBase = NoPoolSyncConfig + _NoPoolAsyncConfigBase = NoPoolAsyncConfig + _SyncPoolConfigBase = SyncDatabaseConfig + _AsyncPoolConfigBase = AsyncDatabaseConfig class _DummyDriver(SyncDriverAdapterBase): @@ -57,6 +78,43 @@ def dispatch_execute(self, cursor: Any, statement: Any): # type: ignore[overrid raise NotImplementedError +class _AsyncDummyDriver(AsyncDriverAdapterBase): + __slots__ = () + + @property + def data_dictionary(self) -> AsyncDataDictionaryBase: # type: ignore[override] + raise NotImplementedError + + @asynccontextmanager + async def with_cursor(self, connection: Any): # type: ignore[override] + yield object() + + @asynccontextmanager + async def handle_database_exceptions(self): # type: ignore[override] + yield None + + async def begin(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def rollback(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def commit(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def dispatch_special_handling(self, cursor: Any, statement: Any): # type: ignore[override] + return None + + async def dispatch_execute_script(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + async def dispatch_execute_many(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + async def dispatch_execute(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + class _CapabilityConfig(_NoPoolSyncConfigBase): driver_type = _DummyDriver connection_type = object @@ -81,6 +139,71 @@ def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] yield object() +class _AsyncCapabilityConfig(_NoPoolAsyncConfigBase): + driver_type = _AsyncDummyDriver + connection_type = object + supports_native_arrow_export = True + supports_native_arrow_import = True + requires_staging_for_load = True + staging_protocols = ("s3://",) + 
storage_partition_strategies = ("fixed", "rows_per_chunk") + + async def create_connection(self) -> object: + return object() + + @asynccontextmanager + async def provide_connection(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + @asynccontextmanager + async def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + +class _SyncPoolConfig(_SyncPoolConfigBase): + driver_type = _DummyDriver + connection_type = object + + def create_connection(self) -> object: + return object() + + @contextmanager + def provide_connection(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + @contextmanager + def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + def _create_pool(self) -> object: + return object() + + def _close_pool(self) -> None: + return None + + +class _AsyncPoolConfig(_AsyncPoolConfigBase): + driver_type = _AsyncDummyDriver + connection_type = object + + async def create_connection(self) -> object: + return object() + + @asynccontextmanager + async def provide_connection(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + @asynccontextmanager + async def provide_session(self, *args: Any, **kwargs: Any): # type: ignore[override] + yield object() + + async def _create_pool(self) -> object: + return object() + + async def _close_pool(self) -> None: + return None + + def test_storage_capabilities_snapshot(monkeypatch): monkeypatch.setattr(_CapabilityConfig, "_dependency_available", staticmethod(lambda checker: True)) config = _CapabilityConfig() @@ -107,3 +230,83 @@ def test_driver_features_seed_capabilities(monkeypatch): assert "storage_capabilities" in config.driver_features snapshot = config.driver_features["storage_capabilities"] assert isinstance(snapshot, dict) + + +def test_async_driver_features_seed_capabilities(monkeypatch): + monkeypatch.setattr(_AsyncCapabilityConfig, 
"_dependency_available", staticmethod(lambda checker: False)) + config = _AsyncCapabilityConfig() + assert "storage_capabilities" in config.driver_features + snapshot = config.driver_features["storage_capabilities"] + assert isinstance(snapshot, dict) + + +def test_build_default_statement_config_uses_requested_dialect() -> None: + statement_config = build_default_statement_config("postgres") + assert statement_config.dialect == "postgres" + + +def test_seed_runtime_driver_features_preserves_existing_values() -> None: + seeded = seed_runtime_driver_features({"custom": "value"}, {"arrow_export_enabled": True}) + assert seeded["custom"] == "value" + assert seeded["storage_capabilities"] == {"arrow_export_enabled": True} + + +def test_create_sync_pool_emits_observability_once() -> None: + emitted: list[object] = [] + created: list[object] = [] + config = _SyncPoolConfig() + lock = config._pool_lock # pyright: ignore[reportPrivateUsage] + + def _factory() -> object: + pool = object() + created.append(pool) + return pool + + pool = create_sync_pool(None, lock, lambda: None, _factory, emitted.append) + + assert pool is created[0] + assert emitted == [pool] + + +def test_close_sync_pool_emits_observability_once() -> None: + closed: list[str] = [] + emitted: list[object] = [] + pool = object() + + close_sync_pool(pool, lambda: closed.append("closed"), emitted.append) + + assert closed == ["closed"] + assert emitted == [pool] + + +@pytest.mark.anyio +async def test_create_async_pool_emits_observability_once() -> None: + emitted: list[object] = [] + created: list[object] = [] + config = _AsyncPoolConfig() + lock = config._pool_lock # pyright: ignore[reportPrivateUsage] + + async def _factory() -> object: + pool = object() + created.append(pool) + return pool + + pool = await create_async_pool(None, lock, lambda: None, _factory, emitted.append) + + assert pool is created[0] + assert emitted == [pool] + + +@pytest.mark.anyio +async def 
test_close_async_pool_emits_observability_once() -> None: + closed: list[str] = [] + emitted: list[object] = [] + pool = object() + + async def _closer() -> None: + closed.append("closed") + + await close_async_pool(pool, _closer, emitted.append) + + assert closed == ["closed"] + assert emitted == [pool] From c9791579bb12caf6ffe34ef34e33a2b5bd75be04 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 13 Mar 2026 17:36:24 +0000 Subject: [PATCH 03/39] refactor(adapters): centralize runtime contract helpers --- sqlspec/adapters/adbc/core.py | 53 ++----------------- sqlspec/adapters/asyncpg/core.py | 53 ++----------------- sqlspec/adapters/psqlpy/core.py | 53 ++----------------- sqlspec/adapters/psycopg/core.py | 53 ++----------------- sqlspec/core/config_runtime.py | 50 +++++++++++++++++ .../unit/config/test_storage_capabilities.py | 30 +++++++++++ 6 files changed, 100 insertions(+), 192 deletions(-) diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index 5af07e572..58c43ec03 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -13,6 +13,11 @@ build_null_pruning_transform, build_statement_config_from_profile, ) +from sqlspec.core.config_runtime import ( + build_postgres_extension_probe_names, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.exceptions import ( CheckViolationError, DatabaseConnectionError, @@ -216,54 +221,6 @@ def detect_postgres_extensions( except Exception: return False, False - -def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": - """Return enabled PostgreSQL extension names to probe on first connection.""" - if driver_features is None: - return [] - - extensions: list[str] = [] - if driver_features.get("enable_pgvector", False): - extensions.append("vector") - if driver_features.get("enable_paradedb", False): - extensions.append("pg_search") - return extensions - - -def resolve_postgres_extension_state( 
- statement_config: "StatementConfig", - driver_features: "Mapping[str, Any] | None", - detected_extensions: "set[str] | None" = None, -) -> "tuple[StatementConfig, bool, bool]": - """Resolve detected PostgreSQL extension flags and promoted dialect.""" - detected = detected_extensions or set() - pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) - paradedb_available = bool( - driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected - ) - - if statement_config.dialect == "postgres": - if paradedb_available: - statement_config = statement_config.replace(dialect="paradedb") - elif pgvector_available: - statement_config = statement_config.replace(dialect="pgvector") - - return statement_config, pgvector_available, paradedb_available - - -def resolve_runtime_statement_config( - statement_config: "StatementConfig | None", - configured_statement_config: "StatementConfig | None", - default_config: "StatementConfig", -) -> "StatementConfig": - """Resolve the effective runtime statement config for a session.""" - if statement_config is not None: - return statement_config - if configured_statement_config is not None: - return configured_statement_config - return default_config - - def normalize_driver_path(driver_name: str) -> str: """Normalize a driver name to an importable connect function path.""" stripped = driver_name.strip() diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index 2ebbc7e96..fcd39dce6 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -9,6 +9,11 @@ import asyncpg from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile +from sqlspec.core.config_runtime import ( + build_postgres_extension_probe_names, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.exceptions import ( 
CheckViolationError, ConnectionTimeoutError, @@ -258,54 +263,6 @@ def apply_driver_features( return statement_config, processed_features - -def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": - """Return enabled PostgreSQL extension names to probe on first connection.""" - if driver_features is None: - return [] - - extensions: list[str] = [] - if driver_features.get("enable_pgvector", False): - extensions.append("vector") - if driver_features.get("enable_paradedb", False): - extensions.append("pg_search") - return extensions - - -def resolve_postgres_extension_state( - statement_config: "StatementConfig", - driver_features: "Mapping[str, Any] | None", - detected_extensions: "set[str] | None" = None, -) -> "tuple[StatementConfig, bool, bool]": - """Resolve detected PostgreSQL extension flags and promoted dialect.""" - detected = detected_extensions or set() - pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) - paradedb_available = bool( - driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected - ) - - if statement_config.dialect == "postgres": - if paradedb_available: - statement_config = statement_config.replace(dialect="paradedb") - elif pgvector_available: - statement_config = statement_config.replace(dialect="pgvector") - - return statement_config, pgvector_available, paradedb_available - - -def resolve_runtime_statement_config( - statement_config: "StatementConfig | None", - configured_statement_config: "StatementConfig | None", - default_config: "StatementConfig", -) -> "StatementConfig": - """Resolve the effective runtime statement config for a session.""" - if statement_config is not None: - return statement_config - if configured_statement_config is not None: - return configured_statement_config - return default_config - - def parse_status(status: Any) -> int: """Parse AsyncPG status string to extract 
row count. diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index 81ba74a06..f2da1100c 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -8,6 +8,11 @@ from typing import TYPE_CHECKING, Any, Final, cast from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile +from sqlspec.core.config_runtime import ( + build_postgres_extension_probe_names, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.exceptions import ( CheckViolationError, ConnectionTimeoutError, @@ -259,54 +264,6 @@ def apply_driver_features( return statement_config, features - -def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": - """Return enabled PostgreSQL extension names to probe on first connection.""" - if driver_features is None: - return [] - - extensions: list[str] = [] - if driver_features.get("enable_pgvector", False): - extensions.append("vector") - if driver_features.get("enable_paradedb", False): - extensions.append("pg_search") - return extensions - - -def resolve_postgres_extension_state( - statement_config: "StatementConfig", - driver_features: "Mapping[str, Any] | None", - detected_extensions: "set[str] | None" = None, -) -> "tuple[StatementConfig, bool, bool]": - """Resolve detected PostgreSQL extension flags and promoted dialect.""" - detected = detected_extensions or set() - pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) - paradedb_available = bool( - driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected - ) - - if statement_config.dialect == "postgres": - if paradedb_available: - statement_config = statement_config.replace(dialect="paradedb") - elif pgvector_available: - statement_config = statement_config.replace(dialect="pgvector") - - return statement_config, 
pgvector_available, paradedb_available - - -def resolve_runtime_statement_config( - statement_config: "StatementConfig | None", - configured_statement_config: "StatementConfig | None", - default_config: "StatementConfig", -) -> "StatementConfig": - """Resolve the effective runtime statement config for a session.""" - if statement_config is not None: - return statement_config - if configured_statement_config is not None: - return configured_statement_config - return default_config - - def collect_rows(query_result: Any | None) -> "tuple[list[dict[str, Any]], list[str]]": """Collect psqlpy rows and column names. diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index 6c0bb4c29..9a52c9abd 100644 --- a/sqlspec/adapters/psycopg/core.py +++ b/sqlspec/adapters/psycopg/core.py @@ -13,6 +13,11 @@ StatementConfig, build_statement_config_from_profile, ) +from sqlspec.core.config_runtime import ( + build_postgres_extension_probe_names, + resolve_postgres_extension_state, + resolve_runtime_statement_config, +) from sqlspec.driver import ExecutionResult from sqlspec.exceptions import ( CheckViolationError, @@ -208,54 +213,6 @@ def apply_driver_features( return statement_config, features - -def build_postgres_extension_probe_names(driver_features: "Mapping[str, Any] | None") -> "list[str]": - """Return enabled PostgreSQL extension names to probe on first connection.""" - if driver_features is None: - return [] - - extensions: list[str] = [] - if driver_features.get("enable_pgvector", False): - extensions.append("vector") - if driver_features.get("enable_paradedb", False): - extensions.append("pg_search") - return extensions - - -def resolve_postgres_extension_state( - statement_config: "StatementConfig", - driver_features: "Mapping[str, Any] | None", - detected_extensions: "set[str] | None" = None, -) -> "tuple[StatementConfig, bool, bool]": - """Resolve detected PostgreSQL extension flags and promoted dialect.""" - detected = detected_extensions or 
set() - pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) - paradedb_available = bool( - driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected - ) - - if statement_config.dialect == "postgres": - if paradedb_available: - statement_config = statement_config.replace(dialect="paradedb") - elif pgvector_available: - statement_config = statement_config.replace(dialect="pgvector") - - return statement_config, pgvector_available, paradedb_available - - -def resolve_runtime_statement_config( - statement_config: "StatementConfig | None", - configured_statement_config: "StatementConfig | None", - default_config: "StatementConfig", -) -> "StatementConfig": - """Resolve the effective runtime statement config for a session.""" - if statement_config is not None: - return statement_config - if configured_statement_config is not None: - return configured_statement_config - return default_config - - def collect_rows(fetched_data: "list[Any] | None", description: "list[Any] | None") -> "tuple[list[Any], list[str]]": """Collect psycopg rows and column names. 
diff --git a/sqlspec/core/config_runtime.py b/sqlspec/core/config_runtime.py index 923f35975..2d7deeb64 100644 --- a/sqlspec/core/config_runtime.py +++ b/sqlspec/core/config_runtime.py @@ -12,10 +12,13 @@ __all__ = ( "build_default_statement_config", + "build_postgres_extension_probe_names", "close_async_pool", "close_sync_pool", "create_async_pool", "create_sync_pool", + "resolve_postgres_extension_state", + "resolve_runtime_statement_config", "seed_runtime_driver_features", ) @@ -43,6 +46,53 @@ def seed_runtime_driver_features( return seeded_features +def build_postgres_extension_probe_names(driver_features: "dict[str, Any] | None") -> "list[str]": + """Return enabled PostgreSQL extension names to probe on first connection.""" + if driver_features is None: + return [] + + extensions: list[str] = [] + if driver_features.get("enable_pgvector", False): + extensions.append("vector") + if driver_features.get("enable_paradedb", False): + extensions.append("pg_search") + return extensions + + +def resolve_postgres_extension_state( + statement_config: StatementConfig, + driver_features: "dict[str, Any] | None", + detected_extensions: "set[str] | None" = None, +) -> "tuple[StatementConfig, bool, bool]": + """Resolve detected PostgreSQL extension flags and promoted dialect.""" + detected = detected_extensions or set() + pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + paradedb_available = bool( + driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected + ) + + if statement_config.dialect == "postgres": + if paradedb_available: + statement_config = statement_config.replace(dialect="paradedb") + elif pgvector_available: + statement_config = statement_config.replace(dialect="pgvector") + + return statement_config, pgvector_available, paradedb_available + + +def resolve_runtime_statement_config( + statement_config: StatementConfig | None, + configured_statement_config: 
StatementConfig | None, + default_config: StatementConfig, +) -> StatementConfig: + """Resolve the effective runtime statement config for a session.""" + if statement_config is not None: + return statement_config + if configured_statement_config is not None: + return configured_statement_config + return default_config + + def create_sync_pool( connection_instance: "PoolT | None", lock: threading.Lock, diff --git a/tests/unit/config/test_storage_capabilities.py b/tests/unit/config/test_storage_capabilities.py index b4c90c120..b60a26432 100644 --- a/tests/unit/config/test_storage_capabilities.py +++ b/tests/unit/config/test_storage_capabilities.py @@ -4,12 +4,16 @@ import pytest from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig, NoPoolSyncConfig, SyncDatabaseConfig +from sqlspec.core import StatementConfig from sqlspec.core.config_runtime import ( build_default_statement_config, + build_postgres_extension_probe_names, close_async_pool, close_sync_pool, create_async_pool, create_sync_pool, + resolve_postgres_extension_state, + resolve_runtime_statement_config, seed_runtime_driver_features, ) from sqlspec.driver import ( @@ -245,6 +249,32 @@ def test_build_default_statement_config_uses_requested_dialect() -> None: assert statement_config.dialect == "postgres" +def test_build_postgres_extension_probe_names_filters_disabled_features() -> None: + assert build_postgres_extension_probe_names({"enable_pgvector": True, "enable_paradedb": False}) == ["vector"] + + +def test_resolve_postgres_extension_state_promotes_paradedb() -> None: + statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( + StatementConfig(dialect="postgres"), + {"enable_pgvector": True, "enable_paradedb": True}, + {"vector", "pg_search"}, + ) + + assert statement_config.dialect == "paradedb" + assert pgvector_available is True + assert paradedb_available is True + + +def test_resolve_runtime_statement_config_prefers_explicit_override() -> None: + 
explicit_config = StatementConfig(dialect="pgvector") + configured_config = StatementConfig(dialect="postgres") + + assert ( + resolve_runtime_statement_config(explicit_config, configured_config, build_default_statement_config("postgres")) + is explicit_config + ) + + def test_seed_runtime_driver_features_preserves_existing_values() -> None: seeded = seed_runtime_driver_features({"custom": "value"}, {"arrow_export_enabled": True}) assert seeded["custom"] == "value" From c7bcfe05b483bd1aec614d657b76e208d1ff2387 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 00:46:14 +0000 Subject: [PATCH 04/39] refactor(adbc): route runtime conversion through helper --- sqlspec/adapters/adbc/core.py | 4 +-- tests/unit/adapters/test_adbc/test_core.py | 28 +++++++++++++++++++ .../test_adbc/test_extension_detection.py | 22 +++++++++++++-- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index 58c43ec03..59b34aafd 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -5,7 +5,7 @@ from collections.abc import Sized from typing import TYPE_CHECKING, Any, cast -from sqlspec.adapters.adbc.type_converter import ADBCOutputConverter +from sqlspec.adapters.adbc.type_converter import get_adbc_type_converter from sqlspec.core import ( DriverParameterProfile, ParameterStyle, @@ -877,7 +877,7 @@ def prepare_parameters_with_casts( if isinstance(parameters, (list, tuple)): result: list[Any] = [] - converter = ADBCOutputConverter(dialect) + converter = get_adbc_type_converter(dialect) for idx, param in enumerate(parameters, start=1): cast_type = parameter_casts.get(idx, "").upper() if cast_type in {"JSON", "JSONB", "TYPE.JSON", "TYPE.JSONB"}: diff --git a/tests/unit/adapters/test_adbc/test_core.py b/tests/unit/adapters/test_adbc/test_core.py index 201506a85..c5650262c 100644 --- a/tests/unit/adapters/test_adbc/test_core.py +++ b/tests/unit/adapters/test_adbc/test_core.py @@ -3,9 
+3,11 @@ from types import SimpleNamespace from typing import Any +from sqlspec.adapters.adbc import core as adbc_core from sqlspec.adapters.adbc.core import ( collect_rows, get_statement_config, + prepare_parameters_with_casts, prepare_postgres_parameters, resolve_column_names, resolve_many_rowcount, @@ -74,3 +76,29 @@ def test_collect_rows_uses_precomputed_column_names() -> None: assert data is rows assert column_names == ["id", "name"] + + +def test_prepare_parameters_with_casts_uses_type_converter_factory(monkeypatch: Any) -> None: + statement_config = get_statement_config("postgres") + factory_calls: list[tuple[str, int]] = [] + + class FakeConverter: + def convert_dict(self, value: dict[str, Any]) -> str: + return f"factory:{sorted(value.items())!r}" + + def fake_factory(dialect: str, cache_size: int = 5000) -> FakeConverter: + factory_calls.append((dialect, cache_size)) + return FakeConverter() + + monkeypatch.setattr(adbc_core, "get_adbc_type_converter", fake_factory) + + prepared = prepare_parameters_with_casts( + [{"id": 1}], + {}, + statement_config, + dialect="postgres", + json_serializer=lambda value: str(value), + ) + + assert prepared == ["factory:[('id', 1)]"] + assert factory_calls == [("postgres", 5000)] diff --git a/tests/unit/adapters/test_adbc/test_extension_detection.py b/tests/unit/adapters/test_adbc/test_extension_detection.py index 12bd32a50..331eb0d2d 100644 --- a/tests/unit/adapters/test_adbc/test_extension_detection.py +++ b/tests/unit/adapters/test_adbc/test_extension_detection.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock +from pytest import MonkeyPatch + from sqlspec.adapters.adbc.config import AdbcConfig from sqlspec.adapters.adbc.core import ( apply_driver_features, @@ -10,6 +12,7 @@ get_statement_config, resolve_postgres_extension_state, ) +from sqlspec.core import StatementConfig def test_apply_driver_features_sets_pgvector_default() -> None: @@ -148,8 +151,6 @@ def test_adbc_config_update_dialect_skips_non_postgres() -> 
None: def test_adbc_config_update_dialect_preserves_custom_dialect() -> None: """If user explicitly set a non-postgres dialect, don't override it.""" - from sqlspec.core import StatementConfig - config = AdbcConfig( connection_config={"uri": "postgresql://localhost/test"}, statement_config=StatementConfig(dialect="custom") ) @@ -157,3 +158,20 @@ def test_adbc_config_update_dialect_preserves_custom_dialect() -> None: config._paradedb_available = True # pyright: ignore[reportPrivateUsage] config._update_dialect_for_extensions() # pyright: ignore[reportPrivateUsage] assert config.statement_config.dialect == "custom" + + +def test_adbc_config_provide_session_skips_extension_probe_for_non_postgres(monkeypatch: MonkeyPatch) -> None: + """Non-postgres sessions should not create a connection for extension detection.""" + config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) + + def fail_create_connection() -> None: + raise AssertionError("non-postgres startup path should not probe extensions") + + monkeypatch.setattr(config, "create_connection", fail_create_connection) + + session = config.provide_session() + + assert session is not None + assert config._pgvector_available is False # pyright: ignore[reportPrivateUsage] + assert config._paradedb_available is False # pyright: ignore[reportPrivateUsage] + assert config.statement_config.dialect == "sqlite" From 89f3be1eb67277cc0fcdbe593cdfc943c2ce9b28 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 00:59:39 +0000 Subject: [PATCH 05/39] perf(coercion): keep json decode in compiled path --- sqlspec/core/type_converter.py | 7 ++++--- tests/unit/core/test_type_conversion.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sqlspec/core/type_converter.py b/sqlspec/core/type_converter.py index 0e199e22b..10250a6e6 100644 --- a/sqlspec/core/type_converter.py +++ b/sqlspec/core/type_converter.py @@ -1,5 +1,6 @@ """Base classes and detection for adapter 
type conversion.""" +import json import re from collections.abc import Callable from datetime import date, datetime, time, timezone @@ -10,8 +11,6 @@ from mypy_extensions import mypyc_attr -from sqlspec._serialization import decode_json - __all__ = ( "DEFAULT_CACHE_SIZE", "DEFAULT_SPECIAL_CHARS", @@ -110,7 +109,9 @@ def convert_json(value: str) -> "Any": Returns: Decoded Python object. """ - return decode_json(value) + # Keep the hot coercion path in this compiled module instead of bouncing + # through the interpreted serializer-selection shell. + return json.loads(value) def convert_decimal(value: str) -> "Decimal": diff --git a/tests/unit/core/test_type_conversion.py b/tests/unit/core/test_type_conversion.py index bbfe091e0..5b3f7d0c7 100644 --- a/tests/unit/core/test_type_conversion.py +++ b/tests/unit/core/test_type_conversion.py @@ -11,6 +11,7 @@ import pytest +import sqlspec._serialization from sqlspec.core import ( BaseTypeConverter, convert_decimal, @@ -233,6 +234,19 @@ def test_convert_json() -> None: assert result["key"] == "value" +def test_convert_json_avoids_serializer_dispatch(monkeypatch: pytest.MonkeyPatch) -> None: + """JSON coercion should not bounce through serializer selection.""" + + def fail_get_default_serializer() -> None: + raise AssertionError("convert_json should not call serializer selection") + + monkeypatch.setattr(sqlspec._serialization, "get_default_serializer", fail_get_default_serializer) + + result = convert_json('{"key": "value"}') + + assert result == {"key": "value"} + + def test_convert_decimal() -> None: """Test decimal conversion.""" decimal_str = "123.456" From 26c2144edf00918ad0853c5792b532ba83bb5b22 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 01:00:11 +0000 Subject: [PATCH 06/39] perf(parameters): reuse ordered placeholder metadata --- sqlspec/core/parameters/_converter.py | 42 ++++++++++++++++++++------- tests/unit/core/test_parameters.py | 26 +++++++++++++++++ 2 files changed, 58 insertions(+), 10 
deletions(-) diff --git a/sqlspec/core/parameters/_converter.py b/sqlspec/core/parameters/_converter.py index 8fb619092..4e5887ff4 100644 --- a/sqlspec/core/parameters/_converter.py +++ b/sqlspec/core/parameters/_converter.py @@ -20,6 +20,8 @@ __all__ = ("ParameterConverter",) +_ORDERED_PARAM_INFO_MIN_SIZE = 2 + def _placeholder_qmark(_: Any) -> str: return "?" @@ -53,6 +55,29 @@ def _placeholder_positional_pyformat(_: Any) -> str: return "%s" +def _ordered_parameter_info(param_info: "list[ParameterInfo]") -> "list[ParameterInfo]": + if len(param_info) < _ORDERED_PARAM_INFO_MIN_SIZE: + return param_info + + previous_position = param_info[0].position + for param in param_info[1:]: + if param.position < previous_position: + return sorted(param_info, key=lambda item: item.position) + previous_position = param.position + return param_info + + +def _single_parameter_style(param_info: "list[ParameterInfo]") -> "ParameterStyle | None": + if not param_info: + return None + + style = param_info[0].style + for param in param_info[1:]: + if param.style != style: + return None + return style + + @mypyc_attr(allow_interpreted_subclasses=False) class ParameterConverter: """Parameter style conversion helper.""" @@ -87,8 +112,8 @@ def convert_placeholder_style( if target_style == ParameterStyle.STATIC: return self._embed_static_parameters(sql, parameters, param_info) - current_styles = {p.style for p in param_info} - if len(current_styles) == 1 and target_style in current_styles: + current_style = _single_parameter_style(param_info) + if current_style is not None and target_style == current_style: converted_parameters = self._convert_parameter_format( parameters, param_info, @@ -120,13 +145,12 @@ def _convert_placeholders_to_style( msg = f"Unsupported target parameter style: {target_style}" raise ValueError(msg) - param_styles = {p.style for p in param_info} - use_sequential_for_qmark = ( - len(param_styles) == 1 and ParameterStyle.QMARK in param_styles and target_style == 
ParameterStyle.NUMERIC - ) + ordered_params = _ordered_parameter_info(param_info) + source_style = _single_parameter_style(ordered_params) + use_sequential_for_qmark = source_style == ParameterStyle.QMARK and target_style == ParameterStyle.NUMERIC unique_params: dict[str, int] = {} - for param in param_info: + for param in ordered_params: param_key = ( f"{param.placeholder_text}_{param.ordinal}" if use_sequential_for_qmark and param.style == ParameterStyle.QMARK @@ -135,8 +159,6 @@ def _convert_placeholders_to_style( if param_key not in unique_params: unique_params[param_key] = len(unique_params) - # Sort by position for forward iteration (O(n) string building) - sorted_params = sorted(param_info, key=lambda p: p.position) placeholder_text_len_cache: dict[str, int] = {} # Build SQL using forward iteration with list join (O(n) vs O(n^2) string slicing) segments: list[str] = [] @@ -149,7 +171,7 @@ def _convert_placeholders_to_style( ParameterStyle.POSITIONAL_COLON, } - for param in sorted_params: + for param in ordered_params: # Cache placeholder text length if param.placeholder_text not in placeholder_text_len_cache: placeholder_text_len_cache[param.placeholder_text] = len(param.placeholder_text) diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 42e65d140..98ca5ea48 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -1729,6 +1729,32 @@ def test_positional_parameter_output_type_narrowing(converter: ParameterConverte assert result_dict == (1, 2, 3) +def test_convert_placeholders_to_style_skips_sort_for_position_ordered_params(converter: ParameterConverter, monkeypatch: Any) -> None: + """Position-ordered parameter metadata should not pay an extra sorted() pass.""" + sql = "SELECT :a, :b, :c" + param_info = converter.validator.extract_parameters(sql) + + def fail_sorted(*_args: Any, **_kwargs: Any) -> object: + raise AssertionError("sorted() should not run for already ordered parameter 
metadata") + + monkeypatch.setattr("builtins.sorted", fail_sorted) + + converted_sql = converter._convert_placeholders_to_style(sql, param_info, ParameterStyle.NUMERIC) # pyright: ignore + + assert converted_sql == "SELECT $1, $2, $3" + + +def test_convert_placeholders_to_style_sorts_unsafely_ordered_params_as_fallback(converter: ParameterConverter) -> None: + """Manually unordered parameter metadata should still be normalized correctly.""" + sql = "SELECT :a, :b, :c" + param_info = converter.validator.extract_parameters(sql) + unordered = [param_info[2], param_info[0], param_info[1]] + + converted_sql = converter._convert_placeholders_to_style(sql, unordered, ParameterStyle.NUMERIC) # pyright: ignore + + assert converted_sql == "SELECT $1, $2, $3" + + def test_named_parameter_output_type_narrowing(converter: ParameterConverter) -> None: """Test _convert_sequence_to_dict returns NamedParameterOutput.""" sql = "SELECT * FROM table WHERE id = :id AND name = :name" From 942ba9b0a56828bdc546df399a5a50470024d1a8 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 01:01:12 +0000 Subject: [PATCH 07/39] docs(bench): record mypyc expansion matrix --- tests/unit/utils/test_bench_gate.py | 53 ++++++++++++++++++++++++ tools/scripts/bench_gate.py | 63 ++++++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 tests/unit/utils/test_bench_gate.py diff --git a/tests/unit/utils/test_bench_gate.py b/tests/unit/utils/test_bench_gate.py new file mode 100644 index 000000000..00b135ac5 --- /dev/null +++ b/tests/unit/utils/test_bench_gate.py @@ -0,0 +1,53 @@ +"""Tests for benchmark matrix metadata in bench_gate.py.""" + +import importlib.util +from pathlib import Path +from types import ModuleType + + +def _load_bench_gate_module() -> ModuleType: + module_path = Path(__file__).resolve().parents[3] / "tools" / "scripts" / "bench_gate.py" + spec = importlib.util.spec_from_file_location("bench_gate_for_tests", module_path) + assert spec 
is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_benchmark_scenario_matrix_covers_prd_hot_paths() -> None: + module = _load_bench_gate_module() + + assert set(module.BENCHMARK_SCENARIO_MATRIX) == { + "parameter_pipeline", + "coercion_engine", + "adapter_runtime_boundaries", + "storage_runtime_expansion", + "exclusion_revalidation", + } + + for entry in module.BENCHMARK_SCENARIO_MATRIX.values(): + assert entry["tracked_by"] + assert entry["goal"] + assert entry["scenarios"] + + +def test_module_admission_criteria_records_required_evidence() -> None: + module = _load_bench_gate_module() + + assert module.MODULE_ADMISSION_CRITERIA["benchmark_delta"] == "measurable or neutral" + assert module.MODULE_ADMISSION_CRITERIA["mypy_mypyc"] == "must compile cleanly" + assert module.MODULE_ADMISSION_CRITERIA["segfaults"] == "no new crashes or segfaults" + assert module.MODULE_ADMISSION_CRITERIA["any_boundaries"] == "explicitly justified" + assert module.MODULE_ADMISSION_CRITERIA["unsafe_surfaces"] == "keep Arrow/metaclass-heavy paths interpreted" + + +def test_chapter_rollout_order_starts_with_guardrails_and_parameter_work() -> None: + module = _load_bench_gate_module() + + assert module.CHAPTER_ROLLOUT_ORDER[0] == "compile-boundary-guardrails" + assert module.CHAPTER_ROLLOUT_ORDER[1:4] == ( + "compiled-parameter-pipeline", + "compiled-coercion-engine", + "adapter-runtime-boundaries", + ) diff --git a/tools/scripts/bench_gate.py b/tools/scripts/bench_gate.py index c007b5bc6..b8f449d0f 100644 --- a/tools/scripts/bench_gate.py +++ b/tools/scripts/bench_gate.py @@ -26,7 +26,16 @@ from rich.console import Console from rich.table import Table -__all__ = ("main", "print_gate_table", "run_gate") +__all__ = ( + "BENCHMARK_SCENARIO_MATRIX", + "CHAPTER_ROLLOUT_ORDER", + "DEFAULT_THRESHOLDS", + "GATE_SCENARIOS", + "MODULE_ADMISSION_CRITERIA", + "main", + "print_gate_table", + 
"run_gate", +) # Import bench.py from the same directory, regardless of working directory @@ -51,6 +60,58 @@ # Core scenarios to gate on GATE_SCENARIOS = ["iterative_inserts", "repeated_queries", "write_heavy", "read_heavy"] +# PRD benchmark matrix for mypyc expansion work. This keeps the benchmark +# expectations next to the scripts that actually exercise them. +BENCHMARK_SCENARIO_MATRIX: dict[str, dict[str, str | tuple[str, ...]]] = { + "parameter_pipeline": { + "tracked_by": "tools/scripts/bench_subsystems.py + tools/scripts/bench.py", + "goal": "Placeholder conversion, parameter preparation, and execute-many shaping", + "scenarios": ( + "prepare_driver_parameters (tuple)", + "prepare_driver_parameters (dict)", + "_format_parameter_set (3 params)", + "complex_parameters", + ), + }, + "coercion_engine": { + "tracked_by": "tools/scripts/bench.py", + "goal": "Schema mapping, key transformation, and JSON-heavy coercion paths", + "scenarios": ("schema_mapping", "dict_key_transform", "complex_parameters"), + }, + "adapter_runtime_boundaries": { + "tracked_by": "tools/scripts/bench.py + tools/scripts/bench_subsystems.py", + "goal": "Startup/runtime setup and end-to-end thin execution path overhead", + "scenarios": ("initialization", "session.execute() - full path"), + }, + "storage_runtime_expansion": { + "tracked_by": "tools/scripts/bench.py", + "goal": "Storage-adjacent write/read throughput until dedicated storage micro-benchmarks land", + "scenarios": ("write_heavy", "read_heavy"), + }, + "exclusion_revalidation": { + "tracked_by": "tools/scripts/bench.py", + "goal": "Regression proof that narrowing exclusions does not harm hot-path throughput", + "scenarios": ("thin_path_stress", "repeated_queries"), + }, +} + +MODULE_ADMISSION_CRITERIA: dict[str, str] = { + "benchmark_delta": "measurable or neutral", + "mypy_mypyc": "must compile cleanly", + "segfaults": "no new crashes or segfaults", + "any_boundaries": "explicitly justified", + "unsafe_surfaces": "keep 
Arrow/metaclass-heavy paths interpreted", +} + +CHAPTER_ROLLOUT_ORDER: tuple[str, ...] = ( + "compile-boundary-guardrails", + "compiled-parameter-pipeline", + "compiled-coercion-engine", + "adapter-runtime-boundaries", + "storage-runtime-expansion", + "exclusion-revalidation", +) + def run_gate( *, rows: int, iterations: int, warmup: int, thresholds: dict[str, float] From 7cbb8c53cde39a2c05ef18cc3dbeafa4f8f1751c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 16:25:46 +0000 Subject: [PATCH 08/39] perf(parameters): reuse placeholder metadata --- sqlspec/core/parameters/_converter.py | 116 ++++++++++++++++++-------- sqlspec/core/parameters/_processor.py | 55 ++++++++---- tests/unit/core/test_parameters.py | 40 +++++++++ 3 files changed, 160 insertions(+), 51 deletions(-) diff --git a/sqlspec/core/parameters/_converter.py b/sqlspec/core/parameters/_converter.py index 4e5887ff4..94bbbd0dd 100644 --- a/sqlspec/core/parameters/_converter.py +++ b/sqlspec/core/parameters/_converter.py @@ -78,6 +78,28 @@ def _single_parameter_style(param_info: "list[ParameterInfo]") -> "ParameterStyl return style +def _is_positional_style(style: "ParameterStyle") -> bool: + return style in { + ParameterStyle.QMARK, + ParameterStyle.NUMERIC, + ParameterStyle.POSITIONAL_PYFORMAT, + ParameterStyle.POSITIONAL_COLON, + } + + +def _parameter_lookup_key(param: "ParameterInfo", use_sequential_for_qmark: bool) -> str: + if use_sequential_for_qmark and param.style == ParameterStyle.QMARK: + return f"{param.placeholder_text}_{param.ordinal}" + return param.placeholder_text + + +def _normalized_named_parameter_name(param: "ParameterInfo") -> str: + param_name = param.name or f"param_{param.ordinal}" + if param_name.isdigit(): + return f"param_{param.ordinal}" + return param_name + + @mypyc_attr(allow_interpreted_subclasses=False) class ParameterConverter: """Parameter style conversion helper.""" @@ -106,17 +128,18 @@ def convert_placeholder_style( is_many: bool = False, *, 
strict_named_parameters: bool = True, + param_info: "list[ParameterInfo] | None" = None, ) -> "tuple[str, ConvertedParameters]": - param_info = self.validator.extract_parameters(sql) + extracted_param_info = param_info if param_info is not None else self.validator.extract_parameters(sql) if target_style == ParameterStyle.STATIC: - return self._embed_static_parameters(sql, parameters, param_info) + return self._embed_static_parameters(sql, parameters, extracted_param_info) - current_style = _single_parameter_style(param_info) + current_style = _single_parameter_style(extracted_param_info) if current_style is not None and target_style == current_style: converted_parameters = self._convert_parameter_format( parameters, - param_info, + extracted_param_info, target_style, parameters, preserve_parameter_format=True, @@ -125,10 +148,10 @@ def convert_placeholder_style( ) return sql, converted_parameters - converted_sql = self._convert_placeholders_to_style(sql, param_info, target_style) + converted_sql = self._convert_placeholders_to_style(sql, extracted_param_info, target_style) converted_parameters = self._convert_parameter_format( parameters, - param_info, + extracted_param_info, target_style, parameters, preserve_parameter_format=True, @@ -137,39 +160,37 @@ def convert_placeholder_style( ) return converted_sql, converted_parameters - def _convert_placeholders_to_style( - self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" - ) -> str: - generator = self._placeholder_generators.get(target_style) - if generator is None: - msg = f"Unsupported target parameter style: {target_style}" - raise ValueError(msg) - + def _build_conversion_plan( + self, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + ) -> "tuple[list[ParameterInfo], dict[str, int], bool]": ordered_params = _ordered_parameter_info(param_info) source_style = _single_parameter_style(ordered_params) use_sequential_for_qmark = source_style == ParameterStyle.QMARK and 
target_style == ParameterStyle.NUMERIC unique_params: dict[str, int] = {} for param in ordered_params: - param_key = ( - f"{param.placeholder_text}_{param.ordinal}" - if use_sequential_for_qmark and param.style == ParameterStyle.QMARK - else param.placeholder_text - ) + param_key = _parameter_lookup_key(param, use_sequential_for_qmark) if param_key not in unique_params: unique_params[param_key] = len(unique_params) + return ordered_params, unique_params, use_sequential_for_qmark + + def _convert_placeholders_to_style( + self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + ) -> str: + generator = self._placeholder_generators.get(target_style) + if generator is None: + msg = f"Unsupported target parameter style: {target_style}" + raise ValueError(msg) + + ordered_params, unique_params, use_sequential_for_qmark = self._build_conversion_plan(param_info, target_style) + placeholder_text_len_cache: dict[str, int] = {} # Build SQL using forward iteration with list join (O(n) vs O(n^2) string slicing) segments: list[str] = [] last_end = 0 - is_positional_style = target_style in { - ParameterStyle.QMARK, - ParameterStyle.NUMERIC, - ParameterStyle.POSITIONAL_PYFORMAT, - ParameterStyle.POSITIONAL_COLON, - } + is_positional_style = _is_positional_style(target_style) for param in ordered_params: # Cache placeholder text length @@ -179,16 +200,10 @@ def _convert_placeholders_to_style( # Generate new placeholder based on target style if is_positional_style: - param_key = ( - f"{param.placeholder_text}_{param.ordinal}" - if use_sequential_for_qmark and param.style == ParameterStyle.QMARK - else param.placeholder_text - ) + param_key = _parameter_lookup_key(param, use_sequential_for_qmark) new_placeholder = generator(unique_params[param_key]) else: - param_name = param.name or f"param_{param.ordinal}" - if isinstance(param_name, str) and param_name.isdigit(): - param_name = f"param_{param.ordinal}" + param_name = 
_normalized_named_parameter_name(param) new_placeholder = generator(param_name) # Append segment before this placeholder and the new placeholder @@ -200,6 +215,41 @@ def _convert_placeholders_to_style( return "".join(segments) + def convert_parameter_info_style( + self, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" + ) -> "list[ParameterInfo]": + generator = self._placeholder_generators.get(target_style) + if generator is None: + msg = f"Unsupported target parameter style: {target_style}" + raise ValueError(msg) + + ordered_params, unique_params, use_sequential_for_qmark = self._build_conversion_plan(param_info, target_style) + is_positional_style = _is_positional_style(target_style) + converted_param_info: list[ParameterInfo] = [] + + for param in ordered_params: + if is_positional_style: + converted_index = unique_params[_parameter_lookup_key(param, use_sequential_for_qmark)] + placeholder_text = generator(converted_index) + name = None + if target_style in {ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_COLON}: + name = str(converted_index + 1) + else: + name = _normalized_named_parameter_name(param) + placeholder_text = generator(name) + + converted_param_info.append( + ParameterInfo( + name=name, + style=target_style, + position=param.position, + ordinal=param.ordinal, + placeholder_text=placeholder_text, + ) + ) + + return converted_param_info + def _convert_sequence_to_dict( self, parameters: "ParameterSequence", param_info: "list[ParameterInfo]" ) -> "NamedParameterOutput": diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index 3963851b1..dd2412846 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -737,7 +737,9 @@ def _normalize_sql_for_parsing( return sql # Convert to the default style that sqlglot can parse for this dialect target_style = config.default_parameter_style - normalized_sql, _ = self._converter.convert_placeholder_style(sql, None, 
target_style, is_many=False) + normalized_sql, _ = self._converter.convert_placeholder_style( + sql, None, target_style, is_many=False, param_info=param_info + ) return normalized_sql def _make_processor_cache_key( @@ -948,7 +950,11 @@ def _process_internal( target_style, is_many, strict_named_parameters=config.strict_named_parameters, + param_info=param_info, ) + param_info = self._converter.convert_parameter_info_style(param_info, target_style) + original_styles = {target_style} + needs_execution_conversion = False applied_wrap_types = False if processed_parameters and wrap_types: @@ -960,14 +966,14 @@ def _process_internal( if config.type_coercion_map and processed_parameters: processed_parameters = self._coerce_parameter_types(processed_parameters, config.type_coercion_map, is_many) - processed_sql, processed_parameters = self._convert_placeholders_for_execution( - processed_sql, processed_parameters, config, original_styles, needs_execution_conversion, is_many + processed_sql, processed_parameters, converted_param_info = self._convert_placeholders_for_execution( + processed_sql, processed_parameters, config, param_info, original_styles, needs_execution_conversion, is_many ) if config.output_transformer: processed_sql, processed_parameters = config.output_transformer(processed_sql, processed_parameters) - final_param_info = self._validator.extract_parameters(processed_sql) + final_param_info = converted_param_info if converted_param_info is not None else param_info final_profile = ParameterProfile(final_param_info) sqlglot_sql = ( self._normalize_sql_for_parsing(processed_sql, final_param_info, config) @@ -1040,34 +1046,47 @@ def _convert_placeholders_for_execution( sql: str, parameters: "ParameterPayload", config: "ParameterStyleConfig", + param_info: "list[ParameterInfo]", original_styles: "set[ParameterStyle]", needs_execution_conversion: bool, is_many: bool, - ) -> "tuple[str, ConvertedParameters]": + ) -> "tuple[str, ConvertedParameters, 
list[ParameterInfo] | None]": if not needs_execution_conversion: # Convert parameters to concrete type for return if parameters is None: - return sql, None + return sql, None, None if isinstance(parameters, dict): - return sql, parameters + return sql, parameters, None if isinstance(parameters, list): - return sql, parameters + return sql, parameters, None if isinstance(parameters, tuple): - return sql, parameters + return sql, parameters, None if isinstance(parameters, Mapping): - return sql, dict(parameters) + return sql, dict(parameters), None if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): - return sql, list(parameters) - return sql, None + return sql, list(parameters), None + return sql, None, None + + target_style = self._select_execution_style(original_styles, config) + converted_param_info = self._converter.convert_parameter_info_style(param_info, target_style) if is_many and config.preserve_original_params_for_many and isinstance(parameters, (list, tuple)): - target_style = self._select_execution_style(original_styles, config) processed_sql, _ = self._converter.convert_placeholder_style( - sql, parameters, target_style, is_many, strict_named_parameters=config.strict_named_parameters + sql, + parameters, + target_style, + is_many, + strict_named_parameters=config.strict_named_parameters, + param_info=param_info, ) - return processed_sql, parameters + return processed_sql, parameters, converted_param_info - target_style = self._select_execution_style(original_styles, config) - return self._converter.convert_placeholder_style( - sql, parameters, target_style, is_many, strict_named_parameters=config.strict_named_parameters + processed_sql, processed_parameters = self._converter.convert_placeholder_style( + sql, + parameters, + target_style, + is_many, + strict_named_parameters=config.strict_named_parameters, + param_info=param_info, ) + return processed_sql, processed_parameters, converted_param_info diff --git 
a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 98ca5ea48..a3b5cc457 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -14,6 +14,7 @@ from decimal import Decimal from importlib import import_module from typing import Any +from unittest.mock import patch import pytest import sqlglot @@ -2064,6 +2065,45 @@ def test_multiple_unsupported_parameters_all_normalized( # Should have NUMERIC placeholders assert "$" in normalized_sql + def test_process_reuses_extracted_metadata_for_parse_normalization( + self, processor: ParameterProcessor, validator: ParameterValidator + ) -> None: + """process() should normalize SQL without re-extracting placeholders.""" + sql = "SELECT * FROM t WHERE id = %(name)s" + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NUMERIC, + default_execution_parameter_style=ParameterStyle.NAMED_PYFORMAT, + supported_parameter_styles={ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NAMED_PYFORMAT}, + ) + + with patch.object(ParameterValidator, "extract_parameters", wraps=validator.extract_parameters) as mock_extract: + result = processor.process(sql, {"name": 1}, config) + + assert mock_extract.call_count == 1 + assert result.sql == sql + assert result.sqlglot_sql == "SELECT * FROM t WHERE id = $1" + + def test_process_reuses_extracted_metadata_for_execution_conversion( + self, processor: ParameterProcessor, validator: ParameterValidator + ) -> None: + """process() should derive the final profile without re-parsing converted SQL.""" + sql = "SELECT * FROM t WHERE id = :id AND name = :name" + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.NAMED_COLON, + default_execution_parameter_style=ParameterStyle.NUMERIC, + supported_parameter_styles={ParameterStyle.NAMED_COLON, ParameterStyle.NUMERIC}, + supported_execution_parameter_styles={ParameterStyle.NUMERIC}, + ) + + with patch.object(ParameterValidator, 
"extract_parameters", wraps=validator.extract_parameters) as mock_extract: + result = processor.process(sql, {"id": 1, "name": "a"}, config) + + assert mock_extract.call_count == 1 + assert result.sql == "SELECT * FROM t WHERE id = $1 AND name = $2" + assert result.parameter_profile.styles == (ParameterStyle.NUMERIC.value,) + assert result.parameter_profile.named_parameters == ("1", "2") + class TestDriverProfileValidation: """Validate all driver profiles have correct supported_styles.""" From fba712fa2c41b49d1dd1be350d3d03f0a3f0a3e5 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 16:30:20 +0000 Subject: [PATCH 09/39] perf(parameters): add exact-type coercion fallback --- sqlspec/core/parameters/_processor.py | 115 ++++++++++++++++++-------- sqlspec/driver/_common.py | 74 ++++++++++++----- tests/unit/core/test_parameters.py | 20 +++++ tests/unit/driver/test_query_cache.py | 22 +++++ 4 files changed, 174 insertions(+), 57 deletions(-) diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index dd2412846..4b5b9145d 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -34,6 +34,8 @@ # Number of records to sample for type signatures _EXECUTE_MANY_SAMPLE_SIZE = 3 +TypeCoercionFallback = tuple[type, Callable[[Any], Any]] + def _structural_fingerprint(parameters: "ParameterPayload", is_many: bool = False) -> Any: """Return a structural fingerprint for caching parameter payloads. 
@@ -197,19 +199,50 @@ def _value_fingerprint(parameters: "ParameterPayload") -> Any: return ("values", repr(parameters)) -def _coerce_nested_value(value: object, type_coercion_map: "dict[type, Callable[[Any], Any]]") -> object: +def _type_coercion_fallbacks( + type_coercion_map: "dict[type, Callable[[Any], Any]]", +) -> "tuple[TypeCoercionFallback, ...]": + return tuple(type_coercion_map.items()) + + +def _resolve_type_coercion( + value: object, + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", +) -> object: + value_type = type(value) + exact_converter = type_coercion_map.get(value_type) + if exact_converter is not None: + return exact_converter(value) + for type_check, converter in fallback_items: + if type_check is value_type: + continue + if isinstance(value, type_check): + return converter(value) + return value + + +def _coerce_nested_value( + value: object, + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", +) -> object: # Fast type dispatch for common types value_type = type(value) if value_type is list or value_type is tuple: seq_value = cast("Sequence[Any]", value) - return [_coerce_parameter_value(item, type_coercion_map) for item in seq_value] + return [_coerce_parameter_value(item, type_coercion_map, fallback_items) for item in seq_value] if value_type is dict: dict_value = cast("dict[Any, Any]", value) - return {key: _coerce_parameter_value(val, type_coercion_map) for key, val in dict_value.items()} + return {key: _coerce_parameter_value(val, type_coercion_map, fallback_items) for key, val in dict_value.items()} return value -def _coerce_parameter_value(value: object, type_coercion_map: "dict[type, Callable[[Any], Any]]") -> object: +def _coerce_parameter_value( + value: object, + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", +) -> object: if value is None: return value @@ 
-221,23 +254,25 @@ def _coerce_parameter_value(value: object, type_coercion_map: "dict[type, Callab if wrapped_value is None: return wrapped_value original_type = typed_param.original_type - if original_type in type_coercion_map: - coerced = type_coercion_map[original_type](wrapped_value) - return _coerce_nested_value(coerced, type_coercion_map) - return wrapped_value - - if value_type in type_coercion_map: - coerced = type_coercion_map[value_type](value) - return _coerce_nested_value(coerced, type_coercion_map) - return value + coerced = _resolve_type_coercion(wrapped_value, type_coercion_map, fallback_items) + if coerced is wrapped_value: + return wrapped_value + return _coerce_nested_value(coerced, type_coercion_map, fallback_items) + + coerced = _resolve_type_coercion(value, type_coercion_map, fallback_items) + if coerced is value: + return value + return _coerce_nested_value(coerced, type_coercion_map, fallback_items) def _coerce_sequence_preserving_identity( - seq_value: "Sequence[Any]", type_coercion_map: "dict[type, Callable[[Any], Any]]" + seq_value: "Sequence[Any]", + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", ) -> "Sequence[Any] | list[Any]": updated_seq: list[Any] | None = None for idx, item in enumerate(seq_value): - coerced_value = _coerce_parameter_value(item, type_coercion_map) + coerced_value = _coerce_parameter_value(item, type_coercion_map, fallback_items) if updated_seq is None: if coerced_value is item: continue @@ -249,11 +284,13 @@ def _coerce_sequence_preserving_identity( def _coerce_mapping_preserving_identity( - mapping: "Mapping[Any, Any]", type_coercion_map: "dict[type, Callable[[Any], Any]]" + mapping: "Mapping[Any, Any]", + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", ) -> "Mapping[Any, Any] | dict[Any, Any]": updated_mapping: dict[Any, Any] | None = None for key, val in mapping.items(): - coerced_value = 
_coerce_parameter_value(val, type_coercion_map) + coerced_value = _coerce_parameter_value(val, type_coercion_map, fallback_items) if updated_mapping is None: if coerced_value is val: continue @@ -264,36 +301,43 @@ def _coerce_mapping_preserving_identity( return updated_mapping -def _coerce_parameter_set(param_set: object, type_coercion_map: "dict[type, Callable[[Any], Any]]") -> object: +def _coerce_parameter_set( + param_set: object, + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", +) -> object: # Fast type dispatch for common types param_type = type(param_set) if param_type is list: - return _coerce_sequence_preserving_identity(cast("list[Any]", param_set), type_coercion_map) + return _coerce_sequence_preserving_identity(cast("list[Any]", param_set), type_coercion_map, fallback_items) if param_type is tuple: seq_value = cast("tuple[Any, ...]", param_set) - coerced_seq = _coerce_sequence_preserving_identity(seq_value, type_coercion_map) + coerced_seq = _coerce_sequence_preserving_identity(seq_value, type_coercion_map, fallback_items) if coerced_seq is seq_value: return seq_value return tuple(cast("list[Any]", coerced_seq)) if param_type is dict: - return _coerce_mapping_preserving_identity(cast("dict[Any, Any]", param_set), type_coercion_map) + return _coerce_mapping_preserving_identity(cast("dict[Any, Any]", param_set), type_coercion_map, fallback_items) # Fallback to ABC checks for custom types if isinstance(param_set, Sequence) and not isinstance(param_set, (str, bytes)): seq_fallback = param_set - coerced_seq = _coerce_sequence_preserving_identity(seq_fallback, type_coercion_map) + coerced_seq = _coerce_sequence_preserving_identity(seq_fallback, type_coercion_map, fallback_items) if coerced_seq is seq_fallback: return param_set return coerced_seq if isinstance(param_set, Mapping): - coerced_mapping = _coerce_mapping_preserving_identity(param_set, type_coercion_map) + coerced_mapping = 
_coerce_mapping_preserving_identity(param_set, type_coercion_map, fallback_items) if coerced_mapping is param_set: return param_set return coerced_mapping - return _coerce_parameter_value(param_set, type_coercion_map) + return _coerce_parameter_value(param_set, type_coercion_map, fallback_items) def _coerce_parameters_payload( - parameters: "ParameterPayload", type_coercion_map: "dict[type, Callable[[Any], Any]]", is_many: bool + parameters: "ParameterPayload", + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[TypeCoercionFallback, ...]", + is_many: bool, ) -> object: # Fast type dispatch for common types param_type = type(parameters) @@ -302,7 +346,7 @@ def _coerce_parameters_payload( if is_many: updated_many: list[Any] | None = None for idx, param_set in enumerate(seq_params): - coerced_set = _coerce_parameter_set(param_set, type_coercion_map) + coerced_set = _coerce_parameter_set(param_set, type_coercion_map, fallback_items) if updated_many is None: if coerced_set is param_set: continue @@ -314,7 +358,7 @@ def _coerce_parameters_payload( updated_seq: list[Any] | None = None for idx, item in enumerate(seq_params): - coerced_item = _coerce_parameter_value(item, type_coercion_map) + coerced_item = _coerce_parameter_value(item, type_coercion_map, fallback_items) if updated_seq is None: if coerced_item is item: continue @@ -326,13 +370,13 @@ def _coerce_parameters_payload( if param_type is tuple: tuple_params = cast("tuple[Any, ...]", parameters) if is_many: - return [_coerce_parameter_set(param_set, type_coercion_map) for param_set in tuple_params] - return [_coerce_parameter_value(item, type_coercion_map) for item in tuple_params] + return [_coerce_parameter_set(param_set, type_coercion_map, fallback_items) for param_set in tuple_params] + return [_coerce_parameter_value(item, type_coercion_map, fallback_items) for item in tuple_params] if param_type is dict: dict_params = cast("dict[Any, Any]", parameters) updated_mapping: dict[Any, 
Any] | None = None for key, val in dict_params.items(): - coerced_value = _coerce_parameter_value(val, type_coercion_map) + coerced_value = _coerce_parameter_value(val, type_coercion_map, fallback_items) if updated_mapping is None: if coerced_value is val: continue @@ -343,12 +387,12 @@ def _coerce_parameters_payload( return updated_mapping # Fallback to ABC checks for custom types if is_many and isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): - return [_coerce_parameter_set(param_set, type_coercion_map) for param_set in parameters] + return [_coerce_parameter_set(param_set, type_coercion_map, fallback_items) for param_set in parameters] if isinstance(parameters, Sequence) and not isinstance(parameters, (str, bytes)): - return [_coerce_parameter_value(item, type_coercion_map) for item in parameters] + return [_coerce_parameter_value(item, type_coercion_map, fallback_items) for item in parameters] if isinstance(parameters, Mapping): - return {key: _coerce_parameter_value(val, type_coercion_map) for key, val in parameters.items()} - return _coerce_parameter_value(parameters, type_coercion_map) + return {key: _coerce_parameter_value(val, type_coercion_map, fallback_items) for key, val in parameters.items()} + return _coerce_parameter_value(parameters, type_coercion_map, fallback_items) @mypyc_attr(allow_interpreted_subclasses=False) @@ -516,7 +560,8 @@ def _coerce_parameter_types( type_coercion_map: "dict[type, Callable[[Any], Any]]", is_many: bool = False, ) -> "ConvertedParameters": - result = _coerce_parameters_payload(parameters, type_coercion_map, is_many) + fallback_items = _type_coercion_fallbacks(type_coercion_map) + result = _coerce_parameters_payload(parameters, type_coercion_map, fallback_items, is_many) # Fast type narrowing - _coerce_parameters_payload returns object but produces concrete types if result is None: return None diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index fdf5c35fb..1a9049296 100644 
--- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -225,6 +225,12 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: ... _CONVERT_TO_FROZENSET = object() +def _type_coercion_fallbacks(type_coercion_map: "dict[type, Any] | None") -> "tuple[tuple[type, Any], ...]": + if not type_coercion_map: + return () + return tuple(type_coercion_map.items()) + + def make_cache_key_hashable(obj: Any) -> Any: """Recursively convert unhashable types to hashable ones for cache keys. @@ -1545,29 +1551,21 @@ def prepare_driver_parameters( if is_many: if isinstance(parameters, list): type_coercion_map = statement_config.parameter_config.type_coercion_map + fallback_items = _type_coercion_fallbacks(type_coercion_map) needs_transform = False for param_set in parameters: if isinstance(param_set, dict): for value in param_set.values(): - value_type = type(value) - if value_type is TypedParameter: - needs_transform = True - break - if type_coercion_map and value_type in type_coercion_map: + if self._needs_coercion_candidate(value, type_coercion_map, fallback_items): needs_transform = True break elif isinstance(param_set, (list, tuple)): for value in param_set: - value_type = type(value) - if value_type is TypedParameter: - needs_transform = True - break - if type_coercion_map and value_type in type_coercion_map: + if self._needs_coercion_candidate(value, type_coercion_map, fallback_items): needs_transform = True break else: - value_type = type(param_set) - if value_type is TypedParameter or (type_coercion_map and value_type in type_coercion_map): + if self._needs_coercion_candidate(param_set, type_coercion_map, fallback_items): needs_transform = True if needs_transform: break @@ -1604,17 +1602,49 @@ def _apply_coercion(self, value: object, type_coercion_map: "dict[type, Callable if not type_coercion_map: return unwrapped_value - value_type = type(unwrapped_value) + return self._apply_coercion_with_fallback( + unwrapped_value, type_coercion_map, 
_type_coercion_fallbacks(type_coercion_map) + ) + + def _apply_coercion_with_fallback( + self, + value: object, + type_coercion_map: "dict[type, Callable[[Any], Any]]", + fallback_items: "tuple[tuple[type, Any], ...]", + ) -> object: + value_type = type(value) exact_converter = type_coercion_map.get(value_type) if exact_converter is not None: - return exact_converter(unwrapped_value) + return exact_converter(value) + + for type_check, converter in fallback_items: + if type_check is value_type: + continue + if isinstance(value, type_check): + return converter(value) + return value - for type_check, converter in type_coercion_map.items(): + def _needs_coercion_candidate( + self, + value: object, + type_coercion_map: "dict[type, Callable[[Any], Any]] | None", + fallback_items: "tuple[tuple[type, Any], ...]", + ) -> bool: + if type(value) is TypedParameter: + return True + if not type_coercion_map: + return False + + value_type = type(value) + if value_type in type_coercion_map: + return True + + for type_check, _converter in fallback_items: if type_check is value_type: continue - if isinstance(unwrapped_value, type_check): - return converter(unwrapped_value) - return unwrapped_value + if isinstance(value, type_check): + return True + return False def _format_parameter_set_for_many( self, parameters: "StatementParameters", statement_config: "StatementConfig" @@ -1636,15 +1666,15 @@ def _format_parameter_set_for_many( return [] type_coercion_map = statement_config.parameter_config.type_coercion_map - coerce_value = self._apply_coercion + fallback_items = _type_coercion_fallbacks(type_coercion_map) if not isinstance(parameters, (dict, list, tuple)): - return [coerce_value(parameters, type_coercion_map)] + return [self._apply_coercion_with_fallback(parameters, type_coercion_map, fallback_items)] if isinstance(parameters, dict): coerced_mapping: dict[str, Any] | None = None for key, value in parameters.items(): - coerced_value = coerce_value(value, type_coercion_map) + 
coerced_value = self._apply_coercion_with_fallback(value, type_coercion_map, fallback_items) if coerced_mapping is None: if coerced_value is value: continue @@ -1656,7 +1686,7 @@ def _format_parameter_set_for_many( updated_params: list[Any] | None = None for idx, value in enumerate(parameters): - coerced_value = coerce_value(value, type_coercion_map) + coerced_value = self._apply_coercion_with_fallback(value, type_coercion_map, fallback_items) if updated_params is None: if coerced_value is value: continue diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index a3b5cc457..2ea6e13eb 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -1164,6 +1164,26 @@ def test_process_execute_many_coerces_only_rows_that_require_conversion(processo assert tuple(result.parameters[1]) == ("v2",) +def test_process_type_coercion_supports_subclass_fallback(processor: "ParameterProcessor") -> None: + """Subclass values should still hit coercion fallback when no exact entry exists.""" + + class MyInt(int): + pass + + config = ParameterStyleConfig( + default_parameter_style=ParameterStyle.QMARK, + supported_execution_parameter_styles={ParameterStyle.QMARK}, + default_execution_parameter_style=ParameterStyle.QMARK, + type_coercion_map={int: lambda value: value + 1}, + ) + sql = "SELECT * FROM metrics WHERE value = ?" 
+ + result = processor.process(sql, [MyInt(4)], config, wrap_types=False) + + assert isinstance(result.parameters, list) + assert result.parameters == [5] + + def test_list_parameter_preservation(converter: ParameterConverter) -> None: """Test that list parameters are properly handled.""" sql = "INSERT INTO users (id, name, active) VALUES (?, ?, ?)" diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index 181ed52c7..8cfed2a8b 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -167,6 +167,28 @@ def test_prepare_driver_parameters_many_coerces_rows_when_needed() -> None: assert tuple(prepared[1]) == ("b",) +def test_prepare_driver_parameters_many_coerces_subclass_rows_when_needed() -> None: + class MyInt(int): + pass + + config = StatementConfig( + parameter_config=ParameterStyleConfig( + default_parameter_style=ParameterStyle.QMARK, + supported_parameter_styles={ParameterStyle.QMARK}, + type_coercion_map={int: lambda value: value + 1}, + ) + ) + driver = _FakeDriver(object(), config) + parameters = [(MyInt(2),), ("b",)] + + prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) + + assert isinstance(prepared, list) + assert prepared is not parameters + assert tuple(prepared[0]) == (3,) + assert tuple(prepared[1]) == ("b",) + + def test_sync_stmt_cache_execute_direct_uses_dispatch_path(mock_sync_driver, monkeypatch) -> None: class _CursorManager: def __enter__(self) -> object: From c1fb9eff028db8a4f4705db3b888f52df0644f67 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 16:32:00 +0000 Subject: [PATCH 10/39] perf(parameters): preserve execute-many identity --- sqlspec/core/parameters/_processor.py | 24 +++++++++++++++++------- tests/unit/core/test_parameters.py | 25 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index 
4b5b9145d..d245844d6 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -663,9 +663,8 @@ def _map_named_to_positional( param_type = type(parameters) if is_many and (param_type is list or param_type is tuple): - # Process each row in execute_many - result: list[Any] = [] - for row in parameters: # type: ignore[union-attr] + updated_rows: list[Any] | None = None + for idx, row in enumerate(parameters): # type: ignore[union-attr] row_type = type(row) if row_type is dict: row_dict: dict[str, Any] = row # type: ignore[assignment] @@ -676,7 +675,7 @@ def _map_named_to_positional( msg = f"Missing required parameters: {missing}" raise SQLSpecError(msg) - result.append(tuple(row_dict.get(name) for name in named_order)) + mapped_row: Any = tuple(row_dict.get(name) for name in named_order) elif isinstance(row, Mapping): # Fallback for custom Mapping types if strict: @@ -686,10 +685,21 @@ def _map_named_to_positional( msg = f"Missing required parameters: {missing}" raise SQLSpecError(msg) - result.append(tuple(row.get(name) for name in named_order)) + mapped_row = tuple(row.get(name) for name in named_order) else: - result.append(row) - return result + mapped_row = row + + if updated_rows is None: + if mapped_row is row: + continue + updated_rows = list(parameters[:idx]) # type: ignore[index] + updated_rows.append(mapped_row) + + if updated_rows is None: + return parameters + if param_type is tuple: + return tuple(updated_rows) + return updated_rows if param_type is dict: if strict: diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 2ea6e13eb..990654ba2 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -1184,6 +1184,31 @@ class MyInt(int): assert result.parameters == [5] +def test_map_named_to_positional_preserves_execute_many_identity_when_rows_are_already_positional( + processor: "ParameterProcessor", +) -> None: + """execute_many rebinding should 
avoid allocating when every row is already positional.""" + parameters = [(1, 2), (3, 4)] + + remapped = processor._map_named_to_positional(parameters, ("a", "b"), is_many=True) # pyright: ignore[reportPrivateUsage] + + assert remapped is parameters + + +def test_map_named_to_positional_only_allocates_when_execute_many_row_needs_mapping( + processor: "ParameterProcessor", +) -> None: + """execute_many rebinding should allocate only after the first mapping row.""" + parameters: list[object] = [(1, 2), {"a": 3, "b": 4}] + + remapped = processor._map_named_to_positional(parameters, ("a", "b"), is_many=True) # pyright: ignore[reportPrivateUsage] + + assert isinstance(remapped, list) + assert remapped is not parameters + assert remapped[0] is parameters[0] + assert remapped[1] == (3, 4) + + def test_list_parameter_preservation(converter: ParameterConverter) -> None: """Test that list parameters are properly handled.""" sql = "INSERT INTO users (id, name, active) VALUES (?, ?, ?)" From c07c6f2d4d04d017430737779f06a23d492b77fc Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 16:59:13 +0000 Subject: [PATCH 11/39] perf(schema): hoist msgspec helpers --- sqlspec/utils/schema.py | 22 +++++++++++++++------- tests/unit/driver/test_result_tools.py | 12 ++++++++++++ tests/unit/utils/test_to_value_type.py | 12 ++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/sqlspec/utils/schema.py b/sqlspec/utils/schema.py index 9989b3e4f..0161266fa 100644 --- a/sqlspec/utils/schema.py +++ b/sqlspec/utils/schema.py @@ -56,6 +56,11 @@ _DATETIME_TYPES: Final[set[type]] = {datetime.datetime, datetime.date, datetime.time} _DATETIME_TYPE_TUPLE: Final[tuple[type, ...]] = (datetime.datetime, datetime.date, datetime.time) +_MSGSPEC_RENAME_CONVERTERS: Final[dict[str, Callable[[str], str]]] = { + "camel": camelize, + "kebab": kebabize, + "pascal": pascalize, +} # ============================================================================= @@ -370,6 +375,11 @@ 
def _default_msgspec_deserializer( return value +_DEFAULT_MSGSPEC_DESERIALIZER: Final[Callable[[Any, Any], Any]] = partial( + _default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS +) + + def _convert_numpy_recursive(obj: Any) -> Any: """Recursively convert numpy arrays to lists. @@ -400,13 +410,11 @@ def _convert_numpy_recursive(obj: Any) -> Any: def _convert_msgspec(data: Any, schema_type: Any) -> Any: """Convert data to msgspec Struct.""" rename_config = get_msgspec_rename_config(schema_type) - deserializer = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS) transformed_data = data if (rename_config and is_dict(data)) or (isinstance(data, Sequence) and data and is_dict(data[0])): try: - converter_map: dict[str, Callable[[str], str]] = {"camel": camelize, "kebab": kebabize, "pascal": pascalize} - converter = converter_map.get(rename_config) if rename_config else None + converter = _MSGSPEC_RENAME_CONVERTERS.get(rename_config) if rename_config else None if converter: transformed_data = ( [transform_dict_keys(item, converter) if is_dict(item) else item for item in data] @@ -423,7 +431,7 @@ def _convert_msgspec(data: Any, schema_type: Any) -> Any: obj=transformed_data, type=(list[schema_type] if isinstance(transformed_data, Sequence) else schema_type), from_attributes=True, - dec_hook=deserializer, + dec_hook=_DEFAULT_MSGSPEC_DESERIALIZER, ) @@ -973,10 +981,10 @@ def to_value_type(value: Any, value_type: "type[ValueT]") -> "ValueT": # Schema types (Pydantic, dataclass, msgspec, attrs, TypedDict) # Deferred after scalar checks to avoid overhead for common scalar queries - schema_type_key = _detect_schema_type(value_type) # type: ignore[arg-type] - if schema_type_key is not None: + schema_converter = _get_schema_converter(value_type) # type: ignore[arg-type] + if schema_converter is not None: parsed = _ensure_json_parsed(value) - return cast("ValueT", to_schema(parsed, schema_type=value_type)) + return cast("ValueT", 
schema_converter(parsed, value_type)) # Fallback: try direct construction for custom types try: diff --git a/tests/unit/driver/test_result_tools.py b/tests/unit/driver/test_result_tools.py index 51ef5e8b0..4a64bf4d1 100644 --- a/tests/unit/driver/test_result_tools.py +++ b/tests/unit/driver/test_result_tools.py @@ -20,6 +20,7 @@ _default_msgspec_deserializer, _is_list_type_target, ) +import sqlspec.utils.schema as schema_utils pytestmark = pytest.mark.xdist_group("driver") @@ -291,6 +292,17 @@ def test_to_schema_mixin_with_regular_lists() -> None: assert result.metadata == {"type": "manual"} +def test_to_schema_msgspec_reuses_module_scope_helpers() -> None: + """Msgspec schema conversion should not rebuild helper partials per call.""" + test_data = {"name": "test", "embedding": [1.0, 2.0], "metadata": None} + + with patch.object(schema_utils, "partial", wraps=schema_utils.partial) as mock_partial: + result = CommonDriverAttributesMixin.to_schema(test_data, schema_type=SampleMsgspecStruct) + + assert isinstance(result, SampleMsgspecStruct) + assert mock_partial.call_count == 0 + + def test_to_schema_mixin_without_schema_type() -> None: """Test that data is returned unchanged when no schema_type is provided.""" test_data = {"name": "test", "values": [1, 2, 3]} diff --git a/tests/unit/utils/test_to_value_type.py b/tests/unit/utils/test_to_value_type.py index 62257adf6..93cce8f28 100644 --- a/tests/unit/utils/test_to_value_type.py +++ b/tests/unit/utils/test_to_value_type.py @@ -5,6 +5,7 @@ from decimal import Decimal from pathlib import Path, PurePosixPath from typing import TypedDict +from unittest.mock import patch from uuid import UUID import attrs @@ -12,6 +13,7 @@ import pytest from pydantic import BaseModel +import sqlspec.utils.schema as schema_utils from sqlspec.utils.schema import to_value_type # ============================================================================= @@ -647,6 +649,16 @@ def test_pydantic_identity(self) -> None: assert 
isinstance(result, UserPydantic) assert result.name == "Charlie" + def test_schema_conversion_uses_cached_converter_path(self) -> None: + """Schema conversion should not re-enter schema-type detection before dispatch.""" + data = {"name": "Alice", "email": "alice@example.com"} + + with patch.object(schema_utils, "_detect_schema_type", side_effect=AssertionError("unexpected schema detection")): + result = to_value_type(data, UserPydantic) + + assert isinstance(result, UserPydantic) + assert result.name == "Alice" + class TestDataclassConversion: """Tests for converting values to dataclasses.""" From 30e8919c7c78817894cd6db505ee52fb1aaf9941 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 17:27:59 +0000 Subject: [PATCH 12/39] perf(coercion): preserve nested identity --- sqlspec/adapters/psqlpy/core.py | 31 +++++++++++++-- sqlspec/utils/type_converters.py | 30 ++++++++++++-- tests/unit/adapters/test_psqlpy/test_core.py | 38 +++++++++++++++++- tests/unit/utils/test_type_converters.py | 42 ++++++++++++++++++++ 4 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 tests/unit/utils/test_type_converters.py diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index f2da1100c..230d8a4d2 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -292,12 +292,35 @@ def coerce_numeric_for_write(value: Any) -> Any: if isinstance(value, decimal.Decimal): return value if isinstance(value, list): - return [coerce_numeric_for_write(item) for item in value] + coerced_list: list[Any] | None = None + for index, item in enumerate(value): + coerced_item = coerce_numeric_for_write(item) + if coerced_list is None: + if coerced_item is item: + continue + coerced_list = list(value[:index]) + coerced_list.append(coerced_item) + return value if coerced_list is None else coerced_list if isinstance(value, tuple): - coerced = [coerce_numeric_for_write(item) for item in value] - return tuple(coerced) + 
coerced_tuple: list[Any] | None = None + for index, item in enumerate(value): + coerced_item = coerce_numeric_for_write(item) + if coerced_tuple is None: + if coerced_item is item: + continue + coerced_tuple = list(value[:index]) + coerced_tuple.append(coerced_item) + return value if coerced_tuple is None else tuple(coerced_tuple) if isinstance(value, dict): - return {key: coerce_numeric_for_write(item) for key, item in value.items()} + coerced_dict: dict[Any, Any] | None = None + for key, item in value.items(): + coerced_item = coerce_numeric_for_write(item) + if coerced_dict is None: + if coerced_item is item: + continue + coerced_dict = dict(value) + coerced_dict[key] = coerced_item + return value if coerced_dict is None else coerced_dict return value diff --git a/sqlspec/utils/type_converters.py b/sqlspec/utils/type_converters.py index 6c938c0f3..a855d9035 100644 --- a/sqlspec/utils/type_converters.py +++ b/sqlspec/utils/type_converters.py @@ -70,11 +70,35 @@ def __call__(self, value: Any) -> Any: if isinstance(value, decimal.Decimal): return self._decimal_converter(value) if isinstance(value, list): - return [self(item) for item in value] + normalized_list: list[Any] | None = None + for index, item in enumerate(value): + normalized_item = self(item) + if normalized_list is None: + if normalized_item is item: + continue + normalized_list = list(value[:index]) + normalized_list.append(normalized_item) + return value if normalized_list is None else normalized_list if isinstance(value, tuple): - return tuple(self(item) for item in value) + normalized_tuple: list[Any] | None = None + for index, item in enumerate(value): + normalized_item = self(item) + if normalized_tuple is None: + if normalized_item is item: + continue + normalized_tuple = list(value[:index]) + normalized_tuple.append(normalized_item) + return value if normalized_tuple is None else tuple(normalized_tuple) if isinstance(value, dict): - return {key: self(item) for key, item in value.items()} + 
normalized_dict: dict[Any, Any] | None = None + for key, item in value.items(): + normalized_item = self(item) + if normalized_dict is None: + if normalized_item is item: + continue + normalized_dict = dict(value) + normalized_dict[key] = normalized_item + return value if normalized_dict is None else normalized_dict return value diff --git a/tests/unit/adapters/test_psqlpy/test_core.py b/tests/unit/adapters/test_psqlpy/test_core.py index 7897fff8a..33b1a2abc 100644 --- a/tests/unit/adapters/test_psqlpy/test_core.py +++ b/tests/unit/adapters/test_psqlpy/test_core.py @@ -5,7 +5,12 @@ import pytest -from sqlspec.adapters.psqlpy.core import coerce_records_for_execute_many, collect_rows, format_execute_many_parameters +from sqlspec.adapters.psqlpy.core import ( + coerce_numeric_for_write, + coerce_records_for_execute_many, + collect_rows, + format_execute_many_parameters, +) pytestmark = pytest.mark.xdist_group("adapter_unit") @@ -40,6 +45,37 @@ def test_format_execute_many_parameters_with_coercion_converts_float_to_decimal( assert formatted[1][0] == 2 +def test_coerce_numeric_for_write_preserves_identity_when_unchanged() -> None: + """Nested payloads without float values should keep their existing container identities.""" + payload = {"items": [1, {"value": Decimal("1.5")}], "meta": ("a", None)} + + coerced = coerce_numeric_for_write(payload) + + assert coerced is payload + assert coerced["items"] is payload["items"] + assert coerced["items"][1] is payload["items"][1] + assert coerced["meta"] is payload["meta"] + + +def test_coerce_numeric_for_write_copies_only_changed_branch() -> None: + """Numeric write coercion should allocate only along branches containing float values.""" + payload = { + "changed": [1.5, {"value": 2.5}], + "unchanged": ("a", {"value": Decimal("3.5")}), + } + + coerced = coerce_numeric_for_write(payload) + + assert coerced == { + "changed": [Decimal("1.5"), {"value": Decimal("2.5")}], + "unchanged": ("a", {"value": Decimal("3.5")}), + } + assert 
coerced is not payload + assert coerced["changed"] is not payload["changed"] + assert coerced["changed"][1] is not payload["changed"][1] + assert coerced["unchanged"] is payload["unchanged"] + + def test_format_execute_many_parameters_handles_scalar_input() -> None: """Scalar execute_many payloads should be normalized to a list containing one row.""" formatted = format_execute_many_parameters(5, coerce_numeric=False) diff --git a/tests/unit/utils/test_type_converters.py b/tests/unit/utils/test_type_converters.py new file mode 100644 index 000000000..10c02dd0a --- /dev/null +++ b/tests/unit/utils/test_type_converters.py @@ -0,0 +1,42 @@ +"""Tests for nested converter helpers.""" + +from decimal import Decimal + +import pytest + +from sqlspec.utils.type_converters import build_nested_decimal_normalizer + +pytestmark = pytest.mark.xdist_group("utils") + + +def test_nested_decimal_normalizer_preserves_identity_when_unchanged() -> None: + """Unchanged nested payloads should keep their existing container identities.""" + normalizer = build_nested_decimal_normalizer(mode="float") + payload = {"items": [1, {"value": "x"}], "meta": ("a", None)} + + normalized = normalizer(payload) + + assert normalized is payload + assert normalized["items"] is payload["items"] + assert normalized["items"][1] is payload["items"][1] + assert normalized["meta"] is payload["meta"] + + +def test_nested_decimal_normalizer_copies_only_changed_branch() -> None: + """Nested normalization should allocate only along branches containing Decimal values.""" + normalizer = build_nested_decimal_normalizer(mode="float") + payload = { + "changed": [1, {"value": Decimal("1.5")}], + "unchanged": ("a", {"flag": True}), + } + + normalized = normalizer(payload) + + assert normalized == { + "changed": [1, {"value": 1.5}], + "unchanged": ("a", {"flag": True}), + } + assert normalized is not payload + assert normalized["changed"] is not payload["changed"] + assert normalized["changed"][1] is not 
payload["changed"][1] + assert normalized["unchanged"] is payload["unchanged"] From bd0800fc01265a9c965339e75f5d0a72bb706193 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 18:18:03 +0000 Subject: [PATCH 13/39] perf(result): cache schema conversions --- sqlspec/core/result/_base.py | 45 ++++++++++++++++++++++++--- tests/unit/core/test_result.py | 57 +++++++++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 10 deletions(-) diff --git a/sqlspec/core/result/_base.py b/sqlspec/core/result/_base.py index 1c0c6dcd0..67175ee0a 100644 --- a/sqlspec/core/result/_base.py +++ b/sqlspec/core/result/_base.py @@ -164,6 +164,8 @@ class SQLResult(StatementResult): "_materialized_dicts", "_operation_type", "_row_format", + "_schema_row_cache", + "_schema_rows_cache", "column_names", "error", "errors", @@ -238,6 +240,8 @@ def __init__( self.parameters = parameters self._row_format = row_format self._materialized_dicts: list[dict[str, Any]] | None = None + self._schema_rows_cache: dict[type[Any], list[Any]] | None = None + self._schema_row_cache: dict[type[Any], Any] | None = None self.column_names = column_names or [] self.total_count = total_count @@ -384,9 +388,40 @@ def get_data(self, *, schema_type: "type[SchemaT] | None" = None) -> "list[Schem ] data = self._get_rows() if schema_type: - return cast("list[SchemaT]", to_schema(data, schema_type=schema_type)) + return cast("list[SchemaT]", self._get_schema_rows(schema_type, data)) return data + def _get_schema_rows(self, schema_type: "type[SchemaT]", rows: "list[dict[str, Any]]") -> "list[SchemaT]": + cache = self._schema_rows_cache + if cache is not None: + cached_rows = cache.get(schema_type) + if cached_rows is not None: + return cast("list[SchemaT]", cached_rows) + converted_rows = cast("list[SchemaT]", to_schema(rows, schema_type=schema_type)) + if cache is None: + self._schema_rows_cache = {schema_type: converted_rows} + else: + cache[schema_type] = converted_rows + return converted_rows + + 
def _get_schema_row(self, schema_type: "type[SchemaT]", row: "dict[str, Any]") -> "SchemaT": + rows_cache = self._schema_rows_cache + if rows_cache is not None: + cached_rows = rows_cache.get(schema_type) + if cached_rows: + return cast("SchemaT", cached_rows[0]) + row_cache = self._schema_row_cache + if row_cache is not None: + cached_row = row_cache.get(schema_type) + if cached_row is not None: + return cast("SchemaT", cached_row) + converted_row = cast("SchemaT", to_schema(row, schema_type=schema_type)) + if row_cache is None: + self._schema_row_cache = {schema_type: converted_row} + else: + row_cache[schema_type] = converted_row + return converted_row + def add_statement_result(self, result: "SQLResult") -> None: """Add a statement result to the script execution results. @@ -451,7 +486,7 @@ def get_first(self, *, schema_type: "type[SchemaT] | None" = None) -> "SchemaT | return None row = rows[0] if schema_type: - return to_schema(row, schema_type=schema_type) + return self._get_schema_row(schema_type, row) return row def get_count(self) -> int: @@ -551,7 +586,7 @@ def all(self, *, schema_type: "type[SchemaT] | None" = None) -> "list[SchemaT] | """ data = self._get_rows() if schema_type: - return cast("list[SchemaT]", to_schema(data, schema_type=schema_type)) + return cast("list[SchemaT]", self._get_schema_rows(schema_type, data)) return data @overload @@ -588,7 +623,7 @@ def one(self, *, schema_type: "type[SchemaT] | None" = None) -> "SchemaT | dict[ row = rows[0] if schema_type: - return to_schema(row, schema_type=schema_type) + return self._get_schema_row(schema_type, row) return row @overload @@ -623,7 +658,7 @@ def one_or_none(self, *, schema_type: "type[SchemaT] | None" = None) -> "SchemaT row = rows[0] if schema_type: - return to_schema(row, schema_type=schema_type) + return self._get_schema_row(schema_type, row) return row def scalar(self) -> Any: diff --git a/tests/unit/core/test_result.py b/tests/unit/core/test_result.py index 6f49f93c7..c0e55f46f 
100644 --- a/tests/unit/core/test_result.py +++ b/tests/unit/core/test_result.py @@ -1,9 +1,12 @@ """Tests for the SQLResult iteration functionality.""" +from dataclasses import dataclass from typing import Any, cast +from unittest.mock import patch import pytest +import sqlspec.core.result._base as result_base from sqlspec.core import SQL, ArrowResult, SQLResult, StackResult, create_sql_result from sqlspec.typing import PYARROW_INSTALLED @@ -175,7 +178,6 @@ def test_create_sql_result_iteration() -> None: def test_sql_result_get_data_with_schema_type() -> None: """Test SQLResult.get_data() with schema_type parameter.""" - from dataclasses import dataclass @dataclass class User: @@ -265,7 +267,6 @@ def test_stack_result_with_error_and_factory() -> None: def test_sql_result_all_with_schema_type() -> None: """Test SQLResult.all() with schema_type parameter.""" - from dataclasses import dataclass @dataclass class User: @@ -290,7 +291,6 @@ class User: def test_sql_result_one_with_schema_type() -> None: """Test SQLResult.one() with schema_type parameter.""" - from dataclasses import dataclass @dataclass class User: @@ -311,7 +311,6 @@ class User: def test_sql_result_one_or_none_with_schema_type() -> None: """Test SQLResult.one_or_none() with schema_type parameter.""" - from dataclasses import dataclass @dataclass class User: @@ -335,7 +334,6 @@ class User: def test_sql_result_get_first_with_schema_type() -> None: """Test SQLResult.get_first() with schema_type parameter.""" - from dataclasses import dataclass @dataclass class User: @@ -360,6 +358,55 @@ class User: assert none_user is None +def test_sql_result_reuses_cached_schema_list_conversion() -> None: + """Repeated list-shaped schema access should not re-run to_schema().""" + + @dataclass + class User: + id: int + name: str + + result = SQLResult( + statement=SQL("SELECT id, name FROM users"), + data=[{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + rows_affected=2, + ) + + original_to_schema = 
result_base.to_schema + with patch.object(result_base, "to_schema", wraps=original_to_schema) as mocked_to_schema: + first = result.get_data(schema_type=User) + second = result.all(schema_type=User) + + assert first is second + assert mocked_to_schema.call_count == 1 + + +def test_sql_result_reuses_cached_single_row_schema_conversion() -> None: + """Repeated single-row schema access should not re-run to_schema().""" + + @dataclass + class User: + id: int + name: str + + result = SQLResult( + statement=SQL("SELECT id, name FROM users WHERE id = 1"), + data=[{"id": 1, "name": "Alice"}], + rows_affected=1, + ) + + original_to_schema = result_base.to_schema + with patch.object(result_base, "to_schema", wraps=original_to_schema) as mocked_to_schema: + first = result.get_first(schema_type=User) + second = result.one(schema_type=User) + third = result.one_or_none(schema_type=User) + + assert isinstance(first, User) + assert first is second + assert second is third + assert mocked_to_schema.call_count == 1 + + class TestTupleFormatSchemaType: """Tests for schema_type with raw tuple data (the real driver path).""" From e40ce1e4d6227a7aaf72386e8993013e827e79f9 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 18:24:15 +0000 Subject: [PATCH 14/39] build(mypyc): include safe storage runtime modules --- pyproject.toml | 3 +++ sqlspec/core/parameters/_processor.py | 5 +++-- tests/unit/test_mypyc_config.py | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_mypyc_config.py diff --git a/pyproject.toml b/pyproject.toml index bd0cd72aa..1abd34ea4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,6 +209,9 @@ include = [ "sqlspec/loader.py", # Loader module "sqlspec/observability/**/*.py", # Observability utilities "sqlspec/driver/**/*.py", # Driver module + "sqlspec/storage/registry.py", # Safe storage registry/runtime routing + "sqlspec/storage/errors.py", # Safe storage error normalization + 
"sqlspec/storage/backends/base.py", # Storage backend runtime base classes and iterators "sqlspec/data_dictionary/**/*.py", # Data dictionary mixin (required for adapter inheritance) "sqlspec/adapters/**/core.py", # Adapter compiled helpers "sqlspec/adapters/**/type_converter.py", # All adapters type converters diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index d245844d6..f4a428529 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -663,8 +663,9 @@ def _map_named_to_positional( param_type = type(parameters) if is_many and (param_type is list or param_type is tuple): + parameter_rows = cast("Sequence[Any]", parameters) updated_rows: list[Any] | None = None - for idx, row in enumerate(parameters): # type: ignore[union-attr] + for idx, row in enumerate(parameter_rows): row_type = type(row) if row_type is dict: row_dict: dict[str, Any] = row # type: ignore[assignment] @@ -692,7 +693,7 @@ def _map_named_to_positional( if updated_rows is None: if mapped_row is row: continue - updated_rows = list(parameters[:idx]) # type: ignore[index] + updated_rows = list(parameter_rows[:idx]) updated_rows.append(mapped_row) if updated_rows is None: diff --git a/tests/unit/test_mypyc_config.py b/tests/unit/test_mypyc_config.py new file mode 100644 index 000000000..fe7daeb02 --- /dev/null +++ b/tests/unit/test_mypyc_config.py @@ -0,0 +1,25 @@ +"""Tests for mypyc build configuration.""" + +from pathlib import Path + +import pytest + +try: + import tomllib +except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + + +def test_mypyc_include_set_covers_safe_storage_runtime_modules() -> None: + """Safe storage runtime modules should be in the mypyc include set.""" + pyproject = Path(__file__).resolve().parents[2] / "pyproject.toml" + config = tomllib.loads(pyproject.read_text()) + mypyc_config = config["tool"]["hatch"]["build"]["targets"]["wheel"]["hooks"]["mypyc"] + + include = 
set(mypyc_config["include"]) + exclude = set(mypyc_config["exclude"]) + + assert "sqlspec/storage/registry.py" in include + assert "sqlspec/storage/errors.py" in include + assert "sqlspec/storage/backends/base.py" in include + assert "sqlspec/utils/arrow_helpers.py" in exclude From f94ed78fae22f904c8c4cb2512082183c7ccd69b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 18:46:35 +0000 Subject: [PATCH 15/39] perf(storage): tighten registry cache handling --- sqlspec/storage/pipeline.py | 19 +++++++----- sqlspec/storage/registry.py | 25 ++++++++-------- .../storage/test_storage_registry_source.py | 30 +++++++++++++++++++ 3 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 tests/unit/storage/test_storage_registry_source.py diff --git a/sqlspec/storage/pipeline.py b/sqlspec/storage/pipeline.py index f801ee0cc..fb4ae506a 100644 --- a/sqlspec/storage/pipeline.py +++ b/sqlspec/storage/pipeline.py @@ -144,6 +144,7 @@ def reset(self) -> None: _METRICS = _StorageBridgeMetrics() _RECENT_STORAGE_EVENTS: "deque[StorageTelemetry]" = deque(maxlen=25) +_EMPTY_STORAGE_OPTIONS: dict[str, Any] = {} def get_storage_bridge_metrics() -> "dict[str, int]": @@ -307,7 +308,7 @@ def _resolve_storage_backend( registry: StorageRegistry, destination: StorageDestination, backend_options: "dict[str, Any] | None" ) -> "tuple[ObjectStoreProtocol, str]": destination_str = destination.as_posix() if isinstance(destination, Path) else str(destination) - options = backend_options or {} + options = _EMPTY_STORAGE_OPTIONS if backend_options is None else backend_options alias_resolution = _resolve_alias_destination(registry, destination_str, options) if alias_resolution is not None: return alias_resolution @@ -344,12 +345,13 @@ def write_rows( serialized = serialize_collection(rows) format_choice = format_hint or "jsonl" payload = _encode_row_payload(serialized, format_choice) + resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options 
return self._write_bytes( payload, destination, rows=len(serialized), format_label=format_choice, - storage_options=storage_options or {}, + storage_options=resolved_options, ) def write_arrow( @@ -364,7 +366,8 @@ def write_arrow( """Write an Arrow table to storage using zero-copy buffers.""" format_choice = format_hint or "parquet" - format_write_options = (storage_options or {}).get("write_options") if format_choice == "csv" else None + resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options + format_write_options = resolved_options.get("write_options") if format_choice == "csv" else None payload = _encode_arrow_payload( table, format_choice, compression=compression, write_options=format_write_options ) @@ -373,7 +376,7 @@ def write_arrow( destination, rows=int(table.num_rows), format_label=format_choice, - storage_options=storage_options or {}, + storage_options=resolved_options, ) def read_arrow( @@ -485,12 +488,13 @@ async def write_rows( serialized = serialize_collection(rows) format_choice = format_hint or "jsonl" payload = await async_(_encode_row_payload)(serialized, format_choice) + resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options return await self._write_bytes_async( payload, destination, rows=len(serialized), format_label=format_choice, - storage_options=storage_options or {}, + storage_options=resolved_options, ) async def write_arrow( @@ -503,7 +507,8 @@ async def write_arrow( compression: str | None = None, ) -> StorageTelemetry: format_choice = format_hint or "parquet" - format_write_options = (storage_options or {}).get("write_options") if format_choice == "csv" else None + resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options + format_write_options = resolved_options.get("write_options") if format_choice == "csv" else None payload = await async_(_encode_arrow_payload)( table, format_choice, compression=compression, 
write_options=format_write_options ) @@ -512,7 +517,7 @@ async def write_arrow( destination, rows=int(table.num_rows), format_label=format_choice, - storage_options=storage_options or {}, + storage_options=resolved_options, ) async def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None: diff --git a/sqlspec/storage/registry.py b/sqlspec/storage/registry.py index e6e6001c7..b21d8e75d 100644 --- a/sqlspec/storage/registry.py +++ b/sqlspec/storage/registry.py @@ -50,13 +50,11 @@ class StorageRegistry: backend = registry.get("s3://bucket", backend="fsspec") """ - __slots__ = ("_alias_configs", "_aliases", "_cache", "_instances") + __slots__ = ("_alias_configs", "_instances") def __init__(self) -> None: self._alias_configs: dict[str, tuple[type[ObjectStoreProtocol], str, dict[str, Any]]] = {} - self._aliases: dict[str, dict[str, Any]] = {} self._instances: dict[str | tuple[str, tuple[tuple[str, Any], ...]], ObjectStoreProtocol] = {} - self._cache: dict[str, tuple[str, type[ObjectStoreProtocol]]] = {} def _make_hashable(self, obj: Any) -> Any: """Convert nested dict/list structures to hashable tuples.""" @@ -86,10 +84,6 @@ def register_alias( if base_path: backend_config["base_path"] = base_path self._alias_configs[alias] = (backend_cls, uri, backend_config) - - test_config = dict(backend_config) - test_config["uri"] = uri - self._aliases[alias] = test_config log_with_context( logger, logging.DEBUG, @@ -278,14 +272,22 @@ def list_aliases(self) -> "list[str]": def clear_cache(self, uri_or_alias: str | None = None) -> None: """Clear resolved backend cache.""" if uri_or_alias: - self._instances.pop(uri_or_alias, None) - else: - self._instances.clear() + keys_to_remove: list[str | tuple[str, tuple[tuple[str, Any], ...]]] = [] + for key in list(self._instances): + if isinstance(key, str): + if key == uri_or_alias: + keys_to_remove.append(key) + continue + if key and key[0] == uri_or_alias: + keys_to_remove.append(key) + 
for key in keys_to_remove: + self._instances.pop(key, None) + return + self._instances.clear() def clear(self) -> None: """Clear all aliases and instances.""" self._alias_configs.clear() - self._aliases.clear() self._instances.clear() def clear_instances(self) -> None: @@ -295,7 +297,6 @@ def clear_instances(self) -> None: def clear_aliases(self) -> None: """Clear only aliases, keeping cached instances.""" self._alias_configs.clear() - self._aliases.clear() storage_registry = StorageRegistry() diff --git a/tests/unit/storage/test_storage_registry_source.py b/tests/unit/storage/test_storage_registry_source.py new file mode 100644 index 000000000..45498f42f --- /dev/null +++ b/tests/unit/storage/test_storage_registry_source.py @@ -0,0 +1,30 @@ +"""Source-level regressions for storage registry hot paths. + +These tests load the Python source module directly so they remain stable even +when a stale compiled extension exists in the workspace. +""" + +import importlib.util +from pathlib import Path + + +def _load_registry_source_module(): + module_path = Path(__file__).resolve().parents[3] / "sqlspec" / "storage" / "registry.py" + spec = importlib.util.spec_from_file_location("storage_registry_source_tests", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_source_registry_clear_cache_removes_parameterized_uri_entries(tmp_path: Path) -> None: + """Source implementation should evict backend-override cache entries by base URI.""" + module = _load_registry_source_module() + registry = module.StorageRegistry() + + backend1 = registry.get(f"file://{tmp_path}", backend="local") + registry.clear_cache(f"file://{tmp_path}") + backend2 = registry.get(f"file://{tmp_path}", backend="local") + + assert backend1 is not backend2 From 0045002783cf1c4714cfc7afcca2529e0b5eda3f Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 18:49:00 
+0000 Subject: [PATCH 16/39] perf(bench): add storage runtime benchmarks --- tests/unit/utils/test_bench_subsystems.py | 34 ++++++++++++ tools/scripts/bench_gate.py | 12 +++- tools/scripts/bench_subsystems.py | 67 ++++++++++++++++++++++- 3 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 tests/unit/utils/test_bench_subsystems.py diff --git a/tests/unit/utils/test_bench_subsystems.py b/tests/unit/utils/test_bench_subsystems.py new file mode 100644 index 000000000..83bed0d52 --- /dev/null +++ b/tests/unit/utils/test_bench_subsystems.py @@ -0,0 +1,34 @@ +"""Tests for subsystem benchmark registration.""" + +import importlib.util +from pathlib import Path +from types import ModuleType + + +def _load_bench_subsystems_module() -> ModuleType: + module_path = Path(__file__).resolve().parents[3] / "tools" / "scripts" / "bench_subsystems.py" + spec = importlib.util.spec_from_file_location("bench_subsystems_for_tests", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_storage_subsystem_benchmarks_are_registered() -> None: + module = _load_bench_subsystems_module() + db_path = module._make_temp_db() + module._setup_test_table(db_path) + benchmarks = [] + + try: + benchmarks = module._build_benchmarks(db_path, iterations=1) + names = {benchmark.name for benchmark in benchmarks} + assert "StorageRegistry.get() - cached alias" in names + assert "SyncStoragePipeline.write_rows() - local jsonl" in names + assert "_decode_arrow_payload() - jsonl" in names + finally: + for benchmark in benchmarks: + if benchmark.name == "_cleanup_" and benchmark.setup_fn is not None: + benchmark.setup_fn() + db_path.unlink(missing_ok=True) diff --git a/tools/scripts/bench_gate.py b/tools/scripts/bench_gate.py index b8f449d0f..74c93d7ec 100644 --- a/tools/scripts/bench_gate.py +++ b/tools/scripts/bench_gate.py @@ -84,9 +84,15 @@ "scenarios": 
("initialization", "session.execute() - full path"), }, "storage_runtime_expansion": { - "tracked_by": "tools/scripts/bench.py", - "goal": "Storage-adjacent write/read throughput until dedicated storage micro-benchmarks land", - "scenarios": ("write_heavy", "read_heavy"), + "tracked_by": "tools/scripts/bench.py + tools/scripts/bench_subsystems.py", + "goal": "Storage registry/runtime write overhead and JSONL-to-Arrow boundary crossings", + "scenarios": ( + "write_heavy", + "read_heavy", + "StorageRegistry.get() - cached alias", + "SyncStoragePipeline.write_rows() - local jsonl", + "_decode_arrow_payload() - jsonl", + ), }, "exclusion_revalidation": { "tracked_by": "tools/scripts/bench.py", diff --git a/tools/scripts/bench_subsystems.py b/tools/scripts/bench_subsystems.py index 64284b438..d7e6a4b9a 100644 --- a/tools/scripts/bench_subsystems.py +++ b/tools/scripts/bench_subsystems.py @@ -322,7 +322,61 @@ def bench_result_init_select() -> None: ) ) - # --- 6. Cursor context manager overhead --- + # --- 6. 
Storage runtime + Arrow boundary helpers --- + + from sqlspec.storage.pipeline import SyncStoragePipeline, _decode_arrow_payload, _encode_row_payload + from sqlspec.storage.registry import StorageRegistry + + storage_root = db_path.parent / "bench-storage-runtime" + storage_root.mkdir(exist_ok=True) + storage_path = storage_root / "bench-storage.jsonl" + storage_uri = f"file://{storage_root}" + storage_destination = "alias://bench_store/bench-storage.jsonl" + storage_rows = [{"id": idx, "label": f"value_{idx}"} for idx in range(16)] + storage_registry = StorageRegistry() + storage_registry.register_alias("bench_store", storage_uri, backend="local") + storage_registry.get("bench_store") + storage_pipeline = SyncStoragePipeline(registry=storage_registry) + jsonl_payload = _encode_row_payload(storage_rows, "jsonl") + storage_pipeline.write_rows(storage_rows, storage_destination, format_hint="jsonl") + + def bench_storage_registry_cached_alias() -> None: + storage_registry.get("bench_store") + + benchmarks.append( + SubsystemBenchmark( + name="StorageRegistry.get() - cached alias", + bench_fn=bench_storage_registry_cached_alias, + iterations=iterations, + description="Resolve a cached storage alias through the registry hot path", + ) + ) + + def bench_storage_write_rows_local() -> None: + storage_pipeline.write_rows(storage_rows, storage_destination, format_hint="jsonl") + + benchmarks.append( + SubsystemBenchmark( + name="SyncStoragePipeline.write_rows() - local jsonl", + bench_fn=bench_storage_write_rows_local, + iterations=iterations, + description="Encode rows and route them through the sync local storage pipeline", + ) + ) + + def bench_storage_decode_jsonl_arrow() -> None: + _decode_arrow_payload(jsonl_payload, "jsonl") + + benchmarks.append( + SubsystemBenchmark( + name="_decode_arrow_payload() - jsonl", + bench_fn=bench_storage_decode_jsonl_arrow, + iterations=iterations, + description="Decode JSONL bytes through the Arrow boundary helper", + ) + ) + + # --- 7. 
Cursor context manager overhead --- raw_conn = sqlite3.connect(str(db_path)) raw_conn.execute("PRAGMA journal_mode = WAL") @@ -356,7 +410,7 @@ def bench_raw_cursor() -> None: ) ) - # --- 7. Full execute() overhead (single statement, end-to-end) --- + # --- 8. Full execute() overhead (single statement, end-to-end) --- def bench_full_execute() -> None: session.execute("INSERT INTO test (value) VALUES (?)", ("bench_e2e",)) @@ -383,13 +437,20 @@ def bench_raw_execute() -> None: ) ) + def cleanup_benchmarks() -> None: + _session_ctx.__exit__(None, None, None) + raw_conn.close() + storage_path.unlink(missing_ok=True) + with suppress(OSError): + storage_root.rmdir() + # Store session context for cleanup benchmarks.append( SubsystemBenchmark( name="_cleanup_", bench_fn=lambda: None, iterations=0, - setup_fn=lambda: (_session_ctx.__exit__(None, None, None), raw_conn.close()), + setup_fn=cleanup_benchmarks, ) ) From d60af1d0c7709177c0bb826ef5a9551fcd7a0641 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 19:04:34 +0000 Subject: [PATCH 17/39] build(mypyc): inventory compiled surface --- tests/unit/utils/test_mypyc_inventory.py | 72 ++++++++++ tools/scripts/mypyc_inventory.py | 175 +++++++++++++++++++++++ 2 files changed, 247 insertions(+) create mode 100644 tests/unit/utils/test_mypyc_inventory.py create mode 100644 tools/scripts/mypyc_inventory.py diff --git a/tests/unit/utils/test_mypyc_inventory.py b/tests/unit/utils/test_mypyc_inventory.py new file mode 100644 index 000000000..9cc397aa9 --- /dev/null +++ b/tests/unit/utils/test_mypyc_inventory.py @@ -0,0 +1,72 @@ +"""Tests for mypyc inventory reporting.""" + +import importlib.util +from pathlib import Path +from types import ModuleType + + +def _load_mypyc_inventory_module() -> ModuleType: + module_path = Path(__file__).resolve().parents[3] / "tools" / "scripts" / "mypyc_inventory.py" + spec = importlib.util.spec_from_file_location("mypyc_inventory_for_tests", module_path) + assert spec is not None + 
assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_build_inventory_reports_current_compiled_surface() -> None: + module = _load_mypyc_inventory_module() + + inventory = module.build_inventory() + + assert inventory["summary"] == { + "compiled_count": 60, + "interpreted_count": 335, + "total_modules": 395, + } + + hot_surfaces = inventory["hot_surfaces"] + assert hot_surfaces["sqlspec/config.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/config.py"]["classification"] == "helper_split_first" + assert hot_surfaces["sqlspec/base.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/base.py"]["classification"] == "helper_split_first" + assert hot_surfaces["sqlspec/_serialization.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/_serialization.py"]["classification"] == "prove_separately" + assert hot_surfaces["sqlspec/storage/pipeline.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/storage/pipeline.py"]["classification"] == "helper_split_first" + assert hot_surfaces["sqlspec/storage/registry.py"]["status"] == "compiled" + assert hot_surfaces["sqlspec/storage/errors.py"]["status"] == "compiled" + assert hot_surfaces["sqlspec/storage/_utils.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/utils/module_loader.py"]["status"] == "interpreted" + assert hot_surfaces["sqlspec/utils/arrow_helpers.py"]["status"] == "interpreted" + + +def test_build_inventory_preserves_known_exclusions() -> None: + module = _load_mypyc_inventory_module() + + inventory = module.build_inventory() + + assert inventory["preserved_exclusions"] == [ + "sqlspec/adapters/**/data_dictionary.py", + "sqlspec/builder/_vector_expressions.py", + "sqlspec/config.py", + "sqlspec/data_dictionary/_loader.py", + "sqlspec/dialects/**", + "sqlspec/migrations/commands.py", + "sqlspec/observability/_formatting.py", + "sqlspec/utils/arrow_helpers.py", + ] + + +def 
test_build_inventory_summarizes_adapter_shells() -> None: + module = _load_mypyc_inventory_module() + + inventory = module.build_inventory() + + assert inventory["adapter_config_shells"]["count"] > 0 + assert inventory["adapter_config_shells"]["status"] == "interpreted" + assert inventory["adapter_config_shells"]["classification"] == "helper_split_first" + assert inventory["adapter_core_helpers"]["count"] > 0 + assert inventory["adapter_core_helpers"]["status"] == "compiled" + assert inventory["adapter_core_helpers"]["classification"] == "compile_now" diff --git a/tools/scripts/mypyc_inventory.py b/tools/scripts/mypyc_inventory.py new file mode 100644 index 000000000..ecb205f05 --- /dev/null +++ b/tools/scripts/mypyc_inventory.py @@ -0,0 +1,175 @@ +"""Inventory the current mypyc compiled vs interpreted module surface.""" + +from fnmatch import fnmatch +from pathlib import Path +from typing import Any + +try: + import tomllib +except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + +__all__ = ( + "HOT_SURFACE_CLASSIFICATIONS", + "build_inventory", + "classify_module", + "list_sqlspec_modules", + "load_mypyc_patterns", +) + + +HOT_SURFACE_CLASSIFICATIONS: dict[str, dict[str, str]] = { + "sqlspec/config.py": { + "classification": "helper_split_first", + "reason": "Owns runtime hooks, migration setup, and observability/bootstrap orchestration.", + }, + "sqlspec/base.py": { + "classification": "helper_split_first", + "reason": "Registry/session wrappers still manage runtime pool and telemetry orchestration.", + }, + "sqlspec/_serialization.py": { + "classification": "prove_separately", + "reason": "Serializer selection remains dynamic and historically broke same-unit coercion paths.", + }, + "sqlspec/storage/pipeline.py": { + "classification": "helper_split_first", + "reason": "Most orchestration is safe, but Arrow encode/decode helpers still cross interpreted boundaries.", + }, + "sqlspec/storage/registry.py": { + "classification": "compile_now", + 
"reason": "Pure routing/cache logic with backend selection only.", + }, + "sqlspec/storage/errors.py": { + "classification": "compile_now", + "reason": "Storage error normalization is typed runtime logic with no Arrow dependence.", + }, + "sqlspec/storage/_utils.py": { + "classification": "helper_split_first", + "reason": "Path resolution is safe, but the same module owns dynamic PyArrow import shims.", + }, + "sqlspec/utils/module_loader.py": { + "classification": "keep_interpreted", + "reason": "Heavy dynamic import and optional dependency probing surface.", + }, + "sqlspec/utils/serializers.py": { + "classification": "compile_now", + "reason": "Already part of the compiled utility path and performance sensitive.", + }, + "sqlspec/utils/sync_tools.py": { + "classification": "compile_now", + "reason": "Hot async bridge helpers are already in the include set.", + }, + "sqlspec/utils/schema.py": { + "classification": "compile_now", + "reason": "Core schema conversion path is already compiled and actively optimized.", + }, + "sqlspec/utils/type_converters.py": { + "classification": "compile_now", + "reason": "Compiled adapter coercion helpers are on the hot path.", + }, + "sqlspec/storage/backends/base.py": { + "classification": "compile_now", + "reason": "Mypyc-safe runtime base classes and iterator wrappers.", + }, + "sqlspec/utils/arrow_helpers.py": { + "classification": "keep_interpreted", + "reason": "Direct PyArrow boundary with historical mypyc segfault risk.", + }, +} + + +def load_mypyc_patterns(root: Path) -> tuple[list[str], list[str]]: + """Load mypyc include/exclude glob patterns from pyproject.toml.""" + + config = tomllib.loads((root / "pyproject.toml").read_text()) + mypyc_config = config["tool"]["hatch"]["build"]["targets"]["wheel"]["hooks"]["mypyc"] + return list(mypyc_config["include"]), list(mypyc_config["exclude"]) + + +def list_sqlspec_modules(root: Path) -> list[str]: + """Return all Python module paths under sqlspec/.""" + + return 
sorted(str(path.relative_to(root)).replace("\\", "/") for path in (root / "sqlspec").rglob("*.py")) + + +def classify_module(module_path: str, include_patterns: list[str], exclude_patterns: list[str]) -> str: + """Return whether a module is currently compiled or interpreted.""" + + included = any(fnmatch(module_path, pattern) for pattern in include_patterns) + excluded = any(fnmatch(module_path, pattern) for pattern in exclude_patterns) + return "compiled" if included and not excluded else "interpreted" + + +def build_inventory(root: Path | None = None) -> dict[str, Any]: + """Build the current module inventory and hot-surface classification.""" + + project_root = root or Path(__file__).resolve().parents[2] + include_patterns, exclude_patterns = load_mypyc_patterns(project_root) + modules = list_sqlspec_modules(project_root) + + compiled: list[str] = [] + interpreted: list[str] = [] + for module in modules: + if classify_module(module, include_patterns, exclude_patterns) == "compiled": + compiled.append(module) + else: + interpreted.append(module) + + hot_surfaces: dict[str, dict[str, str]] = {} + for module_path, details in HOT_SURFACE_CLASSIFICATIONS.items(): + hot_surfaces[module_path] = { + "status": classify_module(module_path, include_patterns, exclude_patterns), + "classification": details["classification"], + "reason": details["reason"], + } + + adapter_configs = sorted( + module + for module in modules + if module.startswith("sqlspec/adapters/") and module.endswith("/config.py") + ) + adapter_cores = sorted(module for module in modules if module.startswith("sqlspec/adapters/") and module.endswith("/core.py")) + + return { + "summary": { + "compiled_count": len(compiled), + "interpreted_count": len(interpreted), + "total_modules": len(modules), + }, + "compiled_modules": compiled, + "interpreted_modules": interpreted, + "adapter_config_shells": { + "count": len(adapter_configs), + "modules": adapter_configs, + "status": "interpreted", + "classification": 
"helper_split_first", + }, + "adapter_core_helpers": { + "count": len(adapter_cores), + "modules": adapter_cores, + "status": "compiled", + "classification": "compile_now", + }, + "preserved_exclusions": sorted( + pattern + for pattern in exclude_patterns + if pattern + in { + "sqlspec/dialects/**", + "sqlspec/utils/arrow_helpers.py", + "sqlspec/builder/_vector_expressions.py", + "sqlspec/data_dictionary/_loader.py", + "sqlspec/adapters/**/data_dictionary.py", + "sqlspec/observability/_formatting.py", + "sqlspec/migrations/commands.py", + "sqlspec/config.py", + } + ), + "hot_surfaces": hot_surfaces, + } + + +if __name__ == "__main__": # pragma: no cover + import json + + print(json.dumps(build_inventory(), indent=2)) From 48a1990c598231c33c0017eecf0887ae1d36e718 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 19:10:29 +0000 Subject: [PATCH 18/39] build(mypyc): map hot boundary crossings --- tests/unit/utils/test_mypyc_boundary_map.py | 121 +++++++ tools/scripts/mypyc_boundary_map.py | 331 ++++++++++++++++++++ 2 files changed, 452 insertions(+) create mode 100644 tests/unit/utils/test_mypyc_boundary_map.py create mode 100644 tools/scripts/mypyc_boundary_map.py diff --git a/tests/unit/utils/test_mypyc_boundary_map.py b/tests/unit/utils/test_mypyc_boundary_map.py new file mode 100644 index 000000000..cbab93b0d --- /dev/null +++ b/tests/unit/utils/test_mypyc_boundary_map.py @@ -0,0 +1,121 @@ +"""Tests for mypyc boundary inventory reporting.""" + +import importlib.util +from pathlib import Path +from types import ModuleType + + +def _load_mypyc_boundary_map_module() -> ModuleType: + module_path = Path(__file__).resolve().parents[3] / "tools" / "scripts" / "mypyc_boundary_map.py" + spec = importlib.util.spec_from_file_location("mypyc_boundary_map_for_tests", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def 
test_build_boundary_map_reports_expected_hot_groups() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + assert boundary_map["summary"] == { + "config_runtime_edges": 2, + "adapter_config_core_edges": 17, + "interpreted_to_compiled_adapter_edges": 16, + "serializer_bridges": 18, + "storage_arrow_edges": 3, + "any_audit_seams": 8, + "exclusion_revalidation_buckets": 7, + } + + +def test_build_boundary_map_captures_config_and_storage_runtime_edges() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + config_boundaries = { + (entry["from_module"], entry["to_module"]): entry for entry in boundary_map["config_runtime_boundaries"] + } + config_runtime_edge = config_boundaries[("sqlspec/config.py", "sqlspec/core/config_runtime.py")] + assert config_runtime_edge["classification"] == "interpreted_runtime_helper_boundary" + assert config_runtime_edge["from_status"] == "interpreted" + assert config_runtime_edge["to_status"] == "interpreted" + + module_loader_edge = config_boundaries[("sqlspec/config.py", "sqlspec/utils/module_loader.py")] + assert module_loader_edge["classification"] == "interpreted_optional_dependency_boundary" + assert module_loader_edge["sites"][1]["symbol"] == "_build_storage_capabilities" + + storage_boundaries = { + (entry["from_module"], entry["to_module"]): entry for entry in boundary_map["storage_arrow_boundaries"] + } + pipeline_edge = storage_boundaries[("sqlspec/storage/pipeline.py", "sqlspec/storage/_utils.py")] + assert pipeline_edge["classification"] == "interpreted_to_interpreted_arrow_boundary" + assert pipeline_edge["sites"][3]["symbol"] == "SyncStoragePipeline.write_arrow" + + +def test_build_boundary_map_lists_adapter_config_to_compiled_core_crossings() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + adapter_edges = { + (entry["from_module"], entry["to_module"]): entry for entry 
in boundary_map["adapter_config_core_boundaries"] + } + + psqlpy_edge = adapter_edges[("sqlspec/adapters/psqlpy/config.py", "sqlspec/adapters/psqlpy/core.py")] + assert psqlpy_edge["classification"] == "interpreted_to_compiled" + assert psqlpy_edge["helpers"] == [ + "apply_driver_features", + "build_connection_config", + "build_postgres_extension_probe_names", + "default_statement_config", + "resolve_postgres_extension_state", + "resolve_runtime_statement_config", + ] + + asyncpg_edge = adapter_edges[("sqlspec/adapters/asyncpg/config.py", "sqlspec/adapters/asyncpg/core.py")] + assert asyncpg_edge["classification"] == "interpreted_to_compiled" + assert "register_json_codecs" in asyncpg_edge["helpers"] + assert "register_pgvector_support" in asyncpg_edge["helpers"] + + mock_edge = adapter_edges[("sqlspec/adapters/mock/config.py", "sqlspec/adapters/mock/core.py")] + assert mock_edge["classification"] == "same_mode_import" + assert mock_edge["to_status"] == "interpreted" + + +def test_build_boundary_map_tracks_serializer_and_any_seams() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + serializer_edges = {entry["from_module"]: entry for entry in boundary_map["serializer_bridges"]} + assert serializer_edges["sqlspec/adapters/psqlpy/core.py"]["terminal_module"] == "sqlspec/_serialization.py" + assert serializer_edges["sqlspec/adapters/psqlpy/core.py"]["helpers"] == ["to_json"] + assert serializer_edges["sqlspec/adapters/asyncpg/core.py"]["helpers"] == ["from_json", "to_json"] + assert serializer_edges["sqlspec/core/parameters/_registry.py"]["classification"] == ( + "compiled_to_interpreted_json_boundary" + ) + + any_seams = { + (entry["module"], entry["symbol"]): entry for entry in boundary_map["any_audit_matrix"] + } + assert any_seams[("sqlspec/config.py", "_DriverFeatureHookWrapper.__init__")]["annotation"] == "Callable[..., Any]" + assert any_seams[("sqlspec/storage/pipeline.py", 
"_encode_arrow_payload")]["annotation"] == ( + "write_options: dict[str, Any] | None" + ) + + +def test_build_boundary_map_seeds_exclusion_revalidation_buckets() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + exclusion_seed = boundary_map["exclusion_revalidation_seed"] + assert exclusion_seed["sqlspec/utils/arrow_helpers.py"]["bucket"] == "hard_block" + assert exclusion_seed["sqlspec/adapters/**/data_dictionary.py"]["bucket"] == "hard_block" + assert exclusion_seed["sqlspec/builder/_vector_expressions.py"]["bucket"] == "helper_split" + assert exclusion_seed["sqlspec/data_dictionary/_loader.py"]["bucket"] == "helper_split" + assert exclusion_seed["sqlspec/dialects/**"]["bucket"] == "low_roi" diff --git a/tools/scripts/mypyc_boundary_map.py b/tools/scripts/mypyc_boundary_map.py new file mode 100644 index 000000000..cd2e61336 --- /dev/null +++ b/tools/scripts/mypyc_boundary_map.py @@ -0,0 +1,331 @@ +"""Map current interpreted/compiled hot boundaries for mypyc rollout work.""" + +import ast +from fnmatch import fnmatch +from pathlib import Path +from typing import Any + +try: + import tomllib +except ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + +__all__ = ( + "ANY_AUDIT_SEAMS", + "CONFIG_RUNTIME_BOUNDARIES", + "EXCLUSION_REVALIDATION_SEED", + "STORAGE_ARROW_BOUNDARIES", + "build_boundary_map", + "classify_module", + "collect_adapter_core_boundaries", + "collect_serializer_bridges", + "load_mypyc_patterns", +) + + +CONFIG_RUNTIME_BOUNDARIES: tuple[dict[str, Any], ...] 
= ( + { + "from_module": "sqlspec/config.py", + "to_module": "sqlspec/core/config_runtime.py", + "sites": [ + {"line": 11, "symbol": "config_runtime import"}, + {"line": 1210, "symbol": "build_default_statement_config"}, + {"line": 1211, "symbol": "seed_runtime_driver_features"}, + {"line": 1555, "symbol": "create_sync_pool"}, + {"line": 1568, "symbol": "close_sync_pool"}, + {"line": 1752, "symbol": "create_async_pool"}, + {"line": 1765, "symbol": "close_async_pool"}, + ], + "classification": "interpreted_runtime_helper_boundary", + "reason": "Base config shells stay interpreted and currently delegate statement defaults, driver feature seeding, and pool helpers to another interpreted runtime helper layer.", + }, + { + "from_module": "sqlspec/config.py", + "to_module": "sqlspec/utils/module_loader.py", + "sites": [ + {"line": 22, "symbol": "ensure_pyarrow import"}, + {"line": 824, "symbol": "_build_storage_capabilities"}, + {"line": 828, "symbol": "_dependency_available(ensure_pyarrow)"}, + ], + "classification": "interpreted_optional_dependency_boundary", + "reason": "Storage capability detection remains interpreted because it probes optional PyArrow availability at runtime.", + }, +) + + +STORAGE_ARROW_BOUNDARIES: tuple[dict[str, Any], ...] 
= ( + { + "from_module": "sqlspec/storage/pipeline.py", + "to_module": "sqlspec/storage/_utils.py", + "sites": [ + {"line": 13, "symbol": "import_pyarrow/import_pyarrow_csv/import_pyarrow_parquet"}, + {"line": 211, "symbol": "_encode_arrow_payload"}, + {"line": 256, "symbol": "_decode_arrow_payload"}, + {"line": 357, "symbol": "SyncStoragePipeline.write_arrow"}, + {"line": 382, "symbol": "SyncStoragePipeline.read_arrow"}, + {"line": 500, "symbol": "AsyncStoragePipeline.write_arrow"}, + {"line": 578, "symbol": "AsyncStoragePipeline.read_arrow_async"}, + ], + "classification": "interpreted_to_interpreted_arrow_boundary", + "reason": "Pipeline orchestration is still interpreted and delegates Arrow imports/codecs to `_utils.py`.", + }, + { + "from_module": "sqlspec/storage/_utils.py", + "to_module": "sqlspec/utils/module_loader.py", + "sites": [ + {"line": 5, "symbol": "ensure_pyarrow import"}, + {"line": 18, "symbol": "import_pyarrow"}, + {"line": 31, "symbol": "import_pyarrow_parquet"}, + {"line": 44, "symbol": "import_pyarrow_csv"}, + ], + "classification": "interpreted_optional_dependency_boundary", + "reason": "Arrow helpers remain isolated behind optional-dependency probes in `module_loader.py`.", + }, + { + "from_module": "sqlspec/utils/serializers.py", + "to_module": "sqlspec/_serialization.py", + "sites": [ + {"line": 11, "symbol": "decode_json/encode_json import"}, + {"line": 103, "symbol": "to_json"}, + {"line": 126, "symbol": "from_json"}, + ], + "classification": "compiled_to_interpreted_json_boundary", + "reason": "Compiled serializer helpers still terminate in the interpreted fallback serializers defined in `_serialization.py`.", + }, +) + + +ANY_AUDIT_SEAMS: tuple[dict[str, Any], ...] 
= ( + { + "module": "sqlspec/config.py", + "line": 86, + "symbol": "_DriverFeatureHookWrapper.__init__", + "annotation": "Callable[..., Any]", + "reason": "Lifecycle hook callbacks accept heterogeneous driver/pool/session payloads.", + }, + { + "module": "sqlspec/config.py", + "line": 107, + "symbol": "LifecycleConfig", + "annotation": "Callable[[Any], None] and query hooks with dict[str, Any]", + "reason": "Observability lifecycle hooks bridge raw driver objects and event payload maps.", + }, + { + "module": "sqlspec/_serialization.py", + "line": 23, + "symbol": "_type_to_string", + "annotation": "Any -> Any", + "reason": "Serializer fallback path handles arbitrary runtime values and optional third-party model types.", + }, + { + "module": "sqlspec/_serialization.py", + "line": 274, + "symbol": "encode_json", + "annotation": "data: Any", + "reason": "Top-level JSON encoding surface is intentionally untyped because it serves every adapter/runtime layer.", + }, + { + "module": "sqlspec/storage/pipeline.py", + "line": 198, + "symbol": "_encode_row_payload", + "annotation": "list[Any]", + "reason": "Storage bridge accepts pre-serialized row payloads without schema specialization.", + }, + { + "module": "sqlspec/storage/pipeline.py", + "line": 216, + "symbol": "_encode_arrow_payload", + "annotation": "write_options: dict[str, Any] | None", + "reason": "CSV/Parquet writer options pass backend-specific dictionaries through unchanged.", + }, + { + "module": "sqlspec/adapters/psqlpy/config.py", + "line": 79, + "symbol": "PsqlpyPoolParams.configure", + "annotation": "Callable[..., Any]", + "reason": "psqlpy exposes raw driver configure callbacks that are opaque to the shared config shell.", + }, + { + "module": "sqlspec/adapters/psqlpy/config.py", + "line": 126, + "symbol": "_PsqlpySessionFactory._ctx", + "annotation": "Any | None", + "reason": "Pool acquire context objects are driver-owned and not weakref/Protocol-friendly.", + }, +) + + +EXCLUSION_REVALIDATION_SEED: 
dict[str, dict[str, str]] = { + "sqlspec/utils/arrow_helpers.py": { + "bucket": "hard_block", + "reason": "Direct PyArrow table/batch boundary with prior segfault history.", + }, + "sqlspec/adapters/**/data_dictionary.py": { + "bucket": "hard_block", + "reason": "Still carries native_class=False and inline cache patterns to avoid mypyc crashes.", + }, + "sqlspec/builder/_vector_expressions.py": { + "bucket": "helper_split", + "reason": "Current sqlglot Expression no longer has the old metaclass concern, but registration side effects remain.", + }, + "sqlspec/data_dictionary/_loader.py": { + "bucket": "helper_split", + "reason": "Path discovery is the risky piece; cache/query wrapper logic is otherwise straightforward.", + }, + "sqlspec/dialects/**": { + "bucket": "low_roi", + "reason": "Dialect metaclass/plugin surfaces remain mostly registration code, not hot loops.", + }, + "sqlspec/observability/_formatting.py": { + "bucket": "low_roi", + "reason": "Small logging formatter module with negligible performance upside.", + }, + "sqlspec/migrations/commands.py": { + "bucket": "low_roi", + "reason": "Large CLI/orchestration shell with dynamic inspection and Rich output, not a hot path.", + }, +} + + +def load_mypyc_patterns(root: Path) -> tuple[list[str], list[str]]: + """Load mypyc include/exclude globs from pyproject.toml.""" + + config = tomllib.loads((root / "pyproject.toml").read_text()) + mypyc_config = config["tool"]["hatch"]["build"]["targets"]["wheel"]["hooks"]["mypyc"] + return list(mypyc_config["include"]), list(mypyc_config["exclude"]) + + +def classify_module(module_path: str, include_patterns: list[str], exclude_patterns: list[str]) -> str: + """Classify a module as currently compiled or interpreted.""" + + included = any(fnmatch(module_path, pattern) for pattern in include_patterns) + excluded = any(fnmatch(module_path, pattern) for pattern in exclude_patterns) + return "compiled" if included and not excluded else "interpreted" + + +def 
_module_path_from_file(root: Path, file_path: Path) -> str: + return str(file_path.relative_to(root)).replace("\\", "/") + + +def _read_ast(file_path: Path) -> ast.Module: + return ast.parse(file_path.read_text(), filename=str(file_path)) + + +def collect_adapter_core_boundaries(root: Path) -> list[dict[str, Any]]: + """Collect adapter config.py imports that cross into core.py helpers.""" + + include_patterns, exclude_patterns = load_mypyc_patterns(root) + boundaries: list[dict[str, Any]] = [] + + for config_path in sorted((root / "sqlspec" / "adapters").glob("*/config.py")): + module_path = _module_path_from_file(root, config_path) + tree = _read_ast(config_path) + + for node in tree.body: + if not isinstance(node, ast.ImportFrom) or node.module is None: + continue + if not node.module.startswith("sqlspec.adapters.") or not node.module.endswith(".core"): + continue + + target_module = f"{node.module.replace('.', '/')}.py" + imported_symbols = sorted(alias.name for alias in node.names if alias.name != "*") + boundaries.append({ + "from_module": module_path, + "from_status": classify_module(module_path, include_patterns, exclude_patterns), + "to_module": target_module, + "to_status": classify_module(target_module, include_patterns, exclude_patterns), + "import_line": node.lineno, + "helpers": imported_symbols, + "classification": "interpreted_to_compiled" + if classify_module(module_path, include_patterns, exclude_patterns) == "interpreted" + and classify_module(target_module, include_patterns, exclude_patterns) == "compiled" + else "same_mode_import", + }) + + return boundaries + + +def collect_serializer_bridges(root: Path) -> list[dict[str, Any]]: + """Collect compiled helper modules that import JSON helpers from utils.serializers.""" + + include_patterns, exclude_patterns = load_mypyc_patterns(root) + bridges: list[dict[str, Any]] = [] + + for module_path in sorted(str(path.relative_to(root)).replace("\\", "/") for path in (root / "sqlspec").rglob("*.py")): + if 
classify_module(module_path, include_patterns, exclude_patterns) != "compiled": + continue + + file_path = root / module_path + tree = _read_ast(file_path) + for node in tree.body: + if not isinstance(node, ast.ImportFrom) or node.module != "sqlspec.utils.serializers": + continue + imported_symbols = sorted(alias.name for alias in node.names if alias.name != "*") + bridges.append({ + "from_module": module_path, + "from_status": "compiled", + "via_module": "sqlspec/utils/serializers.py", + "via_status": classify_module("sqlspec/utils/serializers.py", include_patterns, exclude_patterns), + "terminal_module": "sqlspec/_serialization.py", + "terminal_status": classify_module("sqlspec/_serialization.py", include_patterns, exclude_patterns), + "import_line": node.lineno, + "helpers": imported_symbols, + "classification": "compiled_to_interpreted_json_boundary", + }) + break + + return bridges + + +def build_boundary_map(root: Path | None = None) -> dict[str, Any]: + """Build the current hot boundary map for mypyc rollout planning.""" + + project_root = root or Path(__file__).resolve().parents[2] + include_patterns, exclude_patterns = load_mypyc_patterns(project_root) + + config_boundaries = [ + { + **entry, + "from_status": classify_module(entry["from_module"], include_patterns, exclude_patterns), + "to_status": classify_module(entry["to_module"], include_patterns, exclude_patterns), + } + for entry in CONFIG_RUNTIME_BOUNDARIES + ] + storage_boundaries = [ + { + **entry, + "from_status": classify_module(entry["from_module"], include_patterns, exclude_patterns), + "to_status": classify_module(entry["to_module"], include_patterns, exclude_patterns), + } + for entry in STORAGE_ARROW_BOUNDARIES + ] + adapter_boundaries = collect_adapter_core_boundaries(project_root) + serializer_bridges = collect_serializer_bridges(project_root) + + interpreted_to_compiled_adapter_edges = [ + entry for entry in adapter_boundaries if entry["classification"] == "interpreted_to_compiled" + ] + 
+ return { + "summary": { + "config_runtime_edges": len(config_boundaries), + "adapter_config_core_edges": len(adapter_boundaries), + "interpreted_to_compiled_adapter_edges": len(interpreted_to_compiled_adapter_edges), + "serializer_bridges": len(serializer_bridges), + "storage_arrow_edges": len(storage_boundaries), + "any_audit_seams": len(ANY_AUDIT_SEAMS), + "exclusion_revalidation_buckets": len(EXCLUSION_REVALIDATION_SEED), + }, + "config_runtime_boundaries": config_boundaries, + "adapter_config_core_boundaries": adapter_boundaries, + "serializer_bridges": serializer_bridges, + "storage_arrow_boundaries": storage_boundaries, + "any_audit_matrix": list(ANY_AUDIT_SEAMS), + "exclusion_revalidation_seed": EXCLUSION_REVALIDATION_SEED, + } + + +if __name__ == "__main__": # pragma: no cover + import json + + print(json.dumps(build_boundary_map(), indent=2)) From 78177f5f52d3c3c4dcee692a2d61dd31f02924fc Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 19:17:17 +0000 Subject: [PATCH 19/39] build(mypyc): design exclusion helper splits --- tests/unit/utils/test_mypyc_boundary_map.py | 30 ++++++ tools/scripts/mypyc_boundary_map.py | 105 ++++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/tests/unit/utils/test_mypyc_boundary_map.py b/tests/unit/utils/test_mypyc_boundary_map.py index cbab93b0d..b098e456e 100644 --- a/tests/unit/utils/test_mypyc_boundary_map.py +++ b/tests/unit/utils/test_mypyc_boundary_map.py @@ -28,6 +28,8 @@ def test_build_boundary_map_reports_expected_hot_groups() -> None: "storage_arrow_edges": 3, "any_audit_seams": 8, "exclusion_revalidation_buckets": 7, + "helper_split_designs": 3, + "rollout_feedback_entries": 3, } @@ -119,3 +121,31 @@ def test_build_boundary_map_seeds_exclusion_revalidation_buckets() -> None: assert exclusion_seed["sqlspec/builder/_vector_expressions.py"]["bucket"] == "helper_split" assert exclusion_seed["sqlspec/data_dictionary/_loader.py"]["bucket"] == "helper_split" assert 
exclusion_seed["sqlspec/dialects/**"]["bucket"] == "low_roi" + + +def test_build_boundary_map_records_helper_split_designs_and_rollout_feedback() -> None: + module = _load_mypyc_boundary_map_module() + + boundary_map = module.build_boundary_map() + + helper_splits = {entry["surface"]: entry for entry in boundary_map["helper_split_designs"]} + assert helper_splits["sqlspec/builder/_vector_expressions.py"]["extract_module"] == ( + "sqlspec/builder/_vector_renderers.py" + ) + assert "render_postgres_vector_distance" in helper_splits["sqlspec/builder/_vector_expressions.py"]["safe_symbols"] + assert "_register_with_sqlglot" in helper_splits["sqlspec/builder/_vector_expressions.py"][ + "keep_interpreted_symbols" + ] + assert helper_splits["sqlspec/data_dictionary/_loader.py"]["compile_target"] == ( + "sqlspec/data_dictionary/_loader_core.py" + ) + assert "ensure_dialect_path" in helper_splits["sqlspec/data_dictionary/_loader.py"]["safe_symbols"] + assert helper_splits["sqlspec/adapters/**/data_dictionary.py"]["extract_module"] == ( + "sqlspec/data_dictionary/_plans.py" + ) + assert "build_query_plan" in helper_splits["sqlspec/adapters/**/data_dictionary.py"]["safe_symbols"] + + rollout_feedback = {entry["task_id"]: entry["recommendation"] for entry in boundary_map["rollout_feedback"]} + assert "renderer helper module" in rollout_feedback["sqlspec-k1a.4"] + assert "Arrow boundaries interpreted" in rollout_feedback["sqlspec-k1a.5"] + assert "_loader_core.py" in rollout_feedback["sqlspec-k1a.6.3"] diff --git a/tools/scripts/mypyc_boundary_map.py b/tools/scripts/mypyc_boundary_map.py index cd2e61336..c117adbb2 100644 --- a/tools/scripts/mypyc_boundary_map.py +++ b/tools/scripts/mypyc_boundary_map.py @@ -14,6 +14,8 @@ "ANY_AUDIT_SEAMS", "CONFIG_RUNTIME_BOUNDARIES", "EXCLUSION_REVALIDATION_SEED", + "HELPER_SPLIT_DESIGNS", + "ROLLOUT_FEEDBACK", "STORAGE_ARROW_BOUNDARIES", "build_boundary_map", "classify_module", @@ -187,6 +189,105 @@ } +HELPER_SPLIT_DESIGNS: tuple[dict[str, 
Any], ...] = ( + { + "surface": "sqlspec/builder/_vector_expressions.py", + "split_kind": "extract_pure_renderers", + "extract_module": "sqlspec/builder/_vector_renderers.py", + "compile_target": "sqlspec/builder/_vector_renderers.py", + "safe_symbols": ( + "_normalize_metric_name", + "_coerce_oracle_vector_literal", + "_maybe_wrap_mysql_vector_literal", + "_duckdb_target_type", + "render_postgres_vector_distance", + "render_mysql_vector_distance", + "render_oracle_vector_distance", + "render_bigquery_vector_distance", + "render_duckdb_vector_distance", + "render_generic_vector_distance", + ), + "keep_interpreted_symbols": ( + "VectorDistance", + "_register_with_sqlglot", + "_vector_distance_sql_base", + "_vector_distance_sql_postgres", + "_vector_distance_sql_mysql", + "_vector_distance_sql_oracle", + "_vector_distance_sql_bigquery", + "_vector_distance_sql_spanner", + "_vector_distance_sql_duckdb", + ), + "reason": "The expression subclass and sqlglot registration side effects remain unsafe, but dialect-specific string rendering and metric normalization are pure helpers.", + "feeds_chapter": "adapter-runtime-boundaries", + }, + { + "surface": "sqlspec/data_dictionary/_loader.py", + "split_kind": "extract_loader_state_and_path_resolution", + "extract_module": "sqlspec/data_dictionary/_loader_core.py", + "compile_target": "sqlspec/data_dictionary/_loader_core.py", + "safe_symbols": ( + "build_sql_dir_path", + "ensure_dialect_path", + "list_sql_dialects", + "get_or_create_loader", + "mark_dialect_loaded", + "is_dialect_loaded", + ), + "keep_interpreted_symbols": ( + "SQL_DIR", + "DataDictionaryLoader._ensure_dialect_loaded", + "DataDictionaryLoader.get_query", + "DataDictionaryLoader.get_query_text", + "get_data_dictionary_loader", + ), + "reason": "Path discovery and loader-cache mutation are straightforward helpers; keep singleton lifecycle and SQLFileLoader orchestration interpreted.", + "feeds_chapter": "exclusion-revalidation", + }, + { + "surface": 
"sqlspec/adapters/**/data_dictionary.py", + "split_kind": "extract_query_plans_and_version_resolution", + "extract_module": "sqlspec/data_dictionary/_plans.py", + "compile_target": "sqlspec/data_dictionary/_plans.py", + "safe_symbols": ( + "resolve_schema_name", + "resolve_feature_flag_from_version", + "resolve_optimal_type_from_version", + "build_query_plan", + "build_sqlite_query_text_plan", + "collect_index_columns_metadata", + ), + "keep_interpreted_symbols": ( + "SyncDataDictionaryBase subclasses", + "AsyncDataDictionaryBase subclasses", + "get_version", + "get_tables", + "get_columns", + "get_indexes", + "get_foreign_keys", + ), + "reason": "Cross-module inheritance and driver I/O stay unsafe, but repeated schema resolution, feature gating, and query-plan assembly can be centralized into compiled helpers.", + "feeds_chapter": "storage-runtime-expansion", + }, +) + + +ROLLOUT_FEEDBACK: tuple[dict[str, str], ...] = ( + { + "task_id": "sqlspec-k1a.4", + "recommendation": "Do not reopen adapter runtime compilation for dialect/vector registration; only revisit if a pure renderer helper module is extracted first.", + }, + { + "task_id": "sqlspec-k1a.5", + "recommendation": "Keep Arrow boundaries interpreted and only route data-dictionary query-plan helpers toward future storage/runtime widening.", + }, + { + "task_id": "sqlspec-k1a.6.3", + "recommendation": "Prioritize `_loader_core.py` and shared data-dictionary plan helpers before any file-level exclusion removal.", + }, +) + + def load_mypyc_patterns(root: Path) -> tuple[list[str], list[str]]: """Load mypyc include/exclude globs from pyproject.toml.""" @@ -315,6 +416,8 @@ def build_boundary_map(root: Path | None = None) -> dict[str, Any]: "storage_arrow_edges": len(storage_boundaries), "any_audit_seams": len(ANY_AUDIT_SEAMS), "exclusion_revalidation_buckets": len(EXCLUSION_REVALIDATION_SEED), + "helper_split_designs": len(HELPER_SPLIT_DESIGNS), + "rollout_feedback_entries": len(ROLLOUT_FEEDBACK), }, 
"config_runtime_boundaries": config_boundaries, "adapter_config_core_boundaries": adapter_boundaries, @@ -322,6 +425,8 @@ def build_boundary_map(root: Path | None = None) -> dict[str, Any]: "storage_arrow_boundaries": storage_boundaries, "any_audit_matrix": list(ANY_AUDIT_SEAMS), "exclusion_revalidation_seed": EXCLUSION_REVALIDATION_SEED, + "helper_split_designs": list(HELPER_SPLIT_DESIGNS), + "rollout_feedback": list(ROLLOUT_FEEDBACK), } From c81e16db253955865e892de54a1deec9c369e144 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 23:06:17 +0000 Subject: [PATCH 20/39] refactor: relocate PGO training script to `tools/scripts/pgo_training.py` and update its invocation in tests and Makefile. --- Makefile | 2 +- pyproject.toml | 1 - tests/unit/test_pgo_training.py | 6 ++-- .../scripts/pgo_training.py | 34 ++++++------------- 4 files changed, 15 insertions(+), 28 deletions(-) rename sqlspec/_pgo_training.py => tools/scripts/pgo_training.py (78%) diff --git a/Makefile b/Makefile index 72052d292..20bea0830 100644 --- a/Makefile +++ b/Makefile @@ -365,7 +365,7 @@ pgo-local: ## Run full three-stage PGO @echo "${OK} Instrumented wheel built" @echo "${INFO} Stage 2: Running training workload..." @uv pip install dist/*.whl --force-reinstall --no-deps >/dev/null 2>&1 - @.venv/bin/python -m sqlspec._pgo_training + @.venv/bin/python tools/scripts/pgo_training.py @echo "${OK} Training complete" @rm -rf dist/ $(PGO_BUILD_DIR)/build $(PGO_BUILD_DIR)/tmp @echo "${INFO} Stage 3: Building PGO-optimized wheel..." 
diff --git a/pyproject.toml b/pyproject.toml index 1abd34ea4..a6eb731bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,7 +200,6 @@ exclude = [ "sqlspec/observability/_formatting.py", # Inherits from non-compiled logging.Formatter "sqlspec/utils/arrow_helpers.py", # Arrow operations cause segfaults when compiled "sqlspec/utils/profiling.py", # Uses sys.setprofile (dynamic, not mypyc compatible) - "sqlspec/_pgo_training.py", # PGO training script — runs against compiled modules ] include = [ "sqlspec/core/**/*.py", # Core module diff --git a/tests/unit/test_pgo_training.py b/tests/unit/test_pgo_training.py index 544dc9f6e..f4755c4f9 100644 --- a/tests/unit/test_pgo_training.py +++ b/tests/unit/test_pgo_training.py @@ -1,13 +1,13 @@ """Tests for the PGO training script.""" -from __future__ import annotations - import subprocess import sys +from pathlib import Path def test_pgo_training_runs_without_error() -> None: """Training script must exit cleanly and complete in reasonable time.""" - result = subprocess.run([sys.executable, "-m", "sqlspec._pgo_training"], capture_output=True, timeout=90) + script = Path(__file__).resolve().parents[2] / "tools" / "scripts" / "pgo_training.py" + result = subprocess.run([sys.executable, str(script)], capture_output=True, timeout=90) assert result.returncode == 0, f"Training failed: {result.stderr.decode()}" assert b"PGO training complete" in result.stdout diff --git a/sqlspec/_pgo_training.py b/tools/scripts/pgo_training.py similarity index 78% rename from sqlspec/_pgo_training.py rename to tools/scripts/pgo_training.py index ebd9bbe0e..c39307bcf 100644 --- a/sqlspec/_pgo_training.py +++ b/tools/scripts/pgo_training.py @@ -1,13 +1,11 @@ """PGO training workload for sqlspec. Exercises hot paths to generate compiler profile data for Profile-Guided Optimization. -This module is excluded from mypyc compilation — it runs against compiled modules. 
+This script runs against compiled modules and should not be packaged with the library. -Run as: python -m sqlspec._pgo_training +Run as: python tools/scripts/pgo_training.py """ -from __future__ import annotations - import sys import tempfile import time @@ -32,12 +30,12 @@ def _train_text_transforms() -> None: "multi_word_column_name_here", ] for _ in range(25000): - for s in inputs: - camelize(s) - snake_case(s) - pascalize(s) - kebabize(s) - slugify(s) + for value in inputs: + camelize(value) + snake_case(value) + pascalize(value) + kebabize(value) + slugify(value) def _train_schema_transforms() -> None: @@ -55,11 +53,7 @@ def _train_schema_transforms() -> None: def _train_sqlite_sync() -> None: - """Exercise sync driver via in-memory SQLite. - - Covers: driver initialization, execute, execute_many, fetch, fetch_one_or_none, - query cache warm/miss/evict paths, parameter processing, and type conversion. - """ + """Exercise sync driver via in-memory SQLite.""" from sqlspec import SQLSpec from sqlspec.adapters.sqlite.config import SqliteConfig @@ -77,14 +71,10 @@ def _train_sqlite_sync() -> None: with spec.provide_session(config) as session: session.execute(create_sql) - # Write heavy: bulk insert via execute_many data = [(f"value_{i}",) for i in range(2000)] session.execute_many(insert_sql, data) - - # Read heavy: fetch all rows session.fetch(select_all) - # Repeated queries: hammer the query cache for i in range(10000): session.fetch_one_or_none(select_by, (f"value_{i % 100}",)) @@ -92,7 +82,6 @@ def _train_sqlite_sync() -> None: finally: Path(tmp_name).unlink() - # Second pass: focus on query cache warm path with more iterations with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: tmp_name = tmp.name @@ -102,7 +91,6 @@ def _train_sqlite_sync() -> None: session.execute(create_sql) for i in range(500): session.execute(insert_sql, (f"val_{i}",)) - # 50k queries to thoroughly exercise the query cache fast path for i in range(50000): 
session.fetch_one_or_none(select_by, (f"val_{i % 500}",)) @@ -115,7 +103,7 @@ def main() -> None: """Run all PGO training workloads.""" start = time.perf_counter() - workloads: list[tuple[str, object]] = [ + workloads = [ ("text_transforms", _train_text_transforms), ("schema_transforms", _train_schema_transforms), ("sqlite_sync", _train_sqlite_sync), @@ -123,7 +111,7 @@ def main() -> None: for name, fn in workloads: t0 = time.perf_counter() - fn() # type: ignore[operator] + fn() elapsed = time.perf_counter() - t0 print(f" {name}: {elapsed:.2f}s") # noqa: T201 From 40cef511ed5a1386e8a676ae46660be839067671 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 14 Mar 2026 23:26:41 +0000 Subject: [PATCH 21/39] chore(perf): capture current surface inventory --- performance_surface_inventory.json | 285 ++++++++++++++++++++++ tests/unit/test_perf_surface_inventory.py | 126 ++++++++++ 2 files changed, 411 insertions(+) create mode 100644 performance_surface_inventory.json create mode 100644 tests/unit/test_perf_surface_inventory.py diff --git a/performance_surface_inventory.json b/performance_surface_inventory.json new file mode 100644 index 000000000..c19cb9a44 --- /dev/null +++ b/performance_surface_inventory.json @@ -0,0 +1,285 @@ +{ + "task_id": "sqlspec-iab.6.1", + "captured_on": "2026-03-14", + "summary": { + "adapter_count": 16, + "bench_harness_driver_count": 4, + "regression_gate_driver_count": 1 + }, + "benchmark_scripts": { + "bench.py": { + "drivers": [ + "aiosqlite", + "asyncpg", + "duckdb", + "sqlite" + ], + "core_scenarios": [ + "initialization", + "write_heavy", + "read_heavy", + "iterative_inserts", + "repeated_queries" + ], + "extended_scenarios": [ + "dict_key_transform", + "schema_mapping", + "complex_parameters", + "thin_path_stress" + ] + }, + "bench_gate.py": { + "drivers": [ + "sqlite" + ], + "gate_scenarios": [ + "iterative_inserts", + "repeated_queries", + "write_heavy", + "read_heavy" + ] + }, + "bench_subsystems.py": { + "drivers": [ + "sqlite" 
+ ], + "focus_areas": [ + "SQL object construction", + "SQL.compile()", + "query-cache prepare/lookup", + "parameter processing", + "cursor lifecycle", + "storage/runtime helpers" + ] + } + }, + "build_workflows": { + "publish.yml": { + "tier": "release", + "builds_mypyc_wheels": true, + "exercises_three_stage_pgo": true, + "uses_packaged_training_module": true, + "notes": "Release wheels still run `python -m sqlspec._pgo_training` during the three-stage PGO flow." + }, + "test-build.yml": { + "tier": "pull_request/workflow_dispatch", + "builds_mypyc_wheels": true, + "exercises_three_stage_pgo": false, + "uses_packaged_training_module": false, + "notes": "PR build validation compiles mypyc wheels but does not run reduced three-stage PGO parity." + } + }, + "external_infra": { + "dma_accelerator": { + "reusable_lifecycle_modules": [ + "postgres", + "mysql", + "oracle", + "sqlserver" + ], + "root_cli_registered_families": [ + "postgres" + ], + "notes": "Reusable lifecycle code exists for PostgreSQL, MySQL, Oracle, and SQL Server, but the root CLI currently mounts only PostgreSQL. MySQL and Oracle still need SQLSpec-facing exposure work before they can be treated as directly reusable perf infrastructure." + } + }, + "adapters": [ + { + "name": "adbc", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/adbc/config.py", + "integration_root": "tests/integration/adapters/adbc", + "current_infra_family": "bridge/underlying-engine", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage spans PostgreSQL, SQLite, and DuckDB backends, but no dedicated perf-harness workload exists yet." 
+ }, + { + "name": "aiosqlite", + "execution_surfaces": [ + "async" + ], + "config_surface": "sqlspec/adapters/aiosqlite/config.py", + "integration_root": "tests/integration/adapters/aiosqlite", + "current_infra_family": "file-local", + "bench_harness_driver": true, + "regression_gate_driver": false, + "notes": "Covered in bench.py, but not in bench_gate.py or bench_subsystems.py." + }, + { + "name": "asyncmy", + "execution_surfaces": [ + "async" + ], + "config_surface": "sqlspec/adapters/asyncmy/config.py", + "integration_root": "tests/integration/adapters/asyncmy", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "MySQL-family integration coverage exists, but there is no current perf-harness workload." + }, + { + "name": "asyncpg", + "execution_surfaces": [ + "async" + ], + "config_surface": "sqlspec/adapters/asyncpg/config.py", + "integration_root": "tests/integration/adapters/asyncpg", + "current_infra_family": "server-backed", + "bench_harness_driver": true, + "regression_gate_driver": false, + "notes": "Bench coverage exists in bench.py, but the regression gate remains SQLite-only." + }, + { + "name": "bigquery", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/bigquery/config.py", + "integration_root": "tests/integration/adapters/bigquery", + "current_infra_family": "cloud-managed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists, but there is no current perf or build-path matrix entry for the adapter." + }, + { + "name": "cockroach_asyncpg", + "execution_surfaces": [ + "async" + ], + "config_surface": "sqlspec/adapters/cockroach_asyncpg/config.py", + "integration_root": "tests/integration/adapters/cockroach_asyncpg", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Cockroach async coverage exists in integration tests only." 
+ }, + { + "name": "cockroach_psycopg", + "execution_surfaces": [ + "sync", + "async" + ], + "config_surface": "sqlspec/adapters/cockroach_psycopg/config.py", + "integration_root": "tests/integration/adapters/cockroach_psycopg", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Cockroach psycopg integration coverage exists, but no perf harness tracks either sync or async surfaces." + }, + { + "name": "duckdb", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/duckdb/config.py", + "integration_root": "tests/integration/adapters/duckdb", + "current_infra_family": "file-local", + "bench_harness_driver": true, + "regression_gate_driver": false, + "notes": "Bench coverage exists in bench.py, but no driver-specific gate or subsystem micro-benchmark exists." + }, + { + "name": "mock", + "execution_surfaces": [ + "sync", + "async" + ], + "config_surface": "sqlspec/adapters/mock/config.py", + "integration_root": null, + "current_infra_family": "mock-only", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Useful for isolated behavior tests, but not a real performance validation surface." + }, + { + "name": "mysqlconnector", + "execution_surfaces": [ + "sync", + "async" + ], + "config_surface": "sqlspec/adapters/mysqlconnector/config.py", + "integration_root": "tests/integration/adapters/mysqlconnector", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists for sync and async paths, but the perf harness does not exercise them yet." 
+ }, + { + "name": "oracledb", + "execution_surfaces": [ + "sync", + "async" + ], + "config_surface": "sqlspec/adapters/oracledb/config.py", + "integration_root": "tests/integration/adapters/oracledb", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists for sync and async Oracle paths, but there is no Oracle perf workload in the harness." + }, + { + "name": "psqlpy", + "execution_surfaces": [ + "async" + ], + "config_surface": "sqlspec/adapters/psqlpy/config.py", + "integration_root": "tests/integration/adapters/psqlpy", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists, but the current perf harness does not represent the adapter." + }, + { + "name": "psycopg", + "execution_surfaces": [ + "sync", + "async" + ], + "config_surface": "sqlspec/adapters/psycopg/config.py", + "integration_root": "tests/integration/adapters/psycopg", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists for sync and async psycopg paths, but the perf harness still skips them." + }, + { + "name": "pymysql", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/pymysql/config.py", + "integration_root": "tests/integration/adapters/pymysql", + "current_infra_family": "server-backed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists, but no current perf workload or gate covers the adapter." 
+ }, + { + "name": "spanner", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/spanner/config.py", + "integration_root": "tests/integration/adapters/spanner", + "current_infra_family": "cloud-managed", + "bench_harness_driver": false, + "regression_gate_driver": false, + "notes": "Integration coverage exists, but cloud-managed validation is not represented in the current perf harness." + }, + { + "name": "sqlite", + "execution_surfaces": [ + "sync" + ], + "config_surface": "sqlspec/adapters/sqlite/config.py", + "integration_root": "tests/integration/adapters/sqlite", + "current_infra_family": "file-local", + "bench_harness_driver": true, + "regression_gate_driver": true, + "notes": "SQLite is the only adapter covered by bench.py, bench_gate.py, and bench_subsystems.py." + } + ] +} diff --git a/tests/unit/test_perf_surface_inventory.py b/tests/unit/test_perf_surface_inventory.py new file mode 100644 index 000000000..0fa1bb1a4 --- /dev/null +++ b/tests/unit/test_perf_surface_inventory.py @@ -0,0 +1,126 @@ +"""Tests for the workload-matrix foundation surface inventory.""" + +import json +import re +from pathlib import Path +from typing import TypedDict, cast + + +class BenchmarkScriptInventory(TypedDict): + drivers: list[str] + + +class BuildWorkflowInventory(TypedDict): + exercises_three_stage_pgo: bool + uses_packaged_training_module: bool + + +class AdapterInventory(TypedDict): + name: str + execution_surfaces: list[str] + config_surface: str + integration_root: str | None + current_infra_family: str + bench_harness_driver: bool + regression_gate_driver: bool + notes: str + + +class PerfSurfaceInventory(TypedDict): + benchmark_scripts: dict[str, BenchmarkScriptInventory] + build_workflows: dict[str, BuildWorkflowInventory] + adapters: list[AdapterInventory] + + +REPO_ROOT = Path(__file__).resolve().parents[2] +INVENTORY_PATH = REPO_ROOT / "performance_surface_inventory.json" +BENCH_DRIVER_PATTERN = 
re.compile(r'\("(?:raw|sqlspec|sqlalchemy)",\s*"(?P<driver>[^"]+)",\s*"[^"]+"\)')
+GATE_DRIVER_PATTERN = re.compile(r'run_benchmark\("(?P<driver>[^"]+)"')
+
+
+def _load_inventory() -> PerfSurfaceInventory:
+    return cast(PerfSurfaceInventory, json.loads(INVENTORY_PATH.read_text()))
+
+
+def _bench_drivers() -> list[str]:
+    bench_text = (REPO_ROOT / "tools" / "scripts" / "bench.py").read_text()
+    return sorted({match.group("driver") for match in BENCH_DRIVER_PATTERN.finditer(bench_text)})
+
+
+def test_perf_surface_inventory_covers_all_adapter_configs_and_integration_roots() -> None:
+    inventory = _load_inventory()
+    actual_adapters = sorted(path.parent.name for path in (REPO_ROOT / "sqlspec" / "adapters").glob("*/config.py"))
+    inventory_adapters = sorted(entry["name"] for entry in inventory["adapters"])
+
+    assert INVENTORY_PATH.is_file()
+    assert inventory_adapters == actual_adapters
+
+    allowed_families = {
+        "bridge/underlying-engine",
+        "cloud-managed",
+        "file-local",
+        "mock-only",
+        "server-backed",
+    }
+
+    for entry in inventory["adapters"]:
+        assert entry["execution_surfaces"]
+        assert set(entry["execution_surfaces"]) <= {"async", "sync"}
+        assert entry["current_infra_family"] in allowed_families
+        assert entry["notes"]
+        assert (REPO_ROOT / entry["config_surface"]).is_file()
+
+        if entry["integration_root"] is None:
+            assert entry["name"] == "mock"
+        else:
+            assert (REPO_ROOT / entry["integration_root"]).is_dir()
+
+
+def test_perf_surface_inventory_records_expected_execution_surfaces_and_perf_state() -> None:
+    inventory = _load_inventory()
+    lookup = {entry["name"]: entry for entry in inventory["adapters"]}
+
+    assert lookup["sqlite"]["execution_surfaces"] == ["sync"]
+    assert lookup["aiosqlite"]["execution_surfaces"] == ["async"]
+    assert lookup["psycopg"]["execution_surfaces"] == ["sync", "async"]
+    assert lookup["mysqlconnector"]["execution_surfaces"] == ["sync", "async"]
+    assert lookup["oracledb"]["execution_surfaces"] == ["sync", "async"]
+    assert 
lookup["adbc"]["current_infra_family"] == "bridge/underlying-engine" + assert lookup["bigquery"]["current_infra_family"] == "cloud-managed" + assert lookup["spanner"]["current_infra_family"] == "cloud-managed" + assert lookup["mock"]["current_infra_family"] == "mock-only" + + bench_drivers = {entry["name"] for entry in inventory["adapters"] if entry["bench_harness_driver"]} + gate_drivers = {entry["name"] for entry in inventory["adapters"] if entry["regression_gate_driver"]} + + assert bench_drivers == {"aiosqlite", "asyncpg", "duckdb", "sqlite"} + assert gate_drivers == {"sqlite"} + + +def test_perf_surface_inventory_matches_current_perf_scripts_and_build_workflows() -> None: + inventory = _load_inventory() + bench_gate_text = (REPO_ROOT / "tools" / "scripts" / "bench_gate.py").read_text() + bench_subsystems_text = (REPO_ROOT / "tools" / "scripts" / "bench_subsystems.py").read_text() + publish_text = (REPO_ROOT / ".github" / "workflows" / "publish.yml").read_text() + test_build_text = (REPO_ROOT / ".github" / "workflows" / "test-build.yml").read_text() + + assert inventory["benchmark_scripts"]["bench.py"]["drivers"] == _bench_drivers() + + gate_match = GATE_DRIVER_PATTERN.search(bench_gate_text) + assert gate_match is not None + assert inventory["benchmark_scripts"]["bench_gate.py"]["drivers"] == [gate_match.group("driver")] + assert inventory["benchmark_scripts"]["bench_subsystems.py"]["drivers"] == ["sqlite"] + assert "Subsystem Micro-Benchmarks (sqlite)" in bench_subsystems_text + + assert inventory["build_workflows"]["publish.yml"]["uses_packaged_training_module"] == ( + "python -m sqlspec._pgo_training" in publish_text + ) + assert inventory["build_workflows"]["publish.yml"]["exercises_three_stage_pgo"] == ( + "fprofile-generate" in publish_text or "fprofile-instr-generate" in publish_text + ) + assert inventory["build_workflows"]["test-build.yml"]["uses_packaged_training_module"] == ( + "python -m sqlspec._pgo_training" in test_build_text + ) + assert 
inventory["build_workflows"]["test-build.yml"]["exercises_three_stage_pgo"] == ( + "fprofile-generate" in test_build_text or "fprofile-instr-generate" in test_build_text + ) From cdb29e9b25eeda696a4ea1ffa60d8eddcfd892bf Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 15:23:23 +0000 Subject: [PATCH 22/39] refactor: consolidate data dictionary mixins and inherit them in sync/async base classes --- .pre-commit-config.yaml | 2 +- sqlspec/driver/_async.py | 285 +--------------------------- sqlspec/driver/_common.py | 41 +++-- sqlspec/driver/_sync.py | 287 +---------------------------- tests/unit/core/test_parameters.py | 9 + uv.lock | 126 +++++++------ 6 files changed, 118 insertions(+), 632 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 50d4ac62f..2a7994f9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.15.5" + rev: "v0.15.6" hooks: - id: ruff args: ["--fix"] diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index e02850145..441981ed6 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -1,10 +1,6 @@ """Asynchronous driver protocol implementation.""" -import graphlib -import logging -import re from abc import abstractmethod -from contextlib import suppress from time import perf_counter from typing import TYPE_CHECKING, Any, ClassVar, Final, cast, final, overload @@ -13,18 +9,15 @@ from sqlspec.core import SQL, ProcessedState, StackResult, Statement, create_arrow_result from sqlspec.core.result import DMLResult from sqlspec.core.stack import StackOperation, StatementStack -from sqlspec.data_dictionary._loader import get_data_dictionary_loader -from sqlspec.data_dictionary._registry import get_dialect_config from sqlspec.driver._common import ( - VERSION_GROUPS_MIN_FOR_MINOR, - VERSION_GROUPS_MIN_FOR_PATCH, 
AsyncExceptionHandler, CommonDriverAttributesMixin, + DataDictionaryDialectMixin, + DataDictionaryMixin, ExecutionResult, StackExecutionObserver, describe_stack_statement, handle_single_row_error, - resolve_db_system, ) from sqlspec.driver._query_cache import CachedQuery from sqlspec.driver._sql_helpers import DEFAULT_PRETTY @@ -37,11 +30,11 @@ create_storage_job, stringify_storage_target, ) -from sqlspec.exceptions import ImproperConfigurationError, SQLFileNotFoundError, StackExecutionError +from sqlspec.exceptions import ImproperConfigurationError, StackExecutionError from sqlspec.storage import AsyncStoragePipeline, StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry from sqlspec.typing import VersionInfo from sqlspec.utils.arrow_helpers import convert_dict_to_arrow_with_schema -from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.logging import get_logger from sqlspec.utils.schema import ValueT, to_value_type from sqlspec.utils.type_guards import has_asdict_method, is_dict_row, is_mapping_like @@ -1805,270 +1798,20 @@ def _create_storage_job( return create_storage_job(produced, provided, status=status) -@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) -class AsyncDataDictionaryBase: +@mypyc_attr(allow_interpreted_subclasses=True) +class AsyncDataDictionaryBase(DataDictionaryDialectMixin, DataDictionaryMixin): """Base class for asynchronous data dictionary implementations. Uses Python-compatible class layouts for cross-module inheritance. Child classes define dialect as a class attribute. """ - _version_cache: "dict[int, VersionInfo | None]" - _version_fetch_attempted: "set[int]" - dialect: "ClassVar[str]" """Dialect identifier. 
Must be defined by subclasses as a class attribute.""" def __init__(self) -> None: - self._version_cache = {} - self._version_fetch_attempted = set() - - # ───────────────────────────────────────────────────────────────────────────── - # DIALECT SQL METHODS (merged from DialectSQLMixin) - # ───────────────────────────────────────────────────────────────────────────── - - def get_dialect_config(self) -> "DialectConfig": - """Return the dialect configuration for this data dictionary.""" - return get_dialect_config(type(self).dialect) - - def get_query(self, name: str) -> "SQL": - """Return a named SQL query for this dialect.""" - loader = get_data_dictionary_loader() - return loader.get_query(type(self).dialect, name) - - def get_query_text(self, name: str) -> str: - """Return raw SQL text for a named query for this dialect.""" - loader = get_data_dictionary_loader() - return loader.get_query_text(type(self).dialect, name) - - def get_query_text_or_none(self, name: str) -> "str | None": - """Return raw SQL text for a named query or None if missing.""" - try: - return self.get_query_text(name) - except SQLFileNotFoundError: - return None - - def resolve_schema(self, schema: "str | None") -> "str | None": - """Return a schema name using dialect defaults when missing.""" - if schema is not None: - return schema - config = self.get_dialect_config() - return config.default_schema - - def resolve_feature_flag(self, feature: str, version: "VersionInfo | None") -> bool: - """Resolve a feature flag using dialect config and version info.""" - config = self.get_dialect_config() - flag = config.get_feature_flag(feature) - if flag is not None: - return flag - required_version = config.get_feature_version(feature) - if required_version is None or version is None: - return False - return bool(version >= required_version) - - # ───────────────────────────────────────────────────────────────────────────── - # VERSION CACHING METHODS (inlined from DataDictionaryMixin) - # 
───────────────────────────────────────────────────────────────────────────── - - def get_cached_version(self, driver_id: int) -> object: - """Get cached version info for a driver. - - Args: - driver_id: The id() of the driver instance. - - Returns: - Tuple of (was_cached, version_info). If was_cached is False, - the caller should fetch the version and call cache_version(). - """ - if driver_id in self._version_fetch_attempted: - return True, self._version_cache.get(driver_id) - return False, None - - def cache_version(self, driver_id: int, version: "VersionInfo | None") -> None: - """Cache version info for a driver. - - Args: - driver_id: The id() of the driver instance. - version: The version info to cache (can be None if detection failed). - """ - self._version_fetch_attempted.add(driver_id) - if version is not None: - self._version_cache[driver_id] = version - - def parse_version_string(self, version_str: str) -> "VersionInfo | None": - """Parse version string into VersionInfo. - - Args: - version_str: Raw version string from database - - Returns: - VersionInfo instance or None if parsing fails - """ - patterns = [r"(\d+)\.(\d+)\.(\d+)", r"(\d+)\.(\d+)", r"(\d+)"] - for pattern in patterns: - match = re.search(pattern, version_str) - if match: - groups = match.groups() - major = int(groups[0]) - minor = int(groups[1]) if len(groups) > VERSION_GROUPS_MIN_FOR_MINOR else 0 - patch = int(groups[2]) if len(groups) > VERSION_GROUPS_MIN_FOR_PATCH else 0 - return VersionInfo(major, minor, patch) - return None - - def parse_version_with_pattern(self, pattern: "re.Pattern[str]", version_str: str) -> "VersionInfo | None": - """Parse version string using a specific regex pattern. 
- - Args: - pattern: Compiled regex pattern for the version format - version_str: Raw version string from database - - Returns: - VersionInfo instance or None if parsing fails - """ - match = pattern.search(version_str) - if not match: - return None - groups = match.groups() - if not groups: - return None - major = int(groups[0]) - minor = int(groups[1]) if len(groups) > VERSION_GROUPS_MIN_FOR_MINOR and groups[1] else 0 - patch = int(groups[2]) if len(groups) > VERSION_GROUPS_MIN_FOR_PATCH and groups[2] else 0 - return VersionInfo(major, minor, patch) - - def _resolve_log_adapter(self) -> str: - """Resolve adapter identifier for logging.""" - return str(type(self).dialect) - - def _log_version_detected(self, adapter: str, version: "VersionInfo") -> None: - """Log detected database version with db.system context.""" - logger.debug( - "Detected database version", extra={"db.system": resolve_db_system(adapter), "db.version": str(version)} - ) - - def _log_version_unavailable(self, adapter: str, reason: str) -> None: - """Log that database version could not be determined.""" - logger.debug("Database version unavailable", extra={"db.system": resolve_db_system(adapter), "reason": reason}) - - def _log_schema_introspect( - self, driver: Any, *, schema_name: "str | None", table_name: "str | None", operation: str - ) -> None: - """Log schema-level introspection activity.""" - log_with_context( - logger, - logging.DEBUG, - "schema.introspect", - db_system=resolve_db_system(type(driver).__name__), - schema_name=schema_name, - table_name=table_name, - operation=operation, - ) - - def _log_table_describe(self, driver: Any, *, schema_name: "str | None", table_name: str, operation: str) -> None: - """Log table-level introspection activity.""" - log_with_context( - logger, - logging.DEBUG, - "table.describe", - db_system=resolve_db_system(type(driver).__name__), - schema_name=schema_name, - table_name=table_name, - operation=operation, - ) - - def detect_version_with_queries(self, 
driver: "HasExecuteProtocol", queries: "list[str]") -> "VersionInfo | None": - """Try multiple version queries to detect database version. - - Args: - driver: Database driver with execute support - queries: List of SQL queries to try - - Returns: - Version information or None if detection fails - """ - for query in queries: - with suppress(Exception): - result: HasDataProtocol = driver.execute(query) - result_data = result.data - if result_data: - first_row = result_data[0] - version_str = str(first_row) - if isinstance(first_row, dict): - version_str = str(next(iter(first_row.values()))) - elif isinstance(first_row, (list, tuple)): - version_str = str(first_row[0]) - - parsed_version = self.parse_version_string(version_str) - if parsed_version: - self._log_version_detected(self._resolve_log_adapter(), parsed_version) - return parsed_version - - self._log_version_unavailable(self._resolve_log_adapter(), "queries_exhausted") - return None - - def get_default_type_mapping(self) -> "dict[str, str]": - """Get default type mappings for common categories. - - Returns: - Dictionary mapping type categories to generic SQL types - """ - return { - "json": "TEXT", - "uuid": "VARCHAR(36)", - "boolean": "INTEGER", - "timestamp": "TIMESTAMP", - "text": "TEXT", - "blob": "BLOB", - } - - def get_default_features(self) -> "list[str]": - """Get default feature flags supported by most databases. - - Returns: - List of commonly supported feature names - """ - return ["supports_transactions", "supports_prepared_statements"] - - def sort_tables_topologically(self, tables: "list[str]", foreign_keys: "list[ForeignKeyMetadata]") -> "list[str]": - """Sort tables topologically based on foreign key dependencies. - - Args: - tables: List of table names. - foreign_keys: List of foreign key metadata. - - Returns: - List of table names in topological order (dependencies first). 
- """ - sorter: graphlib.TopologicalSorter[str] = graphlib.TopologicalSorter() - for table in tables: - sorter.add(table) - for fk in foreign_keys: - if fk.table_name == fk.referenced_table: - continue - sorter.add(fk.table_name, fk.referenced_table) - return list(sorter.static_order()) - - def get_cached_version_for_driver(self, driver: Any) -> object: - """Get cached version info for a driver instance. - - Args: - driver: Async database driver instance. - - Returns: - Tuple of (was_cached, version_info). - - """ - return self.get_cached_version(id(driver)) - - def cache_version_for_driver(self, driver: Any, version: "VersionInfo | None") -> None: - """Cache version info for a driver instance. - - Args: - driver: Async database driver instance. - version: Parsed version info or None. - - """ - self.cache_version(id(driver), version) + self._version_cache: dict[int, VersionInfo | None] = {} + self._version_fetch_attempted: set[int] = set() @abstractmethod async def get_version(self, driver: Any) -> "VersionInfo | None": @@ -2169,15 +1912,3 @@ async def get_foreign_keys( """ - def list_available_features(self) -> "list[str]": - """List all features that can be checked via get_feature_flag. 
- - Returns: - List of feature names this data dictionary supports - - """ - config = self.get_dialect_config() - features = set(self.get_default_features()) - features.update(config.feature_flags.keys()) - features.update(config.feature_versions.keys()) - return sorted(features) diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index 1a9049296..3b8e1bf0e 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -8,7 +8,7 @@ from time import perf_counter from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal, NamedTuple, NoReturn, Protocol, cast, overload -from mypy_extensions import mypyc_attr +from mypy_extensions import mypyc_attr, trait from sqlglot import exp from typing_extensions import Self @@ -479,27 +479,28 @@ def handle_single_row_error(error: ValueError) -> "NoReturn": raise error -@mypyc_attr(native_class=False, allow_interpreted_subclasses=True) +@mypyc_attr(allow_interpreted_subclasses=True) +@trait class DataDictionaryDialectMixin: """Mixin providing dialect SQL helpers for data dictionaries.""" __slots__ = () - dialect: str + dialect: "ClassVar[str]" def get_dialect_config(self) -> "DialectConfig": """Return the dialect configuration for this data dictionary.""" - return get_dialect_config(self.dialect) + return get_dialect_config(type(self).dialect) def get_query(self, name: str) -> "SQL": """Return a named SQL query for this dialect.""" loader = get_data_dictionary_loader() - return loader.get_query(self.dialect, name) + return loader.get_query(type(self).dialect, name) def get_query_text(self, name: str) -> str: """Return raw SQL text for a named query for this dialect.""" loader = get_data_dictionary_loader() - return loader.get_query_text(self.dialect, name) + return loader.get_query_text(type(self).dialect, name) def get_query_text_or_none(self, name: str) -> "str | None": """Return raw SQL text for a named query or None if missing.""" @@ -526,14 +527,26 @@ def resolve_feature_flag(self, feature: str, 
version: "VersionInfo | None") -> b return False return bool(version >= required_version) + def get_default_features(self) -> "list[str]": + """Get default feature flags. Overridden by DataDictionaryMixin.""" + return [] + def list_available_features(self) -> "list[str]": - """List available feature flags for this dialect.""" + """List all features that can be checked via get_feature_flag. + + Returns: + List of feature names this data dictionary supports + + """ config = self.get_dialect_config() - features = set(config.feature_flags.keys()) | set(config.feature_versions.keys()) + features = set(self.get_default_features()) + features.update(config.feature_flags.keys()) + features.update(config.feature_versions.keys()) return sorted(features) @mypyc_attr(allow_interpreted_subclasses=True) +@trait class DataDictionaryMixin: """Mixin providing common data dictionary functionality. @@ -541,15 +554,13 @@ class DataDictionaryMixin: feature flags or optimal types. """ - __slots__ = ("_version_cache", "_version_fetch_attempted") + __slots__ = () + + dialect: "ClassVar[str]" _version_cache: "dict[int, VersionInfo | None]" _version_fetch_attempted: "set[int]" - def __init__(self) -> None: - self._version_cache = {} - self._version_fetch_attempted = set() - def get_cached_version(self, driver_id: int) -> "VersionCacheResult": """Get cached version info for a driver. 
@@ -649,9 +660,7 @@ def parse_version_with_pattern(self, pattern: "re.Pattern[str]", version_str: st def _resolve_log_adapter(self) -> str: """Resolve adapter identifier for logging.""" - if hasattr(self, "dialect"): - return str(self.dialect) # pyright: ignore[reportAttributeAccessIssue] - return type(self).__name__ + return str(type(self).dialect) def _log_version_detected(self, adapter: str, version: VersionInfo) -> None: """Log detected database version with db.system context.""" diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index c7e2db3d3..f67ece950 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -1,10 +1,6 @@ """Synchronous driver protocol implementation.""" -import graphlib -import logging -import re from abc import abstractmethod -from contextlib import suppress from time import perf_counter from typing import TYPE_CHECKING, Any, ClassVar, Final, cast, final, overload @@ -13,18 +9,15 @@ from sqlspec.core import SQL, StackResult, create_arrow_result from sqlspec.core.result import DMLResult from sqlspec.core.stack import StackOperation, StatementStack -from sqlspec.data_dictionary._loader import get_data_dictionary_loader -from sqlspec.data_dictionary._registry import get_dialect_config from sqlspec.driver._common import ( - VERSION_GROUPS_MIN_FOR_MINOR, - VERSION_GROUPS_MIN_FOR_PATCH, CommonDriverAttributesMixin, + DataDictionaryDialectMixin, + DataDictionaryMixin, ExecutionResult, StackExecutionObserver, SyncExceptionHandler, describe_stack_statement, handle_single_row_error, - resolve_db_system, ) from sqlspec.driver._query_cache import CachedQuery from sqlspec.driver._sql_helpers import DEFAULT_PRETTY @@ -37,11 +30,11 @@ create_storage_job, stringify_storage_target, ) -from sqlspec.exceptions import ImproperConfigurationError, SQLFileNotFoundError, StackExecutionError +from sqlspec.exceptions import ImproperConfigurationError, StackExecutionError from sqlspec.storage import StorageBridgeJob, StorageDestination, 
StorageFormat, StorageTelemetry, SyncStoragePipeline from sqlspec.typing import VersionInfo from sqlspec.utils.arrow_helpers import convert_dict_to_arrow_with_schema -from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.logging import get_logger from sqlspec.utils.schema import ValueT, to_value_type from sqlspec.utils.type_guards import has_asdict_method, is_dict_row, is_mapping_like @@ -52,8 +45,6 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, SQLResult, Statement, StatementConfig, StatementFilter - from sqlspec.data_dictionary._types import DialectConfig - from sqlspec.protocols import HasDataProtocol, HasExecuteProtocol from sqlspec.typing import ( ArrowReturnFormat, ArrowTable, @@ -1769,270 +1760,20 @@ def _create_storage_job( return create_storage_job(produced, provided, status=status) -@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) -class SyncDataDictionaryBase: +@mypyc_attr(allow_interpreted_subclasses=True) +class SyncDataDictionaryBase(DataDictionaryDialectMixin, DataDictionaryMixin): """Base class for synchronous data dictionary implementations. Uses Python-compatible class layouts for cross-module inheritance. Child classes define dialect as a class attribute. """ - _version_cache: "dict[int, VersionInfo | None]" - _version_fetch_attempted: "set[int]" - dialect: "ClassVar[str]" """Dialect identifier. 
Must be defined by subclasses as a class attribute.""" def __init__(self) -> None: - self._version_cache = {} - self._version_fetch_attempted = set() - - # ───────────────────────────────────────────────────────────────────────────── - # DIALECT SQL METHODS (merged from DialectSQLMixin) - # ───────────────────────────────────────────────────────────────────────────── - - def get_dialect_config(self) -> "DialectConfig": - """Return the dialect configuration for this data dictionary.""" - return get_dialect_config(type(self).dialect) - - def get_query(self, name: str) -> "SQL": - """Return a named SQL query for this dialect.""" - loader = get_data_dictionary_loader() - return loader.get_query(type(self).dialect, name) - - def get_query_text(self, name: str) -> str: - """Return raw SQL text for a named query for this dialect.""" - loader = get_data_dictionary_loader() - return loader.get_query_text(type(self).dialect, name) - - def get_query_text_or_none(self, name: str) -> "str | None": - """Return raw SQL text for a named query or None if missing.""" - try: - return self.get_query_text(name) - except SQLFileNotFoundError: - return None - - def resolve_schema(self, schema: "str | None") -> "str | None": - """Return a schema name using dialect defaults when missing.""" - if schema is not None: - return schema - config = self.get_dialect_config() - return config.default_schema - - def resolve_feature_flag(self, feature: str, version: "VersionInfo | None") -> bool: - """Resolve a feature flag using dialect config and version info.""" - config = self.get_dialect_config() - flag = config.get_feature_flag(feature) - if flag is not None: - return flag - required_version = config.get_feature_version(feature) - if required_version is None or version is None: - return False - return bool(version >= required_version) - - # ───────────────────────────────────────────────────────────────────────────── - # VERSION CACHING METHODS (inlined from DataDictionaryMixin) - # 
───────────────────────────────────────────────────────────────────────────── - - def get_cached_version(self, driver_id: int) -> object: - """Get cached version info for a driver. - - Args: - driver_id: The id() of the driver instance. - - Returns: - Tuple of (was_cached, version_info). If was_cached is False, - the caller should fetch the version and call cache_version(). - """ - if driver_id in self._version_fetch_attempted: - return True, self._version_cache.get(driver_id) - return False, None - - def cache_version(self, driver_id: int, version: "VersionInfo | None") -> None: - """Cache version info for a driver. - - Args: - driver_id: The id() of the driver instance. - version: The version info to cache (can be None if detection failed). - """ - self._version_fetch_attempted.add(driver_id) - if version is not None: - self._version_cache[driver_id] = version - - def parse_version_string(self, version_str: str) -> "VersionInfo | None": - """Parse version string into VersionInfo. - - Args: - version_str: Raw version string from database - - Returns: - VersionInfo instance or None if parsing fails - """ - patterns = [r"(\d+)\.(\d+)\.(\d+)", r"(\d+)\.(\d+)", r"(\d+)"] - for pattern in patterns: - match = re.search(pattern, version_str) - if match: - groups = match.groups() - major = int(groups[0]) - minor = int(groups[1]) if len(groups) > VERSION_GROUPS_MIN_FOR_MINOR else 0 - patch = int(groups[2]) if len(groups) > VERSION_GROUPS_MIN_FOR_PATCH else 0 - return VersionInfo(major, minor, patch) - return None - - def parse_version_with_pattern(self, pattern: "re.Pattern[str]", version_str: str) -> "VersionInfo | None": - """Parse version string using a specific regex pattern. 
- - Args: - pattern: Compiled regex pattern for the version format - version_str: Raw version string from database - - Returns: - VersionInfo instance or None if parsing fails - """ - match = pattern.search(version_str) - if not match: - return None - groups = match.groups() - if not groups: - return None - major = int(groups[0]) - minor = int(groups[1]) if len(groups) > VERSION_GROUPS_MIN_FOR_MINOR and groups[1] else 0 - patch = int(groups[2]) if len(groups) > VERSION_GROUPS_MIN_FOR_PATCH and groups[2] else 0 - return VersionInfo(major, minor, patch) - - def _resolve_log_adapter(self) -> str: - """Resolve adapter identifier for logging.""" - return str(type(self).dialect) - - def _log_version_detected(self, adapter: str, version: "VersionInfo") -> None: - """Log detected database version with db.system context.""" - logger.debug( - "Detected database version", extra={"db.system": resolve_db_system(adapter), "db.version": str(version)} - ) - - def _log_version_unavailable(self, adapter: str, reason: str) -> None: - """Log that database version could not be determined.""" - logger.debug("Database version unavailable", extra={"db.system": resolve_db_system(adapter), "reason": reason}) - - def _log_schema_introspect( - self, driver: Any, *, schema_name: "str | None", table_name: "str | None", operation: str - ) -> None: - """Log schema-level introspection activity.""" - log_with_context( - logger, - logging.DEBUG, - "schema.introspect", - db_system=resolve_db_system(type(driver).__name__), - schema_name=schema_name, - table_name=table_name, - operation=operation, - ) - - def _log_table_describe(self, driver: Any, *, schema_name: "str | None", table_name: str, operation: str) -> None: - """Log table-level introspection activity.""" - log_with_context( - logger, - logging.DEBUG, - "table.describe", - db_system=resolve_db_system(type(driver).__name__), - schema_name=schema_name, - table_name=table_name, - operation=operation, - ) - - def detect_version_with_queries(self, 
driver: "HasExecuteProtocol", queries: "list[str]") -> "VersionInfo | None": - """Try multiple version queries to detect database version. - - Args: - driver: Database driver with execute support - queries: List of SQL queries to try - - Returns: - Version information or None if detection fails - """ - for query in queries: - with suppress(Exception): - result: HasDataProtocol = driver.execute(query) - result_data = result.data - if result_data: - first_row = result_data[0] - version_str = str(first_row) - if isinstance(first_row, dict): - version_str = str(next(iter(first_row.values()))) - elif isinstance(first_row, (list, tuple)): - version_str = str(first_row[0]) - - parsed_version = self.parse_version_string(version_str) - if parsed_version: - self._log_version_detected(self._resolve_log_adapter(), parsed_version) - return parsed_version - - self._log_version_unavailable(self._resolve_log_adapter(), "queries_exhausted") - return None - - def get_default_type_mapping(self) -> "dict[str, str]": - """Get default type mappings for common categories. - - Returns: - Dictionary mapping type categories to generic SQL types - """ - return { - "json": "TEXT", - "uuid": "VARCHAR(36)", - "boolean": "INTEGER", - "timestamp": "TIMESTAMP", - "text": "TEXT", - "blob": "BLOB", - } - - def get_default_features(self) -> "list[str]": - """Get default feature flags supported by most databases. - - Returns: - List of commonly supported feature names - """ - return ["supports_transactions", "supports_prepared_statements"] - - def sort_tables_topologically(self, tables: "list[str]", foreign_keys: "list[ForeignKeyMetadata]") -> "list[str]": - """Sort tables topologically based on foreign key dependencies. - - Args: - tables: List of table names. - foreign_keys: List of foreign key metadata. - - Returns: - List of table names in topological order (dependencies first). 
- """ - sorter: graphlib.TopologicalSorter[str] = graphlib.TopologicalSorter() - for table in tables: - sorter.add(table) - for fk in foreign_keys: - if fk.table_name == fk.referenced_table: - continue - sorter.add(fk.table_name, fk.referenced_table) - return list(sorter.static_order()) - - def get_cached_version_for_driver(self, driver: Any) -> object: - """Get cached version info for a driver instance. - - Args: - driver: Sync database driver instance. - - Returns: - Tuple of (was_cached, version_info). - - """ - return self.get_cached_version(id(driver)) - - def cache_version_for_driver(self, driver: Any, version: "VersionInfo | None") -> None: - """Cache version info for a driver instance. - - Args: - driver: Sync database driver instance. - version: Parsed version info or None. - - """ - self.cache_version(id(driver), version) + self._version_cache: dict[int, VersionInfo | None] = {} + self._version_fetch_attempted: set[int] = set() @abstractmethod def get_version(self, driver: Any) -> "VersionInfo | None": @@ -2133,15 +1874,3 @@ def get_foreign_keys( """ - def list_available_features(self) -> "list[str]": - """List all features that can be checked via get_feature_flag. - - Returns: - List of feature names this data dictionary supports - - """ - config = self.get_dialect_config() - features = set(self.get_default_features()) - features.update(config.feature_flags.keys()) - features.update(config.feature_versions.keys()) - return sorted(features) diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 990654ba2..116d3af46 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -40,6 +40,13 @@ from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError from sqlspec.utils.serializers import from_json, to_json +# Detect whether the core parameters module is mypyc-compiled. 
+# When compiled, `patch.object` on C-extension classes is a no-op, +# so tests that assert mock call counts must be skipped. +from sqlspec.core.parameters import _validator as _validator_module + +_VALIDATOR_COMPILED = (_validator_module.__file__ or "").endswith((".so", ".pyd")) + _ADAPTER_DRIVER_MODULES: "tuple[str, ...]" = ( "sqlspec.adapters.adbc.driver", "sqlspec.adapters.aiosqlite.driver", @@ -2110,6 +2117,7 @@ def test_multiple_unsupported_parameters_all_normalized( # Should have NUMERIC placeholders assert "$" in normalized_sql + @pytest.mark.skipif(_VALIDATOR_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_process_reuses_extracted_metadata_for_parse_normalization( self, processor: ParameterProcessor, validator: ParameterValidator ) -> None: @@ -2129,6 +2137,7 @@ def test_process_reuses_extracted_metadata_for_parse_normalization( assert result.sql == sql assert result.sqlglot_sql == "SELECT * FROM t WHERE id = $1" + @pytest.mark.skipif(_VALIDATOR_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_process_reuses_extracted_metadata_for_execution_conversion( self, processor: ParameterProcessor, validator: ParameterValidator ) -> None: diff --git a/uv.lock b/uv.lock index a5b2d2ab3..2c09320f5 100644 --- a/uv.lock +++ b/uv.lock @@ -1529,14 +1529,14 @@ wheels = [ [[package]] name = "faker" -version = "40.8.0" +version = "40.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/03/14428edc541467c460d363f6e94bee9acc271f3e62470630fc9a647d0cf2/faker-40.8.0.tar.gz", hash = "sha256:936a3c9be6c004433f20aa4d99095df5dec82b8c7ad07459756041f8c1728875", size = 1956493, upload-time = "2026-03-04T16:18:48.161Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/dc/b68e5378e5a7db0ab776efcdd53b6fe374b29d703e156fd5bb4c5437069e/faker-40.11.0.tar.gz", hash = 
"sha256:7c419299103b13126bd02ec14bd2b47b946edb5a5eedf305e66a193b25f9a734", size = 1957570, upload-time = "2026-03-13T14:36:11.844Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4c/3b/c6348f1e285e75b069085b18110a4e6325b763a5d35d5e204356fc7c20b3/faker-40.8.0-py3-none-any.whl", hash = "sha256:eb21bdba18f7a8375382eb94fb436fce07046893dc94cb20817d28deb0c3d579", size = 1989124, upload-time = "2026-03-04T16:18:46.45Z" }, + { url = "https://files.pythonhosted.org/packages/b1/fa/a86c6ba66f0308c95b9288b1e3eaccd934b545646f63494a86f1ec2f8c8e/faker-40.11.0-py3-none-any.whl", hash = "sha256:0e9816c950528d2a37d74863f3ef389ea9a3a936cbcde0b11b8499942e25bf90", size = 1989457, upload-time = "2026-03-13T14:36:09.792Z" }, ] [[package]] @@ -1819,7 +1819,7 @@ s3 = [ [[package]] name = "google-adk" -version = "1.26.0" +version = "1.27.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiosqlite" }, @@ -1833,6 +1833,7 @@ dependencies = [ { name = "google-cloud-bigquery" }, { name = "google-cloud-bigquery-storage" }, { name = "google-cloud-bigtable" }, + { name = "google-cloud-dataplex" }, { name = "google-cloud-discoveryengine" }, { name = "google-cloud-pubsub" }, { name = "google-cloud-secret-manager" }, @@ -1867,9 +1868,9 @@ dependencies = [ { name = "watchdog" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/b2/09b9ee1374b767eaba29e693b0b867fb587a9a131ea159300c9f9fa97d61/google_adk-1.26.0.tar.gz", hash = "sha256:29ec8636025848716246228b595749f785ddc83fb3982052ec92ae871f12fcd8", size = 2250703, upload-time = "2026-02-26T23:39:15.614Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/59/e49d38b6948192180ed971d65bd03ad75d07593476db1cdd63b5cf6cfbeb/google_adk-1.27.1.tar.gz", hash = "sha256:b252b6e2139385fb8b96cc24026f77be83a93a71da877aa3e7e98868a8f5c9d9", size = 2297550, upload-time = "2026-03-13T23:09:05.028Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/d8/a0/0ca4174ad1ad5f8a81b26e0d67bdff509e18ecc2ae79ca7a87e6f16dd394/google_adk-1.26.0-py3-none-any.whl", hash = "sha256:1a74c6b25f8f4d4098e1a01118b8eefcdf7b3741ba07993093a773bc6775b4d5", size = 2621967, upload-time = "2026-02-26T23:39:13.026Z" }, + { url = "https://files.pythonhosted.org/packages/63/ac/dfaa3f751c22662ff045814f62f68be7f8e5a9c0dbd882c24a7afb5284c7/google_adk-1.27.1-py3-none-any.whl", hash = "sha256:806c72dd6d79b16a1dd86e5874ed48f5a42f2acc0129becbb77487a7629af224", size = 2688647, upload-time = "2026-03-13T23:09:03.457Z" }, ] [[package]] @@ -1912,16 +1913,15 @@ wheels = [ [[package]] name = "google-auth" -version = "2.49.0" +version = "2.49.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "pyasn1-modules" }, - { name = "rsa" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7d/59/7371175bfd949abfb1170aa076352131d7281bd9449c0f978604fc4431c3/google_auth-2.49.0.tar.gz", hash = "sha256:9cc2d9259d3700d7a257681f81052db6737495a1a46b610597f4b8bafe5286ae", size = 333444, upload-time = "2026-03-06T21:53:06.07Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/80/6a696a07d3d3b0a92488933532f03dbefa4a24ab80fb231395b9a2a1be77/google_auth-2.49.1.tar.gz", hash = "sha256:16d40da1c3c5a0533f57d268fe72e0ebb0ae1cc3b567024122651c045d879b64", size = 333825, upload-time = "2026-03-12T19:30:58.135Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/37/45/de64b823b639103de4b63dd193480dce99526bd36be6530c2dba85bf7817/google_auth-2.49.0-py3-none-any.whl", hash = "sha256:f893ef7307f19cf53700b7e2f61b5a6affe3aa0edf9943b13788920ab92d8d87", size = 240676, upload-time = "2026-03-06T21:52:38.304Z" }, + { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, 
upload-time = "2026-03-12T19:30:53.159Z" }, ] [package.optional-dependencies] @@ -2112,6 +2112,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl", hash = "sha256:67d977b41ae6c7211ee830c7912e41003ea8194bff15ae7d72fd6f51e57acabc", size = 29469, upload-time = "2025-10-29T23:17:38.548Z" }, ] +[[package]] +name = "google-cloud-dataplex" +version = "2.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "grpc-google-iam-v1" }, + { name = "grpcio" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/64/38445469e85e20b6fbb0ad58d0466daa3bd779789729562c12b35cfc24c3/google_cloud_dataplex-2.16.0.tar.gz", hash = "sha256:f9086abb94ae1f35151b2df5b729cc6bbf9361354d5afd22e76515ec0a8e7fdc", size = 766385, upload-time = "2026-01-15T13:15:22.79Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/1a/9d0fc0188abcfe3c4e58db72972b100badb9899e34d94471223ac2037816/google_cloud_dataplex-2.16.0-py3-none-any.whl", hash = "sha256:173ce519395cd424c1ae22de4efb194767524fb5a2424194f091e63b34f4dfc1", size = 584533, upload-time = "2026-01-15T13:13:12.348Z" }, +] + [[package]] name = "google-cloud-discoveryengine" version = "0.13.12" @@ -2183,22 +2200,22 @@ wheels = [ [[package]] name = "google-cloud-pubsub" -version = "2.35.0" +version = "2.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core", extra = ["grpc"] }, { name = "google-auth" }, { name = "grpc-google-iam-v1" }, - { name = "grpcio" }, + { name = "grpcio", marker = "python_full_version < '3.14'" }, { name = "grpcio-status" }, { name = "opentelemetry-api" }, { name = "opentelemetry-sdk" }, { name = "proto-plus" }, { name = "protobuf" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/65/ad/dde4c0b014247190a4df0dfa9c90de81b47909e22e2e442198f449a3593f/google_cloud_pubsub-2.35.0.tar.gz", hash = "sha256:2c0d1d7ccda52fa12fb73f34b7eb9899381e2fd931c7d47b10f724cdfac06f95", size = 396812, upload-time = "2026-02-05T22:29:14.584Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8d/f5cece431daaa2024129569ed35e6eb90a72bb51f0c96e5c7f5cab6d34d7/google_cloud_pubsub-2.36.0.tar.gz", hash = "sha256:96e057e5f83433ce428852095d652c2f7fc193f0f77db1f27cc39186fe69c1f4", size = 401324, upload-time = "2026-03-12T19:31:02.099Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/cb/b783f4e910f0ec4010d279bafce0cd1ed8a10bac41970eb5c6a6416008ab/google_cloud_pubsub-2.35.0-py3-none-any.whl", hash = "sha256:c32e4eb29e532ec784b5abb5d674807715ec07895b7c022b9404871dec09970d", size = 320973, upload-time = "2026-02-05T22:29:13.096Z" }, + { url = "https://files.pythonhosted.org/packages/93/fd/d0a8f0f93a4d115282ecdd8ef0267e4611bde6ca29c9dba803f3ebae7115/google_cloud_pubsub-2.36.0-py3-none-any.whl", hash = "sha256:d6726ccf9373924e0746338dadf8244b9aa1a97a24130b59a2106c926ea37598", size = 323364, upload-time = "2026-03-12T19:30:48.077Z" }, ] [[package]] @@ -2345,7 +2362,7 @@ wheels = [ [[package]] name = "google-genai" -version = "1.66.0" +version = "1.67.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2359,9 +2376,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9b/ba/0b343b0770d4710ad2979fd9301d7caa56c940174d5361ed4a7cc4979241/google_genai-1.66.0.tar.gz", hash = "sha256:ffc01647b65046bca6387320057aa51db0ad64bcc72c8e3e914062acfa5f7c49", size = 504386, upload-time = "2026-03-04T22:15:28.156Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/07/59a498f81f2c7b0649eacda2ea470b7fd8bd7149f20caba22962081bdd51/google_genai-1.67.0.tar.gz", hash = 
"sha256:897195a6a9742deb6de240b99227189ada8b2d901d61bdfba836c3092021eab6", size = 506972, upload-time = "2026-03-12T20:39:16.241Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/dd/403949d922d4e261b08b64aaa132af4e456c3b15c8e2a2d9e6ef693f66e2/google_genai-1.66.0-py3-none-any.whl", hash = "sha256:7f127a39cf695277104ce4091bb26e417c59bb46e952ff3699c3a982d9c474ee", size = 732174, upload-time = "2026-03-04T22:15:26.63Z" }, + { url = "https://files.pythonhosted.org/packages/6e/c2/562aa1f086e53529ffbeb5b43d5d8bc42c1b968102b5e2163fad005ce298/google_genai-1.67.0-py3-none-any.whl", hash = "sha256:58b0484ff2d4335fa53c724b489e9f807fcca8115d9cdbd8fdf341121fbd6d2d", size = 733542, upload-time = "2026-03-12T20:39:14.615Z" }, ] [[package]] @@ -5527,11 +5544,14 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.12.0" +version = "2.12.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a8/10/e8192be5f38f3e8e7e046716de4cae33d56fd5ae08927a823bb916be36c1/pyjwt-2.12.0.tar.gz", hash = "sha256:2f62390b667cd8257de560b850bb5a883102a388829274147f1d724453f8fb02", size = 102511, upload-time = "2026-03-12T17:15:30.831Z" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/27/a3b6e5bf6ff856d2509292e95c8f57f0df7017cf5394921fc4e4ef40308a/pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b", size = 102564, upload-time = "2026-03-13T19:27:37.25Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/70/70f895f404d363d291dcf62c12c85fdd47619ad9674ac0f53364d035925a/pyjwt-2.12.0-py3-none-any.whl", hash = "sha256:9bb459d1bdd0387967d287f5656bf7ec2b9a26645d1961628cda1764e087fd6e", size = 29700, upload-time = "2026-03-12T17:15:29.257Z" }, + { url = 
"https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, ] [package.optional-dependencies] @@ -5655,15 +5675,15 @@ wheels = [ [[package]] name = "pyopenssl" -version = "25.3.0" +version = "26.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/80/be/97b83a464498a79103036bc74d1038df4a7ef0e402cfaf4d5e113fb14759/pyopenssl-25.3.0.tar.gz", hash = "sha256:c981cb0a3fd84e8602d7afc209522773b94c1c2446a3c710a75b06fe1beae329", size = 184073, upload-time = "2025-09-17T00:32:21.037Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/11/a62e1d33b373da2b2c2cd9eb508147871c80f12b1cacde3c5d314922afdd/pyopenssl-26.0.0.tar.gz", hash = "sha256:f293934e52936f2e3413b89c6ce36df66a0b34ae1ea3a053b8c5020ff2f513fc", size = 185534, upload-time = "2026-03-15T14:28:26.353Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/81/ef2b1dfd1862567d573a4fdbc9f969067621764fbb74338496840a1d2977/pyopenssl-25.3.0-py3-none-any.whl", hash = "sha256:1fda6fc034d5e3d179d39e59c1895c9faeaf40a79de5fc4cbbfbe0d36f4a77b6", size = 57268, upload-time = "2025-09-17T00:32:19.474Z" }, + { url = "https://files.pythonhosted.org/packages/fb/7d/d4f7d908fa8415571771b30669251d57c3cf313b36a856e6d7548ae01619/pyopenssl-26.0.0-py3-none-any.whl", hash = "sha256:df94d28498848b98cc1c0ffb8ef1e71e40210d3b0a8064c9d29571ed2904bf81", size = 57969, upload-time = "2026-03-15T14:28:24.864Z" }, ] [[package]] @@ -6237,18 +6257,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash 
= "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] -[[package]] -name = "rsa" -version = "4.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, -] - [[package]] name = "ruamel-yaml" version = "0.18.16" @@ -6321,27 +6329,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/77/9b/840e0039e65fcf12758adf684d2289024d6140cde9268cc59887dc55189c/ruff-0.15.5.tar.gz", hash = "sha256:7c3601d3b6d76dce18c5c824fc8d06f4eef33d6df0c21ec7799510cde0f159a2", size = 4574214, upload-time = "2026-03-05T20:06:34.946Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/20/5369c3ce21588c708bcbe517a8fbe1a8dfdb5dfd5137e14790b1da71612c/ruff-0.15.5-py3-none-linux_armv6l.whl", hash = "sha256:4ae44c42281f42e3b06b988e442d344a5b9b72450ff3c892e30d11b29a96a57c", size = 10478185, upload-time = "2026-03-05T20:06:29.093Z" }, - { url = "https://files.pythonhosted.org/packages/44/ed/e81dd668547da281e5dce710cf0bc60193f8d3d43833e8241d006720e42b/ruff-0.15.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6edd3792d408ebcf61adabc01822da687579a1a023f297618ac27a5b51ef0080", size = 10859201, upload-time = "2026-03-05T20:06:32.632Z" }, - { url = 
"https://files.pythonhosted.org/packages/c4/8f/533075f00aaf19b07c5cd6aa6e5d89424b06b3b3f4583bfa9c640a079059/ruff-0.15.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:89f463f7c8205a9f8dea9d658d59eff49db05f88f89cc3047fb1a02d9f344010", size = 10184752, upload-time = "2026-03-05T20:06:40.312Z" }, - { url = "https://files.pythonhosted.org/packages/66/0e/ba49e2c3fa0395b3152bad634c7432f7edfc509c133b8f4529053ff024fb/ruff-0.15.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba786a8295c6574c1116704cf0b9e6563de3432ac888d8f83685654fe528fd65", size = 10534857, upload-time = "2026-03-05T20:06:19.581Z" }, - { url = "https://files.pythonhosted.org/packages/59/71/39234440f27a226475a0659561adb0d784b4d247dfe7f43ffc12dd02e288/ruff-0.15.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd4b801e57955fe9f02b31d20375ab3a5c4415f2e5105b79fb94cf2642c91440", size = 10309120, upload-time = "2026-03-05T20:06:00.435Z" }, - { url = "https://files.pythonhosted.org/packages/f5/87/4140aa86a93df032156982b726f4952aaec4a883bb98cb6ef73c347da253/ruff-0.15.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:391f7c73388f3d8c11b794dbbc2959a5b5afe66642c142a6effa90b45f6f5204", size = 11047428, upload-time = "2026-03-05T20:05:51.867Z" }, - { url = "https://files.pythonhosted.org/packages/5a/f7/4953e7e3287676f78fbe85e3a0ca414c5ca81237b7575bdadc00229ac240/ruff-0.15.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc18f30302e379fe1e998548b0f5e9f4dff907f52f73ad6da419ea9c19d66c8", size = 11914251, upload-time = "2026-03-05T20:06:22.887Z" }, - { url = "https://files.pythonhosted.org/packages/77/46/0f7c865c10cf896ccf5a939c3e84e1cfaeed608ff5249584799a74d33835/ruff-0.15.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc6e7f90087e2d27f98dc34ed1b3ab7c8f0d273cc5431415454e22c0bd2a681", size = 11333801, upload-time = "2026-03-05T20:05:57.168Z" }, - { url = 
"https://files.pythonhosted.org/packages/d3/01/a10fe54b653061585e655f5286c2662ebddb68831ed3eaebfb0eb08c0a16/ruff-0.15.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1cb7169f53c1ddb06e71a9aebd7e98fc0fea936b39afb36d8e86d36ecc2636a", size = 11206821, upload-time = "2026-03-05T20:06:03.441Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0d/2132ceaf20c5e8699aa83da2706ecb5c5dcdf78b453f77edca7fb70f8a93/ruff-0.15.5-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9b037924500a31ee17389b5c8c4d88874cc6ea8e42f12e9c61a3d754ff72f1ca", size = 11133326, upload-time = "2026-03-05T20:06:25.655Z" }, - { url = "https://files.pythonhosted.org/packages/72/cb/2e5259a7eb2a0f87c08c0fe5bf5825a1e4b90883a52685524596bfc93072/ruff-0.15.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:65bb414e5b4eadd95a8c1e4804f6772bbe8995889f203a01f77ddf2d790929dd", size = 10510820, upload-time = "2026-03-05T20:06:37.79Z" }, - { url = "https://files.pythonhosted.org/packages/ff/20/b67ce78f9e6c59ffbdb5b4503d0090e749b5f2d31b599b554698a80d861c/ruff-0.15.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d20aa469ae3b57033519c559e9bc9cd9e782842e39be05b50e852c7c981fa01d", size = 10302395, upload-time = "2026-03-05T20:05:54.504Z" }, - { url = "https://files.pythonhosted.org/packages/5f/e5/719f1acccd31b720d477751558ed74e9c88134adcc377e5e886af89d3072/ruff-0.15.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:15388dd28c9161cdb8eda68993533acc870aa4e646a0a277aa166de9ad5a8752", size = 10754069, upload-time = "2026-03-05T20:06:06.422Z" }, - { url = "https://files.pythonhosted.org/packages/c3/9c/d1db14469e32d98f3ca27079dbd30b7b44dbb5317d06ab36718dee3baf03/ruff-0.15.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b30da330cbd03bed0c21420b6b953158f60c74c54c5f4c1dabbdf3a57bf355d2", size = 11304315, upload-time = "2026-03-05T20:06:10.867Z" }, - { url = 
"https://files.pythonhosted.org/packages/28/3a/950367aee7c69027f4f422059227b290ed780366b6aecee5de5039d50fa8/ruff-0.15.5-py3-none-win32.whl", hash = "sha256:732e5ee1f98ba5b3679029989a06ca39a950cced52143a0ea82a2102cb592b74", size = 10551676, upload-time = "2026-03-05T20:06:13.705Z" }, - { url = "https://files.pythonhosted.org/packages/b8/00/bf077a505b4e649bdd3c47ff8ec967735ce2544c8e4a43aba42ee9bf935d/ruff-0.15.5-py3-none-win_amd64.whl", hash = "sha256:821d41c5fa9e19117616c35eaa3f4b75046ec76c65e7ae20a333e9a8696bc7fe", size = 11678972, upload-time = "2026-03-05T20:06:45.379Z" }, - { url = "https://files.pythonhosted.org/packages/fe/4e/cd76eca6db6115604b7626668e891c9dd03330384082e33662fb0f113614/ruff-0.15.5-py3-none-win_arm64.whl", hash = "sha256:b498d1c60d2fe5c10c45ec3f698901065772730b411f164ae270bb6bfcc4740b", size = 10965572, upload-time = "2026-03-05T20:06:16.984Z" }, +version = "0.15.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/df/f8629c19c5318601d3121e230f74cbee7a3732339c52b21daa2b82ef9c7d/ruff-0.15.6.tar.gz", hash = "sha256:8394c7bb153a4e3811a4ecdacd4a8e6a4fa8097028119160dffecdcdf9b56ae4", size = 4597916, upload-time = "2026-03-12T23:05:47.51Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/2f/4e03a7e5ce99b517e98d3b4951f411de2b0fa8348d39cf446671adcce9a2/ruff-0.15.6-py3-none-linux_armv6l.whl", hash = "sha256:7c98c3b16407b2cf3d0f2b80c80187384bc92c6774d85fefa913ecd941256fff", size = 10508953, upload-time = "2026-03-12T23:05:17.246Z" }, + { url = "https://files.pythonhosted.org/packages/70/60/55bcdc3e9f80bcf39edf0cd272da6fa511a3d94d5a0dd9e0adf76ceebdb4/ruff-0.15.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ee7dcfaad8b282a284df4aa6ddc2741b3f4a18b0555d626805555a820ea181c3", size = 10942257, upload-time = "2026-03-12T23:05:23.076Z" }, + { url = 
"https://files.pythonhosted.org/packages/e7/f9/005c29bd1726c0f492bfa215e95154cf480574140cb5f867c797c18c790b/ruff-0.15.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3bd9967851a25f038fc8b9ae88a7fbd1b609f30349231dffaa37b6804923c4bb", size = 10322683, upload-time = "2026-03-12T23:05:33.738Z" }, + { url = "https://files.pythonhosted.org/packages/5f/74/2f861f5fd7cbb2146bddb5501450300ce41562da36d21868c69b7a828169/ruff-0.15.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13f4594b04e42cd24a41da653886b04d2ff87adbf57497ed4f728b0e8a4866f8", size = 10660986, upload-time = "2026-03-12T23:05:53.245Z" }, + { url = "https://files.pythonhosted.org/packages/c1/a1/309f2364a424eccb763cdafc49df843c282609f47fe53aa83f38272389e0/ruff-0.15.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e2ed8aea2f3fe57886d3f00ea5b8aae5bf68d5e195f487f037a955ff9fbaac9e", size = 10332177, upload-time = "2026-03-12T23:05:56.145Z" }, + { url = "https://files.pythonhosted.org/packages/30/41/7ebf1d32658b4bab20f8ac80972fb19cd4e2c6b78552be263a680edc55ac/ruff-0.15.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70789d3e7830b848b548aae96766431c0dc01a6c78c13381f423bf7076c66d15", size = 11170783, upload-time = "2026-03-12T23:06:01.742Z" }, + { url = "https://files.pythonhosted.org/packages/76/be/6d488f6adca047df82cd62c304638bcb00821c36bd4881cfca221561fdfc/ruff-0.15.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:542aaf1de3154cea088ced5a819ce872611256ffe2498e750bbae5247a8114e9", size = 12044201, upload-time = "2026-03-12T23:05:28.697Z" }, + { url = "https://files.pythonhosted.org/packages/71/68/e6f125df4af7e6d0b498f8d373274794bc5156b324e8ab4bf5c1b4fc0ec7/ruff-0.15.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c22e6f02c16cfac3888aa636e9eba857254d15bbacc9906c9689fdecb1953ab", size = 11421561, upload-time = "2026-03-12T23:05:31.236Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/9f/f85ef5fd01a52e0b472b26dc1b4bd228b8f6f0435975442ffa4741278703/ruff-0.15.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98893c4c0aadc8e448cfa315bd0cc343a5323d740fe5f28ef8a3f9e21b381f7e", size = 11310928, upload-time = "2026-03-12T23:05:45.288Z" }, + { url = "https://files.pythonhosted.org/packages/8c/26/b75f8c421f5654304b89471ed384ae8c7f42b4dff58fa6ce1626d7f2b59a/ruff-0.15.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:70d263770d234912374493e8cc1e7385c5d49376e41dfa51c5c3453169dc581c", size = 11235186, upload-time = "2026-03-12T23:05:50.677Z" }, + { url = "https://files.pythonhosted.org/packages/fc/d4/d5a6d065962ff7a68a86c9b4f5500f7d101a0792078de636526c0edd40da/ruff-0.15.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:55a1ad63c5a6e54b1f21b7514dfadc0c7fb40093fa22e95143cf3f64ebdcd512", size = 10635231, upload-time = "2026-03-12T23:05:37.044Z" }, + { url = "https://files.pythonhosted.org/packages/d6/56/7c3acf3d50910375349016cf33de24be021532042afbed87942858992491/ruff-0.15.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8dc473ba093c5ec238bb1e7429ee676dca24643c471e11fbaa8a857925b061c0", size = 10340357, upload-time = "2026-03-12T23:06:04.748Z" }, + { url = "https://files.pythonhosted.org/packages/06/54/6faa39e9c1033ff6a3b6e76b5df536931cd30caf64988e112bbf91ef5ce5/ruff-0.15.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:85b042377c2a5561131767974617006f99f7e13c63c111b998f29fc1e58a4cfb", size = 10860583, upload-time = "2026-03-12T23:05:58.978Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1e/509a201b843b4dfb0b32acdedf68d951d3377988cae43949ba4c4133a96a/ruff-0.15.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cef49e30bc5a86a6a92098a7fbf6e467a234d90b63305d6f3ec01225a9d092e0", size = 11410976, upload-time = "2026-03-12T23:05:39.955Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/25/3fc9114abf979a41673ce877c08016f8e660ad6cf508c3957f537d2e9fa9/ruff-0.15.6-py3-none-win32.whl", hash = "sha256:bbf67d39832404812a2d23020dda68fee7f18ce15654e96fb1d3ad21a5fe436c", size = 10616872, upload-time = "2026-03-12T23:05:42.451Z" }, + { url = "https://files.pythonhosted.org/packages/89/7a/09ece68445ceac348df06e08bf75db72d0e8427765b96c9c0ffabc1be1d9/ruff-0.15.6-py3-none-win_amd64.whl", hash = "sha256:aee25bc84c2f1007ecb5037dff75cef00414fdf17c23f07dc13e577883dca406", size = 11787271, upload-time = "2026-03-12T23:05:20.168Z" }, + { url = "https://files.pythonhosted.org/packages/7f/d0/578c47dd68152ddddddf31cd7fc67dc30b7cdf639a86275fda821b0d9d98/ruff-0.15.6-py3-none-win_arm64.whl", hash = "sha256:c34de3dd0b0ba203be50ae70f5910b17188556630e2178fd7d79fc030eb0d837", size = 11060497, upload-time = "2026-03-12T23:05:25.968Z" }, ] [[package]] @@ -7731,14 +7739,14 @@ wheels = [ [[package]] name = "types-cffi" -version = "1.17.0.20260307" +version = "2.0.0.20260315" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/cf/f727256b5b17e8d54a08e3c40877b3c49fdb35e7f683f9ee295bd4df51e1/types_cffi-1.17.0.20260307.tar.gz", hash = "sha256:1a4f1168d43ed8cd2b0ed40a3eb870cda685a154d98478b0a65862084f190a02", size = 17437, upload-time = "2026-03-07T03:49:26.106Z" } +sdist = { url = "https://files.pythonhosted.org/packages/42/30/ced203cc831683bf09eb4482491649a79705c489284c8b355480bebb0e82/types_cffi-2.0.0.20260315.tar.gz", hash = "sha256:b62f052d83fa6897b5987f82d43ebdde7ee718e8ed7beaf37257f2d98de31d25", size = 17407, upload-time = "2026-03-15T04:22:05.376Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/15/4564f173d031f64bf56964d192b6b705e679fc23c02704b84ccbcb809396/types_cffi-1.17.0.20260307-py3-none-any.whl", hash = "sha256:89b5b2c798d32fc6e3304903ed99af93fd608b741483ce7d57fa69eda40430e5", size = 20115, upload-time 
= "2026-03-07T03:49:25.031Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a4/aa3f0a25dfa3e45c4c8c7a01e79bbd94f15294290254e63d8263568f29c2/types_cffi-2.0.0.20260315-py3-none-any.whl", hash = "sha256:09efb6b5ea8e65e62fafb88da69b4ab3196a5876d31c6f7de7d0b83d8973c49c", size = 20094, upload-time = "2026-03-15T04:22:04.25Z" }, ] [[package]] From bf60df226b69db6fe47409b6a081f0b2d59f54fd Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 16:00:49 +0000 Subject: [PATCH 23/39] test: skip mock-dependent tests when modules are mypyc-compiled patch.object cannot intercept calls in C-extension modules, causing false assertion failures on mock call counts. --- tests/unit/core/test_result.py | 4 ++++ tests/unit/utils/test_sync_tools.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/tests/unit/core/test_result.py b/tests/unit/core/test_result.py index c0e55f46f..ccbfc9dfa 100644 --- a/tests/unit/core/test_result.py +++ b/tests/unit/core/test_result.py @@ -12,6 +12,8 @@ pytestmark = pytest.mark.xdist_group("core") +_RESULT_BASE_COMPILED = (result_base.__file__ or "").endswith((".so", ".pyd")) + @pytest.fixture def sample_data() -> list[dict[str, Any]]: @@ -358,6 +360,7 @@ class User: assert none_user is None +@pytest.mark.skipif(_RESULT_BASE_COMPILED, reason="patch.object cannot intercept mypyc-compiled modules") def test_sql_result_reuses_cached_schema_list_conversion() -> None: """Repeated list-shaped schema access should not re-run to_schema().""" @@ -381,6 +384,7 @@ class User: assert mocked_to_schema.call_count == 1 +@pytest.mark.skipif(_RESULT_BASE_COMPILED, reason="patch.object cannot intercept mypyc-compiled modules") def test_sql_result_reuses_cached_single_row_schema_conversion() -> None: """Repeated single-row schema access should not re-run to_schema().""" diff --git a/tests/unit/utils/test_sync_tools.py b/tests/unit/utils/test_sync_tools.py index 111a4bed2..edc2fac7f 100644 --- a/tests/unit/utils/test_sync_tools.py +++ 
b/tests/unit/utils/test_sync_tools.py @@ -25,6 +25,13 @@ with_ensure_async_, ) +# Detect whether the sync_tools module is mypyc-compiled. +# When compiled, `patch.object` / `patch()` on C-extension modules is a no-op, +# so tests that assert mock call counts must be skipped. +import sqlspec.utils.sync_tools as _sync_tools_module + +_SYNC_TOOLS_COMPILED = (_sync_tools_module.__file__ or "").endswith((".so", ".pyd")) + pytestmark = pytest.mark.xdist_group("utils") @@ -465,6 +472,7 @@ async def async_func() -> int: # --------------------------------------------------------------------------- +@pytest.mark.skipif(_SYNC_TOOLS_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_await_portal_fallback_when_current_task_exists() -> None: """When current_task is non-None and raise_sync_error=False, await_ should fall back to get_global_portal() instead of raising RuntimeError.""" @@ -492,6 +500,7 @@ async def async_double(x: int) -> int: mock_portal.call.assert_called_once() +@pytest.mark.skipif(_SYNC_TOOLS_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_await_raises_when_current_task_exists_and_raise_sync_error_true() -> None: """When current_task is non-None and raise_sync_error=True, await_ should raise RuntimeError with the appropriate message.""" @@ -512,6 +521,7 @@ async def async_func() -> int: sync_func() +@pytest.mark.skipif(_SYNC_TOOLS_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_await_portal_fallback_propagates_exceptions() -> None: """When using portal fallback (current_task non-None, raise_sync_error=False), exceptions from the coroutine should propagate through the portal.""" @@ -536,6 +546,7 @@ async def async_explode() -> int: sync_explode() +@pytest.mark.skipif(_SYNC_TOOLS_COMPILED, reason="patch.object cannot intercept mypyc-compiled methods") def test_await_run_coroutine_threadsafe_when_no_current_task() -> None: """When the loop is running but current_task 
is None (worker thread context), await_ should use asyncio.run_coroutine_threadsafe.""" From 4eab8cdd16b3ff6d1816d44ee30b75916e43afae Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 16:23:27 +0000 Subject: [PATCH 24/39] refactor(driver): consolidate exception handlers --- sqlspec/adapters/adbc/driver.py | 17 +-- sqlspec/adapters/aiosqlite/driver.py | 19 +-- sqlspec/adapters/asyncmy/driver.py | 16 +-- sqlspec/adapters/asyncpg/driver.py | 24 ++-- sqlspec/adapters/bigquery/driver.py | 16 +-- sqlspec/adapters/cockroach_asyncpg/driver.py | 17 +-- sqlspec/adapters/cockroach_psycopg/driver.py | 26 +--- sqlspec/adapters/duckdb/driver.py | 17 +-- sqlspec/adapters/mock/driver.py | 34 ++--- sqlspec/adapters/mysqlconnector/driver.py | 33 ++--- sqlspec/adapters/oracledb/driver.py | 29 ++-- sqlspec/adapters/psqlpy/driver.py | 15 +- sqlspec/adapters/psycopg/driver.py | 27 ++-- sqlspec/adapters/pymysql/driver.py | 16 +-- sqlspec/adapters/spanner/driver.py | 16 +-- sqlspec/adapters/sqlite/driver.py | 16 +-- sqlspec/driver/__init__.py | 3 + sqlspec/driver/_exception_handler.py | 78 +++++++++++ .../unit/exceptions/test_exception_handler.py | 131 ++++++++++++++++++ 19 files changed, 312 insertions(+), 238 deletions(-) create mode 100644 sqlspec/driver/_exception_handler.py create mode 100644 tests/unit/exceptions/test_exception_handler.py diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index dbcd4d5a5..70e01076d 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -7,8 +7,6 @@ import contextlib from typing import TYPE_CHECKING, Any, Literal, cast -from typing_extensions import Self - from sqlspec.adapters.adbc._typing import AdbcSessionContext from sqlspec.adapters.adbc.core import ( collect_rows, @@ -29,7 +27,7 @@ ) from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile 
-from sqlspec.driver import SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow @@ -71,7 +69,7 @@ def __exit__(self, *_: Any) -> None: self.cursor.close() # type: ignore[no-untyped-call] -class AdbcExceptionHandler: +class AdbcExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling ADBC database exceptions. ADBC propagates underlying database errors. Exception mapping @@ -82,16 +80,9 @@ class AdbcExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False self.pending_exception = create_mapped_exception(exc_val) diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index e9745899b..776519e2d 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -21,7 +21,7 @@ ) from sqlspec.adapters.aiosqlite.data_dictionary import AiosqliteDataDictionary from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import AsyncDriverAdapterBase +from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler from sqlspec.exceptions import SQLSpecError if TYPE_CHECKING: @@ -30,8 +30,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from typing_extensions import Self - from sqlspec.adapters.aiosqlite._typing import AiosqliteSessionContext 
__all__ = ("AiosqliteCursor", "AiosqliteDriver", "AiosqliteExceptionHandler", "AiosqliteSessionContext") @@ -67,7 +65,7 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc await self.cursor.close() -class AiosqliteExceptionHandler: +class AiosqliteExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling aiosqlite database exceptions. Maps SQLite extended result codes to specific SQLSpec exceptions @@ -78,17 +76,10 @@ class AiosqliteExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_val is None: - return False + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + _ = exc_type if isinstance(exc_val, (aiosqlite.Error, sqlite3.Error)): self.pending_exception = create_mapped_exception(exc_val) return True diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index 5f0872ea0..b01e8b3fc 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -28,7 +28,7 @@ ) from sqlspec.adapters.asyncmy.data_dictionary import AsyncmyDataDictionary from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import AsyncDriverAdapterBase +from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler from sqlspec.exceptions import SQLSpecError from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json @@ -42,8 +42,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from typing_extensions import Self - from sqlspec.adapters.asyncmy._typing import 
AsyncmySessionContext __all__ = ("AsyncmyCursor", "AsyncmyDriver", "AsyncmyExceptionHandler", "AsyncmySessionContext") @@ -77,7 +75,7 @@ async def __aexit__(self, *_: Any) -> None: await self.cursor.close() -class AsyncmyExceptionHandler: +class AsyncmyExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling asyncmy (MySQL) database exceptions. Maps MySQL error codes and SQLSTATE to specific SQLSpec exceptions @@ -88,15 +86,9 @@ class AsyncmyExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, asyncmy.errors.Error): diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 2026f4453..24976bb17 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -28,7 +28,12 @@ is_copy_operation, register_driver_profile, ) -from sqlspec.driver import AsyncDriverAdapterBase, StackExecutionObserver, describe_stack_statement +from sqlspec.driver import ( + AsyncDriverAdapterBase, + BaseAsyncExceptionHandler, + StackExecutionObserver, + describe_stack_statement, +) from sqlspec.exceptions import SQLSpecError, StackExecutionError from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import has_sqlstate @@ -41,8 +46,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from typing_extensions import Self - from sqlspec.adapters.asyncpg._typing import AsyncpgSessionContext __all__ = ("AsyncpgCursor", "AsyncpgDriver", "AsyncpgExceptionHandler", 
"AsyncpgSessionContext") @@ -64,7 +67,7 @@ async def __aenter__(self) -> "AsyncpgConnection": async def __aexit__(self, *_: Any) -> None: ... -class AsyncpgExceptionHandler: +class AsyncpgExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling AsyncPG database exceptions. Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions @@ -75,17 +78,10 @@ class AsyncpgExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_val is None: - return False + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + _ = exc_type if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val): self.pending_exception = create_mapped_exception(exc_val) return True diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 113bd1c43..35390439b 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -38,7 +38,7 @@ get_cache_config, register_driver_profile, ) -from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, ExecutionResult, SyncDriverAdapterBase from sqlspec.exceptions import MissingDependencyError, SQLSpecError, StorageCapabilityError from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow @@ -48,7 +48,6 @@ from collections.abc import Callable from google.cloud.bigquery import QueryJob, QueryJobConfig - from typing_extensions import Self from sqlspec.builder import QueryBuilder from sqlspec.core import SQL, ArrowResult, SQLResult, Statement, StatementFilter @@ -91,7 +90,7 @@ def __exit__(self, *_: Any) -> None: logger.exception("Failed to cancel 
BigQuery job during cursor cleanup") -class BigQueryExceptionHandler: +class BigQueryExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling BigQuery API exceptions. Maps HTTP status codes and error reasons to specific SQLSpec exceptions @@ -102,16 +101,9 @@ class BigQueryExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) + __slots__ = () - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> "Self": - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, GoogleCloudError): diff --git a/sqlspec/adapters/cockroach_asyncpg/driver.py b/sqlspec/adapters/cockroach_asyncpg/driver.py index a31c7d263..707b5efd6 100644 --- a/sqlspec/adapters/cockroach_asyncpg/driver.py +++ b/sqlspec/adapters/cockroach_asyncpg/driver.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, cast import asyncpg -from typing_extensions import Self from sqlspec.adapters.asyncpg.core import create_mapped_exception, driver_profile from sqlspec.adapters.asyncpg.driver import AsyncpgDriver @@ -17,6 +16,7 @@ ) from sqlspec.adapters.cockroach_asyncpg.data_dictionary import CockroachAsyncpgDataDictionary from sqlspec.core import SQL, register_driver_profile +from sqlspec.driver import BaseAsyncExceptionHandler from sqlspec.exceptions import SerializationConflictError, TransactionRetryError from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import has_sqlstate @@ -33,20 +33,13 @@ logger = get_logger("sqlspec.adapters.cockroach_asyncpg") -class CockroachAsyncpgExceptionHandler: +class CockroachAsyncpgExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for CockroachDB AsyncPG exceptions.""" - __slots__ = ("pending_exception",) + __slots__ = () - def 
__init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - if exc_val is None: - return False + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + _ = exc_type if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val): if has_sqlstate(exc_val) and str(exc_val.sqlstate) == "40001": self.pending_exception = SerializationConflictError(str(exc_val)) diff --git a/sqlspec/adapters/cockroach_psycopg/driver.py b/sqlspec/adapters/cockroach_psycopg/driver.py index 23af8a484..d02350978 100644 --- a/sqlspec/adapters/cockroach_psycopg/driver.py +++ b/sqlspec/adapters/cockroach_psycopg/driver.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, cast import psycopg -from typing_extensions import Self from sqlspec.adapters.cockroach_psycopg._typing import ( CockroachAsyncConnection, @@ -29,6 +28,7 @@ from sqlspec.adapters.psycopg.core import create_mapped_exception from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver from sqlspec.core import SQL, StatementConfig, get_cache_config, register_driver_profile +from sqlspec.driver import BaseAsyncExceptionHandler, BaseSyncExceptionHandler from sqlspec.exceptions import SerializationConflictError, TransactionRetryError from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import has_sqlstate @@ -50,18 +50,12 @@ logger = get_logger("sqlspec.adapters.cockroach_psycopg") -class CockroachPsycopgSyncExceptionHandler: +class CockroachPsycopgSyncExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling CockroachDB psycopg exceptions.""" - __slots__ = ("pending_exception",) + __slots__ = () - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> 
bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): @@ -73,18 +67,12 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: return False -class CockroachPsycopgAsyncExceptionHandler: +class CockroachPsycopgAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling CockroachDB psycopg exceptions.""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index fe8421b07..24d0e6a62 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -18,7 +18,7 @@ from sqlspec.adapters.duckdb.data_dictionary import DuckDBDataDictionary from sqlspec.adapters.duckdb.type_converter import DuckDBOutputConverter from sqlspec.core import SQL, StatementConfig, build_arrow_result_from_table, get_cache_config, register_driver_profile -from sqlspec.driver import SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import DatabaseConnectionError, SQLSpecError from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import ensure_pyarrow @@ -31,8 +31,6 @@ from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry from sqlspec.typing import ArrowReturnFormat, StatementParameters -from typing_extensions import Self - from sqlspec.adapters.duckdb._typing import DuckDBSessionContext __all__ = ("DuckDBCursor", "DuckDBDriver", 
"DuckDBExceptionHandler", "DuckDBSessionContext") @@ -64,7 +62,7 @@ def __exit__(self, *_: Any) -> None: pass # Connection lifecycle managed by pool/session -class DuckDBExceptionHandler: +class DuckDBExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling DuckDB database exceptions. Uses exception type and message-based detection to map DuckDB errors @@ -75,16 +73,9 @@ class DuckDBExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False self.pending_exception = create_mapped_exception(exc_type, exc_val) diff --git a/sqlspec/adapters/mock/driver.py b/sqlspec/adapters/mock/driver.py index 5fd6dbb9d..08fd0a8b1 100644 --- a/sqlspec/adapters/mock/driver.py +++ b/sqlspec/adapters/mock/driver.py @@ -10,8 +10,6 @@ import sqlite3 from typing import TYPE_CHECKING, Any -from typing_extensions import Self - from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockSyncSessionContext from sqlspec.adapters.mock.core import ( build_insert_statement, @@ -26,7 +24,13 @@ ) from sqlspec.adapters.mock.data_dictionary import MockAsyncDataDictionary, MockDataDictionary from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase, convert_to_dialect +from sqlspec.driver import ( + AsyncDriverAdapterBase, + BaseAsyncExceptionHandler, + BaseSyncExceptionHandler, + SyncDriverAdapterBase, + convert_to_dialect, +) from sqlspec.exceptions import SQLSpecError from sqlspec.utils.sync_tools import async_ @@ -111,7 +115,7 @@ async def __aexit__( self.cursor.close() -class 
MockExceptionHandler: +class MockExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling SQLite database exceptions. Maps SQLite extended result codes to specific SQLSpec exceptions @@ -122,15 +126,9 @@ class MockExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None + __slots__ = () - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): @@ -139,21 +137,15 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: return False -class MockAsyncExceptionHandler: +class MockAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling SQLite database exceptions. Uses deferred exception pattern for mypyc compatibility. 
""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): diff --git a/sqlspec/adapters/mysqlconnector/driver.py b/sqlspec/adapters/mysqlconnector/driver.py index 958f1b7bb..f3cfd0e08 100644 --- a/sqlspec/adapters/mysqlconnector/driver.py +++ b/sqlspec/adapters/mysqlconnector/driver.py @@ -30,7 +30,12 @@ MysqlConnectorSyncDataDictionary, ) from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase +from sqlspec.driver import ( + AsyncDriverAdapterBase, + BaseAsyncExceptionHandler, + BaseSyncExceptionHandler, + SyncDriverAdapterBase, +) from sqlspec.exceptions import SQLSpecError from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json @@ -44,8 +49,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from typing_extensions import Self - from sqlspec.adapters.mysqlconnector._typing import MysqlConnectorAsyncSessionContext, MysqlConnectorSyncSessionContext __all__ = ( @@ -83,18 +86,12 @@ def __exit__(self, *_: Any) -> None: self.cursor.close() -class MysqlConnectorSyncExceptionHandler: +class MysqlConnectorSyncExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling mysql-connector sync exceptions.""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None + __slots__ = () - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def 
_handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, mysql.connector.Error): @@ -321,18 +318,12 @@ async def __aexit__(self, *_: Any) -> None: await self.cursor.close() -class MysqlConnectorAsyncExceptionHandler: +class MysqlConnectorAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling mysql-connector exceptions.""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, mysql.connector.Error): diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 594aaede7..8d52499b1 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -6,7 +6,6 @@ import oracledb from oracledb import AsyncCursor, Cursor -from typing_extensions import Self from sqlspec.adapters.oracledb._typing import ( OracleAsyncConnection, @@ -44,6 +43,8 @@ ) from sqlspec.driver import ( AsyncDriverAdapterBase, + BaseAsyncExceptionHandler, + BaseSyncExceptionHandler, StackExecutionObserver, SyncDriverAdapterBase, describe_stack_statement, @@ -266,7 +267,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.cursor.close() -class OracleSyncExceptionHandler: +class OracleSyncExceptionHandler(BaseSyncExceptionHandler): """Sync Context manager for handling Oracle database exceptions. Maps Oracle ORA-XXXXX error codes to specific SQLSpec exceptions @@ -277,16 +278,9 @@ class OracleSyncExceptionHandler: to avoid ABI boundary violations with compiled code. 
""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, oracledb.DatabaseError): @@ -295,7 +289,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: return False -class OracleAsyncExceptionHandler: +class OracleAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling Oracle database exceptions. Maps Oracle ORA-XXXXX error codes to specific SQLSpec exceptions @@ -306,16 +300,9 @@ class OracleAsyncExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, oracledb.DatabaseError): diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index 63a6e3e1b..bf8d2d261 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, cast import psqlpy.exceptions -from typing_extensions import Self from sqlspec.adapters.psqlpy._typing import PsqlpySessionContext from sqlspec.adapters.psqlpy.core import ( @@ -31,7 +30,7 @@ from sqlspec.adapters.psqlpy.data_dictionary import PsqlpyDataDictionary from sqlspec.adapters.psqlpy.type_converter import PostgreSQLOutputConverter from sqlspec.core import SQL, StatementConfig, get_cache_config, 
register_driver_profile -from sqlspec.driver import AsyncDriverAdapterBase +from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler from sqlspec.exceptions import SQLSpecError from sqlspec.utils.logging import get_logger @@ -87,7 +86,7 @@ def is_in_use(self) -> bool: return self._in_use -class PsqlpyExceptionHandler: +class PsqlpyExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling psqlpy database exceptions. Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions @@ -98,15 +97,9 @@ class PsqlpyExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) + __slots__ = () - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.Error)): diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 59b1bfa5a..bc4b2b69d 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, cast import psycopg -from typing_extensions import Self from sqlspec.adapters.psycopg._typing import ( PsycopgAsyncConnection, @@ -45,6 +44,8 @@ ) from sqlspec.driver import ( AsyncDriverAdapterBase, + BaseAsyncExceptionHandler, + BaseSyncExceptionHandler, StackExecutionObserver, SyncDriverAdapterBase, describe_stack_statement, @@ -131,7 +132,7 @@ def __exit__(self, *_: Any) -> None: self.cursor.close() -class PsycopgSyncExceptionHandler: +class PsycopgSyncExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling PostgreSQL psycopg database exceptions. 
Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions @@ -142,15 +143,9 @@ class PsycopgSyncExceptionHandler: to avoid ABI boundary violations with compiled code. """ - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): @@ -605,7 +600,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.cursor.close() -class PsycopgAsyncExceptionHandler: +class PsycopgAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling PostgreSQL psycopg database exceptions. Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions @@ -616,15 +611,9 @@ class PsycopgAsyncExceptionHandler: to avoid ABI boundary violations with compiled code. 
""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - async def __aenter__(self) -> Self: - return self + __slots__ = () - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): diff --git a/sqlspec/adapters/pymysql/driver.py b/sqlspec/adapters/pymysql/driver.py index cbeccaffe..6886b1c9e 100644 --- a/sqlspec/adapters/pymysql/driver.py +++ b/sqlspec/adapters/pymysql/driver.py @@ -23,7 +23,7 @@ ) from sqlspec.adapters.pymysql.data_dictionary import PyMysqlDataDictionary from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile -from sqlspec.driver import SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import SQLSpecError from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json @@ -37,8 +37,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from typing_extensions import Self - from sqlspec.adapters.pymysql._typing import PyMysqlSessionContext __all__ = ("PyMysqlCursor", "PyMysqlDriver", "PyMysqlExceptionHandler", "PyMysqlSessionContext") @@ -67,18 +65,12 @@ def __exit__(self, *_: Any) -> None: self.cursor.close() -class PyMysqlExceptionHandler: +class PyMysqlExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling PyMySQL exceptions.""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is 
None: return False if issubclass(exc_type, pymysql.MySQLError): diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index bfb7c8f08..d185e1311 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -5,7 +5,6 @@ from google.api_core import exceptions as api_exceptions from google.cloud.spanner_v1.transaction import Transaction -from typing_extensions import Self from sqlspec.adapters.spanner._typing import SpannerSessionContext from sqlspec.adapters.spanner.core import ( @@ -24,7 +23,7 @@ from sqlspec.adapters.spanner.data_dictionary import SpannerDataDictionary from sqlspec.adapters.spanner.type_converter import SpannerOutputConverter from sqlspec.core import StatementConfig, create_arrow_result, register_driver_profile -from sqlspec.driver import ExecutionResult, SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, ExecutionResult, SyncDriverAdapterBase from sqlspec.exceptions import SQLConversionError from sqlspec.utils.serializers import from_json @@ -76,7 +75,7 @@ def commit(self) -> None: ... def rollback(self) -> None: ... -class SpannerExceptionHandler: +class SpannerExceptionHandler(BaseSyncExceptionHandler): """Map Spanner client exceptions to SQLSpec exceptions. Uses deferred exception pattern for mypyc compatibility: exceptions @@ -84,16 +83,9 @@ class SpannerExceptionHandler: to avoid ABI boundary violations with compiled code. 
""" - __slots__ = ("pending_exception",) + __slots__ = () - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: - _ = exc_tb + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 13c76fd18..c3d4c3ed9 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -4,8 +4,6 @@ import sqlite3 from typing import TYPE_CHECKING, Any -from typing_extensions import Self - from sqlspec.adapters.sqlite._typing import SqliteSessionContext from sqlspec.adapters.sqlite.core import ( build_insert_statement, @@ -21,7 +19,7 @@ from sqlspec.adapters.sqlite.data_dictionary import SqliteDataDictionary from sqlspec.core import ArrowResult, ParameterStyle, TypedParameter, get_cache_config, register_driver_profile from sqlspec.core.result import DMLResult -from sqlspec.driver import SyncDriverAdapterBase +from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase from sqlspec.exceptions import SQLSpecError if TYPE_CHECKING: @@ -78,7 +76,7 @@ def __exit__(self, *_: Any) -> None: self.cursor.close() -class SqliteExceptionHandler: +class SqliteExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling SQLite database exceptions. Maps SQLite extended result codes to specific SQLSpec exceptions @@ -89,15 +87,9 @@ class SqliteExceptionHandler: to avoid ABI boundary violations with compiled code. 
""" - __slots__ = ("pending_exception",) - - def __init__(self) -> None: - self.pending_exception: Exception | None = None - - def __enter__(self) -> Self: - return self + __slots__ = () - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): diff --git a/sqlspec/driver/__init__.py b/sqlspec/driver/__init__.py index 61d2bd312..58647e340 100644 --- a/sqlspec/driver/__init__.py +++ b/sqlspec/driver/__init__.py @@ -10,12 +10,15 @@ describe_stack_statement, hash_stack_operations, ) +from sqlspec.driver._exception_handler import BaseAsyncExceptionHandler, BaseSyncExceptionHandler from sqlspec.driver._sql_helpers import convert_to_dialect from sqlspec.driver._sync import SyncDataDictionaryBase, SyncDriverAdapterBase __all__ = ( "AsyncDataDictionaryBase", "AsyncDriverAdapterBase", + "BaseAsyncExceptionHandler", + "BaseSyncExceptionHandler", "CommonDriverAttributesMixin", "DataDictionaryDialectMixin", "DataDictionaryMixin", diff --git a/sqlspec/driver/_exception_handler.py b/sqlspec/driver/_exception_handler.py new file mode 100644 index 000000000..b7e494805 --- /dev/null +++ b/sqlspec/driver/_exception_handler.py @@ -0,0 +1,78 @@ +"""Shared exception handler bases for driver adapters.""" + +from typing import Any + +from mypy_extensions import mypyc_attr +from typing_extensions import Self + + +@mypyc_attr(allow_interpreted_subclasses=True) +class BaseAsyncExceptionHandler: + """Base async exception handler using the deferred exception pattern.""" + + __slots__ = ("pending_exception",) + + def __init__(self) -> None: + self.pending_exception: Exception | None = None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: "type[BaseException] | None", + exc_val: "BaseException | None", + exc_tb: Any, + ) -> bool: + _ = exc_tb + if exc_val is None: + return False + 
return self._handle_exception(exc_type, exc_val) + + def _handle_exception( + self, + exc_type: "type[BaseException] | None", + exc_val: "BaseException", + ) -> bool: + """Handle an adapter exception. + + Subclasses should set ``pending_exception`` before returning ``True``. + """ + _ = (exc_type, exc_val) + return False + + +@mypyc_attr(allow_interpreted_subclasses=True) +class BaseSyncExceptionHandler: + """Base sync exception handler using the deferred exception pattern.""" + + __slots__ = ("pending_exception",) + + def __init__(self) -> None: + self.pending_exception: Exception | None = None + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: "type[BaseException] | None", + exc_val: "BaseException | None", + exc_tb: Any, + ) -> bool: + _ = exc_tb + if exc_val is None: + return False + return self._handle_exception(exc_type, exc_val) + + def _handle_exception( + self, + exc_type: "type[BaseException] | None", + exc_val: "BaseException", + ) -> bool: + """Handle an adapter exception. + + Subclasses should set ``pending_exception`` before returning ``True``. 
+ """ + _ = (exc_type, exc_val) + return False diff --git a/tests/unit/exceptions/test_exception_handler.py b/tests/unit/exceptions/test_exception_handler.py new file mode 100644 index 000000000..c6481ca8d --- /dev/null +++ b/tests/unit/exceptions/test_exception_handler.py @@ -0,0 +1,131 @@ +# pyright: reportPrivateUsage=false +"""Tests for shared exception handler bases and representative adapter handlers.""" + +import pytest + +from sqlspec.driver import BaseAsyncExceptionHandler, BaseSyncExceptionHandler +from sqlspec.exceptions import SerializationConflictError + +pytestmark = pytest.mark.xdist_group("driver") + + +def test_base_sync_exception_handler_defaults_to_passthrough() -> None: + """Base sync handler should not suppress or map without an override.""" + handler = BaseSyncExceptionHandler() + + assert handler.__enter__() is handler + assert handler.pending_exception is None + assert handler.__exit__(None, None, None) is False + assert handler.pending_exception is None + + +@pytest.mark.anyio +async def test_base_async_exception_handler_defaults_to_passthrough() -> None: + """Base async handler should not suppress or map without an override.""" + handler = BaseAsyncExceptionHandler() + + assert await handler.__aenter__() is handler + assert handler.pending_exception is None + assert await handler.__aexit__(None, None, None) is False + assert handler.pending_exception is None + + +def test_sync_exception_handlers_inherit_shared_base() -> None: + """Representative sync handlers should inherit the shared base.""" + from sqlspec.adapters.bigquery.driver import BigQueryExceptionHandler + from sqlspec.adapters.mock.driver import MockExceptionHandler + from sqlspec.adapters.sqlite.driver import SqliteExceptionHandler + + assert issubclass(BigQueryExceptionHandler, BaseSyncExceptionHandler) + assert issubclass(MockExceptionHandler, BaseSyncExceptionHandler) + assert issubclass(SqliteExceptionHandler, BaseSyncExceptionHandler) + + +def 
test_async_exception_handlers_inherit_shared_base() -> None: + """Representative async handlers should inherit the shared base.""" + from sqlspec.adapters.aiosqlite.driver import AiosqliteExceptionHandler + from sqlspec.adapters.asyncpg.driver import AsyncpgExceptionHandler + from sqlspec.adapters.mock.driver import MockAsyncExceptionHandler + + assert issubclass(AiosqliteExceptionHandler, BaseAsyncExceptionHandler) + assert issubclass(AsyncpgExceptionHandler, BaseAsyncExceptionHandler) + assert issubclass(MockAsyncExceptionHandler, BaseAsyncExceptionHandler) + + +def test_duckdb_exception_handler_maps_any_present_exception(monkeypatch: pytest.MonkeyPatch) -> None: + """DuckDB handler should map any exception when one is present.""" + pytest.importorskip("duckdb") + from sqlspec.adapters.duckdb import driver as duckdb_driver + + mapped = RuntimeError("mapped") + seen: dict[str, object] = {} + + def fake_create_mapped_exception(exc_type: type[BaseException], exc_val: BaseException) -> Exception: + seen["exc_type"] = exc_type + seen["exc_val"] = exc_val + return mapped + + monkeypatch.setattr(duckdb_driver, "create_mapped_exception", fake_create_mapped_exception) + + error = ValueError("boom") + handler = duckdb_driver.DuckDBExceptionHandler() + + assert handler.__exit__(type(error), error, None) is True + assert handler.pending_exception is mapped + assert seen == {"exc_type": ValueError, "exc_val": error} + + +def test_mysqlconnector_sync_exception_handler_preserves_suppression(monkeypatch: pytest.MonkeyPatch) -> None: + """mysql-connector sync handler should preserve migration-suppression sentinel values.""" + pytest.importorskip("mysql.connector") + import mysql.connector + + from sqlspec.adapters.mysqlconnector import driver as mysqlconnector_driver + + monkeypatch.setattr(mysqlconnector_driver, "create_mapped_exception", lambda *args, **kwargs: True) + + error = mysql.connector.Error("skip mapping") + handler = 
mysqlconnector_driver.MysqlConnectorSyncExceptionHandler() + + assert handler.__exit__(type(error), error, None) is True + assert handler.pending_exception is None + + +@pytest.mark.anyio +async def test_mysqlconnector_async_exception_handler_preserves_suppression( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """mysql-connector async handler should preserve migration-suppression sentinel values.""" + pytest.importorskip("mysql.connector") + import mysql.connector + + from sqlspec.adapters.mysqlconnector import driver as mysqlconnector_driver + + monkeypatch.setattr(mysqlconnector_driver, "create_mapped_exception", lambda *args, **kwargs: True) + + error = mysql.connector.Error("skip mapping") + handler = mysqlconnector_driver.MysqlConnectorAsyncExceptionHandler() + + assert await handler.__aexit__(type(error), error, None) is True + assert handler.pending_exception is None + + +@pytest.mark.anyio +async def test_cockroach_asyncpg_exception_handler_preserves_serialization_conflicts( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Cockroach asyncpg handler should keep serialization conflicts as dedicated errors.""" + pytest.importorskip("asyncpg") + from sqlspec.adapters.cockroach_asyncpg import driver as cockroach_asyncpg_driver + + class RetryableError(RuntimeError): + pass + + monkeypatch.setattr(cockroach_asyncpg_driver, "has_sqlstate", lambda exc: True) + + error = RetryableError("retry") + error.sqlstate = "40001" # type: ignore[attr-defined] + handler = cockroach_asyncpg_driver.CockroachAsyncpgExceptionHandler() + + assert await handler.__aexit__(type(error), error, None) is True + assert isinstance(handler.pending_exception, SerializationConflictError) From 579bd96a33417804ef05af2b74182a7c5a69ed33 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 17:28:58 +0000 Subject: [PATCH 25/39] refactor(config): consolidate pooled provide methods --- sqlspec/adapters/aiosqlite/config.py | 40 +- sqlspec/adapters/asyncmy/config.py | 38 +- 
sqlspec/adapters/asyncpg/config.py | 14 +- sqlspec/adapters/cockroach_asyncpg/config.py | 5 +- sqlspec/adapters/cockroach_psycopg/config.py | 12 +- sqlspec/adapters/duckdb/config.py | 39 +- sqlspec/adapters/mysqlconnector/config.py | 23 +- sqlspec/adapters/oracledb/config.py | 70 +--- sqlspec/adapters/psqlpy/config.py | 14 +- sqlspec/adapters/psycopg/config.py | 28 +- sqlspec/adapters/pymysql/config.py | 20 +- sqlspec/adapters/sqlite/config.py | 39 +- sqlspec/config.py | 36 +- sqlspec/core/config_runtime.py | 5 +- tests/unit/config/test_provide_methods.py | 373 +++++++++++++++++++ 15 files changed, 459 insertions(+), 297 deletions(-) create mode 100644 tests/unit/config/test_provide_methods.py diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index b251daff2..85e67f8c1 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -143,6 +143,10 @@ class AiosqliteConfig(AsyncDatabaseConfig["AiosqliteConnection", AiosqliteConnec driver_type: "ClassVar[type[AiosqliteDriver]]" = AiosqliteDriver connection_type: "ClassVar[type[AiosqliteConnection]]" = AiosqliteConnection + _connection_context_class: "ClassVar[type[AiosqliteConnectionContext]]" = AiosqliteConnectionContext + _session_factory_class: "ClassVar[type[_AiosqliteSessionFactory]]" = _AiosqliteSessionFactory + _session_context_class: "ClassVar[type[AiosqliteSessionContext]]" = AiosqliteSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -214,42 +218,6 @@ def __init__( **kwargs, ) - def provide_connection(self, *args: Any, **kwargs: Any) -> "AiosqliteConnectionContext": - """Provide an async connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - An aiosqlite connection context manager. 
- - """ - return AiosqliteConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "AiosqliteSessionContext": - """Provide an async driver session context manager. - - Args: - *_args: Additional arguments. - statement_config: Optional statement configuration override. - **_kwargs: Additional keyword arguments. - - Returns: - An AiosqliteDriver session context manager. - - """ - factory = _AiosqliteSessionFactory(self) - return AiosqliteSessionContext( - acquire_connection=factory.acquire_connection, - release_connection=factory.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - async def _create_pool(self) -> AiosqliteConnectionPool: """Create the connection pool instance. diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 562d34d53..bf6c67592 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -150,6 +150,10 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm driver_type: ClassVar[type[AsyncmyDriver]] = AsyncmyDriver connection_type: "ClassVar[type[Any]]" = cast("type[Any]", AsyncmyConnection) + _connection_context_class: "ClassVar[type[AsyncmyConnectionContext]]" = AsyncmyConnectionContext + _session_factory_class: "ClassVar[type[_AsyncmySessionFactory]]" = _AsyncmySessionFactory + _session_context_class: "ClassVar[type[AsyncmySessionContext]]" = AsyncmySessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True @@ -258,40 +262,6 @@ async def create_connection(self) -> AsyncmyConnection: await self._ensure_connection_initialized(connection) return connection - def 
provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncmyConnectionContext": - """Provide an async connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - An Asyncmy connection context manager. - """ - return AsyncmyConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "AsyncmySessionContext": - """Provide an async driver session context manager. - - Args: - *_args: Additional arguments. - statement_config: Optional statement configuration override. - **_kwargs: Additional keyword arguments. - - Returns: - An Asyncmy driver session context manager. - """ - factory = _AsyncmySessionFactory(self) - return AsyncmySessionContext( - acquire_connection=factory.acquire_connection, - release_connection=factory.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": """Provide async pool instance. 
diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index a4436aeb8..44ac1265b 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -281,6 +281,8 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment] + _connection_context_class: "ClassVar[type[AsyncpgConnectionContext]]" = AsyncpgConnectionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -529,18 +531,6 @@ async def create_connection(self) -> "AsyncpgConnection": self.connection_instance = pool return await pool.acquire() - def provide_connection(self, *args: Any, **kwargs: Any) -> "AsyncpgConnectionContext": - """Provide an async connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - An AsyncPG connection context manager. 
- """ - return AsyncpgConnectionContext(self) - def provide_session( self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any ) -> "AsyncpgSessionContext": diff --git a/sqlspec/adapters/cockroach_asyncpg/config.py b/sqlspec/adapters/cockroach_asyncpg/config.py index 2187b79d6..036f51324 100644 --- a/sqlspec/adapters/cockroach_asyncpg/config.py +++ b/sqlspec/adapters/cockroach_asyncpg/config.py @@ -143,6 +143,8 @@ class CockroachAsyncpgConfig( driver_type: "ClassVar[type[CockroachAsyncpgDriver]]" = CockroachAsyncpgDriver connection_type: "ClassVar[type[CockroachAsyncpgConnection]]" = CockroachAsyncpgConnection # type: ignore[assignment] + _connection_context_class: "ClassVar[type[CockroachAsyncpgConnectionContext]]" = CockroachAsyncpgConnectionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -208,9 +210,6 @@ async def create_connection(self) -> "CockroachAsyncpgConnection": self.connection_instance = await self.create_pool() return cast("CockroachAsyncpgConnection", await self.connection_instance.acquire()) - def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachAsyncpgConnectionContext": - return CockroachAsyncpgConnectionContext(self) - def provide_session( self, *_args: Any, diff --git a/sqlspec/adapters/cockroach_psycopg/config.py b/sqlspec/adapters/cockroach_psycopg/config.py index 67e3b8d10..9a5ae9d69 100644 --- a/sqlspec/adapters/cockroach_psycopg/config.py +++ b/sqlspec/adapters/cockroach_psycopg/config.py @@ -162,6 +162,9 @@ class CockroachPsycopgSyncConfig( driver_type: "ClassVar[type[CockroachPsycopgSyncDriver]]" = CockroachPsycopgSyncDriver connection_type: "ClassVar[type[CockroachSyncConnection]]" = CockroachSyncConnection + _connection_context_class: "ClassVar[type[CockroachPsycopgSyncConnectionContext]]" = ( + 
CockroachPsycopgSyncConnectionContext + ) supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -256,9 +259,6 @@ def create_connection(self) -> "CockroachSyncConnection": self.connection_instance = self.create_pool() return cast("CockroachSyncConnection", self.connection_instance.getconn()) - def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachPsycopgSyncConnectionContext": - return CockroachPsycopgSyncConnectionContext(self) - def provide_session( self, *_args: Any, @@ -358,6 +358,9 @@ class CockroachPsycopgAsyncConfig( driver_type: "ClassVar[type[CockroachPsycopgAsyncDriver]]" = CockroachPsycopgAsyncDriver connection_type: "ClassVar[type[CockroachAsyncConnection]]" = CockroachAsyncConnection + _connection_context_class: "ClassVar[type[CockroachPsycopgAsyncConnectionContext]]" = ( + CockroachPsycopgAsyncConnectionContext + ) supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -457,9 +460,6 @@ async def create_connection(self) -> "CockroachAsyncConnection": self.connection_instance = await self.create_pool() return cast("CockroachAsyncConnection", await self.connection_instance.getconn()) - def provide_connection(self, *args: Any, **kwargs: Any) -> "CockroachPsycopgAsyncConnectionContext": - return CockroachPsycopgAsyncConnectionContext(self) - def provide_session( self, *_args: Any, diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index abf6fcbab..46ec8d899 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -251,6 +251,10 @@ class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, DuckDBConnectionPool, Du driver_type: "ClassVar[type[DuckDBDriver]]" = DuckDBDriver connection_type: "ClassVar[type[DuckDBConnection]]" = DuckDBConnection + _connection_context_class: 
"ClassVar[type[DuckDBConnectionContext]]" = DuckDBConnectionContext + _session_factory_class: "ClassVar[type[_DuckDBSessionConnectionHandler]]" = _DuckDBSessionConnectionHandler + _session_context_class: "ClassVar[type[DuckDBSessionContext]]" = DuckDBSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -379,41 +383,6 @@ def create_connection(self) -> DuckDBConnection: return pool.acquire() - def provide_connection(self, *args: Any, **kwargs: Any) -> "DuckDBConnectionContext": - """Provide a pooled DuckDB connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A DuckDB connection context manager. - """ - return DuckDBConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "DuckDBSessionContext": - """Provide a DuckDB driver session context manager. - - Args: - *_args: Additional arguments. - statement_config: Optional statement configuration override. - **_kwargs: Additional keyword arguments. - - Returns: - A DuckDB driver session context manager. - """ - handler = _DuckDBSessionConnectionHandler(self) - - return DuckDBSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for DuckDB types. 
diff --git a/sqlspec/adapters/mysqlconnector/config.py b/sqlspec/adapters/mysqlconnector/config.py index 104e7a81c..41f024422 100644 --- a/sqlspec/adapters/mysqlconnector/config.py +++ b/sqlspec/adapters/mysqlconnector/config.py @@ -220,6 +220,12 @@ class MysqlConnectorSyncConfig( driver_type: ClassVar[type[MysqlConnectorSyncDriver]] = MysqlConnectorSyncDriver connection_type: ClassVar[type[MysqlConnectorSyncConnection]] = MysqlConnectorSyncConnection + _connection_context_class: "ClassVar[type[MysqlConnectorSyncConnectionContext]]" = MysqlConnectorSyncConnectionContext + _session_factory_class: "ClassVar[type[_MysqlConnectorSyncSessionConnectionHandler]]" = ( + _MysqlConnectorSyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[MysqlConnectorSyncSessionContext]]" = MysqlConnectorSyncSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True @@ -295,23 +301,6 @@ def create_connection(self) -> MysqlConnectorSyncConnection: setattr(connection, "autocommit", bool(autocommit)) return connection - def provide_connection(self, *args: Any, **kwargs: Any) -> "MysqlConnectorSyncConnectionContext": - return MysqlConnectorSyncConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "MysqlConnectorSyncSessionContext": - statement_config = statement_config or self.statement_config or default_statement_config - handler = _MysqlConnectorSyncSessionConnectionHandler(self) - - return MysqlConnectorSyncSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - def get_signature_namespace(self) -> "dict[str, Any]": namespace = 
super().get_signature_namespace() namespace.update({ diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 719f7c21f..01a51aea8 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -180,6 +180,10 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConne driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker + _connection_context_class: "ClassVar[type[OracleSyncConnectionContext]]" = OracleSyncConnectionContext + _session_factory_class: "ClassVar[type[_OracleSyncSessionConnectionHandler]]" = _OracleSyncSessionConnectionHandler + _session_context_class: "ClassVar[type[OracleSyncSessionContext]]" = OracleSyncSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True @@ -279,37 +283,6 @@ def create_connection(self) -> "OracleSyncConnection": self.connection_instance = self.create_pool() return self.connection_instance.acquire() - def provide_connection(self) -> "OracleSyncConnectionContext": - """Provide a connection context manager. - - Returns: - An Oracle Connection context manager. - """ - return OracleSyncConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "OracleSyncSessionContext": - """Provide a driver session context manager. - - Args: - *_args: Positional arguments (unused). - statement_config: Optional statement configuration override. - **_kwargs: Keyword arguments (unused). - - Returns: - An OracleSyncDriver session context manager. 
- """ - handler = _OracleSyncSessionConnectionHandler(self) - - return OracleSyncSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - def provide_pool(self) -> "OracleSyncConnectionPool": """Provide pool instance. @@ -405,6 +378,10 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncC connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker + _connection_context_class: "ClassVar[type[OracleAsyncConnectionContext]]" = OracleAsyncConnectionContext + _session_factory_class: "ClassVar[type[_OracleAsyncSessionConnectionHandler]]" = _OracleAsyncSessionConnectionHandler + _session_context_class: "ClassVar[type[OracleAsyncSessionContext]]" = OracleAsyncSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True @@ -507,37 +484,6 @@ async def create_connection(self) -> OracleAsyncConnection: self.connection_instance = await self.create_pool() return cast("OracleAsyncConnection", await self.connection_instance.acquire()) - def provide_connection(self) -> "OracleAsyncConnectionContext": - """Provide an async connection context manager. - - Returns: - An Oracle AsyncConnection context manager. - """ - return OracleAsyncConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "OracleAsyncSessionContext": - """Provide an async driver session context manager. 
- - Args: - *_args: Positional arguments (unused). - statement_config: Optional statement configuration override. - **_kwargs: Keyword arguments (unused). - - Returns: - An OracleAsyncDriver session context manager. - """ - handler = _OracleAsyncSessionConnectionHandler(self) - - return OracleAsyncSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - async def provide_pool(self) -> "OracleAsyncConnectionPool": """Provide async pool instance. diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 5c946fa5a..8266da9f7 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -179,6 +179,8 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD driver_type: ClassVar[type[PsqlpyDriver]] = PsqlpyDriver connection_type: "ClassVar[type[PsqlpyConnection]]" = PsqlpyConnection + _connection_context_class: "ClassVar[type[PsqlpyConnectionContext]]" = PsqlpyConnectionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True @@ -308,18 +310,6 @@ async def create_connection(self) -> "PsqlpyConnection": return await pool.connection() - def provide_connection(self, *args: Any, **kwargs: Any) -> "PsqlpyConnectionContext": - """Provide an async connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A psqlpy Connection context manager. 
- """ - return PsqlpyConnectionContext(self) - def provide_session( self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any ) -> "PsqlpySessionContext": diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index dca788d02..192198086 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -182,6 +182,8 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool driver_type: "ClassVar[type[PsycopgSyncDriver]]" = PsycopgSyncDriver connection_type: "ClassVar[type[PsycopgSyncConnection]]" = PsycopgSyncConnection + _connection_context_class: "ClassVar[type[PsycopgSyncConnectionContext]]" = PsycopgSyncConnectionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -335,18 +337,6 @@ def create_connection(self) -> "PsycopgSyncConnection": self.connection_instance = self.create_pool() return cast("PsycopgSyncConnection", self.connection_instance.getconn()) # pyright: ignore - def provide_connection(self, *args: Any, **kwargs: Any) -> "PsycopgSyncConnectionContext": - """Provide a connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A psycopg Connection context manager. 
- """ - return PsycopgSyncConnectionContext(self) - def provide_session( self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any ) -> "PsycopgSyncSessionContext": @@ -462,6 +452,8 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec driver_type: ClassVar[type[PsycopgAsyncDriver]] = PsycopgAsyncDriver connection_type: "ClassVar[type[PsycopgAsyncConnection]]" = PsycopgAsyncConnection + _connection_context_class: "ClassVar[type[PsycopgAsyncConnectionContext]]" = PsycopgAsyncConnectionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True @@ -618,18 +610,6 @@ async def create_connection(self) -> "PsycopgAsyncConnection": # pyright: ignor self.connection_instance = await self.create_pool() return cast("PsycopgAsyncConnection", await self.connection_instance.getconn()) # pyright: ignore - def provide_connection(self, *args: Any, **kwargs: Any) -> "PsycopgAsyncConnectionContext": # pyright: ignore - """Provide an async connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A psycopg AsyncConnection context manager. - """ - return PsycopgAsyncConnectionContext(self) - def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for PsycopgAsyncConfig types. 
diff --git a/sqlspec/adapters/pymysql/config.py b/sqlspec/adapters/pymysql/config.py index 272e08d9b..d69462e1c 100644 --- a/sqlspec/adapters/pymysql/config.py +++ b/sqlspec/adapters/pymysql/config.py @@ -117,6 +117,10 @@ class PyMysqlConfig(SyncDatabaseConfig[PyMysqlConnection, PyMysqlConnectionPool, driver_type: "ClassVar[type[PyMysqlDriver]]" = PyMysqlDriver connection_type: "ClassVar[type[PyMysqlConnection]]" = cast("type[PyMysqlConnection]", PyMysqlConnection) + _connection_context_class: "ClassVar[type[PyMysqlConnectionContext]]" = PyMysqlConnectionContext + _session_factory_class: "ClassVar[type[_PyMysqlSessionConnectionHandler]]" = _PyMysqlSessionConnectionHandler + _session_context_class: "ClassVar[type[PyMysqlSessionContext]]" = PyMysqlSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = False supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -182,22 +186,6 @@ def create_connection(self) -> PyMysqlConnection: pool = self.provide_pool() return pool.acquire() - def provide_connection(self, *args: Any, **kwargs: Any) -> "PyMysqlConnectionContext": - return PyMysqlConnectionContext(self) - - def provide_session( - self, *_args: Any, statement_config: "StatementConfig | None" = None, **_kwargs: Any - ) -> "PyMysqlSessionContext": - handler = _PyMysqlSessionConnectionHandler(self) - - return PyMysqlSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - def get_signature_namespace(self) -> "dict[str, Any]": namespace = super().get_signature_namespace() namespace.update({ diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 1d27abab8..8343ad464 100644 --- 
a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -117,6 +117,10 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq driver_type: "ClassVar[type[SqliteDriver]]" = SqliteDriver connection_type: "ClassVar[type[SqliteConnection]]" = SqliteConnection + _connection_context_class: "ClassVar[type[SqliteConnectionContext]]" = SqliteConnectionContext + _session_factory_class: "ClassVar[type[_SqliteSessionConnectionHandler]]" = _SqliteSessionConnectionHandler + _session_context_class: "ClassVar[type[SqliteSessionContext]]" = SqliteSessionContext + _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True @@ -236,41 +240,6 @@ def create_connection(self) -> SqliteConnection: pool = self.provide_pool() return pool.acquire() - def provide_connection(self, *args: "Any", **kwargs: "Any") -> "SqliteConnectionContext": - """Provide a SQLite connection context manager. - - Args: - *args: Additional arguments. - **kwargs: Additional keyword arguments. - - Returns: - A Sqlite connection context manager. - """ - return SqliteConnectionContext(self) - - def provide_session( - self, *_args: "Any", statement_config: "StatementConfig | None" = None, **_kwargs: "Any" - ) -> "SqliteSessionContext": - """Provide a SQLite driver session. - - Args: - *_args: Additional arguments. - statement_config: Optional statement configuration override. - **_kwargs: Additional keyword arguments. - - Returns: - A Sqlite driver session context manager. 
- """ - handler = _SqliteSessionConnectionHandler(self) - - return SqliteSessionContext( - acquire_connection=handler.acquire_connection, - release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or default_statement_config, - driver_features=self.driver_features, - prepare_driver=self._prepare_driver, - ) - def get_signature_namespace(self) -> "dict[str, Any]": """Get the signature namespace for SQLite types. diff --git a/sqlspec/config.py b/sqlspec/config.py index 765d954e0..92a932a05 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1512,6 +1512,10 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker + _connection_context_class: "ClassVar[type[Any]]" + _session_factory_class: "ClassVar[type[Any]]" + _session_context_class: "ClassVar[type[Any]]" + _default_statement_config: "ClassVar[StatementConfig]" def __init__( self, @@ -1578,13 +1582,23 @@ def create_connection(self) -> ConnectionT: def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractContextManager[ConnectionT]": """Provide a database connection context manager.""" - raise NotImplementedError + return cast("AbstractContextManager[ConnectionT]", self._connection_context_class(self)) def provide_session( self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any ) -> "AbstractContextManager[DriverT]": """Provide a database session context manager.""" - raise NotImplementedError + handler = self._session_factory_class(self) + return cast( + "AbstractContextManager[DriverT]", + self._session_context_class( + acquire_connection=handler.acquire_connection, + release_connection=handler.release_connection, + statement_config=statement_config or self.statement_config or self._default_statement_config, + 
driver_features=self.driver_features, + prepare_driver=self._prepare_driver, + ), + ) @abstractmethod def _create_pool(self) -> PoolT: @@ -1709,6 +1723,10 @@ class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = True migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker + _connection_context_class: "ClassVar[type[Any]]" + _session_factory_class: "ClassVar[type[Any]]" + _session_context_class: "ClassVar[type[Any]]" + _default_statement_config: "ClassVar[StatementConfig]" def __init__( self, @@ -1775,13 +1793,23 @@ async def create_connection(self) -> ConnectionT: def provide_connection(self, *args: Any, **kwargs: Any) -> "AbstractAsyncContextManager[ConnectionT]": """Provide a database connection context manager.""" - raise NotImplementedError + return cast("AbstractAsyncContextManager[ConnectionT]", self._connection_context_class(self)) def provide_session( self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any ) -> "AbstractAsyncContextManager[DriverT]": """Provide a database session context manager.""" - raise NotImplementedError + handler = self._session_factory_class(self) + return cast( + "AbstractAsyncContextManager[DriverT]", + self._session_context_class( + acquire_connection=handler.acquire_connection, + release_connection=handler.release_connection, + statement_config=statement_config or self.statement_config or self._default_statement_config, + driver_features=self.driver_features, + prepare_driver=self._prepare_driver, + ), + ) @abstractmethod async def _create_pool(self) -> PoolT: diff --git a/sqlspec/core/config_runtime.py b/sqlspec/core/config_runtime.py index 2d7deeb64..e8f2a7f12 100644 --- a/sqlspec/core/config_runtime.py +++ b/sqlspec/core/config_runtime.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from sqlspec.storage import StorageCapabilities + __all__ = 
( "build_default_statement_config", @@ -37,7 +39,8 @@ def build_default_statement_config(default_dialect: str) -> StatementConfig: def seed_runtime_driver_features( - driver_features: "dict[str, Any] | None", storage_capabilities: "dict[str, Any] | None" + driver_features: "dict[str, Any] | None", + storage_capabilities: "dict[str, Any] | StorageCapabilities | None", ) -> "dict[str, Any]": """Clone and seed driver feature state used on the runtime hot path.""" seeded_features = dict(driver_features) if driver_features else {} diff --git a/tests/unit/config/test_provide_methods.py b/tests/unit/config/test_provide_methods.py new file mode 100644 index 000000000..9af31c215 --- /dev/null +++ b/tests/unit/config/test_provide_methods.py @@ -0,0 +1,373 @@ +from contextlib import AbstractContextManager, asynccontextmanager, contextmanager +from typing import TYPE_CHECKING, Any, cast + +import pytest + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig, AiosqliteConnectionContext +from sqlspec.adapters.aiosqlite.driver import AiosqliteSessionContext +from sqlspec.adapters.asyncmy.config import AsyncmyConfig +from sqlspec.adapters.asyncpg.config import AsyncpgConfig +from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig +from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig +from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.adapters.mysqlconnector.config import MysqlConnectorSyncConfig +from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig +from sqlspec.adapters.psqlpy.config import PsqlpyConfig +from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig +from sqlspec.adapters.pymysql.config import PyMysqlConfig +from sqlspec.adapters.spanner.config import SpannerSyncConfig +from sqlspec.adapters.sqlite.config import SqliteConfig, SqliteConnectionContext +from sqlspec.adapters.sqlite.driver import 
SqliteSessionContext +from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig +from sqlspec.core import StatementConfig +from sqlspec.driver import ( + AsyncDataDictionaryBase, + AsyncDriverAdapterBase, + SyncDataDictionaryBase, + SyncDriverAdapterBase, +) +from tests.conftest import requires_interpreted + +if TYPE_CHECKING: + _SyncPoolConfigBase = SyncDatabaseConfig[Any, object, "_DummySyncDriver"] + _AsyncPoolConfigBase = AsyncDatabaseConfig[Any, object, "_DummyAsyncDriver"] +else: + _SyncPoolConfigBase = SyncDatabaseConfig + _AsyncPoolConfigBase = AsyncDatabaseConfig + + +class _DummySyncDriver(SyncDriverAdapterBase): + __slots__ = () + + @property + def data_dictionary(self) -> SyncDataDictionaryBase: # type: ignore[override] + raise NotImplementedError + + def with_cursor(self, connection: Any) -> AbstractContextManager[Any]: # type: ignore[override] + @contextmanager + def _cursor_ctx(): + yield object() + + return _cursor_ctx() + + def handle_database_exceptions(self) -> AbstractContextManager[None]: # type: ignore[override] + @contextmanager + def _handler_ctx(): + yield None + + return _handler_ctx() + + def begin(self) -> None: # type: ignore[override] + raise NotImplementedError + + def rollback(self) -> None: # type: ignore[override] + raise NotImplementedError + + def commit(self) -> None: # type: ignore[override] + raise NotImplementedError + + def dispatch_special_handling(self, cursor: Any, statement: Any): # type: ignore[override] + return None + + def dispatch_execute_script(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + def dispatch_execute_many(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + def dispatch_execute(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + +class _DummyAsyncDriver(AsyncDriverAdapterBase): + __slots__ = () + + @property + def data_dictionary(self) -> AsyncDataDictionaryBase: # type: 
ignore[override] + raise NotImplementedError + + @asynccontextmanager + async def with_cursor(self, connection: Any): # type: ignore[override] + yield object() + + @asynccontextmanager + async def handle_database_exceptions(self): # type: ignore[override] + yield None + + async def begin(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def rollback(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def commit(self) -> None: # type: ignore[override] + raise NotImplementedError + + async def dispatch_special_handling(self, cursor: Any, statement: Any): # type: ignore[override] + return None + + async def dispatch_execute_script(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + async def dispatch_execute_many(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + async def dispatch_execute(self, cursor: Any, statement: Any): # type: ignore[override] + raise NotImplementedError + + +class _SyncConnectionContext: + __slots__ = ("config",) + + def __init__(self, config: "_SyncTemplateConfig") -> None: + self.config = config + + +class _SyncSessionHandler: + __slots__ = ("config",) + + def __init__(self, config: "_SyncTemplateConfig") -> None: + self.config = config + + def acquire_connection(self) -> object: + return object() + + def release_connection(self, _connection: object) -> None: + return None + + +class _SyncSessionContext: + __slots__ = ("acquire_connection", "driver_features", "prepare_driver", "release_connection", "statement_config") + + def __init__( + self, + *, + acquire_connection: Any, + release_connection: Any, + statement_config: StatementConfig, + driver_features: dict[str, Any], + prepare_driver: Any, + ) -> None: + self.acquire_connection = acquire_connection + self.release_connection = release_connection + self.statement_config = statement_config + self.driver_features = driver_features + self.prepare_driver = 
prepare_driver + + +class _AsyncConnectionContext: + __slots__ = ("config",) + + def __init__(self, config: "_AsyncTemplateConfig") -> None: + self.config = config + + +class _AsyncSessionHandler: + __slots__ = ("config",) + + def __init__(self, config: "_AsyncTemplateConfig") -> None: + self.config = config + + async def acquire_connection(self) -> object: + return object() + + async def release_connection(self, _connection: object) -> None: + return None + + +class _AsyncSessionContext: + __slots__ = ("acquire_connection", "driver_features", "prepare_driver", "release_connection", "statement_config") + + def __init__( + self, + *, + acquire_connection: Any, + release_connection: Any, + statement_config: StatementConfig, + driver_features: dict[str, Any], + prepare_driver: Any, + ) -> None: + self.acquire_connection = acquire_connection + self.release_connection = release_connection + self.statement_config = statement_config + self.driver_features = driver_features + self.prepare_driver = prepare_driver + + +class _SyncTemplateConfig(_SyncPoolConfigBase): + driver_type = _DummySyncDriver + connection_type = object + _connection_context_class = _SyncConnectionContext + _session_factory_class = _SyncSessionHandler + _session_context_class = _SyncSessionContext + _default_statement_config = StatementConfig(dialect="sqlite") + + def create_connection(self) -> object: + return object() + + def _create_pool(self) -> object: + return object() + + def _close_pool(self) -> None: + return None + + +class _AsyncTemplateConfig(_AsyncPoolConfigBase): + driver_type = _DummyAsyncDriver + connection_type = object + _connection_context_class = _AsyncConnectionContext + _session_factory_class = _AsyncSessionHandler + _session_context_class = _AsyncSessionContext + _default_statement_config = StatementConfig(dialect="sqlite") + + async def create_connection(self) -> object: + return object() + + async def _create_pool(self) -> object: + return object() + + async def 
_close_pool(self) -> None: + return None + + +@requires_interpreted +def test_sync_database_config_template_provides_connection_and_session() -> None: + config = _SyncTemplateConfig( + statement_config=StatementConfig(dialect="postgres"), + driver_features={"enable_events": True}, + ) + + connection_context = config.provide_connection() + session_context = cast(_SyncSessionContext, config.provide_session()) + + assert isinstance(connection_context, _SyncConnectionContext) + assert connection_context.config is config + assert isinstance(session_context, _SyncSessionContext) + assert session_context.statement_config.dialect == "postgres" + assert session_context.driver_features["enable_events"] is True + assert session_context.prepare_driver.__self__ is config + + +@requires_interpreted +def test_sync_database_config_template_uses_default_statement_config_when_unset() -> None: + config = _SyncTemplateConfig(statement_config=None) + cast(Any, config).statement_config = None + + session_context = cast(_SyncSessionContext, config.provide_session()) + + assert session_context.statement_config.dialect == "sqlite" + + +@pytest.mark.anyio +@requires_interpreted +async def test_async_database_config_template_provides_connection_and_session() -> None: + config = _AsyncTemplateConfig( + statement_config=StatementConfig(dialect="postgres"), + driver_features={"enable_events": True}, + ) + + connection_context = config.provide_connection() + session_context = cast(_AsyncSessionContext, config.provide_session()) + + assert isinstance(connection_context, _AsyncConnectionContext) + assert connection_context.config is config + assert isinstance(session_context, _AsyncSessionContext) + assert session_context.statement_config.dialect == "postgres" + assert session_context.driver_features["enable_events"] is True + assert session_context.prepare_driver.__self__ is config + + +@pytest.mark.anyio +@requires_interpreted +async def 
test_async_database_config_template_uses_explicit_statement_override() -> None: + config = _AsyncTemplateConfig(statement_config=None) + explicit_config = StatementConfig(dialect="mysql") + + session_context = cast(_AsyncSessionContext, config.provide_session(statement_config=explicit_config)) + + assert session_context.statement_config is explicit_config + + +def test_sqlite_config_inherited_base_methods_build_expected_contexts() -> None: + config = SqliteConfig() + config.statement_config = config.statement_config.replace(dialect="sqlite") + + connection_context = config.provide_connection() + session_context = config.provide_session() + + assert isinstance(connection_context, SqliteConnectionContext) + assert isinstance(session_context, SqliteSessionContext) + assert session_context._statement_config is config.statement_config # pyright: ignore[reportPrivateUsage] + + +@pytest.mark.anyio +async def test_aiosqlite_config_inherited_base_methods_build_expected_contexts() -> None: + config = AiosqliteConfig() + config.statement_config = config.statement_config.replace(dialect="sqlite") + + connection_context = config.provide_connection() + session_context = config.provide_session() + + assert isinstance(connection_context, AiosqliteConnectionContext) + assert isinstance(session_context, AiosqliteSessionContext) + assert session_context._statement_config is config.statement_config # pyright: ignore[reportPrivateUsage] + + +@pytest.mark.parametrize( + ("config_type", "base_method"), + [ + (SqliteConfig, SyncDatabaseConfig.provide_connection), + (PyMysqlConfig, SyncDatabaseConfig.provide_connection), + (DuckDBConfig, SyncDatabaseConfig.provide_connection), + (OracleSyncConfig, SyncDatabaseConfig.provide_connection), + (MysqlConnectorSyncConfig, SyncDatabaseConfig.provide_connection), + (AiosqliteConfig, AsyncDatabaseConfig.provide_connection), + (AsyncmyConfig, AsyncDatabaseConfig.provide_connection), + (AsyncpgConfig, AsyncDatabaseConfig.provide_connection), + 
(PsqlpyConfig, AsyncDatabaseConfig.provide_connection), + (PsycopgSyncConfig, SyncDatabaseConfig.provide_connection), + (PsycopgAsyncConfig, AsyncDatabaseConfig.provide_connection), + (OracleAsyncConfig, AsyncDatabaseConfig.provide_connection), + (CockroachAsyncpgConfig, AsyncDatabaseConfig.provide_connection), + (CockroachPsycopgSyncConfig, SyncDatabaseConfig.provide_connection), + (CockroachPsycopgAsyncConfig, AsyncDatabaseConfig.provide_connection), + ], +) +def test_pooled_adapters_inherit_base_provide_connection(config_type: type[Any], base_method: Any) -> None: + assert config_type.provide_connection is base_method + + +@pytest.mark.parametrize( + "config_type", + [ + SqliteConfig, + PyMysqlConfig, + DuckDBConfig, + AiosqliteConfig, + AsyncmyConfig, + OracleSyncConfig, + OracleAsyncConfig, + MysqlConnectorSyncConfig, + ], +) +def test_template_only_adapters_inherit_base_provide_session(config_type: type[Any]) -> None: + base_method = SyncDatabaseConfig.provide_session if issubclass(config_type, SyncDatabaseConfig) else AsyncDatabaseConfig.provide_session + assert config_type.provide_session is base_method + + +@pytest.mark.parametrize( + "config_type", + [ + AsyncpgConfig, + PsqlpyConfig, + PsycopgSyncConfig, + PsycopgAsyncConfig, + CockroachAsyncpgConfig, + CockroachPsycopgSyncConfig, + CockroachPsycopgAsyncConfig, + SpannerSyncConfig, + ], +) +def test_specialized_adapters_keep_provide_session_override(config_type: type[Any]) -> None: + base_method = SyncDatabaseConfig.provide_session if issubclass(config_type, SyncDatabaseConfig) else AsyncDatabaseConfig.provide_session + assert config_type.provide_session is not base_method From e71b124e8e535c6b4bf51513c83e2d31243e94ee Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 17:31:06 +0000 Subject: [PATCH 26/39] docs: Remove `Raises` section from `_validate_parameters` docstring. 
--- sqlspec/core/compiler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sqlspec/core/compiler.py b/sqlspec/core/compiler.py index 26161839d..94553ff83 100644 --- a/sqlspec/core/compiler.py +++ b/sqlspec/core/compiler.py @@ -775,9 +775,6 @@ def _validate_parameters(self, parameter_profile: "ParameterProfile", final_para parameter_profile: Parameter metadata. final_params: Execution parameters. is_many: Whether this is for execute_many. - - Raises: - Exception: Re-raises validation errors from parameter alignment. """ try: validate_parameter_alignment(parameter_profile, final_params, is_many=is_many) From c7c935290d800b1944ff519fd12d2f76215143c3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 17:51:50 +0000 Subject: [PATCH 27/39] refactor: improve type hints for exception handling parameters in adapter drivers and configs. --- pyproject.toml | 1 - sqlspec/adapters/adbc/_typing.py | 3 +- sqlspec/adapters/adbc/config.py | 3 +- sqlspec/adapters/adbc/driver.py | 2 +- sqlspec/adapters/aiosqlite/_typing.py | 3 +- sqlspec/adapters/aiosqlite/config.py | 3 +- sqlspec/adapters/aiosqlite/driver.py | 6 +- sqlspec/adapters/asyncmy/_typing.py | 3 +- sqlspec/adapters/asyncmy/config.py | 3 +- sqlspec/adapters/asyncmy/driver.py | 2 +- sqlspec/adapters/asyncpg/_typing.py | 3 +- sqlspec/adapters/asyncpg/config.py | 3 +- sqlspec/adapters/asyncpg/driver.py | 2 +- sqlspec/adapters/bigquery/_typing.py | 3 +- sqlspec/adapters/bigquery/config.py | 3 +- sqlspec/adapters/bigquery/driver.py | 2 +- sqlspec/adapters/cockroach_asyncpg/_typing.py | 3 +- sqlspec/adapters/cockroach_asyncpg/config.py | 3 +- sqlspec/adapters/cockroach_asyncpg/driver.py | 2 +- sqlspec/adapters/cockroach_psycopg/_typing.py | 5 +- sqlspec/adapters/cockroach_psycopg/config.py | 5 +- sqlspec/adapters/cockroach_psycopg/driver.py | 4 +- sqlspec/adapters/duckdb/_typing.py | 3 +- sqlspec/adapters/duckdb/config.py | 3 +- sqlspec/adapters/duckdb/core.py | 2 +- sqlspec/adapters/duckdb/driver.py | 2 +- 
sqlspec/adapters/mock/_typing.py | 5 +- sqlspec/adapters/mock/config.py | 5 +- sqlspec/adapters/mock/driver.py | 8 ++- sqlspec/adapters/mysqlconnector/_typing.py | 5 +- sqlspec/adapters/mysqlconnector/config.py | 5 +- sqlspec/adapters/mysqlconnector/driver.py | 4 +- sqlspec/adapters/oracledb/_typing.py | 5 +- sqlspec/adapters/oracledb/config.py | 5 +- sqlspec/adapters/oracledb/driver.py | 7 +- sqlspec/adapters/psqlpy/_typing.py | 3 +- sqlspec/adapters/psqlpy/config.py | 3 +- sqlspec/adapters/psqlpy/driver.py | 2 +- sqlspec/adapters/psycopg/_typing.py | 5 +- sqlspec/adapters/psycopg/config.py | 5 +- sqlspec/adapters/psycopg/driver.py | 20 +++--- sqlspec/adapters/pymysql/_typing.py | 3 +- sqlspec/adapters/pymysql/config.py | 3 +- sqlspec/adapters/pymysql/driver.py | 2 +- sqlspec/adapters/spanner/_typing.py | 3 +- sqlspec/adapters/spanner/config.py | 5 +- sqlspec/adapters/spanner/driver.py | 2 +- sqlspec/adapters/sqlite/_typing.py | 3 +- sqlspec/adapters/sqlite/config.py | 3 +- sqlspec/adapters/sqlite/driver.py | 2 +- sqlspec/base.py | 11 +-- sqlspec/driver/_common.py | 19 ++--- sqlspec/driver/_exception_handler.py | 9 ++- sqlspec/storage/backends/base.py | 7 +- {sqlspec/utils => tools}/profiling.py | 71 +++++++------------ tools/scripts/bench_subsystems.py | 2 +- 56 files changed, 168 insertions(+), 141 deletions(-) rename {sqlspec/utils => tools}/profiling.py (64%) diff --git a/pyproject.toml b/pyproject.toml index a6eb731bd..baea0b24f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,7 +199,6 @@ exclude = [ "sqlspec/adapters/**/data_dictionary.py", # Cross-module inheritance causes mypyc segfaults "sqlspec/observability/_formatting.py", # Inherits from non-compiled logging.Formatter "sqlspec/utils/arrow_helpers.py", # Arrow operations cause segfaults when compiled - "sqlspec/utils/profiling.py", # Uses sys.setprofile (dynamic, not mypyc compatible) ] include = [ "sqlspec/core/**/*.py", # Core module diff --git a/sqlspec/adapters/adbc/_typing.py 
b/sqlspec/adapters/adbc/_typing.py index c5e64fb2a..48aff526c 100644 --- a/sqlspec/adapters/adbc/_typing.py +++ b/sqlspec/adapters/adbc/_typing.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.adbc.driver import AdbcDriver @@ -71,7 +72,7 @@ def __enter__(self) -> "AdbcDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 5b0676846..ae2d06568 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from sqlspec.observability import ObservabilityConfig @@ -135,7 +136,7 @@ def __enter__(self) -> "AdbcConnection": return self._connection def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._connection: self._connection.close() diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 70e01076d..328d386dd 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -82,7 +82,7 @@ class AdbcExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False self.pending_exception = 
create_mapped_exception(exc_val) diff --git a/sqlspec/adapters/aiosqlite/_typing.py b/sqlspec/adapters/aiosqlite/_typing.py index 2859c7f2e..233ad9a73 100644 --- a/sqlspec/adapters/aiosqlite/_typing.py +++ b/sqlspec/adapters/aiosqlite/_typing.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver @@ -71,7 +72,7 @@ async def __aenter__(self) -> "AiosqliteDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 85e67f8c1..87a7b6b45 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -130,7 +131,7 @@ async def __aenter__(self) -> AiosqliteConnection: return await self._ctx.__aenter__() async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return await self._ctx.__aexit__(exc_type, exc_val, exc_tb) diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 776519e2d..c8f2abb07 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -25,6 +25,8 @@ from sqlspec.exceptions import SQLSpecError 
if TYPE_CHECKING: + from types import TracebackType + from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection from sqlspec.core import SQL, StatementConfig from sqlspec.driver import ExecutionResult @@ -57,7 +59,7 @@ async def __aenter__(self) -> "aiosqlite.Cursor": self.cursor = await self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: if exc_type is not None: return if self.cursor is not None: @@ -78,7 +80,7 @@ class AiosqliteExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: _ = exc_type if isinstance(exc_val, (aiosqlite.Error, sqlite3.Error)): self.pending_exception = create_mapped_exception(exc_val) diff --git a/sqlspec/adapters/asyncmy/_typing.py b/sqlspec/adapters/asyncmy/_typing.py index 344ed3a88..a7f316911 100644 --- a/sqlspec/adapters/asyncmy/_typing.py +++ b/sqlspec/adapters/asyncmy/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import Protocol from sqlspec.adapters.asyncmy.driver import AsyncmyDriver @@ -76,7 +77,7 @@ async def __aenter__(self) -> "AsyncmyDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 
bf6c67592..171c90332 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from asyncmy.cursors import Cursor, DictCursor # pyright: ignore from asyncmy.pool import Pool # pyright: ignore @@ -137,7 +138,7 @@ async def __aenter__(self) -> AsyncmyConnection: return connection async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index b01e8b3fc..a379b567d 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -88,7 +88,7 @@ class AsyncmyExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, asyncmy.errors.Error): diff --git a/sqlspec/adapters/asyncpg/_typing.py b/sqlspec/adapters/asyncpg/_typing.py index c8b2c2578..bcce033e3 100644 --- a/sqlspec/adapters/asyncpg/_typing.py +++ b/sqlspec/adapters/asyncpg/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from asyncpg import Connection, Pool, Record @@ -78,7 +79,7 @@ async def __aenter__(self) -> "AsyncpgDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: 
"BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 44ac1265b..052dee885 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from asyncio.events import AbstractEventLoop from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -266,7 +267,7 @@ async def __aenter__(self) -> "AsyncpgConnection": return self._connection async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._connection is not None: if self._config.connection_instance: diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index 24976bb17..d3ec176d8 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -80,7 +80,7 @@ class AsyncpgExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: _ = exc_type if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val): self.pending_exception = create_mapped_exception(exc_val) diff --git a/sqlspec/adapters/bigquery/_typing.py b/sqlspec/adapters/bigquery/_typing.py index 0438e71ec..527d572aa 100644 --- a/sqlspec/adapters/bigquery/_typing.py +++ b/sqlspec/adapters/bigquery/_typing.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from 
google.cloud.bigquery import ArrayQueryParameter, Client, ScalarQueryParameter @@ -75,7 +76,7 @@ def __enter__(self) -> "BigQueryDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 8f537c34e..99764b023 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from google.api_core.client_info import ClientInfo from google.api_core.client_options import ClientOptions @@ -116,7 +117,7 @@ def __enter__(self) -> BigQueryConnection: return self._connection def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: return None diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index 35390439b..c7cebc77b 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -103,7 +103,7 @@ class BigQueryExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, GoogleCloudError): diff --git a/sqlspec/adapters/cockroach_asyncpg/_typing.py b/sqlspec/adapters/cockroach_asyncpg/_typing.py index 7df0595da..e721dc5ea 100644 --- 
a/sqlspec/adapters/cockroach_asyncpg/_typing.py +++ b/sqlspec/adapters/cockroach_asyncpg/_typing.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from asyncpg import Connection, Pool, Record @@ -61,7 +62,7 @@ async def __aenter__(self) -> "CockroachAsyncpgDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/cockroach_asyncpg/config.py b/sqlspec/adapters/cockroach_asyncpg/config.py index 036f51324..1f1d97c16 100644 --- a/sqlspec/adapters/cockroach_asyncpg/config.py +++ b/sqlspec/adapters/cockroach_asyncpg/config.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -127,7 +128,7 @@ async def __aenter__(self) -> "CockroachAsyncpgConnection": return self._connection async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._connection is not None: if self._config.connection_instance: diff --git a/sqlspec/adapters/cockroach_asyncpg/driver.py b/sqlspec/adapters/cockroach_asyncpg/driver.py index 707b5efd6..c3664953a 100644 --- a/sqlspec/adapters/cockroach_asyncpg/driver.py +++ b/sqlspec/adapters/cockroach_asyncpg/driver.py @@ -38,7 +38,7 @@ class CockroachAsyncpgExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: 
BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: _ = exc_type if isinstance(exc_val, asyncpg.PostgresError) or has_sqlstate(exc_val): if has_sqlstate(exc_val) and str(exc_val.sqlstate) == "40001": diff --git a/sqlspec/adapters/cockroach_psycopg/_typing.py b/sqlspec/adapters/cockroach_psycopg/_typing.py index fb2f3aca6..1e5e224c2 100644 --- a/sqlspec/adapters/cockroach_psycopg/_typing.py +++ b/sqlspec/adapters/cockroach_psycopg/_typing.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from psycopg.crdb import AsyncCrdbConnection, CrdbConnection @@ -65,7 +66,7 @@ def __enter__(self) -> "CockroachPsycopgSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) @@ -112,7 +113,7 @@ async def __aenter__(self) -> "CockroachPsycopgAsyncDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/cockroach_psycopg/config.py b/sqlspec/adapters/cockroach_psycopg/config.py index 9a5ae9d69..83f1b490f 100644 --- a/sqlspec/adapters/cockroach_psycopg/config.py +++ b/sqlspec/adapters/cockroach_psycopg/config.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig 
from sqlspec.observability import ObservabilityConfig @@ -121,7 +122,7 @@ def __enter__(self) -> "CockroachSyncConnection": return cast("CockroachSyncConnection", self._ctx) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._config.connection_instance and self._ctx: return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb)) @@ -322,7 +323,7 @@ async def __aenter__(self) -> "CockroachAsyncConnection": raise ImproperConfigurationError(msg) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/cockroach_psycopg/driver.py b/sqlspec/adapters/cockroach_psycopg/driver.py index d02350978..8a9f09f6b 100644 --- a/sqlspec/adapters/cockroach_psycopg/driver.py +++ b/sqlspec/adapters/cockroach_psycopg/driver.py @@ -55,7 +55,7 @@ class CockroachPsycopgSyncExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): @@ -72,7 +72,7 @@ class CockroachPsycopgAsyncExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): diff --git 
a/sqlspec/adapters/duckdb/_typing.py b/sqlspec/adapters/duckdb/_typing.py index 1db676dd8..5e38ee3a6 100644 --- a/sqlspec/adapters/duckdb/_typing.py +++ b/sqlspec/adapters/duckdb/_typing.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.duckdb.driver import DuckDBDriver @@ -70,7 +71,7 @@ def __enter__(self) -> "DuckDBDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 46ec8d899..4165d2c8a 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -179,7 +180,7 @@ def __enter__(self) -> DuckDBConnection: return cast("DuckDBConnection", self._ctx.__enter__()) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/duckdb/core.py b/sqlspec/adapters/duckdb/core.py index a28d2bba0..997e0e659 100644 --- a/sqlspec/adapters/duckdb/core.py +++ b/sqlspec/adapters/duckdb/core.py @@ -195,7 +195,7 @@ def _create_duckdb_error(error: Any, error_class: type[SQLSpecError], descriptio return exc -def create_mapped_exception(exc_type: Any, error: Any) -> 
SQLSpecError: +def create_mapped_exception(exc_type: "type[BaseException]", error: "BaseException") -> SQLSpecError: """Map DuckDB exceptions to SQLSpec exceptions. This is a factory function that returns an exception instance rather than diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index 24d0e6a62..c6ee3956e 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -75,7 +75,7 @@ class DuckDBExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False self.pending_exception = create_mapped_exception(exc_type, exc_val) diff --git a/sqlspec/adapters/mock/_typing.py b/sqlspec/adapters/mock/_typing.py index 8dbe4add1..53ab3d38e 100644 --- a/sqlspec/adapters/mock/_typing.py +++ b/sqlspec/adapters/mock/_typing.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.mock.driver import MockAsyncDriver, MockSyncDriver @@ -75,7 +76,7 @@ def __enter__(self) -> "MockSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) @@ -136,7 +137,7 @@ async def __aenter__(self) -> "MockAsyncDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if 
self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/mock/config.py b/sqlspec/adapters/mock/config.py index 2ec0ef0bf..8dd59f23e 100644 --- a/sqlspec/adapters/mock/config.py +++ b/sqlspec/adapters/mock/config.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -70,7 +71,7 @@ def __enter__(self) -> MockConnection: return self._connection def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._connection.close() @@ -92,7 +93,7 @@ async def __aenter__(self) -> MockConnection: return self._connection async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._connection.close() diff --git a/sqlspec/adapters/mock/driver.py b/sqlspec/adapters/mock/driver.py index 08fd0a8b1..e20977e2a 100644 --- a/sqlspec/adapters/mock/driver.py +++ b/sqlspec/adapters/mock/driver.py @@ -35,6 +35,8 @@ from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: + from types import TracebackType + from sqlspec.adapters.mock._typing import MockConnection from sqlspec.core import SQL, StatementConfig from sqlspec.driver import ExecutionResult @@ -107,7 +109,7 @@ async def __aenter__(self) -> "sqlite3.Cursor": return self.cursor async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | 
None" ) -> None: """Clean up cursor resources.""" if self.cursor is not None: @@ -128,7 +130,7 @@ class MockExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): @@ -145,7 +147,7 @@ class MockAsyncExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): diff --git a/sqlspec/adapters/mysqlconnector/_typing.py b/sqlspec/adapters/mysqlconnector/_typing.py index 17a5944c4..0d9066a1e 100644 --- a/sqlspec/adapters/mysqlconnector/_typing.py +++ b/sqlspec/adapters/mysqlconnector/_typing.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import Protocol, TypeAlias from sqlspec.adapters.mysqlconnector.driver import MysqlConnectorAsyncDriver, MysqlConnectorSyncDriver @@ -78,7 +79,7 @@ def __enter__(self) -> "MysqlConnectorSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) @@ -125,7 +126,7 @@ async def __aenter__(self) -> "MysqlConnectorAsyncDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", 
exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/mysqlconnector/config.py b/sqlspec/adapters/mysqlconnector/config.py index 41f024422..d95edd3d7 100644 --- a/sqlspec/adapters/mysqlconnector/config.py +++ b/sqlspec/adapters/mysqlconnector/config.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from mysql.connector.pooling import MySQLConnectionPool @@ -143,7 +144,7 @@ def __enter__(self) -> MysqlConnectorSyncConnection: return self._connection def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._connection is not None: self._connection.close() @@ -187,7 +188,7 @@ async def __aenter__(self) -> MysqlConnectorAsyncConnection: return self._connection async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._connection is not None: await self._connection.close() diff --git a/sqlspec/adapters/mysqlconnector/driver.py b/sqlspec/adapters/mysqlconnector/driver.py index f3cfd0e08..632c72837 100644 --- a/sqlspec/adapters/mysqlconnector/driver.py +++ b/sqlspec/adapters/mysqlconnector/driver.py @@ -91,7 +91,7 @@ class MysqlConnectorSyncExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, mysql.connector.Error): @@ -323,7 +323,7 @@ class 
MysqlConnectorAsyncExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, mysql.connector.Error): diff --git a/sqlspec/adapters/oracledb/_typing.py b/sqlspec/adapters/oracledb/_typing.py index 9210b2b4b..34342f6d7 100644 --- a/sqlspec/adapters/oracledb/_typing.py +++ b/sqlspec/adapters/oracledb/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from oracledb import DB_TYPE_VECTOR # pyright: ignore[reportUnknownVariableType] @@ -106,7 +107,7 @@ def __enter__(self) -> "OracleSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) @@ -161,7 +162,7 @@ async def __aenter__(self) -> "OracleAsyncDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 01a51aea8..26f42c45f 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from oracledb import AuthMode @@ -142,7 
+143,7 @@ def __enter__(self) -> "OracleSyncConnection": return self._conn def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._conn: if self._config.connection_instance: @@ -339,7 +340,7 @@ async def __aenter__(self) -> "OracleAsyncConnection": return self._conn async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._conn: if self._config.connection_instance: diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 8d52499b1..5d4064194 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -57,6 +57,7 @@ if TYPE_CHECKING: from collections.abc import Sequence + from types import TracebackType from sqlspec.adapters.oracledb._typing import OraclePipelineDriver from sqlspec.builder import QueryBuilder @@ -258,7 +259,7 @@ async def __aenter__(self) -> AsyncCursor: self.cursor = self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: _ = (exc_type, exc_val, exc_tb) # Mark as intentionally unused if self.cursor is not None: with contextlib.suppress(Exception): @@ -280,7 +281,7 @@ class OracleSyncExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, oracledb.DatabaseError): 
@@ -302,7 +303,7 @@ class OracleAsyncExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, oracledb.DatabaseError): diff --git a/sqlspec/adapters/psqlpy/_typing.py b/sqlspec/adapters/psqlpy/_typing.py index fdb684455..3662eb60d 100644 --- a/sqlspec/adapters/psqlpy/_typing.py +++ b/sqlspec/adapters/psqlpy/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.psqlpy.driver import PsqlpyDriver @@ -69,7 +70,7 @@ async def __aenter__(self) -> "PsqlpyDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 8266da9f7..d3c35e628 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig @@ -166,7 +167,7 @@ async def __aenter__(self) -> PsqlpyConnection: return connection # type: ignore[no-any-return] async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return await self._ctx.__aexit__(exc_type, exc_val, exc_tb) # type: 
ignore[no-any-return] diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index bf8d2d261..df5921b0e 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -99,7 +99,7 @@ class PsqlpyExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.Error)): diff --git a/sqlspec/adapters/psycopg/_typing.py b/sqlspec/adapters/psycopg/_typing.py index a0acc84c7..f74c318be 100644 --- a/sqlspec/adapters/psycopg/_typing.py +++ b/sqlspec/adapters/psycopg/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from psycopg import AsyncConnection, Connection @@ -92,7 +93,7 @@ def __enter__(self) -> "PsycopgSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) @@ -148,7 +149,7 @@ async def __aenter__(self) -> "PsycopgAsyncDriver": return self._prepare_driver(self._driver) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: await self._release_connection(self._connection) diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 192198086..bae4c91cb 100644 --- 
a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from types import TracebackType from sqlspec.core import StatementConfig @@ -143,7 +144,7 @@ def __enter__(self) -> "PsycopgSyncConnection": return cast("PsycopgSyncConnection", self._ctx) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._config.connection_instance and self._ctx: return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb)) @@ -419,7 +420,7 @@ async def __aenter__(self) -> "PsycopgAsyncConnection": raise ImproperConfigurationError(msg) async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index bc4b2b69d..914d22c6c 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -55,6 +55,8 @@ from sqlspec.utils.type_guards import is_readable if TYPE_CHECKING: + from types import TracebackType + from sqlspec.adapters.psycopg._typing import PsycopgPipelineDriver from sqlspec.core import ArrowResult from sqlspec.driver import ExecutionResult @@ -145,7 +147,7 @@ class PsycopgSyncExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): @@ 
-186,7 +188,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Args: @@ -217,7 +219,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets. Args: @@ -238,7 +240,7 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResu return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements. Args: @@ -262,7 +264,7 @@ def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionRe last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True ) - def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": + def dispatch_special_handling(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. 
Args: @@ -566,13 +568,13 @@ def _resolve_column_names(self, description: Any) -> list[str]: self._column_name_cache[cache_key] = (description, column_names) return column_names - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "PsycopgSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg sync rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "PsycopgSyncCursor") -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -594,7 +596,7 @@ async def __aenter__(self) -> Any: self.cursor = self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: _ = (exc_type, exc_val, exc_tb) if self.cursor is not None: await self.cursor.close() @@ -613,7 +615,7 @@ class PsycopgAsyncExceptionHandler(BaseAsyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, psycopg.Error): diff --git a/sqlspec/adapters/pymysql/_typing.py b/sqlspec/adapters/pymysql/_typing.py index fe0ef144d..481a558f0 100644 --- a/sqlspec/adapters/pymysql/_typing.py +++ b/sqlspec/adapters/pymysql/_typing.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.pymysql.driver import PyMysqlDriver @@ 
-60,7 +61,7 @@ def __enter__(self) -> "PyMysqlDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) diff --git a/sqlspec/adapters/pymysql/config.py b/sqlspec/adapters/pymysql/config.py index d69462e1c..1cb494559 100644 --- a/sqlspec/adapters/pymysql/config.py +++ b/sqlspec/adapters/pymysql/config.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -86,7 +87,7 @@ def __enter__(self) -> PyMysqlConnection: return cast("PyMysqlConnection", self._ctx.__enter__()) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/pymysql/driver.py b/sqlspec/adapters/pymysql/driver.py index 6886b1c9e..ccec56dcd 100644 --- a/sqlspec/adapters/pymysql/driver.py +++ b/sqlspec/adapters/pymysql/driver.py @@ -70,7 +70,7 @@ class PyMysqlExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, pymysql.MySQLError): diff --git a/sqlspec/adapters/spanner/_typing.py b/sqlspec/adapters/spanner/_typing.py index f9335ca2e..73da255c8 100644 --- a/sqlspec/adapters/spanner/_typing.py +++ 
b/sqlspec/adapters/spanner/_typing.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from google.cloud.spanner_v1.database import SnapshotCheckout from google.cloud.spanner_v1.snapshot import Snapshot @@ -75,7 +76,7 @@ def __enter__(self) -> "SpannerSyncDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection, exc_type, exc_val, exc_tb) diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 35fa3b48e..61008b63b 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from google.auth.credentials import Credentials from google.cloud.spanner_v1.database import Database @@ -102,7 +103,7 @@ def __enter__(self) -> SpannerConnection: return self._connection def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._transaction and self._connection: txn = cast("Any", self._connection) @@ -150,7 +151,7 @@ def release_connection( _conn: "SpannerConnection", exc_type: "type[BaseException] | None", exc_val: "BaseException | None", - exc_tb: Any, + exc_tb: "TracebackType | None", ) -> None: self._connection_ctx.__exit__(exc_type, exc_val, exc_tb) diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index d185e1311..e22dc7b5e 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -85,7 +85,7 @@ 
class SpannerExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False diff --git a/sqlspec/adapters/sqlite/_typing.py b/sqlspec/adapters/sqlite/_typing.py index a7c85ff6e..cfc0428ab 100644 --- a/sqlspec/adapters/sqlite/_typing.py +++ b/sqlspec/adapters/sqlite/_typing.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from typing import TypeAlias from sqlspec.adapters.sqlite.driver import SqliteDriver @@ -69,7 +70,7 @@ def __enter__(self) -> "SqliteDriver": return self._prepare_driver(self._driver) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> "bool | None": if self._connection is not None: self._release_connection(self._connection) diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 8343ad464..2147665ad 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from types import TracebackType from sqlspec.core import StatementConfig from sqlspec.observability import ObservabilityConfig @@ -86,7 +87,7 @@ def __enter__(self) -> SqliteConnection: return cast("SqliteConnection", self._ctx.__enter__()) def __exit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: Any + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool | None: if self._ctx: return cast("bool | None", self._ctx.__exit__(exc_type, exc_val, exc_tb)) diff --git a/sqlspec/adapters/sqlite/driver.py 
b/sqlspec/adapters/sqlite/driver.py index c3d4c3ed9..06a5283d8 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -89,7 +89,7 @@ class SqliteExceptionHandler(BaseSyncExceptionHandler): __slots__ = () - def _handle_exception(self, exc_type: Any, exc_val: BaseException) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: if exc_type is None: return False if issubclass(exc_type, sqlite3.Error): diff --git a/sqlspec/base.py b/sqlspec/base.py index 7a913588a..7fd0ab363 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from pathlib import Path + from types import TracebackType from sqlspec.core import SQL from sqlspec.typing import PoolT @@ -59,7 +60,7 @@ def __enter__(self) -> ConnectionT: self._runtime.emit_connection_create(self._connection) return self._connection - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None": + def __exit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": try: return self._context.__exit__(exc_type, exc_val, exc_tb) finally: @@ -79,7 +80,7 @@ async def __aenter__(self) -> ConnectionT: self._runtime.emit_connection_create(self._connection) return self._connection - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None": + async def __aexit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": try: return await self._context.__aexit__(exc_type, exc_val, exc_tb) finally: @@ -110,7 +111,7 @@ def __enter__(self) -> DriverT: self._runtime.emit_session_start(driver) return driver - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None": + def __exit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": 
try: return self._context.__exit__(exc_type, exc_val, exc_tb) finally: @@ -144,7 +145,7 @@ async def __aenter__(self) -> DriverT: self._runtime.emit_session_start(driver) return driver - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> "bool | None": + async def __aexit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": try: return await self._context.__aexit__(exc_type, exc_val, exc_tb) finally: @@ -255,7 +256,7 @@ async def __aenter__(self) -> Self: """Async context manager entry.""" return self - async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: + async def __aexit__(self, _exc_type: "type[BaseException] | None", _exc_val: "BaseException | None", _exc_tb: "TracebackType | None") -> None: """Async context manager exit with automatic cleanup.""" await self.close_all_pools() diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index 3b8e1bf0e..28499473c 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Sequence + from types import TracebackType from sqlspec.core import FilterTypeT, StatementFilter from sqlspec.core.parameters._types import ConvertedParameters @@ -197,7 +198,7 @@ class SyncExceptionHandler(Protocol): def __enter__(self) -> Self: ... - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> bool: ... class AsyncExceptionHandler(Protocol): @@ -212,7 +213,7 @@ class AsyncExceptionHandler(Protocol): async def __aenter__(self) -> Self: ... - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: ... + async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> bool: ... 
logger = get_logger("sqlspec.driver") @@ -429,19 +430,19 @@ def __enter__(self) -> Self: ) return self - def __exit__(self, exc_type: Any, exc: Exception | None, exc_tb: Any) -> Literal[False]: + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> Literal[False]: duration = perf_counter() - self.started self.metrics.record_duration(duration) - if exc is not None: - self.metrics.record_error(exc) - self.runtime.span_manager.end_span(self.span, error=exc if exc is not None else None) + if exc_val is not None: + self.metrics.record_error(exc_val) + self.runtime.span_manager.end_span(self.span, error=exc_val if exc_val is not None else None) self.metrics.emit(self.runtime) - level = logging.ERROR if exc is not None else logging.DEBUG + level = logging.ERROR if exc_val is not None else logging.DEBUG trace_id, span_id = get_trace_context() log_with_context( logger, level, - "stack.execute.failed" if exc is not None else "stack.execute.complete", + "stack.execute.failed" if exc_val is not None else "stack.execute.complete", driver=type(self.driver).__name__, db_system=resolve_db_system(type(self.driver).__name__), stack_size=len(self.stack.operations), @@ -450,7 +451,7 @@ def __exit__(self, exc_type: Any, exc: Exception | None, exc_tb: Any) -> Literal forced_disable=self.driver.stack_native_disabled, hashed_operations=self.hashed_operations, duration_ms=duration * 1000, - error_type=type(exc).__name__ if exc is not None else None, + error_type=type(exc_val).__name__ if exc_val is not None else None, trace_id=trace_id, span_id=span_id, ) diff --git a/sqlspec/driver/_exception_handler.py b/sqlspec/driver/_exception_handler.py index b7e494805..a64a060f4 100644 --- a/sqlspec/driver/_exception_handler.py +++ b/sqlspec/driver/_exception_handler.py @@ -1,10 +1,13 @@ """Shared exception handler bases for driver adapters.""" -from typing import Any +from typing import TYPE_CHECKING from mypy_extensions import 
mypyc_attr from typing_extensions import Self +if TYPE_CHECKING: + from types import TracebackType + @mypyc_attr(allow_interpreted_subclasses=True) class BaseAsyncExceptionHandler: @@ -22,7 +25,7 @@ async def __aexit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", - exc_tb: Any, + exc_tb: "TracebackType | None", ) -> bool: _ = exc_tb if exc_val is None: @@ -58,7 +61,7 @@ def __exit__( self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", - exc_tb: Any, + exc_tb: "TracebackType | None", ) -> bool: _ = exc_tb if exc_val is None: diff --git a/sqlspec/storage/backends/base.py b/sqlspec/storage/backends/base.py index 3c2bd9a53..60a90c2b3 100644 --- a/sqlspec/storage/backends/base.py +++ b/sqlspec/storage/backends/base.py @@ -3,11 +3,14 @@ import asyncio from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator -from typing import Any, NoReturn, cast +from typing import TYPE_CHECKING, Any, NoReturn, cast from mypy_extensions import mypyc_attr from typing_extensions import Self +if TYPE_CHECKING: + from types import TracebackType + from sqlspec.typing import ArrowRecordBatch, ArrowTable from sqlspec.utils.sync_tools import CapacityLimiter @@ -288,7 +291,7 @@ async def __aenter__(self) -> Self: return self async def __aexit__( - self, exc_type: "type[BaseException] | None", exc: "BaseException | None", tb: "Any | None" + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> None: """Close the underlying file when exiting a context.""" await self.aclose() diff --git a/sqlspec/utils/profiling.py b/tools/profiling.py similarity index 64% rename from sqlspec/utils/profiling.py rename to tools/profiling.py index d48cad2e1..40b3f1e7f 100644 --- a/sqlspec/utils/profiling.py +++ b/tools/profiling.py @@ -1,6 +1,6 @@ -"""Profiling utilities for sqlspec hot path analysis. +"""Profiling helpers for internal SQLSpec tooling. 
-Provides a low-overhead profiler using sys.setprofile to capture call counts +Provides a low-overhead profiler using ``sys.setprofile`` to capture call counts and durations in critical execution paths. """ @@ -13,9 +13,9 @@ if TYPE_CHECKING: from collections.abc import Callable - from types import FrameType + from types import FrameType, TracebackType -__all__ = ("CallStats", "HotPathProfiler") +__all__ = ("CallStats", "HotPathProfiler", "profile_hotpath") @dataclass @@ -41,11 +41,7 @@ def update(self, duration: float) -> None: @dataclass class HotPathProfiler: - """Low-overhead profiler using sys.setprofile. - - Captures call counts and durations for functions called within the context. - Designed for surgical profiling of hot paths. - """ + """Low-overhead profiler using ``sys.setprofile``.""" stats: dict[str, CallStats] = field(default_factory=dict) _stack: list[tuple[str, float]] = field(default_factory=list) @@ -56,7 +52,12 @@ def __enter__(self) -> Self: self.start() return self - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: "TracebackType | None", + ) -> None: self.stop() def start(self) -> None: @@ -71,13 +72,7 @@ def stop(self) -> None: self._enabled = False def _profile_callback(self, frame: "FrameType", event: str, arg: Any) -> None: - """Callback for sys.setprofile. - - Args: - frame: The current stack frame. - event: The profile event (call, return, c_call, c_return, c_exception). - arg: Event-specific argument. 
- """ + """Callback for ``sys.setprofile``.""" if not self._enabled: return @@ -87,7 +82,6 @@ def _profile_callback(self, frame: "FrameType", event: str, arg: Any) -> None: code = frame.f_code func_name = f"{code.co_filename}:{code.co_firstlineno}({code.co_name})" if event == "call" else str(arg) self._stack.append((func_name, now)) - elif event in ("return", "c_return", "c_exception") and self._stack: func_name, start_time = self._stack.pop() duration = now - start_time @@ -97,12 +91,7 @@ def _profile_callback(self, frame: "FrameType", event: str, arg: Any) -> None: self.stats[func_name].update(duration) def print_report(self, limit: int = 20, sort_by: str = "count") -> None: - """Print a formatted report of collected statistics. - - Args: - limit: Maximum number of functions to display. - sort_by: Field to sort by (count, time). - """ + """Print a formatted report of collected statistics.""" from rich.console import Console from rich.table import Table @@ -118,44 +107,36 @@ def print_report(self, limit: int = 20, sort_by: str = "count") -> None: items = list(self.stats.items()) if sort_by == "count": - items.sort(key=lambda x: x[1].count, reverse=True) + items.sort(key=lambda item: item[1].count, reverse=True) else: - items.sort(key=lambda x: x[1].total_time, reverse=True) + items.sort(key=lambda item: item[1].total_time, reverse=True) - total_captured_time = sum(s.total_time for s in self.stats.values()) + total_captured_time = sum(stat.total_time for stat in self.stats.values()) - for name, s in items[:limit]: - avg_us = (s.total_time / s.count * 1_000_000) if s.count > 0 else 0 - pct_time = (s.total_time / total_captured_time * 100) if total_captured_time > 0 else 0 + for name, stat in items[:limit]: + avg_us = (stat.total_time / stat.count * 1_000_000) if stat.count > 0 else 0 + pct_time = (stat.total_time / total_captured_time * 100) if total_captured_time > 0 else 0 table.add_row( name, - str(s.count), - f"{s.total_time * 1000:.3f}", + str(stat.count), + 
f"{stat.total_time * 1000:.3f}", f"{pct_time:.1f}%", f"{avg_us:.2f}", - f"{s.min_time * 1_000_000:.2f}", - f"{s.max_time * 1_000_000:.2f}", + f"{stat.min_time * 1_000_000:.2f}", + f"{stat.max_time * 1_000_000:.2f}", ) console.print(table) def profile_hotpath(limit: int = 20, sort_by: str = "count") -> "Callable[..., Any]": - """Decorator to profile a function hot path. - - Args: - limit: Maximum number of functions to display in report. - sort_by: Field to sort by (count, time). - - Returns: - Decorated function. - """ + """Decorator to profile a function hot path.""" def decorator(func: "Callable[..., Any]") -> "Callable[..., Any]": def wrapper(*args: Any, **kwargs: Any) -> Any: - with HotPathProfiler() as prof: + with HotPathProfiler() as profiler: result = func(*args, **kwargs) - prof.print_report(limit=limit, sort_by=sort_by) + profiler.print_report(limit=limit, sort_by=sort_by) return result return wrapper diff --git a/tools/scripts/bench_subsystems.py b/tools/scripts/bench_subsystems.py index d7e6a4b9a..336e74764 100644 --- a/tools/scripts/bench_subsystems.py +++ b/tools/scripts/bench_subsystems.py @@ -19,7 +19,7 @@ from rich.console import Console from rich.table import Table -from sqlspec.utils.profiling import HotPathProfiler +from tools.profiling import HotPathProfiler __all__ = ("SubsystemBenchmark", "main", "print_results_table", "run_benchmarks") From cb912360f42fe62936e7e7ddde597bf66c2f4c78 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 18:14:56 +0000 Subject: [PATCH 28/39] feat: Introduce `TypeDispatcher` for efficient parameter type coercion and enhance exception mapping. 
--- sqlspec/adapters/adbc/config.py | 8 +- sqlspec/adapters/adbc/core.py | 30 ++++- sqlspec/adapters/adbc/driver.py | 4 +- sqlspec/adapters/aiosqlite/driver.py | 8 +- sqlspec/adapters/asyncmy/driver.py | 10 +- sqlspec/adapters/asyncpg/config.py | 6 +- sqlspec/adapters/asyncpg/core.py | 105 +++++++++++------- sqlspec/adapters/asyncpg/driver.py | 4 +- sqlspec/adapters/bigquery/driver.py | 10 +- sqlspec/adapters/cockroach_psycopg/driver.py | 29 ++--- sqlspec/adapters/duckdb/driver.py | 10 +- sqlspec/adapters/mock/driver.py | 8 +- sqlspec/adapters/mysqlconnector/config.py | 4 +- sqlspec/adapters/mysqlconnector/driver.py | 24 ++-- sqlspec/adapters/oracledb/config.py | 4 +- sqlspec/adapters/oracledb/driver.py | 12 +- sqlspec/adapters/psqlpy/config.py | 6 +- sqlspec/adapters/psqlpy/core.py | 31 +++++- sqlspec/adapters/psqlpy/driver.py | 4 +- sqlspec/adapters/psycopg/config.py | 12 +- sqlspec/adapters/psycopg/core.py | 1 + sqlspec/adapters/psycopg/driver.py | 16 +-- sqlspec/adapters/pymysql/driver.py | 10 +- sqlspec/adapters/spanner/driver.py | 4 +- sqlspec/adapters/sqlite/driver.py | 4 +- sqlspec/base.py | 20 +++- sqlspec/builder/_merge.py | 45 ++++---- sqlspec/core/config_runtime.py | 14 +-- sqlspec/core/parameters/_processor.py | 32 ++++-- sqlspec/driver/_async.py | 5 +- sqlspec/driver/_common.py | 47 ++++---- sqlspec/driver/_exception_handler.py | 24 +--- sqlspec/driver/_sync.py | 3 +- sqlspec/migrations/base.py | 38 ++++--- sqlspec/migrations/runner.py | 4 +- sqlspec/storage/pipeline.py | 24 +--- sqlspec/utils/arrow_helpers.py | 37 +++++- sqlspec/utils/dispatch.py | 51 +++++++-- sqlspec/utils/schema.py | 41 +++++-- sqlspec/utils/type_converters.py | 86 ++++++++------ tests/unit/adapters/test_adbc/test_core.py | 38 ++++++- .../test_adbc/test_extension_detection.py | 4 +- .../unit/adapters/test_asyncpg/test_config.py | 4 +- .../test_asyncpg/test_type_handlers.py | 21 ++++ .../unit/adapters/test_psqlpy/test_config.py | 4 +- tests/unit/adapters/test_psqlpy/test_core.py | 
35 +++++- .../unit/adapters/test_psycopg/test_config.py | 4 +- tests/unit/builder/test_merge.py | 18 +++ tests/unit/config/test_provide_methods.py | 18 ++- .../unit/config/test_storage_capabilities.py | 4 +- tests/unit/core/test_parameters.py | 18 ++- tests/unit/core/test_result.py | 4 +- tests/unit/driver/test_query_cache.py | 19 ++++ tests/unit/driver/test_result_tools.py | 2 +- .../unit/exceptions/test_exception_handler.py | 4 +- .../migrations/test_migration_execution.py | 6 +- .../unit/migrations/test_migration_runner.py | 20 ++-- tests/unit/test_mypyc_config.py | 2 - tests/unit/test_perf_surface_inventory.py | 8 +- tests/unit/utils/test_arrow_helpers.py | 20 +++- tests/unit/utils/test_dispatch.py | 10 ++ tests/unit/utils/test_mypyc_boundary_map.py | 10 +- tests/unit/utils/test_mypyc_inventory.py | 6 +- tests/unit/utils/test_sync_tools.py | 9 +- tests/unit/utils/test_to_value_type.py | 18 ++- tests/unit/utils/test_type_converters.py | 24 ++-- tools/profiling.py | 5 +- tools/scripts/bench_subsystems.py | 7 +- tools/scripts/mypyc_boundary_map.py | 16 +-- tools/scripts/mypyc_inventory.py | 12 +- 70 files changed, 764 insertions(+), 441 deletions(-) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index ae2d06568..1668e0b82 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -282,9 +282,7 @@ def _detect_extensions_if_needed(self) -> None: try: probe_names = build_postgres_extension_probe_names(self.driver_features) pgvector_available, paradedb_available = detect_postgres_extensions( - connection, - enable_pgvector="vector" in probe_names, - enable_paradedb="pg_search" in probe_names, + connection, enable_pgvector="vector" in probe_names, enable_paradedb="pg_search" in probe_names ) finally: connection.close() @@ -295,9 +293,7 @@ def _detect_extensions_if_needed(self) -> None: if paradedb_available: detected_extensions.add("pg_search") self.statement_config, self._pgvector_available, 
self._paradedb_available = resolve_postgres_extension_state( - self.statement_config, - self.driver_features, - detected_extensions, + self.statement_config, self.driver_features, detected_extensions ) def provide_connection(self, *args: Any, **kwargs: Any) -> "AdbcConnectionContext": diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index 59b34aafd..ef5cf8789 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -34,6 +34,7 @@ map_sqlstate_to_exception, ) from sqlspec.typing import PGVECTOR_INSTALLED, Empty +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.module_loader import import_string from sqlspec.utils.serializers import to_json from sqlspec.utils.type_guards import has_rowcount, has_sqlstate @@ -152,6 +153,7 @@ ) _BIGQUERY_DB_KWARGS_FIELDS: "tuple[str, ...]" = ("project_id", "dataset_id", "token") +_TYPE_COERCION_DISPATCHERS: "dict[tuple[tuple[type, Callable[[Any], Any]], ...], TypeDispatcher[Callable[[Any], Any]]]" = {} def detect_dialect(connection: Any, logger: Any | None = None) -> str: @@ -221,6 +223,7 @@ def detect_postgres_extensions( except Exception: return False, False + def normalize_driver_path(driver_name: str) -> str: """Normalize a driver name to an importable connect function path.""" stripped = driver_name.strip() @@ -864,6 +867,20 @@ def resolve_parameter_casts(statement: "SQL") -> "dict[int, str]": return {} +def _get_type_coercion_dispatcher( + type_map: "dict[type, Callable[[Any], Any]]", +) -> "TypeDispatcher[Callable[[Any], Any]]": + fallback_items = tuple(type_map.items()) + dispatcher = _TYPE_COERCION_DISPATCHERS.get(fallback_items) + if dispatcher is not None: + return dispatcher + + dispatcher = TypeDispatcher["Callable[[Any], Any]"]() + dispatcher.register_all(fallback_items) + _TYPE_COERCION_DISPATCHERS[fallback_items] = dispatcher + return dispatcher + + def prepare_parameters_with_casts( parameters: Any, parameter_casts: "dict[int, str]", @@ -878,6 
+895,8 @@ def prepare_parameters_with_casts( if isinstance(parameters, (list, tuple)): result: list[Any] = [] converter = get_adbc_type_converter(dialect) + type_map = statement_config.parameter_config.type_coercion_map + dispatcher = _get_type_coercion_dispatcher(type_map) if type_map else None for idx, param in enumerate(parameters, start=1): cast_type = parameter_casts.get(idx, "").upper() if cast_type in {"JSON", "JSONB", "TYPE.JSON", "TYPE.JSONB"}: @@ -888,11 +907,14 @@ def prepare_parameters_with_casts( elif isinstance(param, dict): result.append(converter.convert_dict(param)) else: - if statement_config.parameter_config.type_coercion_map: - for type_check, converter_func in statement_config.parameter_config.type_coercion_map.items(): - if type_check is not dict and isinstance(param, type_check): + if type_map and dispatcher is not None: + exact_converter = type_map.get(type(param)) + if exact_converter is not None and type(param) is not dict: + param = exact_converter(param) + else: + converter_func = dispatcher.get(param) + if converter_func is not None and type(param) is not dict: param = converter_func(param) - break result.append(param) return tuple(result) if isinstance(parameters, tuple) else result return parameters diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 328d386dd..a709efeae 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -492,7 +492,7 @@ def data_dictionary(self) -> "AdbcDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AdbcCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect ADBC rows for the direct execution path.""" column_names = self._resolve_column_names(cursor.description) data, column_names = collect_rows( @@ -500,7 +500,7 @@ def 
collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "AdbcCursor") -> int: """Resolve rowcount from ADBC cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index c8f2abb07..14b0dd5ab 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -59,7 +59,9 @@ async def __aenter__(self) -> "aiosqlite.Cursor": self.cursor = await self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> None: if exc_type is not None: return if self.cursor is not None: @@ -307,11 +309,11 @@ def data_dictionary(self) -> "AiosqliteDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AiosqliteCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect aiosqlite rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "AiosqliteCursor") -> int: """Resolve rowcount from aiosqlite cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index a379b567d..a55faead0 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -129,7 +129,7 @@ def __init__( # CORE DISPATCH METHODS - 
The Execution Engine # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Handles parameter processing, result fetching, and data transformation @@ -168,7 +168,7 @@ async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionRes last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL statement with multiple parameter sets. Uses AsyncMy's executemany for batch operations with MySQL type conversion @@ -194,7 +194,7 @@ async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "Executi return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Splits multi-statement scripts and executes each statement sequentially. 
@@ -380,7 +380,7 @@ def data_dictionary(self) -> "AsyncmyDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AsyncmyCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect asyncmy rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -391,7 +391,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "AsyncmyCursor") -> int: """Resolve rowcount from asyncmy cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 052dee885..ff314618c 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -475,10 +475,8 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None: detected_extensions = {r["extname"] for r in results} except Exception: detected_extensions = set() - self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( - self.statement_config, - self.driver_features, - detected_extensions, + self.statement_config, self._pgvector_available, self._paradedb_available = ( + resolve_postgres_extension_state(self.statement_config, self.driver_features, detected_extensions) ) if self._pgvector_available: diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index fcd39dce6..f879d2a15 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -30,6 +30,7 @@ map_sqlstate_to_exception, ) from sqlspec.typing import PGVECTOR_INSTALLED +from 
sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json from sqlspec.utils.type_guards import has_sqlstate @@ -77,6 +78,7 @@ class NormalizedStackOperation(NamedTuple): PREPARED_STATEMENT_CACHE_SIZE: Final[int] = 32 +_EXCEPTION_MAPPING_DISPATCHER = TypeDispatcher["tuple[str, type[SQLSpecError], str]"]() def _convert_datetime_param(value: Any) -> Any: @@ -263,6 +265,7 @@ def apply_driver_features( return statement_config, processed_features + def parse_status(status: Any) -> int: """Parse AsyncPG status string to extract row count. @@ -319,6 +322,64 @@ def _create_postgres_error( return exc +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.UniqueViolationError, + ("23505", UniqueViolationError, "unique constraint violation"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.ForeignKeyViolationError, + ("23503", ForeignKeyViolationError, "foreign key constraint violation"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.NotNullViolationError, + ("23502", NotNullViolationError, "not-null constraint violation"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.CheckViolationError, + ("23514", CheckViolationError, "check constraint violation"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.IntegrityConstraintViolationError, + ("23000", IntegrityError, "integrity constraint violation"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.DeadlockDetectedError, + ("40P01", DeadlockError, "deadlock detected"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.SerializationError, + ("40001", SerializationConflictError, "serialization failure"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.QueryCanceledError, + ("57014", QueryTimeoutError, "query canceled"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.InsufficientPrivilegeError, + 
("42501", PermissionDeniedError, "insufficient privilege"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.InvalidPasswordError, + ("28P01", PermissionDeniedError, "invalid password"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.InvalidAuthorizationSpecificationError, + ("28000", PermissionDeniedError, "authorization error"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.ConnectionDoesNotExistError, + ("08003", ConnectionTimeoutError, "connection does not exist"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.CannotConnectNowError, + ("57P03", ConnectionTimeoutError, "cannot connect now"), +) +_EXCEPTION_MAPPING_DISPATCHER.register( + asyncpg.exceptions.PostgresSyntaxError, + ("42601", SQLParsingError, "SQL syntax error"), +) + + def create_mapped_exception(error: Any) -> SQLSpecError: """Map asyncpg exceptions to SQLSpec exceptions. @@ -337,46 +398,10 @@ def create_mapped_exception(error: Any) -> SQLSpecError: Returns: A SQLSpec exception that wraps the original error """ - # Priority 1: Check specific exception types first (most reliable) - # Integrity constraint violations - if isinstance(error, asyncpg.exceptions.UniqueViolationError): - return _create_postgres_error(error, "23505", UniqueViolationError, "unique constraint violation") - if isinstance(error, asyncpg.exceptions.ForeignKeyViolationError): - return _create_postgres_error(error, "23503", ForeignKeyViolationError, "foreign key constraint violation") - if isinstance(error, asyncpg.exceptions.NotNullViolationError): - return _create_postgres_error(error, "23502", NotNullViolationError, "not-null constraint violation") - if isinstance(error, asyncpg.exceptions.CheckViolationError): - return _create_postgres_error(error, "23514", CheckViolationError, "check constraint violation") - if isinstance(error, asyncpg.exceptions.IntegrityConstraintViolationError): - return _create_postgres_error(error, "23000", IntegrityError, 
"integrity constraint violation") - - # Transaction and serialization errors - if isinstance(error, asyncpg.exceptions.DeadlockDetectedError): - return _create_postgres_error(error, "40P01", DeadlockError, "deadlock detected") - if isinstance(error, asyncpg.exceptions.SerializationError): - return _create_postgres_error(error, "40001", SerializationConflictError, "serialization failure") - - # Query timeout/cancellation - if isinstance(error, asyncpg.exceptions.QueryCanceledError): - return _create_postgres_error(error, "57014", QueryTimeoutError, "query canceled") - - # Permission/authentication errors - if isinstance(error, asyncpg.exceptions.InsufficientPrivilegeError): - return _create_postgres_error(error, "42501", PermissionDeniedError, "insufficient privilege") - if isinstance(error, asyncpg.exceptions.InvalidPasswordError): - return _create_postgres_error(error, "28P01", PermissionDeniedError, "invalid password") - if isinstance(error, asyncpg.exceptions.InvalidAuthorizationSpecificationError): - return _create_postgres_error(error, "28000", PermissionDeniedError, "authorization error") - - # Connection errors - if isinstance(error, asyncpg.exceptions.ConnectionDoesNotExistError): - return _create_postgres_error(error, "08003", ConnectionTimeoutError, "connection does not exist") - if isinstance(error, asyncpg.exceptions.CannotConnectNowError): - return _create_postgres_error(error, "57P03", ConnectionTimeoutError, "cannot connect now") - - # SQL syntax errors - if isinstance(error, asyncpg.exceptions.PostgresSyntaxError): - return _create_postgres_error(error, "42601", SQLParsingError, "SQL syntax error") + mapped_error = _EXCEPTION_MAPPING_DISPATCHER.get(error) + if mapped_error is not None: + error_code, error_class, description = mapped_error + return _create_postgres_error(error, error_code, error_class, description) # Priority 2: Fall back to SQLSTATE code mapping using centralized utility sqlstate_attr = error.sqlstate if has_sqlstate(error) else 
None diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index d3ec176d8..cb00ce7dc 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -438,12 +438,12 @@ def data_dictionary(self) -> "AsyncpgDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AsyncpgCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect asyncpg rows for the direct execution path.""" data, column_names = collect_rows(fetched) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "AsyncpgCursor") -> int: """Resolve rowcount from asyncpg status for the direct execution path.""" return parse_status(cursor) diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index c7cebc77b..a54ab0640 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -165,7 +165,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: + def dispatch_execute(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: """Execute single SQL statement with BigQuery data handling. 
Args: @@ -206,7 +206,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: affected_rows = build_dml_rowcount(cursor.job, 0) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResult: + def dispatch_execute_many(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: """BigQuery execute_many with Parquet bulk load optimization. Uses Parquet bulk load for INSERT operations (fast path) and falls back @@ -255,7 +255,7 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResul affected_rows = build_dml_rowcount(cursor.job, len(prepared_parameters)) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult: + def dispatch_execute_script(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: """Execute SQL script with statement splitting and parameter handling. Parameters are embedded as static values for script execution compatibility. 
@@ -541,14 +541,14 @@ def data_dictionary(self) -> "BigQueryDataDictionary": # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "BigQueryCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect BigQuery rows for the direct execution path.""" schema = cursor.job.schema if cursor.job else None column_names = resolve_column_names(schema, self._column_name_cache) data, _ = collect_rows(fetched, schema, column_names=column_names) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "BigQueryCursor") -> int: """Resolve rowcount from BigQuery job for the direct execution path.""" return build_dml_rowcount(cursor.job, 0) if cursor.job else 0 diff --git a/sqlspec/adapters/cockroach_psycopg/driver.py b/sqlspec/adapters/cockroach_psycopg/driver.py index 8a9f09f6b..0dea542db 100644 --- a/sqlspec/adapters/cockroach_psycopg/driver.py +++ b/sqlspec/adapters/cockroach_psycopg/driver.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from sqlspec.adapters.psycopg.driver import PsycopgAsyncCursor, PsycopgSyncCursor from sqlspec.driver import ExecutionResult __all__ = ( @@ -133,35 +134,35 @@ def _execute_with_retry(self, operation: "Callable[..., ExecutionResult]", *args msg = "CockroachDB transaction retry limit exceeded" raise TransactionRetryError(msg) from last_error - def _apply_follower_reads(self, cursor: Any) -> None: + def _apply_follower_reads(self, cursor: "PsycopgSyncCursor") -> None: if not self.driver_features.get("enable_follower_reads", False): return if not self._follower_staleness: return cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}") - def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def 
_dispatch_execute_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": if statement.returns_rows(): self._apply_follower_reads(cursor) return super().dispatch_execute(cursor, statement) - def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def _dispatch_execute_many_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": return super().dispatch_execute_many(cursor, statement) - def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def _dispatch_execute_script_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": return super().dispatch_execute_script(cursor, statement) - def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return self._dispatch_execute_impl(cursor, statement) return self._execute_with_retry(self._dispatch_execute_impl, cursor, statement) - def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return super().dispatch_execute_many(cursor, statement) return self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) - def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return super().dispatch_execute_script(cursor, statement) return self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) @@ -226,35 +227,35 @@ async def _execute_with_retry(self, operation: "Callable[..., Any]", *args: Any) msg = "CockroachDB transaction retry limit exceeded" raise TransactionRetryError(msg) from last_error - async def 
_apply_follower_reads(self, cursor: Any) -> None: + async def _apply_follower_reads(self, cursor: "PsycopgAsyncCursor") -> None: if not self.driver_features.get("enable_follower_reads", False): return if not self._follower_staleness: return await cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}") - async def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": if statement.returns_rows(): await self._apply_follower_reads(cursor) return await super().dispatch_execute(cursor, statement) - async def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_many_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": return await super().dispatch_execute_many(cursor, statement) - async def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_script_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": return await super().dispatch_execute_script(cursor, statement) - async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": + async def dispatch_execute(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await self._dispatch_execute_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_impl, cursor, statement) - async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_many(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) - async def dispatch_execute_script(self, cursor: 
Any, statement: SQL) -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_script(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index c6ee3956e..ece710b54 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -116,7 +116,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": """Execute single SQL statement with data handling. Executes a SQL statement with parameter binding and processes the results. @@ -152,7 +152,7 @@ def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": return self.create_execution_result(cursor, rowcount_override=row_count) - def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": """Execute SQL with multiple parameter sets using batch processing. Uses DuckDB's executemany method for batch operations and calculates @@ -177,7 +177,7 @@ def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True) - def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. 
Parses multi-statement scripts and executes each statement sequentially @@ -422,12 +422,12 @@ def data_dictionary(self) -> "DuckDBDataDictionary": # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "DuckDBCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect DuckDB rows for the direct execution path.""" data, column_names = collect_rows(cast("list[Any] | None", fetched), cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "DuckDBCursor") -> int: """Resolve rowcount from DuckDB cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/mock/driver.py b/sqlspec/adapters/mock/driver.py index e20977e2a..3466b1d39 100644 --- a/sqlspec/adapters/mock/driver.py +++ b/sqlspec/adapters/mock/driver.py @@ -431,11 +431,11 @@ def _transpile_to_sqlite(self, statement: "SQL") -> str: return sql return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "MockCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mock sync rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "MockCursor") -> int: """Resolve rowcount from mock cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -735,11 +735,11 @@ def _transpile_to_sqlite(self, statement: "SQL") -> str: return sql return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> 
"tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "MockAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mock async rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "MockAsyncCursor") -> int: """Resolve rowcount from mock cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/mysqlconnector/config.py b/sqlspec/adapters/mysqlconnector/config.py index d95edd3d7..72759b25b 100644 --- a/sqlspec/adapters/mysqlconnector/config.py +++ b/sqlspec/adapters/mysqlconnector/config.py @@ -221,7 +221,9 @@ class MysqlConnectorSyncConfig( driver_type: ClassVar[type[MysqlConnectorSyncDriver]] = MysqlConnectorSyncDriver connection_type: ClassVar[type[MysqlConnectorSyncConnection]] = MysqlConnectorSyncConnection - _connection_context_class: "ClassVar[type[MysqlConnectorSyncConnectionContext]]" = MysqlConnectorSyncConnectionContext + _connection_context_class: "ClassVar[type[MysqlConnectorSyncConnectionContext]]" = ( + MysqlConnectorSyncConnectionContext + ) _session_factory_class: "ClassVar[type[_MysqlConnectorSyncSessionConnectionHandler]]" = ( _MysqlConnectorSyncSessionConnectionHandler ) diff --git a/sqlspec/adapters/mysqlconnector/driver.py b/sqlspec/adapters/mysqlconnector/driver.py index 632c72837..f488159b2 100644 --- a/sqlspec/adapters/mysqlconnector/driver.py +++ b/sqlspec/adapters/mysqlconnector/driver.py @@ -123,7 +123,7 @@ def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: MysqlConnectorSyncDataDictionary | None = None - def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = 
self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -150,7 +150,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -160,7 +160,7 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResu affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -275,7 +275,9 @@ def data_dictionary(self) -> "MysqlConnectorSyncDataDictionary": self._data_dictionary = MysqlConnectorSyncDataDictionary() return self._data_dictionary - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows( + self, cursor: "MysqlConnectorSyncCursor", fetched: "list[Any]" + ) -> "tuple[list[Any], list[str], int]": """Collect mysql-connector sync rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -286,7 +288,7 @@ 
def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "MysqlConnectorSyncCursor") -> int: """Resolve rowcount from mysql-connector cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -355,7 +357,7 @@ def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: MysqlConnectorAsyncDataDictionary | None = None - async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: "MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) await cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -382,7 +384,7 @@ async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionRes last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: "MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -392,7 +394,7 @@ async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "Executi affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: 
"MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -509,7 +511,9 @@ def data_dictionary(self) -> "MysqlConnectorAsyncDataDictionary": self._data_dictionary = MysqlConnectorAsyncDataDictionary() return self._data_dictionary - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows( + self, cursor: "MysqlConnectorAsyncCursor", fetched: "list[Any]" + ) -> "tuple[list[Any], list[str], int]": """Collect mysql-connector async rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -520,7 +524,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "MysqlConnectorAsyncCursor") -> int: """Resolve rowcount from mysql-connector cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 26f42c45f..27e27b3b2 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -380,7 +380,9 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncC driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker _connection_context_class: "ClassVar[type[OracleAsyncConnectionContext]]" = OracleAsyncConnectionContext - _session_factory_class: "ClassVar[type[_OracleAsyncSessionConnectionHandler]]" = _OracleAsyncSessionConnectionHandler + _session_factory_class: "ClassVar[type[_OracleAsyncSessionConnectionHandler]]" = ( + 
_OracleAsyncSessionConnectionHandler + ) _session_context_class: "ClassVar[type[OracleAsyncSessionContext]]" = OracleAsyncSessionContext _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 5d4064194..16ab79cdd 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -259,7 +259,9 @@ async def __aenter__(self) -> AsyncCursor: self.cursor = self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> None: _ = (exc_type, exc_val, exc_tb) # Mark as intentionally unused if self.cursor is not None: with contextlib.suppress(Exception): @@ -699,7 +701,7 @@ def data_dictionary(self) -> "OracledbSyncDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "OracleSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Oracle sync rows for the direct execution path.""" column_names, requires_lob_coercion = self._resolve_row_metadata(cursor.description) data, column_names = collect_sync_rows( @@ -711,7 +713,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "OracleSyncCursor") -> int: """Resolve rowcount from Oracle cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -1201,7 +1203,7 @@ def data_dictionary(self) -> 
"OracledbAsyncDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "OracleAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Oracle async rows for the direct execution path. Uses synchronous LOB coercion. For async LOB coercion, the standard @@ -1217,7 +1219,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "OracleAsyncCursor") -> int: """Resolve rowcount from Oracle cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index d3c35e628..bc4036ce5 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -253,10 +253,8 @@ async def _ensure_connection_initialized(self, connection: "PsqlpyConnection") - detected_extensions = {r["extname"] for r in rows} except Exception: detected_extensions = set() - self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( - self.statement_config, - self.driver_features, - detected_extensions, + self.statement_config, self._pgvector_available, self._paradedb_available = ( + resolve_postgres_extension_state(self.statement_config, self.driver_features, detected_extensions) ) conn_id = id(connection) diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index 230d8a4d2..008e64379 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -29,6 +29,7 @@ UniqueViolationError, ) from sqlspec.typing import PGVECTOR_INSTALLED, Empty +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.logging import 
get_logger from sqlspec.utils.serializers import to_json from sqlspec.utils.type_converters import build_nested_decimal_normalizer @@ -76,6 +77,7 @@ _DECIMAL_NORMALIZER = build_nested_decimal_normalizer(mode="float") _JSONB_TYPE: "type[Any] | None" = None _JSONB_RESOLVED: bool = False +_TYPE_COERCION_DISPATCHERS: "dict[tuple[tuple[type, Callable[[Any], Any]], ...], TypeDispatcher[Callable[[Any], Any]]]" = {} PSQLPY_STATUS_REGEX: "re.Pattern[str]" = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) logger = get_logger("sqlspec.adapters.psqlpy.core") @@ -264,6 +266,7 @@ def apply_driver_features( return statement_config, features + def collect_rows(query_result: Any | None) -> "tuple[list[dict[str, Any]], list[str]]": """Collect psqlpy rows and column names. @@ -416,6 +419,20 @@ def get_parameter_casts(statement: "SQL") -> "dict[int, str]": return {} +def _get_type_coercion_dispatcher( + type_map: "dict[type, Callable[[Any], Any]]", +) -> "TypeDispatcher[Callable[[Any], Any]]": + fallback_items = tuple(type_map.items()) + dispatcher = _TYPE_COERCION_DISPATCHERS.get(fallback_items) + if dispatcher is not None: + return dispatcher + + dispatcher = TypeDispatcher["Callable[[Any], Any]"]() + dispatcher.register_all(fallback_items) + _TYPE_COERCION_DISPATCHERS[fallback_items] = dispatcher + return dispatcher + + def prepare_parameters_with_casts( parameters: Any, parameter_casts: "dict[int, str]", statement_config: "StatementConfig" ) -> Any: @@ -424,14 +441,18 @@ def prepare_parameters_with_casts( result: list[Any] = [] serializer = statement_config.parameter_config.json_serializer or to_json type_map = statement_config.parameter_config.type_coercion_map + dispatcher = _get_type_coercion_dispatcher(type_map) if type_map else None for idx, param in enumerate(parameters, start=1): cast_type = parameter_casts.get(idx, "") prepared_value = param - if type_map: - for type_check, converter in type_map.items(): - if isinstance(prepared_value, type_check): - 
prepared_value = converter(prepared_value) - break + if type_map and dispatcher is not None: + exact_converter = type_map.get(type(prepared_value)) + if exact_converter is not None: + prepared_value = exact_converter(prepared_value) + else: + fallback_converter = dispatcher.get(prepared_value) + if fallback_converter is not None: + prepared_value = fallback_converter(prepared_value) if cast_type: prepared_value = _coerce_parameter_for_cast(prepared_value, cast_type, serializer) result.append(prepared_value) diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index df5921b0e..ecc1615cc 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -420,7 +420,7 @@ def data_dictionary(self) -> "PsqlpyDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "PsqlpyCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psqlpy rows for the direct execution path. The ``fetched`` argument may be a psqlpy query result or a plain list. 
@@ -428,7 +428,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l dict_rows, column_names = collect_rows(fetched) return dict_rows, column_names, len(dict_rows) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "PsqlpyCursor") -> int: """Resolve rowcount from psqlpy result for the direct execution path.""" return extract_rows_affected(cursor) diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index bae4c91cb..0159fb44f 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -301,10 +301,8 @@ def _configure_connection(self, conn: "PsycopgSyncConnection") -> None: detected_extensions = {r[0] for r in results} # type: ignore[index] except Exception: detected_extensions = set() - self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( - self.statement_config, - self.driver_features, - detected_extensions, + self.statement_config, self._pgvector_available, self._paradedb_available = ( + resolve_postgres_extension_state(self.statement_config, self.driver_features, detected_extensions) ) if self._pgvector_available: @@ -574,10 +572,8 @@ async def _configure_async_connection(self, conn: "PsycopgAsyncConnection") -> N detected_extensions = {r[0] for r in results} # type: ignore[index] except Exception: detected_extensions = set() - self.statement_config, self._pgvector_available, self._paradedb_available = resolve_postgres_extension_state( - self.statement_config, - self.driver_features, - detected_extensions, + self.statement_config, self._pgvector_available, self._paradedb_available = ( + resolve_postgres_extension_state(self.statement_config, self.driver_features, detected_extensions) ) if self._pgvector_available: diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index 9a52c9abd..3b6d11b93 100644 --- a/sqlspec/adapters/psycopg/core.py +++ 
b/sqlspec/adapters/psycopg/core.py @@ -213,6 +213,7 @@ def apply_driver_features( return statement_config, features + def collect_rows(fetched_data: "list[Any] | None", description: "list[Any] | None") -> "tuple[list[Any], list[str]]": """Collect psycopg rows and column names. diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 914d22c6c..58eb34ed7 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -596,7 +596,9 @@ async def __aenter__(self) -> Any: self.cursor = self.connection.cursor() return self.cursor - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> None: + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> None: _ = (exc_type, exc_val, exc_tb) if self.cursor is not None: await self.cursor.close() @@ -657,7 +659,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": """Execute single SQL statement (async). Args: @@ -688,7 +690,7 @@ async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionRes affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets (async). 
Args: @@ -709,7 +711,7 @@ async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "Executi return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements (async). Args: @@ -733,7 +735,7 @@ async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "Execu last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True ) - async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": + async def dispatch_special_handling(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. Args: @@ -1047,13 +1049,13 @@ def _resolve_column_names(self, description: Any) -> list[str]: self._column_name_cache[cache_key] = (description, column_names) return column_names - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "PsycopgAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg async rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "PsycopgAsyncCursor") -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/pymysql/driver.py b/sqlspec/adapters/pymysql/driver.py index ccec56dcd..ab6c5ff94 100644 --- a/sqlspec/adapters/pymysql/driver.py +++ b/sqlspec/adapters/pymysql/driver.py @@ -102,7 +102,7 @@ 
def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: PyMysqlDataDictionary | None = None - def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -129,7 +129,7 @@ def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -139,7 +139,7 @@ def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResu affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -254,7 +254,7 @@ def data_dictionary(self) -> "PyMysqlDataDictionary": self._data_dictionary = PyMysqlDataDictionary() return self._data_dictionary - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> 
"tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "PyMysqlCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect PyMySQL rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -265,7 +265,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "PyMysqlCursor") -> int: """Resolve rowcount from PyMySQL cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index e22dc7b5e..15f175e0a 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -372,7 +372,7 @@ def data_dictionary(self) -> "SpannerDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "SpannerSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Spanner rows for the direct execution path. Note: Spanner's collect_rows requires result set fields and a type converter. 
@@ -390,7 +390,7 @@ def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], l # For tuple rows without metadata, return as-is return fetched, [], len(fetched) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "SpannerSyncCursor") -> int: """Resolve rowcount from Spanner cursor for the direct execution path.""" # Spanner uses execute_update return value, not cursor.rowcount return 0 diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 06a5283d8..574cc43e8 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -544,11 +544,11 @@ def data_dictionary(self) -> "SqliteDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "SqliteCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect SQLite rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: Any) -> int: + def resolve_rowcount(self, cursor: "SqliteCursor") -> int: """Resolve rowcount from SQLite cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/base.py b/sqlspec/base.py index 7fd0ab363..429c3748c 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -60,7 +60,9 @@ def __enter__(self) -> ConnectionT: self._runtime.emit_connection_create(self._connection) return self._connection - def __exit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": + def __exit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": try: return self._context.__exit__(exc_type, exc_val, exc_tb) finally: @@ -80,7 +82,9 @@ 
async def __aenter__(self) -> ConnectionT: self._runtime.emit_connection_create(self._connection) return self._connection - async def __aexit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": try: return await self._context.__aexit__(exc_type, exc_val, exc_tb) finally: @@ -111,7 +115,9 @@ def __enter__(self) -> DriverT: self._runtime.emit_session_start(driver) return driver - def __exit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": + def __exit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": try: return self._context.__exit__(exc_type, exc_val, exc_tb) finally: @@ -145,7 +151,9 @@ async def __aenter__(self) -> DriverT: self._runtime.emit_session_start(driver) return driver - async def __aexit__(self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None") -> "bool | None": + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": try: return await self._context.__aexit__(exc_type, exc_val, exc_tb) finally: @@ -256,7 +264,9 @@ async def __aenter__(self) -> Self: """Async context manager entry.""" return self - async def __aexit__(self, _exc_type: "type[BaseException] | None", _exc_val: "BaseException | None", _exc_tb: "TracebackType | None") -> None: + async def __aexit__( + self, _exc_type: "type[BaseException] | None", _exc_val: "BaseException | None", _exc_tb: "TracebackType | None" + ) -> None: """Async context manager exit with automatic cleanup.""" await self.close_all_pools() diff --git 
a/sqlspec/builder/_merge.py b/sqlspec/builder/_merge.py index 17ab55da3..f48ad6563 100644 --- a/sqlspec/builder/_merge.py +++ b/sqlspec/builder/_merge.py @@ -23,6 +23,7 @@ from sqlspec.builder._select import is_explicitly_quoted from sqlspec.core import SQLResult from sqlspec.exceptions import DialectNotSupportedError, SQLBuilderError +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.serializers import to_json from sqlspec.utils.type_guards import has_expression_and_sql @@ -32,6 +33,24 @@ __all__ = ("Merge",) MERGE_UNSUPPORTED_DIALECTS = frozenset({"mysql", "sqlite", "duckdb"}) +_POSTGRES_TYPE_DISPATCHER = TypeDispatcher[str]() +_ORACLE_TYPE_DISPATCHER = TypeDispatcher[str]() + +_POSTGRES_TYPE_DISPATCHER.register(bool, "BOOLEAN") +_POSTGRES_TYPE_DISPATCHER.register(int, "INTEGER") +_POSTGRES_TYPE_DISPATCHER.register(float, "DOUBLE PRECISION") +_POSTGRES_TYPE_DISPATCHER.register(Decimal, "NUMERIC") +_POSTGRES_TYPE_DISPATCHER.register(dict, "JSONB") +_POSTGRES_TYPE_DISPATCHER.register(list, "JSONB") +_POSTGRES_TYPE_DISPATCHER.register(datetime, "TIMESTAMP") + +_ORACLE_TYPE_DISPATCHER.register(bool, "NUMBER(1)") +_ORACLE_TYPE_DISPATCHER.register(int, "NUMBER") +_ORACLE_TYPE_DISPATCHER.register(float, "NUMBER") +_ORACLE_TYPE_DISPATCHER.register(Decimal, "NUMBER") +_ORACLE_TYPE_DISPATCHER.register(dict, "JSON") +_ORACLE_TYPE_DISPATCHER.register(list, "JSON") +_ORACLE_TYPE_DISPATCHER.register(datetime, "TIMESTAMP") @trait @@ -263,32 +282,18 @@ def _infer_postgres_type(self, value: "Any") -> str: """ if value is None: return "NUMERIC" - if isinstance(value, bool): - return "BOOLEAN" - if isinstance(value, int): - return "INTEGER" - if isinstance(value, float): - return "DOUBLE PRECISION" - if isinstance(value, Decimal): - return "NUMERIC" - if isinstance(value, (dict, list)): - return "JSONB" - if isinstance(value, datetime): - return "TIMESTAMP" + resolved_type = _POSTGRES_TYPE_DISPATCHER.get(value) + if resolved_type is not None: + return 
resolved_type return "TEXT" def _infer_oracle_type(self, value: "Any") -> str: """Infer Oracle column type for JSON_TABLE projection.""" varchar2_max = 4000 - if isinstance(value, bool): - return "NUMBER(1)" - if isinstance(value, (int, float, Decimal)): - return "NUMBER" - if isinstance(value, (dict, list)): - return "JSON" - if isinstance(value, datetime): - return "TIMESTAMP" + resolved_type = _ORACLE_TYPE_DISPATCHER.get(value) + if resolved_type is not None: + return resolved_type if value is not None and len(str(value)) > varchar2_max: return "CLOB" return f"VARCHAR2({varchar2_max})" diff --git a/sqlspec/core/config_runtime.py b/sqlspec/core/config_runtime.py index e8f2a7f12..2154ae784 100644 --- a/sqlspec/core/config_runtime.py +++ b/sqlspec/core/config_runtime.py @@ -32,15 +32,13 @@ def build_default_statement_config(default_dialect: str) -> StatementConfig: return StatementConfig( dialect=default_dialect, parameter_config=ParameterStyleConfig( - default_parameter_style=ParameterStyle.QMARK, - supported_parameter_styles={ParameterStyle.QMARK}, + default_parameter_style=ParameterStyle.QMARK, supported_parameter_styles={ParameterStyle.QMARK} ), ) def seed_runtime_driver_features( - driver_features: "dict[str, Any] | None", - storage_capabilities: "dict[str, Any] | StorageCapabilities | None", + driver_features: "dict[str, Any] | None", storage_capabilities: "dict[str, Any] | StorageCapabilities | None" ) -> "dict[str, Any]": """Clone and seed driver feature state used on the runtime hot path.""" seeded_features = dict(driver_features) if driver_features else {} @@ -69,7 +67,9 @@ def resolve_postgres_extension_state( ) -> "tuple[StatementConfig, bool, bool]": """Resolve detected PostgreSQL extension flags and promoted dialect.""" detected = detected_extensions or set() - pgvector_available = bool(driver_features and driver_features.get("enable_pgvector", False) and "vector" in detected) + pgvector_available = bool( + driver_features and 
driver_features.get("enable_pgvector", False) and "vector" in detected + ) paradedb_available = bool( driver_features and driver_features.get("enable_paradedb", False) and "pg_search" in detected ) @@ -117,9 +117,7 @@ def create_sync_pool( def close_sync_pool( - connection_instance: "PoolT | None", - close_pool: "Callable[[], None]", - emit_pool_destroy: "Callable[[PoolT], None]", + connection_instance: "PoolT | None", close_pool: "Callable[[], None]", emit_pool_destroy: "Callable[[PoolT], None]" ) -> None: """Close a sync pool and emit teardown hooks.""" close_pool() diff --git a/sqlspec/core/parameters/_processor.py b/sqlspec/core/parameters/_processor.py index f4a428529..dd44ebcdc 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -20,6 +20,7 @@ wrap_with_type, ) from sqlspec.core.parameters._validator import ParameterValidator +from sqlspec.utils.dispatch import TypeDispatcher __all__ = ( "ParameterProcessor", @@ -35,6 +36,7 @@ _EXECUTE_MANY_SAMPLE_SIZE = 3 TypeCoercionFallback = tuple[type, Callable[[Any], Any]] +_TYPE_COERCION_DISPATCHERS: "dict[tuple[TypeCoercionFallback, ...], TypeDispatcher[Callable[[Any], Any]]]" = {} def _structural_fingerprint(parameters: "ParameterPayload", is_many: bool = False) -> Any: @@ -205,6 +207,19 @@ def _type_coercion_fallbacks( return tuple(type_coercion_map.items()) +def _get_type_coercion_dispatcher( + fallback_items: "tuple[TypeCoercionFallback, ...]", +) -> "TypeDispatcher[Callable[[Any], Any]]": + dispatcher = _TYPE_COERCION_DISPATCHERS.get(fallback_items) + if dispatcher is not None: + return dispatcher + + dispatcher = TypeDispatcher["Callable[[Any], Any]"]() + dispatcher.register_all(fallback_items) + _TYPE_COERCION_DISPATCHERS[fallback_items] = dispatcher + return dispatcher + + def _resolve_type_coercion( value: object, type_coercion_map: "dict[type, Callable[[Any], Any]]", @@ -214,11 +229,9 @@ def _resolve_type_coercion( exact_converter = 
type_coercion_map.get(value_type) if exact_converter is not None: return exact_converter(value) - for type_check, converter in fallback_items: - if type_check is value_type: - continue - if isinstance(value, type_check): - return converter(value) + fallback_converter = _get_type_coercion_dispatcher(fallback_items).get(value) + if fallback_converter is not None: + return fallback_converter(value) return value @@ -253,7 +266,6 @@ def _coerce_parameter_value( wrapped_value: object = typed_param.value if wrapped_value is None: return wrapped_value - original_type = typed_param.original_type coerced = _resolve_type_coercion(wrapped_value, type_coercion_map, fallback_items) if coerced is wrapped_value: return wrapped_value @@ -1023,7 +1035,13 @@ def _process_internal( processed_parameters = self._coerce_parameter_types(processed_parameters, config.type_coercion_map, is_many) processed_sql, processed_parameters, converted_param_info = self._convert_placeholders_for_execution( - processed_sql, processed_parameters, config, param_info, original_styles, needs_execution_conversion, is_many + processed_sql, + processed_parameters, + config, + param_info, + original_styles, + needs_execution_conversion, + is_many, ) if config.output_transformer: diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 441981ed6..d3fe0e9fe 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -32,7 +32,6 @@ ) from sqlspec.exceptions import ImproperConfigurationError, StackExecutionError from sqlspec.storage import AsyncStoragePipeline, StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.typing import VersionInfo from sqlspec.utils.arrow_helpers import convert_dict_to_arrow_with_schema from sqlspec.utils.logging import get_logger from sqlspec.utils.schema import ValueT, to_value_type @@ -45,8 +44,6 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, SQLResult, StatementConfig, StatementFilter - from 
sqlspec.data_dictionary._types import DialectConfig - from sqlspec.protocols import HasDataProtocol, HasExecuteProtocol from sqlspec.typing import ( ArrowReturnFormat, ArrowTable, @@ -56,6 +53,7 @@ SchemaT, StatementParameters, TableMetadata, + VersionInfo, ) @@ -1911,4 +1909,3 @@ async def get_foreign_keys( List of foreign key metadata """ - diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index 28499473c..b5c6d5c24 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -39,6 +39,7 @@ from sqlspec.typing import VersionCacheResult, VersionInfo from sqlspec.utils.logging import get_logger, log_with_context from sqlspec.utils.schema import to_schema as _to_schema_impl +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.type_guards import ( has_array_interface, has_cursor_metadata, @@ -198,7 +199,9 @@ class SyncExceptionHandler(Protocol): def __enter__(self) -> Self: ... - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> bool: ... + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> bool: ... class AsyncExceptionHandler(Protocol): @@ -213,7 +216,9 @@ class AsyncExceptionHandler(Protocol): async def __aenter__(self) -> Self: ... - async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> bool: ... + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> bool: ... 
logger = get_logger("sqlspec.driver") @@ -224,6 +229,7 @@ async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseExc _CONVERT_TO_TUPLE = object() _CONVERT_TO_FROZENSET = object() +_TYPE_COERCION_DISPATCHERS: "dict[tuple[tuple[type, Any], ...], TypeDispatcher[Any]]" = {} def _type_coercion_fallbacks(type_coercion_map: "dict[type, Any] | None") -> "tuple[tuple[type, Any], ...]": @@ -232,6 +238,17 @@ def _type_coercion_fallbacks(type_coercion_map: "dict[type, Any] | None") -> "tu return tuple(type_coercion_map.items()) +def _get_type_coercion_dispatcher(fallback_items: "tuple[tuple[type, Any], ...]") -> "TypeDispatcher[Any]": + dispatcher = _TYPE_COERCION_DISPATCHERS.get(fallback_items) + if dispatcher is not None: + return dispatcher + + dispatcher = TypeDispatcher[Any]() + dispatcher.register_all(fallback_items) + _TYPE_COERCION_DISPATCHERS[fallback_items] = dispatcher + return dispatcher + + def make_cache_key_hashable(obj: Any) -> Any: """Recursively convert unhashable types to hashable ones for cache keys. 
@@ -430,7 +447,9 @@ def __enter__(self) -> Self: ) return self - def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None") -> Literal[False]: + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" + ) -> Literal[False]: duration = perf_counter() - self.started self.metrics.record_duration(duration) if exc_val is not None: @@ -1574,9 +1593,8 @@ def prepare_driver_parameters( if self._needs_coercion_candidate(value, type_coercion_map, fallback_items): needs_transform = True break - else: - if self._needs_coercion_candidate(param_set, type_coercion_map, fallback_items): - needs_transform = True + elif self._needs_coercion_candidate(param_set, type_coercion_map, fallback_items): + needs_transform = True if needs_transform: break @@ -1626,12 +1644,9 @@ def _apply_coercion_with_fallback( exact_converter = type_coercion_map.get(value_type) if exact_converter is not None: return exact_converter(value) - - for type_check, converter in fallback_items: - if type_check is value_type: - continue - if isinstance(value, type_check): - return converter(value) + fallback_converter = _get_type_coercion_dispatcher(fallback_items).get(value) + if fallback_converter is not None: + return fallback_converter(value) return value def _needs_coercion_candidate( @@ -1648,13 +1663,7 @@ def _needs_coercion_candidate( value_type = type(value) if value_type in type_coercion_map: return True - - for type_check, _converter in fallback_items: - if type_check is value_type: - continue - if isinstance(value, type_check): - return True - return False + return _get_type_coercion_dispatcher(fallback_items).get(value) is not None def _format_parameter_set_for_many( self, parameters: "StatementParameters", statement_config: "StatementConfig" diff --git a/sqlspec/driver/_exception_handler.py b/sqlspec/driver/_exception_handler.py index a64a060f4..388231c63 100644 --- 
a/sqlspec/driver/_exception_handler.py +++ b/sqlspec/driver/_exception_handler.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from types import TracebackType +__all__ = ("BaseAsyncExceptionHandler", "BaseSyncExceptionHandler", ) + @mypyc_attr(allow_interpreted_subclasses=True) class BaseAsyncExceptionHandler: @@ -22,21 +24,14 @@ async def __aenter__(self) -> Self: return self async def __aexit__( - self, - exc_type: "type[BaseException] | None", - exc_val: "BaseException | None", - exc_tb: "TracebackType | None", + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool: _ = exc_tb if exc_val is None: return False return self._handle_exception(exc_type, exc_val) - def _handle_exception( - self, - exc_type: "type[BaseException] | None", - exc_val: "BaseException", - ) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: """Handle an adapter exception. Subclasses should set ``pending_exception`` before returning ``True``. @@ -58,21 +53,14 @@ def __enter__(self) -> Self: return self def __exit__( - self, - exc_type: "type[BaseException] | None", - exc_val: "BaseException | None", - exc_tb: "TracebackType | None", + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" ) -> bool: _ = exc_tb if exc_val is None: return False return self._handle_exception(exc_type, exc_val) - def _handle_exception( - self, - exc_type: "type[BaseException] | None", - exc_val: "BaseException", - ) -> bool: + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: """Handle an adapter exception. Subclasses should set ``pending_exception`` before returning ``True``. 
diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index f67ece950..482b9d9c9 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -32,7 +32,6 @@ ) from sqlspec.exceptions import ImproperConfigurationError, StackExecutionError from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry, SyncStoragePipeline -from sqlspec.typing import VersionInfo from sqlspec.utils.arrow_helpers import convert_dict_to_arrow_with_schema from sqlspec.utils.logging import get_logger from sqlspec.utils.schema import ValueT, to_value_type @@ -54,6 +53,7 @@ SchemaT, StatementParameters, TableMetadata, + VersionInfo, ) _LOGGER_NAME: Final[str] = "sqlspec.driver" @@ -1873,4 +1873,3 @@ def get_foreign_keys( List of foreign key metadata """ - diff --git a/sqlspec/migrations/base.py b/sqlspec/migrations/base.py index 53fec4304..e5874b09f 100644 --- a/sqlspec/migrations/base.py +++ b/sqlspec/migrations/base.py @@ -19,7 +19,11 @@ from sqlspec.utils.sync_tools import await_ if TYPE_CHECKING: + from collections.abc import Awaitable + from sqlspec.config import DatabaseConfigProtocol + from sqlspec.core import SQL + from sqlspec.migrations.version import MigrationVersion from sqlspec.observability import ObservabilityRuntime __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker") @@ -310,7 +314,7 @@ def _detect_missing_columns(self, existing_columns: "set[str]") -> "set[str]": return target_columns - existing_lower @abstractmethod - def ensure_tracking_table(self, driver: DriverT) -> Any: + def ensure_tracking_table(self, driver: DriverT) -> "None | Awaitable[None]": """Create the migration tracking table if it doesn't exist. Implementations should also check for and add any missing columns @@ -319,24 +323,24 @@ def ensure_tracking_table(self, driver: DriverT) -> Any: ... 
@abstractmethod - def get_current_version(self, driver: DriverT) -> Any: + def get_current_version(self, driver: DriverT) -> "str | None | Awaitable[str | None]": """Get the latest applied migration version.""" ... @abstractmethod - def get_applied_migrations(self, driver: DriverT) -> Any: + def get_applied_migrations(self, driver: DriverT) -> "list[dict[str, Any]] | Awaitable[list[dict[str, Any]]]": """Get all applied migrations in order.""" ... @abstractmethod def record_migration( self, driver: DriverT, version: str, description: str, execution_time_ms: int, checksum: str - ) -> Any: + ) -> "None | Awaitable[None]": """Record a successfully applied migration.""" ... @abstractmethod - def remove_migration(self, driver: DriverT, version: str) -> Any: + def remove_migration(self, driver: DriverT, version: str) -> "None | Awaitable[None]": """Remove a migration record.""" ... @@ -574,32 +578,32 @@ def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "li return None @abstractmethod - def get_migration_files(self) -> Any: + def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]": """Get all migration files sorted by version.""" ... @abstractmethod - def load_migration(self, file_path: Path) -> Any: + def load_migration(self, file_path: Path) -> "dict[str, Any] | Awaitable[dict[str, Any]]": """Load a migration file and extract its components.""" ... @abstractmethod - def execute_upgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any: + def execute_upgrade(self, driver: DriverT, migration: "dict[str, Any]") -> "None | Awaitable[None]": """Execute an upgrade migration.""" ... @abstractmethod - def execute_downgrade(self, driver: DriverT, migration: "dict[str, Any]") -> Any: + def execute_downgrade(self, driver: DriverT, migration: "dict[str, Any]") -> "None | Awaitable[None]": """Execute a downgrade migration.""" ... 
@abstractmethod - def load_all_migrations(self) -> Any: + def load_all_migrations(self) -> "dict[str, SQL] | Awaitable[dict[str, SQL]]": """Load all migrations into a single namespace for bulk operations.""" ... -def _migration_sort_key(item: "tuple[str, Path]") -> Any: +def _migration_sort_key(item: "tuple[str, Path]") -> "MigrationVersion": return parse_version(item[0]) @@ -824,31 +828,31 @@ def _resolve_output_policy( return resolved_use_logger, resolved_echo, resolved_summary_only @abstractmethod - def init(self, directory: str, package: bool = True) -> Any: + def init(self, directory: str, package: bool = True) -> "None | Awaitable[None]": """Initialize migration directory structure.""" ... @abstractmethod - def current(self, verbose: bool = False) -> Any: + def current(self, verbose: bool = False) -> "str | None | Awaitable[str | None]": """Show current migration version.""" ... @abstractmethod - def upgrade(self, revision: str = "head") -> Any: + def upgrade(self, revision: str = "head") -> "None | Awaitable[None]": """Upgrade to a target revision.""" ... @abstractmethod - def downgrade(self, revision: str = "-1") -> Any: + def downgrade(self, revision: str = "-1") -> "None | Awaitable[None]": """Downgrade to a target revision.""" ... @abstractmethod - def stamp(self, revision: str) -> Any: + def stamp(self, revision: str) -> "None | Awaitable[None]": """Mark database as being at a specific revision without running migrations.""" ... @abstractmethod - def revision(self, message: str, file_type: str = "sql") -> Any: + def revision(self, message: str, file_type: str = "sql") -> "None | Awaitable[None]": """Create a new migration file.""" ... 
diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 58291f6c0..6ce92a326 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -8,7 +8,7 @@ import time from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlspec.core import SQL from sqlspec.loader import SQLFileLoader @@ -267,7 +267,7 @@ def calculate_checksum(self, content: str) -> str: return hashlib.md5(canonical_content.encode()).hexdigest() # noqa: S324 @abstractmethod - def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]: + def load_migration(self, file_path: Path) -> "dict[str, Any] | Awaitable[dict[str, Any]]": """Load a migration file and extract its components. Args: diff --git a/sqlspec/storage/pipeline.py b/sqlspec/storage/pipeline.py index fb4ae506a..db3835644 100644 --- a/sqlspec/storage/pipeline.py +++ b/sqlspec/storage/pipeline.py @@ -347,11 +347,7 @@ def write_rows( payload = _encode_row_payload(serialized, format_choice) resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options return self._write_bytes( - payload, - destination, - rows=len(serialized), - format_label=format_choice, - storage_options=resolved_options, + payload, destination, rows=len(serialized), format_label=format_choice, storage_options=resolved_options ) def write_arrow( @@ -372,11 +368,7 @@ def write_arrow( table, format_choice, compression=compression, write_options=format_write_options ) return self._write_bytes( - payload, - destination, - rows=int(table.num_rows), - format_label=format_choice, - storage_options=resolved_options, + payload, destination, rows=int(table.num_rows), format_label=format_choice, storage_options=resolved_options ) def read_arrow( @@ -490,11 +482,7 @@ async def write_rows( payload = await 
async_(_encode_row_payload)(serialized, format_choice) resolved_options = _EMPTY_STORAGE_OPTIONS if storage_options is None else storage_options return await self._write_bytes_async( - payload, - destination, - rows=len(serialized), - format_label=format_choice, - storage_options=resolved_options, + payload, destination, rows=len(serialized), format_label=format_choice, storage_options=resolved_options ) async def write_arrow( @@ -513,11 +501,7 @@ async def write_arrow( table, format_choice, compression=compression, write_options=format_write_options ) return await self._write_bytes_async( - payload, - destination, - rows=int(table.num_rows), - format_label=format_choice, - storage_options=resolved_options, + payload, destination, rows=int(table.num_rows), format_label=format_choice, storage_options=resolved_options ) async def cleanup_staging_artifacts(self, artifacts: "list[StagedArtifact]", *, ignore_errors: bool = True) -> None: diff --git a/sqlspec/utils/arrow_helpers.py b/sqlspec/utils/arrow_helpers.py index 17c45ec12..ca7cf09eb 100644 --- a/sqlspec/utils/arrow_helpers.py +++ b/sqlspec/utils/arrow_helpers.py @@ -10,6 +10,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Literal, overload +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow from sqlspec.utils.type_guards import has_arrow_table_stats, has_get_data @@ -33,6 +34,7 @@ "convert_dict_to_arrow_with_schema", "ensure_arrow_table", ) +_ARROW_TABLE_COERCER: "TypeDispatcher[Any] | None" = None @overload @@ -161,24 +163,47 @@ def convert_dict_to_arrow_with_schema( def coerce_arrow_table(source: "ArrowResult | Any") -> "ArrowTable": """Coerce various sources to a PyArrow Table.""" ensure_pyarrow() - import pyarrow as pa if has_get_data(source): table = source.get_data() - if isinstance(table, pa.Table): + if _get_arrow_table_coercer().get(table) is _coerce_arrow_table_identity: return table msg 
= "ArrowResult did not return a pyarrow.Table instance" raise TypeError(msg) - if isinstance(source, pa.Table): - return source - if isinstance(source, pa.RecordBatch): - return pa.Table.from_batches([source]) + coercer = _get_arrow_table_coercer().get(source) + if coercer is not None: + return coercer(source) if isinstance(source, Iterable): + import pyarrow as pa + return pa.Table.from_pylist(list(source)) msg = f"Unsupported Arrow source type: {type(source).__name__}" raise TypeError(msg) +def _coerce_arrow_table_identity(source: Any) -> Any: + return source + + +def _coerce_arrow_record_batch(source: Any) -> Any: + import pyarrow as pa + + return pa.Table.from_batches([source]) + + +def _get_arrow_table_coercer() -> "TypeDispatcher[Any]": + global _ARROW_TABLE_COERCER + if _ARROW_TABLE_COERCER is None: + ensure_pyarrow() + import pyarrow as pa + + dispatcher = TypeDispatcher[Any]() + dispatcher.register(pa.Table, _coerce_arrow_table_identity) + dispatcher.register(pa.RecordBatch, _coerce_arrow_record_batch) + _ARROW_TABLE_COERCER = dispatcher + return _ARROW_TABLE_COERCER + + def ensure_arrow_table(data: Any) -> "ArrowTable": """Ensure data is a PyArrow Table.""" ensure_pyarrow() diff --git a/sqlspec/utils/dispatch.py b/sqlspec/utils/dispatch.py index 30132eb7d..e3204d065 100644 --- a/sqlspec/utils/dispatch.py +++ b/sqlspec/utils/dispatch.py @@ -1,11 +1,15 @@ -from typing import Any, Generic, TypeVar +from typing import Any, Final, Generic, TypeVar + +from mypy_extensions import mypyc_attr __all__ = ("TypeDispatcher",) T = TypeVar("T") +_MISSING: Final[object] = object() +@mypyc_attr(allow_interpreted_subclasses=False) class TypeDispatcher(Generic[T]): """O(1) type lookup cache for Mypyc-compatible dispatch. 
@@ -29,6 +33,12 @@ def register(self, type_: type, value: T) -> None: self._registry[type_] = value self._cache.clear() # Invalidate cache on new registration + def register_all(self, registrations: "tuple[tuple[type, T], ...]") -> None: + """Register multiple values while invalidating the cache once.""" + for type_, value in registrations: + self._registry[type_] = value + self._cache.clear() + def get(self, obj: Any) -> T | None: """Get the value associated with the object's type. @@ -40,9 +50,13 @@ def get(self, obj: Any) -> T | None: Returns: The associated value or None if not found. """ - obj_type = type(obj) - if obj_type in self._cache: - return self._cache[obj_type] + return self.resolve_type(type(obj)) + + def resolve_type(self, obj_type: type) -> T | None: + """Resolve a value directly from a concrete runtime type.""" + cached_value = self._cache.get(obj_type, _MISSING) + if cached_value is not _MISSING: + return cached_value # type: ignore[return-value] return self._resolve(obj_type) @@ -56,16 +70,29 @@ def _resolve(self, obj_type: type) -> T | None: The resolved value or None. """ # Fast path: check registry directly - if obj_type in self._registry: - self._cache[obj_type] = self._registry[obj_type] - return self._registry[obj_type] + direct_value = self._registry.get(obj_type, _MISSING) + if direct_value is not _MISSING: + self._cache[obj_type] = direct_value # type: ignore[assignment] + return direct_value # type: ignore[return-value] # Slow path: walk MRO - for base in obj_type.__mro__: - if base in self._registry: - value = self._registry[base] - self._cache[obj_type] = value - return value + for base in obj_type.__mro__[1:]: + value = self._registry.get(base, _MISSING) + if value is not _MISSING: + self._cache[obj_type] = value # type: ignore[assignment] + return value # type: ignore[return-value] + + # ABC/protocol fallback: issubclass() resolves virtual hierarchies not present in __mro__. 
+ for registered_type, value in self._registry.items(): + if registered_type is obj_type: + continue + try: + if not issubclass(obj_type, registered_type): + continue + except TypeError: + continue + self._cache[obj_type] = value + return value return None diff --git a/sqlspec/utils/schema.py b/sqlspec/utils/schema.py index 0161266fa..585ed1265 100644 --- a/sqlspec/utils/schema.py +++ b/sqlspec/utils/schema.py @@ -23,6 +23,7 @@ get_type_adapter, ) from sqlspec.utils.logging import get_logger +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.serializers import from_json from sqlspec.utils.text import camelize, kebabize, pascalize from sqlspec.utils.type_guards import ( @@ -61,6 +62,7 @@ "kebab": kebabize, "pascal": pascalize, } +_NUMPY_RECURSIVE_DISPATCHER: "TypeDispatcher[Callable[[Any], Any]] | None" = None # ============================================================================= @@ -395,18 +397,39 @@ def _convert_numpy_recursive(obj: Any) -> Any: if not NUMPY_INSTALLED: return obj - import numpy as np - - if isinstance(obj, np.ndarray): - return obj.tolist() - if isinstance(obj, dict): - return {k: _convert_numpy_recursive(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - converted = [_convert_numpy_recursive(item) for item in obj] - return type(obj)(converted) + handler = _get_numpy_recursive_dispatcher().get(obj) + if handler is not None: + return handler(obj) return obj +def _convert_numpy_array(obj: Any) -> Any: + return obj.tolist() + + +def _convert_numpy_mapping(obj: Any) -> Any: + return {key: _convert_numpy_recursive(value) for key, value in obj.items()} + + +def _convert_numpy_sequence(obj: Any) -> Any: + converted = [_convert_numpy_recursive(item) for item in obj] + return type(obj)(converted) + + +def _get_numpy_recursive_dispatcher() -> "TypeDispatcher[Callable[[Any], Any]]": + global _NUMPY_RECURSIVE_DISPATCHER + if _NUMPY_RECURSIVE_DISPATCHER is None: + import numpy as np + + dispatcher = 
TypeDispatcher["Callable[[Any], Any]"]() + dispatcher.register(np.ndarray, _convert_numpy_array) + dispatcher.register(dict, _convert_numpy_mapping) + dispatcher.register(list, _convert_numpy_sequence) + dispatcher.register(tuple, _convert_numpy_sequence) + _NUMPY_RECURSIVE_DISPATCHER = dispatcher + return _NUMPY_RECURSIVE_DISPATCHER + + def _convert_msgspec(data: Any, schema_type: Any) -> Any: """Convert data to msgspec Struct.""" rename_config = get_msgspec_rename_config(schema_type) diff --git a/sqlspec/utils/type_converters.py b/sqlspec/utils/type_converters.py index a855d9035..e28c657b2 100644 --- a/sqlspec/utils/type_converters.py +++ b/sqlspec/utils/type_converters.py @@ -3,6 +3,8 @@ import decimal from typing import TYPE_CHECKING, Any +from sqlspec.utils.dispatch import TypeDispatcher + if TYPE_CHECKING: import datetime from collections.abc import Callable, Sequence @@ -19,6 +21,7 @@ JSON_NESTED_TYPES: "tuple[type[Any], ...]" = (dict, list, tuple) DEFAULT_DECIMAL_MODE: str = "preserve" +_DECIMAL_NORMALIZER_DISPATCHER = TypeDispatcher["Callable[['_DecimalNormalizer', Any], Any]"]() def _decimal_identity(value: "decimal.Decimal") -> "decimal.Decimal": @@ -67,39 +70,56 @@ def __init__(self, decimal_converter: "Callable[[decimal.Decimal], Any]") -> Non self._decimal_converter = decimal_converter def __call__(self, value: Any) -> Any: - if isinstance(value, decimal.Decimal): - return self._decimal_converter(value) - if isinstance(value, list): - normalized_list: list[Any] | None = None - for index, item in enumerate(value): - normalized_item = self(item) - if normalized_list is None: - if normalized_item is item: - continue - normalized_list = list(value[:index]) - normalized_list.append(normalized_item) - return value if normalized_list is None else normalized_list - if isinstance(value, tuple): - normalized_tuple: list[Any] | None = None - for index, item in enumerate(value): - normalized_item = self(item) - if normalized_tuple is None: - if normalized_item is 
item: - continue - normalized_tuple = list(value[:index]) - normalized_tuple.append(normalized_item) - return value if normalized_tuple is None else tuple(normalized_tuple) - if isinstance(value, dict): - normalized_dict: dict[Any, Any] | None = None - for key, item in value.items(): - normalized_item = self(item) - if normalized_dict is None: - if normalized_item is item: - continue - normalized_dict = dict(value) - normalized_dict[key] = normalized_item - return value if normalized_dict is None else normalized_dict - return value + handler = _DECIMAL_NORMALIZER_DISPATCHER.get(value) + if handler is None: + return value + return handler(self, value) + + +def _normalize_decimal_value(normalizer: "_DecimalNormalizer", value: Any) -> Any: + return normalizer._decimal_converter(value) + + +def _normalize_decimal_list(normalizer: "_DecimalNormalizer", value: Any) -> Any: + normalized_list: list[Any] | None = None + for index, item in enumerate(value): + normalized_item = normalizer(item) + if normalized_list is None: + if normalized_item is item: + continue + normalized_list = list(value[:index]) + normalized_list.append(normalized_item) + return value if normalized_list is None else normalized_list + + +def _normalize_decimal_tuple(normalizer: "_DecimalNormalizer", value: Any) -> Any: + normalized_tuple: list[Any] | None = None + for index, item in enumerate(value): + normalized_item = normalizer(item) + if normalized_tuple is None: + if normalized_item is item: + continue + normalized_tuple = list(value[:index]) + normalized_tuple.append(normalized_item) + return value if normalized_tuple is None else tuple(normalized_tuple) + + +def _normalize_decimal_dict(normalizer: "_DecimalNormalizer", value: Any) -> Any: + normalized_dict: dict[Any, Any] | None = None + for key, item in value.items(): + normalized_item = normalizer(item) + if normalized_dict is None: + if normalized_item is item: + continue + normalized_dict = dict(value) + normalized_dict[key] = 
normalized_item + return value if normalized_dict is None else normalized_dict + + +_DECIMAL_NORMALIZER_DISPATCHER.register(decimal.Decimal, _normalize_decimal_value) +_DECIMAL_NORMALIZER_DISPATCHER.register(list, _normalize_decimal_list) +_DECIMAL_NORMALIZER_DISPATCHER.register(tuple, _normalize_decimal_tuple) +_DECIMAL_NORMALIZER_DISPATCHER.register(dict, _normalize_decimal_dict) def _time_iso_convert(value: "datetime.date | datetime.datetime | datetime.time") -> str: diff --git a/tests/unit/adapters/test_adbc/test_core.py b/tests/unit/adapters/test_adbc/test_core.py index c5650262c..b6b479314 100644 --- a/tests/unit/adapters/test_adbc/test_core.py +++ b/tests/unit/adapters/test_adbc/test_core.py @@ -1,5 +1,6 @@ """Unit tests for ADBC core execute-many helpers.""" +from collections.abc import Sequence from types import SimpleNamespace from typing import Any @@ -93,12 +94,39 @@ def fake_factory(dialect: str, cache_size: int = 5000) -> FakeConverter: monkeypatch.setattr(adbc_core, "get_adbc_type_converter", fake_factory) prepared = prepare_parameters_with_casts( - [{"id": 1}], - {}, - statement_config, - dialect="postgres", - json_serializer=lambda value: str(value), + [{"id": 1}], {}, statement_config, dialect="postgres", json_serializer=lambda value: str(value) ) assert prepared == ["factory:[('id', 1)]"] assert factory_calls == [("postgres", 5000)] + + +def test_prepare_parameters_with_casts_supports_subclass_type_dispatch() -> None: + class MyInt(int): + pass + + statement_config = get_statement_config("postgres") + statement_config = statement_config.replace( + parameter_config=statement_config.parameter_config.replace(type_coercion_map={int: lambda value: value + 1}) + ) + + prepared = prepare_parameters_with_casts( + [MyInt(4)], {}, statement_config, dialect="postgres", json_serializer=lambda value: str(value) + ) + + assert prepared == [5] + + +def test_prepare_parameters_with_casts_supports_virtual_abc_dispatch() -> None: + statement_config = 
get_statement_config("postgres") + statement_config = statement_config.replace( + parameter_config=statement_config.parameter_config.replace( + type_coercion_map={Sequence: lambda value: tuple(value)} + ) + ) + + prepared = prepare_parameters_with_casts( + [[1, 2]], {}, statement_config, dialect="postgres", json_serializer=lambda value: str(value) + ) + + assert prepared == [(1, 2)] diff --git a/tests/unit/adapters/test_adbc/test_extension_detection.py b/tests/unit/adapters/test_adbc/test_extension_detection.py index 331eb0d2d..bd30d3d6f 100644 --- a/tests/unit/adapters/test_adbc/test_extension_detection.py +++ b/tests/unit/adapters/test_adbc/test_extension_detection.py @@ -111,9 +111,7 @@ def test_adbc_config_initializes_extension_flags_to_none() -> None: def test_resolve_postgres_extension_state_promotes_paradedb() -> None: """Detected extensions should promote the runtime dialect.""" statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( - get_statement_config("postgres"), - {"enable_pgvector": True, "enable_paradedb": True}, - {"vector", "pg_search"}, + get_statement_config("postgres"), {"enable_pgvector": True, "enable_paradedb": True}, {"vector", "pg_search"} ) assert statement_config.dialect == "paradedb" diff --git a/tests/unit/adapters/test_asyncpg/test_config.py b/tests/unit/adapters/test_asyncpg/test_config.py index 8eac2ba6c..073ec397d 100644 --- a/tests/unit/adapters/test_asyncpg/test_config.py +++ b/tests/unit/adapters/test_asyncpg/test_config.py @@ -54,9 +54,7 @@ def test_asyncpg_build_postgres_extension_probe_names_filters_disabled_features( def test_asyncpg_resolve_postgres_extension_state_promotes_paradedb() -> None: """Detected extensions should promote the runtime dialect.""" statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( - StatementConfig(dialect="postgres"), - {"enable_pgvector": True, "enable_paradedb": True}, - {"vector", "pg_search"}, + 
StatementConfig(dialect="postgres"), {"enable_pgvector": True, "enable_paradedb": True}, {"vector", "pg_search"} ) assert statement_config.dialect == "paradedb" diff --git a/tests/unit/adapters/test_asyncpg/test_type_handlers.py b/tests/unit/adapters/test_asyncpg/test_type_handlers.py index e07bbaeaa..f5a682170 100644 --- a/tests/unit/adapters/test_asyncpg/test_type_handlers.py +++ b/tests/unit/adapters/test_asyncpg/test_type_handlers.py @@ -1,8 +1,11 @@ """Unit tests for asyncpg type handlers.""" +import asyncpg from unittest.mock import AsyncMock, MagicMock, patch +from sqlspec.adapters.asyncpg.core import create_mapped_exception from sqlspec.adapters.asyncpg.config import register_json_codecs, register_pgvector_support +from sqlspec.exceptions import PermissionDeniedError, UniqueViolationError async def test_register_json_codecs_success() -> None: @@ -64,3 +67,21 @@ async def test_register_pgvector_support_handles_exception() -> None: with patch("pgvector.asyncpg.register_vector", new_callable=AsyncMock) as mock_register: mock_register.side_effect = Exception("Registration error") await register_pgvector_support(connection) + + +def test_create_mapped_exception_uses_exact_exception_dispatch() -> None: + error = asyncpg.exceptions.UniqueViolationError("duplicate key") + + result = create_mapped_exception(error) + + assert isinstance(result, UniqueViolationError) + assert result.__cause__ is error + + +def test_create_mapped_exception_uses_registered_permission_dispatch() -> None: + error = asyncpg.exceptions.InvalidPasswordError("bad password") + + result = create_mapped_exception(error) + + assert isinstance(result, PermissionDeniedError) + assert result.__cause__ is error diff --git a/tests/unit/adapters/test_psqlpy/test_config.py b/tests/unit/adapters/test_psqlpy/test_config.py index a89922afa..fe7d495a4 100644 --- a/tests/unit/adapters/test_psqlpy/test_config.py +++ b/tests/unit/adapters/test_psqlpy/test_config.py @@ -46,9 +46,7 @@ def 
test_psqlpy_build_postgres_extension_probe_names_filters_disabled_features() def test_psqlpy_resolve_postgres_extension_state_promotes_paradedb() -> None: """Detected extensions should promote the runtime dialect.""" statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( - StatementConfig(dialect="postgres"), - {"enable_pgvector": True, "enable_paradedb": True}, - {"vector", "pg_search"}, + StatementConfig(dialect="postgres"), {"enable_pgvector": True, "enable_paradedb": True}, {"vector", "pg_search"} ) assert statement_config.dialect == "paradedb" diff --git a/tests/unit/adapters/test_psqlpy/test_core.py b/tests/unit/adapters/test_psqlpy/test_core.py index 33b1a2abc..549f4d0a1 100644 --- a/tests/unit/adapters/test_psqlpy/test_core.py +++ b/tests/unit/adapters/test_psqlpy/test_core.py @@ -1,15 +1,18 @@ """Unit tests for psqlpy core helpers.""" +from collections.abc import Sequence from decimal import Decimal from types import SimpleNamespace import pytest from sqlspec.adapters.psqlpy.core import ( + build_statement_config, coerce_numeric_for_write, coerce_records_for_execute_many, collect_rows, format_execute_many_parameters, + prepare_parameters_with_casts, ) pytestmark = pytest.mark.xdist_group("adapter_unit") @@ -59,10 +62,7 @@ def test_coerce_numeric_for_write_preserves_identity_when_unchanged() -> None: def test_coerce_numeric_for_write_copies_only_changed_branch() -> None: """Numeric write coercion should allocate only along branches containing float values.""" - payload = { - "changed": [1.5, {"value": 2.5}], - "unchanged": ("a", {"value": Decimal("3.5")}), - } + payload = {"changed": [1.5, {"value": 2.5}], "unchanged": ("a", {"value": Decimal("3.5")})} coerced = coerce_numeric_for_write(payload) @@ -130,3 +130,30 @@ def test_collect_rows_accepts_raw_list_payload() -> None: assert rows is payload assert columns == ["id", "name"] + + +def test_prepare_parameters_with_casts_supports_subclass_type_dispatch() -> None: + 
class MyInt(int): + pass + + statement_config = build_statement_config() + statement_config = statement_config.replace( + parameter_config=statement_config.parameter_config.replace(type_coercion_map={int: lambda value: value + 1}) + ) + + prepared = prepare_parameters_with_casts([MyInt(4)], {}, statement_config) + + assert prepared == [5] + + +def test_prepare_parameters_with_casts_supports_virtual_abc_dispatch() -> None: + statement_config = build_statement_config() + statement_config = statement_config.replace( + parameter_config=statement_config.parameter_config.replace( + type_coercion_map={Sequence: lambda value: tuple(value)} + ) + ) + + prepared = prepare_parameters_with_casts([[1, 2]], {}, statement_config) + + assert prepared == [(1, 2)] diff --git a/tests/unit/adapters/test_psycopg/test_config.py b/tests/unit/adapters/test_psycopg/test_config.py index a87103062..ddeb74269 100644 --- a/tests/unit/adapters/test_psycopg/test_config.py +++ b/tests/unit/adapters/test_psycopg/test_config.py @@ -47,9 +47,7 @@ def test_psycopg_build_postgres_extension_probe_names_filters_disabled_features( def test_psycopg_resolve_postgres_extension_state_promotes_paradedb() -> None: """Detected extensions should promote the runtime dialect.""" statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( - StatementConfig(dialect="postgres"), - {"enable_pgvector": True, "enable_paradedb": True}, - {"vector", "pg_search"}, + StatementConfig(dialect="postgres"), {"enable_pgvector": True, "enable_paradedb": True}, {"vector", "pg_search"} ) assert statement_config.dialect == "paradedb" diff --git a/tests/unit/builder/test_merge.py b/tests/unit/builder/test_merge.py index 6c1252e31..d433030c7 100644 --- a/tests/unit/builder/test_merge.py +++ b/tests/unit/builder/test_merge.py @@ -657,6 +657,24 @@ def test_merge_oracle_dialect_allowed() -> None: assert "MERGE INTO" in stmt.sql.upper() +def test_merge_type_inference_supports_sequence_subclasses() -> None: 
+ """MERGE type inference should treat builtin container subclasses as JSON payloads.""" + + class JsonList(list): + pass + + builder = sql.merge(dialect="postgres") + + assert builder._infer_postgres_type(JsonList([1, 2])) == "JSONB" # pyright: ignore[reportPrivateUsage] + + +def test_merge_type_inference_preserves_bool_priority_for_oracle() -> None: + """Boolean values should resolve before int-compatible handlers.""" + builder = sql.merge(dialect="oracle") + + assert builder._infer_oracle_type(True) == "NUMBER(1)" # pyright: ignore[reportPrivateUsage] + + def test_merge_no_dialect_allowed() -> None: """Test MERGE with no dialect specified is allowed.""" query = ( diff --git a/tests/unit/config/test_provide_methods.py b/tests/unit/config/test_provide_methods.py index 9af31c215..7feaf5e1f 100644 --- a/tests/unit/config/test_provide_methods.py +++ b/tests/unit/config/test_provide_methods.py @@ -233,8 +233,7 @@ async def _close_pool(self) -> None: @requires_interpreted def test_sync_database_config_template_provides_connection_and_session() -> None: config = _SyncTemplateConfig( - statement_config=StatementConfig(dialect="postgres"), - driver_features={"enable_events": True}, + statement_config=StatementConfig(dialect="postgres"), driver_features={"enable_events": True} ) connection_context = config.provide_connection() @@ -262,8 +261,7 @@ def test_sync_database_config_template_uses_default_statement_config_when_unset( @requires_interpreted async def test_async_database_config_template_provides_connection_and_session() -> None: config = _AsyncTemplateConfig( - statement_config=StatementConfig(dialect="postgres"), - driver_features={"enable_events": True}, + statement_config=StatementConfig(dialect="postgres"), driver_features={"enable_events": True} ) connection_context = config.provide_connection() @@ -351,7 +349,11 @@ def test_pooled_adapters_inherit_base_provide_connection(config_type: type[Any], ], ) def 
test_template_only_adapters_inherit_base_provide_session(config_type: type[Any]) -> None: - base_method = SyncDatabaseConfig.provide_session if issubclass(config_type, SyncDatabaseConfig) else AsyncDatabaseConfig.provide_session + base_method = ( + SyncDatabaseConfig.provide_session + if issubclass(config_type, SyncDatabaseConfig) + else AsyncDatabaseConfig.provide_session + ) assert config_type.provide_session is base_method @@ -369,5 +371,9 @@ def test_template_only_adapters_inherit_base_provide_session(config_type: type[A ], ) def test_specialized_adapters_keep_provide_session_override(config_type: type[Any]) -> None: - base_method = SyncDatabaseConfig.provide_session if issubclass(config_type, SyncDatabaseConfig) else AsyncDatabaseConfig.provide_session + base_method = ( + SyncDatabaseConfig.provide_session + if issubclass(config_type, SyncDatabaseConfig) + else AsyncDatabaseConfig.provide_session + ) assert config_type.provide_session is not base_method diff --git a/tests/unit/config/test_storage_capabilities.py b/tests/unit/config/test_storage_capabilities.py index b60a26432..9bf7db709 100644 --- a/tests/unit/config/test_storage_capabilities.py +++ b/tests/unit/config/test_storage_capabilities.py @@ -255,9 +255,7 @@ def test_build_postgres_extension_probe_names_filters_disabled_features() -> Non def test_resolve_postgres_extension_state_promotes_paradedb() -> None: statement_config, pgvector_available, paradedb_available = resolve_postgres_extension_state( - StatementConfig(dialect="postgres"), - {"enable_pgvector": True, "enable_paradedb": True}, - {"vector", "pg_search"}, + StatementConfig(dialect="postgres"), {"enable_pgvector": True, "enable_paradedb": True}, {"vector", "pg_search"} ) assert statement_config.dialect == "paradedb" diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 116d3af46..569a83533 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -10,6 +10,7 @@ import 
json import math +from collections.abc import Sequence from datetime import date, datetime from decimal import Decimal from importlib import import_module @@ -37,13 +38,14 @@ replace_placeholders_with_literals, wrap_with_type, ) -from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError -from sqlspec.utils.serializers import from_json, to_json # Detect whether the core parameters module is mypyc-compiled. # When compiled, `patch.object` on C-extension classes is a no-op, # so tests that assert mock call counts must be skipped. +from sqlspec.core.parameters import _processor as _processor_module from sqlspec.core.parameters import _validator as _validator_module +from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError +from sqlspec.utils.serializers import from_json, to_json _VALIDATOR_COMPILED = (_validator_module.__file__ or "").endswith((".so", ".pyd")) @@ -1191,6 +1193,14 @@ class MyInt(int): assert result.parameters == [5] +def test_resolve_type_coercion_supports_virtual_abc_fallback() -> None: + """ABC-registered coercions should still resolve for builtin sequence payloads.""" + type_map = {Sequence: lambda value: tuple(value)} + fallback_items = _processor_module._type_coercion_fallbacks(type_map) + + assert _processor_module._resolve_type_coercion([1, 2, 3], type_map, fallback_items) == (1, 2, 3) + + def test_map_named_to_positional_preserves_execute_many_identity_when_rows_are_already_positional( processor: "ParameterProcessor", ) -> None: @@ -1782,7 +1792,9 @@ def test_positional_parameter_output_type_narrowing(converter: ParameterConverte assert result_dict == (1, 2, 3) -def test_convert_placeholders_to_style_skips_sort_for_position_ordered_params(converter: ParameterConverter, monkeypatch: Any) -> None: +def test_convert_placeholders_to_style_skips_sort_for_position_ordered_params( + converter: ParameterConverter, monkeypatch: Any +) -> None: """Position-ordered parameter metadata should not pay an extra sorted() pass.""" 
sql = "SELECT :a, :b, :c" param_info = converter.validator.extract_parameters(sql) diff --git a/tests/unit/core/test_result.py b/tests/unit/core/test_result.py index ccbfc9dfa..78c750eb5 100644 --- a/tests/unit/core/test_result.py +++ b/tests/unit/core/test_result.py @@ -394,9 +394,7 @@ class User: name: str result = SQLResult( - statement=SQL("SELECT id, name FROM users WHERE id = 1"), - data=[{"id": 1, "name": "Alice"}], - rows_affected=1, + statement=SQL("SELECT id, name FROM users WHERE id = 1"), data=[{"id": 1, "name": "Alice"}], rows_affected=1 ) original_to_schema = result_base.to_schema diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index 8cfed2a8b..c48edefc1 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -2,6 +2,7 @@ """Unit tests for fast-path query cache behavior.""" from concurrent.futures import ThreadPoolExecutor +from collections.abc import Sequence from typing import Any, Literal, cast import pytest @@ -189,6 +190,24 @@ class MyInt(int): assert tuple(prepared[1]) == ("b",) +def test_prepare_driver_parameters_many_coerces_virtual_abc_rows_when_needed() -> None: + config = StatementConfig( + parameter_config=ParameterStyleConfig( + default_parameter_style=ParameterStyle.QMARK, + supported_parameter_styles={ParameterStyle.QMARK}, + type_coercion_map={Sequence: lambda value: tuple(value)}, + ) + ) + driver = _FakeDriver(object(), config) + parameters = [[1, 2], ["b"]] + + prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) + + assert isinstance(prepared, list) + assert prepared is not parameters + assert prepared == [(1, 2), ("b",)] + + def test_sync_stmt_cache_execute_direct_uses_dispatch_path(mock_sync_driver, monkeypatch) -> None: class _CursorManager: def __enter__(self) -> object: diff --git a/tests/unit/driver/test_result_tools.py b/tests/unit/driver/test_result_tools.py index 4a64bf4d1..c59cd13f2 100644 --- 
a/tests/unit/driver/test_result_tools.py +++ b/tests/unit/driver/test_result_tools.py @@ -12,6 +12,7 @@ import pytest from typing_extensions import TypedDict +import sqlspec.utils.schema as schema_utils from sqlspec.driver import CommonDriverAttributesMixin from sqlspec.typing import NUMPY_INSTALLED from sqlspec.utils.schema import ( @@ -20,7 +21,6 @@ _default_msgspec_deserializer, _is_list_type_target, ) -import sqlspec.utils.schema as schema_utils pytestmark = pytest.mark.xdist_group("driver") diff --git a/tests/unit/exceptions/test_exception_handler.py b/tests/unit/exceptions/test_exception_handler.py index c6481ca8d..5dc0dd626 100644 --- a/tests/unit/exceptions/test_exception_handler.py +++ b/tests/unit/exceptions/test_exception_handler.py @@ -92,9 +92,7 @@ def test_mysqlconnector_sync_exception_handler_preserves_suppression(monkeypatch @pytest.mark.anyio -async def test_mysqlconnector_async_exception_handler_preserves_suppression( - monkeypatch: pytest.MonkeyPatch, -) -> None: +async def test_mysqlconnector_async_exception_handler_preserves_suppression(monkeypatch: pytest.MonkeyPatch) -> None: """mysql-connector async handler should preserve migration-suppression sentinel values.""" pytest.importorskip("mysql.connector") import mysql.connector diff --git a/tests/unit/migrations/test_migration_execution.py b/tests/unit/migrations/test_migration_execution.py index 67461d73f..69d6526b1 100644 --- a/tests/unit/migrations/test_migration_execution.py +++ b/tests/unit/migrations/test_migration_execution.py @@ -83,7 +83,7 @@ def load_migration(self, file_path: Path) -> dict[str, Any]: """Mock load migration.""" return self._load_migration_metadata(file_path) - def execute_upgrade(self, driver: Any, migration: dict[str, Any]) -> ExecutionResult: + def execute_upgrade(self, driver: Any, migration: dict[str, Any]) -> Any: """Mock execute upgrade.""" sql = self._get_migration_sql(migration, "up") if sql: @@ -91,7 +91,7 @@ def execute_upgrade(self, driver: Any, migration: 
dict[str, Any]) -> ExecutionRe return Mock(spec=ExecutionResult) raise ValueError(f"No upgrade SQL for migration {migration['version']}") - def execute_downgrade(self, driver: Any, migration: dict[str, Any]) -> ExecutionResult: + def execute_downgrade(self, driver: Any, migration: dict[str, Any]) -> Any: """Mock execute downgrade.""" sql = self._get_migration_sql(migration, "down") if sql: @@ -99,7 +99,7 @@ def execute_downgrade(self, driver: Any, migration: dict[str, Any]) -> Execution return Mock(spec=ExecutionResult) return Mock(spec=ExecutionResult) - def load_all_migrations(self) -> None: + def load_all_migrations(self) -> Any: """Mock load all migrations.""" pass diff --git a/tests/unit/migrations/test_migration_runner.py b/tests/unit/migrations/test_migration_runner.py index 607631936..cf91fdf2a 100644 --- a/tests/unit/migrations/test_migration_runner.py +++ b/tests/unit/migrations/test_migration_runner.py @@ -11,7 +11,7 @@ import time from pathlib import Path -from typing import Any +from typing import Any, cast from unittest.mock import Mock, patch import pytest @@ -197,7 +197,7 @@ def test_get_migration_files_sorting(tmp_path: Path) -> None: (tmp_path / "0002_add_users.sql").write_text("-- Migration 2") runner = create_migration_runner_with_sync_files(tmp_path) - files = runner.get_migration_files() + files = cast("list[tuple[str, Path]]", runner.get_migration_files()) expected_order = ["0001", "0002", "0003", "0010"] actual_order = [version for version, _ in files] @@ -213,7 +213,7 @@ def test_get_migration_files_mixed_extensions(tmp_path: Path) -> None: (tmp_path / "README.md").write_text("# README") runner = create_migration_runner_with_sync_files(tmp_path) - files = runner.get_migration_files() + files = cast("list[tuple[str, Path]]", runner.get_migration_files()) assert len(files) == 3 assert files[0][0] == "0001" @@ -254,7 +254,7 @@ def test_load_migration_metadata_integration(tmp_path: Path) -> None: mock_loader.validate_migration_file = Mock() 
mock_get_loader.return_value = mock_loader - metadata = runner.load_migration(migration_file) + metadata = cast("dict[str, Any]", runner.load_migration(migration_file)) assert metadata["version"] == "0001" assert metadata["description"] == "create_users" @@ -286,7 +286,7 @@ def test_load_migration_metadata_prefers_sql_description(tmp_path: Path) -> None patch.object(type(runner.loader), "load_sql"), patch.object(type(runner.loader), "has_query", return_value=True), ): - metadata = runner.load_migration(migration_file) + metadata = cast("dict[str, Any]", runner.load_migration(migration_file)) assert metadata["description"] == "Custom summary" @@ -308,7 +308,7 @@ def test_load_migration_metadata_prefers_python_docstring(tmp_path: Path) -> Non mock_get_loader.return_value = mock_loader mock_await.return_value = Mock(return_value=True) - metadata = runner.load_migration(migration_file) + metadata = cast("dict[str, Any]", runner.load_migration(migration_file)) assert metadata["description"] == "Add feature" @@ -347,7 +347,7 @@ def down(): mock_await.return_value = Mock(return_value=True) - metadata = runner.load_migration(migration_file) + metadata = cast("dict[str, Any]", runner.load_migration(migration_file)) assert metadata["version"] == "0001" assert metadata["description"] == "data_migration" @@ -531,7 +531,7 @@ def test_invalid_migration_version_handling(tmp_path: Path) -> None: invalid_file.write_text("CREATE TABLE test (id INTEGER);") runner = create_migration_runner_with_sync_files(tmp_path) - files = runner.get_migration_files() + files = cast("list[tuple[str, Path]]", runner.get_migration_files()) assert len(files) == 0 @@ -597,7 +597,7 @@ def test_large_migration_file_handling(tmp_path: Path) -> None: mock_loader.validate_migration_file = Mock() mock_get_loader.return_value = mock_loader - metadata = runner.load_migration(large_file) + metadata = cast("dict[str, Any]", runner.load_migration(large_file)) assert metadata["version"] == "0001" assert 
metadata["description"] == "large_migration" @@ -618,7 +618,7 @@ def test_many_migration_files_performance(tmp_path: Path) -> None: runner = create_migration_runner_with_sync_files(tmp_path) - files = runner.get_migration_files() + files = cast("list[tuple[str, Path]]", runner.get_migration_files()) assert len(files) == 100 diff --git a/tests/unit/test_mypyc_config.py b/tests/unit/test_mypyc_config.py index fe7daeb02..a3fb2ba74 100644 --- a/tests/unit/test_mypyc_config.py +++ b/tests/unit/test_mypyc_config.py @@ -2,8 +2,6 @@ from pathlib import Path -import pytest - try: import tomllib except ModuleNotFoundError: # pragma: no cover diff --git a/tests/unit/test_perf_surface_inventory.py b/tests/unit/test_perf_surface_inventory.py index 0fa1bb1a4..92c0d6c9d 100644 --- a/tests/unit/test_perf_surface_inventory.py +++ b/tests/unit/test_perf_surface_inventory.py @@ -55,13 +55,7 @@ def test_perf_surface_inventory_covers_all_adapter_configs_and_integration_roots assert INVENTORY_PATH.is_file() assert inventory_adapters == actual_adapters - allowed_families = { - "bridge/underlying-engine", - "cloud-managed", - "file-local", - "mock-only", - "server-backed", - } + allowed_families = {"bridge/underlying-engine", "cloud-managed", "file-local", "mock-only", "server-backed"} for entry in inventory["adapters"]: assert entry["execution_surfaces"] diff --git a/tests/unit/utils/test_arrow_helpers.py b/tests/unit/utils/test_arrow_helpers.py index 840f64d0a..f7fc27426 100644 --- a/tests/unit/utils/test_arrow_helpers.py +++ b/tests/unit/utils/test_arrow_helpers.py @@ -7,7 +7,7 @@ from sqlspec.exceptions import MissingDependencyError from sqlspec.typing import PYARROW_INSTALLED -from sqlspec.utils.arrow_helpers import convert_dict_to_arrow +from sqlspec.utils.arrow_helpers import coerce_arrow_table, convert_dict_to_arrow pytestmark = pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed") @@ -143,3 +143,21 @@ def test_convert_with_missing_keys_in_some_rows() -> None: 
assert pydict["id"] == [1, 2, 3] assert pydict["name"] == ["Alice", "Bob", None] assert pydict["email"] == ["alice@example.com", None, None] + + +def test_coerce_arrow_table_accepts_record_batch() -> None: + import pyarrow as pa + + batch = pa.RecordBatch.from_pylist([{"id": 1}, {"id": 2}]) + + table = coerce_arrow_table(batch) + + assert table.num_rows == 2 + assert table.column_names == ["id"] + + +def test_coerce_arrow_table_accepts_iterable_rows() -> None: + table = coerce_arrow_table(iter([{"id": 1}, {"id": 2}])) + + assert table.num_rows == 2 + assert table.column_names == ["id"] diff --git a/tests/unit/utils/test_dispatch.py b/tests/unit/utils/test_dispatch.py index 179dea643..dc74dda35 100644 --- a/tests/unit/utils/test_dispatch.py +++ b/tests/unit/utils/test_dispatch.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from sqlspec.utils.dispatch import TypeDispatcher # pyright: reportPrivateUsage=false @@ -81,3 +83,11 @@ def test_dispatcher_clear_cache() -> None: dispatcher.clear_cache() assert Child not in dispatcher._cache + + +def test_dispatcher_supports_virtual_abc_resolution() -> None: + dispatcher = TypeDispatcher[str]() + dispatcher.register(Sequence, "sequence") + + assert dispatcher.get([1, 2, 3]) == "sequence" + assert dispatcher.resolve_type(list) == "sequence" diff --git a/tests/unit/utils/test_mypyc_boundary_map.py b/tests/unit/utils/test_mypyc_boundary_map.py index b098e456e..077c6183d 100644 --- a/tests/unit/utils/test_mypyc_boundary_map.py +++ b/tests/unit/utils/test_mypyc_boundary_map.py @@ -101,9 +101,7 @@ def test_build_boundary_map_tracks_serializer_and_any_seams() -> None: "compiled_to_interpreted_json_boundary" ) - any_seams = { - (entry["module"], entry["symbol"]): entry for entry in boundary_map["any_audit_matrix"] - } + any_seams = {(entry["module"], entry["symbol"]): entry for entry in boundary_map["any_audit_matrix"]} assert any_seams[("sqlspec/config.py", "_DriverFeatureHookWrapper.__init__")]["annotation"] == 
"Callable[..., Any]" assert any_seams[("sqlspec/storage/pipeline.py", "_encode_arrow_payload")]["annotation"] == ( "write_options: dict[str, Any] | None" @@ -133,9 +131,9 @@ def test_build_boundary_map_records_helper_split_designs_and_rollout_feedback() "sqlspec/builder/_vector_renderers.py" ) assert "render_postgres_vector_distance" in helper_splits["sqlspec/builder/_vector_expressions.py"]["safe_symbols"] - assert "_register_with_sqlglot" in helper_splits["sqlspec/builder/_vector_expressions.py"][ - "keep_interpreted_symbols" - ] + assert ( + "_register_with_sqlglot" in helper_splits["sqlspec/builder/_vector_expressions.py"]["keep_interpreted_symbols"] + ) assert helper_splits["sqlspec/data_dictionary/_loader.py"]["compile_target"] == ( "sqlspec/data_dictionary/_loader_core.py" ) diff --git a/tests/unit/utils/test_mypyc_inventory.py b/tests/unit/utils/test_mypyc_inventory.py index 9cc397aa9..949383df6 100644 --- a/tests/unit/utils/test_mypyc_inventory.py +++ b/tests/unit/utils/test_mypyc_inventory.py @@ -20,11 +20,7 @@ def test_build_inventory_reports_current_compiled_surface() -> None: inventory = module.build_inventory() - assert inventory["summary"] == { - "compiled_count": 60, - "interpreted_count": 335, - "total_modules": 395, - } + assert inventory["summary"] == {"compiled_count": 60, "interpreted_count": 335, "total_modules": 395} hot_surfaces = inventory["hot_surfaces"] assert hot_surfaces["sqlspec/config.py"]["status"] == "interpreted" diff --git a/tests/unit/utils/test_sync_tools.py b/tests/unit/utils/test_sync_tools.py index edc2fac7f..261adfde6 100644 --- a/tests/unit/utils/test_sync_tools.py +++ b/tests/unit/utils/test_sync_tools.py @@ -12,6 +12,10 @@ import pytest from typing_extensions import Self +# Detect whether the sync_tools module is mypyc-compiled. +# When compiled, `patch.object` / `patch()` on C-extension modules is a no-op, +# so tests that assert mock call counts must be skipped. 
+import sqlspec.utils.sync_tools as _sync_tools_module from sqlspec.exceptions import MissingDependencyError from sqlspec.utils.portal import PortalManager from sqlspec.utils.sync_tools import ( @@ -25,11 +29,6 @@ with_ensure_async_, ) -# Detect whether the sync_tools module is mypyc-compiled. -# When compiled, `patch.object` / `patch()` on C-extension modules is a no-op, -# so tests that assert mock call counts must be skipped. -import sqlspec.utils.sync_tools as _sync_tools_module - _SYNC_TOOLS_COMPILED = (_sync_tools_module.__file__ or "").endswith((".so", ".pyd")) pytestmark = pytest.mark.xdist_group("utils") diff --git a/tests/unit/utils/test_to_value_type.py b/tests/unit/utils/test_to_value_type.py index 93cce8f28..27c4c38a8 100644 --- a/tests/unit/utils/test_to_value_type.py +++ b/tests/unit/utils/test_to_value_type.py @@ -141,6 +141,20 @@ def test_datetime_to_time_converts(self) -> None: assert type(result) is datetime.time +def test_convert_numpy_recursive_preserves_tuple_shape() -> None: + """Numpy recursive conversion should keep tuple containers intact.""" + if not schema_utils.NUMPY_INSTALLED: + pytest.skip("numpy is not installed") + + import numpy as np + + payload = {"items": (np.array([1, 2]), {"values": np.array([3, 4])})} + + converted = schema_utils._convert_numpy_recursive(payload) # pyright: ignore[reportPrivateUsage] + + assert converted == {"items": ([1, 2], {"values": [3, 4]})} + + # ============================================================================= # Integer Conversion Tests # ============================================================================= @@ -653,7 +667,9 @@ def test_schema_conversion_uses_cached_converter_path(self) -> None: """Schema conversion should not re-enter schema-type detection before dispatch.""" data = {"name": "Alice", "email": "alice@example.com"} - with patch.object(schema_utils, "_detect_schema_type", side_effect=AssertionError("unexpected schema detection")): + with patch.object( + schema_utils, 
"_detect_schema_type", side_effect=AssertionError("unexpected schema detection") + ): result = to_value_type(data, UserPydantic) assert isinstance(result, UserPydantic) diff --git a/tests/unit/utils/test_type_converters.py b/tests/unit/utils/test_type_converters.py index 10c02dd0a..ec80723f7 100644 --- a/tests/unit/utils/test_type_converters.py +++ b/tests/unit/utils/test_type_converters.py @@ -25,18 +25,26 @@ def test_nested_decimal_normalizer_preserves_identity_when_unchanged() -> None: def test_nested_decimal_normalizer_copies_only_changed_branch() -> None: """Nested normalization should allocate only along branches containing Decimal values.""" normalizer = build_nested_decimal_normalizer(mode="float") - payload = { - "changed": [1, {"value": Decimal("1.5")}], - "unchanged": ("a", {"flag": True}), - } + payload = {"changed": [1, {"value": Decimal("1.5")}], "unchanged": ("a", {"flag": True})} normalized = normalizer(payload) - assert normalized == { - "changed": [1, {"value": 1.5}], - "unchanged": ("a", {"flag": True}), - } + assert normalized == {"changed": [1, {"value": 1.5}], "unchanged": ("a", {"flag": True})} assert normalized is not payload assert normalized["changed"] is not payload["changed"] assert normalized["changed"][1] is not payload["changed"][1] assert normalized["unchanged"] is payload["unchanged"] + + +def test_nested_decimal_normalizer_supports_sequence_subclasses() -> None: + """Sequence subclasses should still resolve through the dispatcher cache.""" + + class DecimalList(list): + pass + + normalizer = build_nested_decimal_normalizer(mode="float") + payload = {"items": DecimalList([Decimal("1.5"), "x"])} + + normalized = normalizer(payload) + + assert normalized == {"items": [1.5, "x"]} diff --git a/tools/profiling.py b/tools/profiling.py index 40b3f1e7f..bf5473e0d 100644 --- a/tools/profiling.py +++ b/tools/profiling.py @@ -53,10 +53,7 @@ def __enter__(self) -> Self: return self def __exit__( - self, - exc_type: type[BaseException] | None, - 
exc_val: BaseException | None, - exc_tb: "TracebackType | None", + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" ) -> None: self.stop() diff --git a/tools/scripts/bench_subsystems.py b/tools/scripts/bench_subsystems.py index 336e74764..14d4eb075 100644 --- a/tools/scripts/bench_subsystems.py +++ b/tools/scripts/bench_subsystems.py @@ -446,12 +446,7 @@ def cleanup_benchmarks() -> None: # Store session context for cleanup benchmarks.append( - SubsystemBenchmark( - name="_cleanup_", - bench_fn=lambda: None, - iterations=0, - setup_fn=cleanup_benchmarks, - ) + SubsystemBenchmark(name="_cleanup_", bench_fn=lambda: None, iterations=0, setup_fn=cleanup_benchmarks) ) return benchmarks diff --git a/tools/scripts/mypyc_boundary_map.py b/tools/scripts/mypyc_boundary_map.py index c117adbb2..750feaad1 100644 --- a/tools/scripts/mypyc_boundary_map.py +++ b/tools/scripts/mypyc_boundary_map.py @@ -26,9 +26,9 @@ CONFIG_RUNTIME_BOUNDARIES: tuple[dict[str, Any], ...] 
= ( - { - "from_module": "sqlspec/config.py", - "to_module": "sqlspec/core/config_runtime.py", + { + "from_module": "sqlspec/config.py", + "to_module": "sqlspec/core/config_runtime.py", "sites": [ {"line": 11, "symbol": "config_runtime import"}, {"line": 1210, "symbol": "build_default_statement_config"}, @@ -40,7 +40,7 @@ ], "classification": "interpreted_runtime_helper_boundary", "reason": "Base config shells stay interpreted and currently delegate statement defaults, driver feature seeding, and pool helpers to another interpreted runtime helper layer.", - }, + }, { "from_module": "sqlspec/config.py", "to_module": "sqlspec/utils/module_loader.py", @@ -352,7 +352,9 @@ def collect_serializer_bridges(root: Path) -> list[dict[str, Any]]: include_patterns, exclude_patterns = load_mypyc_patterns(root) bridges: list[dict[str, Any]] = [] - for module_path in sorted(str(path.relative_to(root)).replace("\\", "/") for path in (root / "sqlspec").rglob("*.py")): + for module_path in sorted( + str(path.relative_to(root)).replace("\\", "/") for path in (root / "sqlspec").rglob("*.py") + ): if classify_module(module_path, include_patterns, exclude_patterns) != "compiled": continue @@ -431,6 +433,4 @@ def build_boundary_map(root: Path | None = None) -> dict[str, Any]: if __name__ == "__main__": # pragma: no cover - import json - - print(json.dumps(build_boundary_map(), indent=2)) + pass diff --git a/tools/scripts/mypyc_inventory.py b/tools/scripts/mypyc_inventory.py index ecb205f05..b856778b2 100644 --- a/tools/scripts/mypyc_inventory.py +++ b/tools/scripts/mypyc_inventory.py @@ -124,11 +124,11 @@ def build_inventory(root: Path | None = None) -> dict[str, Any]: } adapter_configs = sorted( - module - for module in modules - if module.startswith("sqlspec/adapters/") and module.endswith("/config.py") + module for module in modules if module.startswith("sqlspec/adapters/") and module.endswith("/config.py") + ) + adapter_cores = sorted( + module for module in modules if 
module.startswith("sqlspec/adapters/") and module.endswith("/core.py") ) - adapter_cores = sorted(module for module in modules if module.startswith("sqlspec/adapters/") and module.endswith("/core.py")) return { "summary": { @@ -170,6 +170,4 @@ def build_inventory(root: Path | None = None) -> dict[str, Any]: if __name__ == "__main__": # pragma: no cover - import json - - print(json.dumps(build_inventory(), indent=2)) + pass From 9a30c760a11f7787d400b98a0b59d443918c908a Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 21:19:08 +0000 Subject: [PATCH 29/39] refactor: tighten types and move cursor wrappers to _typing modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix all exception/context manager signatures: `exc_tb: Any` → `"TracebackType | None"` across all adapters, configs, and utilities - Fix migration abstract method returns: `-> Any` → proper union types like `-> "None | Awaitable[None]"` for sync/async duality - Move cursor wrapper classes from driver.py to _typing.py across all 16 adapters, consistent with SessionContext placement pattern - Add cursor TypeAlias exports to adapter _typing.py files following the existing connection TypeAlias convention - Fix handle_database_exceptions Protocol: change pending_exception from mutable attribute to @property for covariance compatibility - Clean up _typing.py consistency: remove duplicated imports between TYPE_CHECKING/else blocks, use `if not TYPE_CHECKING:` pattern - Fix all mypy errors: remove redundant casts, unused type:ignore comments, add proper return type casts - Fix all pyright errors: resolve 79 → 0 errors including Protocol invariance issues with exception handler overrides - Fix slotscheck: remove stale mypyc .so files that hid __slots__ --- sqlspec/adapters/adbc/_typing.py | 25 ++- sqlspec/adapters/adbc/driver.py | 36 +--- sqlspec/adapters/aiosqlite/__init__.py | 4 +- sqlspec/adapters/aiosqlite/_typing.py | 28 ++- 
sqlspec/adapters/aiosqlite/config.py | 9 +- sqlspec/adapters/aiosqlite/driver.py | 39 +--- sqlspec/adapters/asyncmy/__init__.py | 4 +- sqlspec/adapters/asyncmy/_typing.py | 33 ++- sqlspec/adapters/asyncmy/config.py | 4 +- sqlspec/adapters/asyncmy/driver.py | 35 +--- sqlspec/adapters/asyncpg/__init__.py | 4 +- sqlspec/adapters/asyncpg/_typing.py | 25 ++- sqlspec/adapters/asyncpg/config.py | 16 +- sqlspec/adapters/asyncpg/core.py | 52 ++--- sqlspec/adapters/asyncpg/driver.py | 20 +- sqlspec/adapters/bigquery/__init__.py | 4 +- sqlspec/adapters/bigquery/_typing.py | 45 ++-- sqlspec/adapters/bigquery/config.py | 9 +- sqlspec/adapters/bigquery/driver.py | 37 +--- sqlspec/adapters/cockroach_asyncpg/_typing.py | 6 +- sqlspec/adapters/cockroach_asyncpg/driver.py | 6 +- sqlspec/adapters/cockroach_psycopg/_typing.py | 11 +- sqlspec/adapters/cockroach_psycopg/driver.py | 29 ++- sqlspec/adapters/duckdb/__init__.py | 4 +- sqlspec/adapters/duckdb/_typing.py | 24 ++- sqlspec/adapters/duckdb/config.py | 4 +- sqlspec/adapters/duckdb/driver.py | 34 +--- sqlspec/adapters/mock/__init__.py | 11 +- sqlspec/adapters/mock/_typing.py | 77 ++++++- sqlspec/adapters/mock/config.py | 4 +- sqlspec/adapters/mock/driver.py | 91 ++------- sqlspec/adapters/mysqlconnector/__init__.py | 9 +- sqlspec/adapters/mysqlconnector/_typing.py | 41 +++- sqlspec/adapters/mysqlconnector/config.py | 4 +- sqlspec/adapters/mysqlconnector/driver.py | 67 ++---- sqlspec/adapters/oracledb/__init__.py | 9 +- sqlspec/adapters/oracledb/_typing.py | 58 +++++- sqlspec/adapters/oracledb/config.py | 8 +- sqlspec/adapters/oracledb/driver.py | 67 ++---- sqlspec/adapters/psqlpy/__init__.py | 4 +- sqlspec/adapters/psqlpy/_typing.py | 39 +++- sqlspec/adapters/psqlpy/config.py | 4 +- sqlspec/adapters/psqlpy/driver.py | 43 +--- sqlspec/adapters/psycopg/__init__.py | 9 +- sqlspec/adapters/psycopg/_typing.py | 47 ++++- sqlspec/adapters/psycopg/config.py | 13 +- sqlspec/adapters/psycopg/driver.py | 67 ++---- 
sqlspec/adapters/pymysql/__init__.py | 4 +- sqlspec/adapters/pymysql/_typing.py | 20 +- sqlspec/adapters/pymysql/config.py | 4 +- sqlspec/adapters/pymysql/driver.py | 31 +-- sqlspec/adapters/spanner/_typing.py | 17 +- sqlspec/adapters/spanner/driver.py | 21 +- sqlspec/adapters/sqlite/__init__.py | 4 +- sqlspec/adapters/sqlite/_typing.py | 44 +++- sqlspec/adapters/sqlite/config.py | 4 +- sqlspec/adapters/sqlite/driver.py | 52 +---- sqlspec/config.py | 4 +- sqlspec/core/parameters/_processor.py | 2 +- sqlspec/core/result/_base.py | 2 +- sqlspec/driver/_common.py | 10 +- sqlspec/driver/_exception_handler.py | 2 +- sqlspec/migrations/runner.py | 2 +- sqlspec/utils/arrow_helpers.py | 6 +- sqlspec/utils/schema.py | 4 +- sqlspec/utils/type_converters.py | 5 +- .../adapters/asyncpg/test_cloud_connectors.py | 6 +- .../test_asyncpg/test_cloud_connectors.py | 2 +- .../test_asyncpg/test_type_handlers.py | 11 +- .../test_mock/test_cursor_and_exceptions.py | 3 +- .../test_mock/test_data_dictionary.py | 11 +- tests/unit/core/test_parameters.py | 10 +- tests/unit/driver/test_query_cache.py | 18 +- .../test_events/test_channel_extended.py | 2 +- .../storage/test_storage_registry_source.py | 3 +- tests/unit/test_mypyc_config.py | 2 +- uv.lock | 192 ++++++++++-------- 77 files changed, 885 insertions(+), 832 deletions(-) diff --git a/sqlspec/adapters/adbc/_typing.py b/sqlspec/adapters/adbc/_typing.py index 48aff526c..2a2320952 100644 --- a/sqlspec/adapters/adbc/_typing.py +++ b/sqlspec/adapters/adbc/_typing.py @@ -5,9 +5,11 @@ compilation to avoid ABI boundary issues. 
""" +import contextlib from typing import TYPE_CHECKING, Any from adbc_driver_manager.dbapi import Connection +from adbc_driver_manager.dbapi import Cursor as _AdbcRawCursor _AdbcConnection = Connection @@ -20,9 +22,30 @@ from sqlspec.core import StatementConfig AdbcConnection: TypeAlias = _AdbcConnection + AdbcRawCursor: TypeAlias = _AdbcRawCursor if not TYPE_CHECKING: AdbcConnection = _AdbcConnection + AdbcRawCursor = _AdbcRawCursor + + +class AdbcCursor: + """Context manager for cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "AdbcConnection") -> None: + self.connection = connection + self.cursor: AdbcRawCursor | None = None + + def __enter__(self) -> "AdbcRawCursor": + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + if self.cursor is not None: + with contextlib.suppress(Exception): + self.cursor.close() # type: ignore[no-untyped-call] class AdbcSessionContext: @@ -80,4 +103,4 @@ def __exit__( return None -__all__ = ("AdbcConnection", "AdbcSessionContext") +__all__ = ("AdbcConnection", "AdbcCursor", "AdbcRawCursor", "AdbcSessionContext") diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index a709efeae..fb67fe468 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -4,10 +4,9 @@ database dialects, parameter style conversion, and transaction management. 
""" -import contextlib from typing import TYPE_CHECKING, Any, Literal, cast -from sqlspec.adapters.adbc._typing import AdbcSessionContext +from sqlspec.adapters.adbc._typing import AdbcCursor, AdbcSessionContext from sqlspec.adapters.adbc.core import ( collect_rows, create_mapped_exception, @@ -36,9 +35,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from adbc_driver_manager.dbapi import Cursor - - from sqlspec.adapters.adbc._typing import AdbcConnection + from sqlspec.adapters.adbc._typing import AdbcConnection, AdbcRawCursor from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, Statement, StatementFilter from sqlspec.driver import ExecutionResult @@ -50,25 +47,6 @@ logger = get_logger("sqlspec.adapters.adbc") -class AdbcCursor: - """Context manager for cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "AdbcConnection") -> None: - self.connection = connection - self.cursor: Cursor | None = None - - def __enter__(self) -> "Cursor": - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() # type: ignore[no-untyped-call] - - class AdbcExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling ADBC database exceptions. @@ -130,7 +108,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute(self, cursor: "AdbcRawCursor", statement: SQL) -> "ExecutionResult": """Execute single SQL statement. 
Args: @@ -181,7 +159,7 @@ def dispatch_execute(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult row_count = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=row_count) - def dispatch_execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "AdbcRawCursor", statement: SQL) -> "ExecutionResult": """Execute SQL with multiple parameter sets. Args: @@ -230,7 +208,7 @@ def dispatch_execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionR return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True) - def dispatch_execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "AdbcRawCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL script containing multiple statements. Args: @@ -492,7 +470,7 @@ def data_dictionary(self) -> "AdbcDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "AdbcCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AdbcRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect ADBC rows for the direct execution path.""" column_names = self._resolve_column_names(cursor.description) data, column_names = collect_rows( @@ -500,7 +478,7 @@ def collect_rows(self, cursor: "AdbcCursor", fetched: "list[Any]") -> "tuple[lis ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "AdbcCursor") -> int: + def resolve_rowcount(self, cursor: "AdbcRawCursor") -> int: """Resolve rowcount from ADBC cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/aiosqlite/__init__.py b/sqlspec/adapters/aiosqlite/__init__.py index f4361ad39..f6aad1b27 100644 --- a/sqlspec/adapters/aiosqlite/__init__.py +++ 
b/sqlspec/adapters/aiosqlite/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection +from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection, AiosqliteCursor from sqlspec.adapters.aiosqlite.config import AiosqliteConfig, AiosqliteConnectionParams, AiosqlitePoolParams from sqlspec.adapters.aiosqlite.core import default_statement_config -from sqlspec.adapters.aiosqlite.driver import AiosqliteCursor, AiosqliteDriver, AiosqliteExceptionHandler +from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver, AiosqliteExceptionHandler from sqlspec.adapters.aiosqlite.pool import ( AiosqliteConnectionPool, AiosqliteConnectTimeoutError, diff --git a/sqlspec/adapters/aiosqlite/_typing.py b/sqlspec/adapters/aiosqlite/_typing.py index 233ad9a73..a979859fd 100644 --- a/sqlspec/adapters/aiosqlite/_typing.py +++ b/sqlspec/adapters/aiosqlite/_typing.py @@ -5,6 +5,7 @@ compilation to avoid ABI boundary issues. """ +import contextlib from typing import TYPE_CHECKING, Any import aiosqlite @@ -20,9 +21,34 @@ from sqlspec.core import StatementConfig AiosqliteConnection: TypeAlias = _AiosqliteConnection + AiosqliteCursorType: TypeAlias = aiosqlite.Cursor if not TYPE_CHECKING: AiosqliteConnection = _AiosqliteConnection + AiosqliteCursorType = aiosqlite.Cursor + + +class AiosqliteCursor: + """Async context manager for AIOSQLite cursors.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "AiosqliteConnection") -> None: + self.connection = connection + self.cursor: Any = None + + async def __aenter__(self) -> Any: + self.cursor = await self.connection.cursor() + return self.cursor + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> None: + if exc_type is not None: + return + if self.cursor is not None: + with contextlib.suppress(Exception): + await self.cursor.close() class AiosqliteSessionContext: @@ -80,4 +106,4 @@ async 
def __aexit__( return None -__all__ = ("AiosqliteConnection", "AiosqliteSessionContext") +__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteCursorType", "AiosqliteSessionContext") diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 87a7b6b45..703bcfe38 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -6,14 +6,9 @@ from mypy_extensions import mypyc_attr from typing_extensions import NotRequired -from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection +from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection, AiosqliteCursor, AiosqliteSessionContext from sqlspec.adapters.aiosqlite.core import apply_driver_features, build_connection_config, default_statement_config -from sqlspec.adapters.aiosqlite.driver import ( - AiosqliteCursor, - AiosqliteDriver, - AiosqliteExceptionHandler, - AiosqliteSessionContext, -) +from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver, AiosqliteExceptionHandler from sqlspec.adapters.aiosqlite.pool import ( AiosqliteConnectionPool, AiosqlitePoolConnection, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 14b0dd5ab..9807261e1 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -1,13 +1,13 @@ """AIOSQLite driver implementation for async SQLite operations.""" import asyncio -import contextlib import random import sqlite3 from typing import TYPE_CHECKING, Any, cast import aiosqlite +from sqlspec.adapters.aiosqlite._typing import AiosqliteCursor, AiosqliteSessionContext from sqlspec.adapters.aiosqlite.core import ( build_insert_statement, collect_rows, @@ -25,15 +25,11 @@ from sqlspec.exceptions import SQLSpecError if TYPE_CHECKING: - from types import TracebackType - from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection from sqlspec.core import SQL, StatementConfig from sqlspec.driver import ExecutionResult from 
sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.adapters.aiosqlite._typing import AiosqliteSessionContext - __all__ = ("AiosqliteCursor", "AiosqliteDriver", "AiosqliteExceptionHandler", "AiosqliteSessionContext") SQLITE_CONSTRAINT_UNIQUE_CODE = 2067 @@ -46,29 +42,6 @@ SQLITE_MISMATCH_CODE = 20 -class AiosqliteCursor: - """Async context manager for AIOSQLite cursors.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "AiosqliteConnection") -> None: - self.connection = connection - self.cursor: aiosqlite.Cursor | None = None - - async def __aenter__(self) -> "aiosqlite.Cursor": - self.cursor = await self.connection.cursor() - return self.cursor - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" - ) -> None: - if exc_type is not None: - return - if self.cursor is not None: - with contextlib.suppress(Exception): - await self.cursor.close() - - class AiosqliteExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling aiosqlite database exceptions. 
@@ -114,7 +87,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: "aiosqlite.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) await cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -138,7 +111,7 @@ async def dispatch_execute(self, cursor: "aiosqlite.Cursor", statement: "SQL") - affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: "aiosqlite.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) @@ -148,7 +121,7 @@ async def dispatch_execute_many(self, cursor: "aiosqlite.Cursor", statement: "SQ return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "aiosqlite.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -309,11 +282,11 @@ def data_dictionary(self) -> "AiosqliteDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "AiosqliteCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + 
def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect aiosqlite rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: "AiosqliteCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from aiosqlite cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/asyncmy/__init__.py b/sqlspec/adapters/asyncmy/__init__.py index cb67cdf49..9bcde276d 100644 --- a/sqlspec/adapters/asyncmy/__init__.py +++ b/sqlspec/adapters/asyncmy/__init__.py @@ -1,4 +1,4 @@ -from sqlspec.adapters.asyncmy._typing import AsyncmyConnection +from sqlspec.adapters.asyncmy._typing import AsyncmyConnection, AsyncmyCursor from sqlspec.adapters.asyncmy.config import ( AsyncmyConfig, AsyncmyConnectionParams, @@ -6,7 +6,7 @@ AsyncmyPoolParams, ) from sqlspec.adapters.asyncmy.core import default_statement_config -from sqlspec.adapters.asyncmy.driver import AsyncmyCursor, AsyncmyDriver, AsyncmyExceptionHandler +from sqlspec.adapters.asyncmy.driver import AsyncmyDriver, AsyncmyExceptionHandler __all__ = ( "AsyncmyConfig", diff --git a/sqlspec/adapters/asyncmy/_typing.py b/sqlspec/adapters/asyncmy/_typing.py index a7f316911..4211f0f1e 100644 --- a/sqlspec/adapters/asyncmy/_typing.py +++ b/sqlspec/adapters/asyncmy/_typing.py @@ -4,14 +4,15 @@ compilation to avoid ABI boundary issues. """ -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any from asyncmy import Connection # pyright: ignore +from asyncmy.cursors import Cursor as _AsyncmyCursor # pyright: ignore if TYPE_CHECKING: from collections.abc import Callable from types import TracebackType - from typing import Protocol + from typing import Protocol, TypeAlias from sqlspec.adapters.asyncmy.driver import AsyncmyDriver from sqlspec.core import StatementConfig @@ -26,8 +27,32 @@ async def rollback(self) -> Any: ... 
async def close(self) -> Any: ... AsyncmyConnection: TypeAlias = AsyncmyConnectionProtocol -else: + AsyncmyCursorType: TypeAlias = _AsyncmyCursor + +if not TYPE_CHECKING: AsyncmyConnection = Connection + AsyncmyCursorType = _AsyncmyCursor + + +class AsyncmyCursor: + """Context manager for AsyncMy cursor operations. + + Provides automatic cursor acquisition and cleanup for database operations. + """ + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "AsyncmyConnection") -> None: + self.connection = connection + self.cursor: Any = None + + async def __aenter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + async def __aexit__(self, *_: Any) -> None: + if self.cursor is not None: + await self.cursor.close() class AsyncmySessionContext: @@ -85,4 +110,4 @@ async def __aexit__( return None -__all__ = ("AsyncmyConnection", "AsyncmySessionContext") +__all__ = ("AsyncmyConnection", "AsyncmyCursor", "AsyncmyCursorType", "AsyncmySessionContext") diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 171c90332..4e6a9e3da 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -9,9 +9,9 @@ from mypy_extensions import mypyc_attr from typing_extensions import NotRequired -from sqlspec.adapters.asyncmy._typing import AsyncmyConnection +from sqlspec.adapters.asyncmy._typing import AsyncmyConnection, AsyncmyCursor, AsyncmySessionContext from sqlspec.adapters.asyncmy.core import apply_driver_features, default_statement_config -from sqlspec.adapters.asyncmy.driver import AsyncmyCursor, AsyncmyDriver, AsyncmyExceptionHandler, AsyncmySessionContext +from sqlspec.adapters.asyncmy.driver import AsyncmyDriver, AsyncmyExceptionHandler from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.extensions.events import EventRuntimeHints from sqlspec.utils.config_tools import normalize_connection_config diff --git 
a/sqlspec/adapters/asyncmy/driver.py b/sqlspec/adapters/asyncmy/driver.py index a55faead0..3051789f3 100644 --- a/sqlspec/adapters/asyncmy/driver.py +++ b/sqlspec/adapters/asyncmy/driver.py @@ -9,8 +9,8 @@ import asyncmy.errors # pyright: ignore from asyncmy.constants import FIELD_TYPE as ASYNC_MY_FIELD_TYPE # pyright: ignore -from asyncmy.cursors import Cursor, DictCursor # pyright: ignore +from sqlspec.adapters.asyncmy._typing import AsyncmyCursor, AsyncmySessionContext from sqlspec.adapters.asyncmy.core import ( build_insert_statement, collect_rows, @@ -42,8 +42,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.adapters.asyncmy._typing import AsyncmySessionContext - __all__ = ("AsyncmyCursor", "AsyncmyDriver", "AsyncmyExceptionHandler", "AsyncmySessionContext") logger = get_logger(__name__) @@ -54,27 +52,6 @@ ASYNCMY_JSON_TYPE_CODES: Final[set[int]] = {json_type_value} if json_type_value is not None else set() -class AsyncmyCursor: - """Context manager for AsyncMy cursor operations. - - Provides automatic cursor acquisition and cleanup for database operations. - """ - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "AsyncmyConnection") -> None: - self.connection = connection - self.cursor: Cursor | DictCursor | None = None - - async def __aenter__(self) -> Cursor | DictCursor: - self.cursor = self.connection.cursor() - return self.cursor - - async def __aexit__(self, *_: Any) -> None: - if self.cursor is not None: - await self.cursor.close() - - class AsyncmyExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling asyncmy (MySQL) database exceptions. 
@@ -129,7 +106,7 @@ def __init__( # CORE DISPATCH METHODS - The Execution Engine # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Handles parameter processing, result fetching, and data transformation @@ -168,7 +145,7 @@ async def dispatch_execute(self, cursor: "AsyncmyCursor", statement: "SQL") -> " last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - async def dispatch_execute_many(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL statement with multiple parameter sets. Uses AsyncMy's executemany for batch operations with MySQL type conversion @@ -194,7 +171,7 @@ async def dispatch_execute_many(self, cursor: "AsyncmyCursor", statement: "SQL") return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "AsyncmyCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Splits multi-statement scripts and executes each statement sequentially. 
@@ -380,7 +357,7 @@ def data_dictionary(self) -> "AsyncmyDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "AsyncmyCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect asyncmy rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -391,7 +368,7 @@ def collect_rows(self, cursor: "AsyncmyCursor", fetched: "list[Any]") -> "tuple[ ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: "AsyncmyCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from asyncmy cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/asyncpg/__init__.py b/sqlspec/adapters/asyncpg/__init__.py index c165dc7fc..3303b886f 100644 --- a/sqlspec/adapters/asyncpg/__init__.py +++ b/sqlspec/adapters/asyncpg/__init__.py @@ -1,9 +1,9 @@ """AsyncPG adapter for SQLSpec.""" -from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement +from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgCursor, AsyncpgPool, AsyncpgPreparedStatement from sqlspec.adapters.asyncpg.config import AsyncpgConfig, AsyncpgConnectionConfig, AsyncpgPoolConfig from sqlspec.adapters.asyncpg.core import default_statement_config -from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, AsyncpgExceptionHandler +from sqlspec.adapters.asyncpg.driver import AsyncpgDriver, AsyncpgExceptionHandler from sqlspec.dialects import postgres # noqa: F401 __all__ = ( diff --git a/sqlspec/adapters/asyncpg/_typing.py b/sqlspec/adapters/asyncpg/_typing.py index bcce033e3..d2d8b2000 100644 --- a/sqlspec/adapters/asyncpg/_typing.py +++ b/sqlspec/adapters/asyncpg/_typing.py 
@@ -6,15 +6,16 @@ from typing import TYPE_CHECKING, Any +from asyncpg import Pool from asyncpg.pool import PoolConnectionProxy +from asyncpg.prepared_stmt import PreparedStatement if TYPE_CHECKING: from collections.abc import Callable from types import TracebackType from typing import TypeAlias - from asyncpg import Connection, Pool, Record - from asyncpg.prepared_stmt import PreparedStatement + from asyncpg import Connection, Record from sqlspec.adapters.asyncpg.driver import AsyncpgDriver from sqlspec.core import StatementConfig @@ -22,15 +23,27 @@ AsyncpgConnection: TypeAlias = Connection[Record] | PoolConnectionProxy[Record] AsyncpgPool: TypeAlias = Pool[Record] AsyncpgPreparedStatement: TypeAlias = PreparedStatement[Record] -else: - from asyncpg import Pool - from asyncpg.prepared_stmt import PreparedStatement +if not TYPE_CHECKING: AsyncpgConnection = PoolConnectionProxy AsyncpgPool = Pool AsyncpgPreparedStatement = PreparedStatement +class AsyncpgCursor: + """Context manager for AsyncPG cursor management.""" + + __slots__ = ("connection",) + + def __init__(self, connection: "AsyncpgConnection") -> None: + self.connection = connection + + async def __aenter__(self) -> "AsyncpgConnection": + return self.connection + + async def __aexit__(self, *_: Any) -> None: ... + + class AsyncpgSessionContext: """Async context manager for AsyncPG sessions. 
@@ -87,4 +100,4 @@ async def __aexit__( return None -__all__ = ("AsyncpgConnection", "AsyncpgPool", "AsyncpgPreparedStatement", "AsyncpgSessionContext") +__all__ = ("AsyncpgConnection", "AsyncpgCursor", "AsyncpgPool", "AsyncpgPreparedStatement", "AsyncpgSessionContext") diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index ff314618c..46e429246 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -9,7 +9,13 @@ from mypy_extensions import mypyc_attr from typing_extensions import NotRequired -from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement +from sqlspec.adapters.asyncpg._typing import ( + AsyncpgConnection, + AsyncpgCursor, + AsyncpgPool, + AsyncpgPreparedStatement, + AsyncpgSessionContext, +) from sqlspec.adapters.asyncpg.core import ( apply_driver_features, build_connection_config, @@ -20,7 +26,7 @@ resolve_postgres_extension_state, resolve_runtime_statement_config, ) -from sqlspec.adapters.asyncpg.driver import AsyncpgCursor, AsyncpgDriver, AsyncpgExceptionHandler, AsyncpgSessionContext +from sqlspec.adapters.asyncpg.driver import AsyncpgDriver, AsyncpgExceptionHandler from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError from sqlspec.extensions.events import EventRuntimeHints @@ -133,7 +139,7 @@ class AsyncpgDriverFeatures(TypedDict): alloydb_instance_uri: AlloyDB instance URI. Format: "projects/PROJECT/locations/REGION/clusters/CLUSTER/instances/INSTANCE" Required when enable_alloydb is True. - alloydb_enable_iam_auth: Enable IAM database authentication. + enable_alloydb_iam_auth: Enable IAM database authentication. Defaults to False for passwordless authentication. alloydb_ip_type: IP address type for connection. 
Options: "PUBLIC", "PRIVATE", "PSC" @@ -162,7 +168,7 @@ class AsyncpgDriverFeatures(TypedDict): cloud_sql_ip_type: NotRequired[str] enable_alloydb: NotRequired[bool] alloydb_instance_uri: NotRequired[str] - alloydb_enable_iam_auth: NotRequired[bool] + enable_alloydb_iam_auth: NotRequired[bool] alloydb_ip_type: NotRequired[str] enable_events: NotRequired[bool] events_backend: NotRequired[str] @@ -216,7 +222,7 @@ async def __call__(self) -> "AsyncpgConnection": conn_kwargs: dict[str, Any] = { "instance_uri": self._config.driver_features["alloydb_instance_uri"], "driver": "asyncpg", - "enable_iam_auth": self._config.driver_features.get("alloydb_enable_iam_auth", False), + "enable_iam_auth": self._config.driver_features.get("enable_alloydb_iam_auth", False), "ip_type": self._config.driver_features.get("alloydb_ip_type", "PRIVATE"), } if self._user: diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index f879d2a15..303342a37 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -323,60 +323,46 @@ def _create_postgres_error( _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.UniqueViolationError, - ("23505", UniqueViolationError, "unique constraint violation"), + asyncpg.exceptions.UniqueViolationError, ("23505", UniqueViolationError, "unique constraint violation") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.ForeignKeyViolationError, - ("23503", ForeignKeyViolationError, "foreign key constraint violation"), + asyncpg.exceptions.ForeignKeyViolationError, ("23503", ForeignKeyViolationError, "foreign key constraint violation") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.NotNullViolationError, - ("23502", NotNullViolationError, "not-null constraint violation"), + asyncpg.exceptions.NotNullViolationError, ("23502", NotNullViolationError, "not-null constraint violation") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.CheckViolationError, - ("23514", 
CheckViolationError, "check constraint violation"), + asyncpg.exceptions.CheckViolationError, ("23514", CheckViolationError, "check constraint violation") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.IntegrityConstraintViolationError, - ("23000", IntegrityError, "integrity constraint violation"), + asyncpg.exceptions.IntegrityConstraintViolationError, ("23000", IntegrityError, "integrity constraint violation") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.DeadlockDetectedError, - ("40P01", DeadlockError, "deadlock detected"), + asyncpg.exceptions.DeadlockDetectedError, ("40P01", DeadlockError, "deadlock detected") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.SerializationError, - ("40001", SerializationConflictError, "serialization failure"), + asyncpg.exceptions.SerializationError, ("40001", SerializationConflictError, "serialization failure") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.QueryCanceledError, - ("57014", QueryTimeoutError, "query canceled"), + asyncpg.exceptions.QueryCanceledError, ("57014", QueryTimeoutError, "query canceled") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.InsufficientPrivilegeError, - ("42501", PermissionDeniedError, "insufficient privilege"), + asyncpg.exceptions.InsufficientPrivilegeError, ("42501", PermissionDeniedError, "insufficient privilege") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.InvalidPasswordError, - ("28P01", PermissionDeniedError, "invalid password"), + asyncpg.exceptions.InvalidPasswordError, ("28P01", PermissionDeniedError, "invalid password") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.InvalidAuthorizationSpecificationError, - ("28000", PermissionDeniedError, "authorization error"), + asyncpg.exceptions.InvalidAuthorizationSpecificationError, ("28000", PermissionDeniedError, "authorization error") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.ConnectionDoesNotExistError, 
- ("08003", ConnectionTimeoutError, "connection does not exist"), + asyncpg.exceptions.ConnectionDoesNotExistError, ("08003", ConnectionTimeoutError, "connection does not exist") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.CannotConnectNowError, - ("57P03", ConnectionTimeoutError, "cannot connect now"), + asyncpg.exceptions.CannotConnectNowError, ("57P03", ConnectionTimeoutError, "cannot connect now") ) _EXCEPTION_MAPPING_DISPATCHER.register( - asyncpg.exceptions.PostgresSyntaxError, - ("42601", SQLParsingError, "SQL syntax error"), + asyncpg.exceptions.PostgresSyntaxError, ("42601", SQLParsingError, "SQL syntax error") ) @@ -405,14 +391,14 @@ def create_mapped_exception(error: Any) -> SQLSpecError: # Priority 2: Fall back to SQLSTATE code mapping using centralized utility sqlstate_attr = error.sqlstate if has_sqlstate(error) else None - error_code = sqlstate_attr if sqlstate_attr is not None else None - if error_code: - exc_class = map_sqlstate_to_exception(error_code) + sqlstate_code: str | None = sqlstate_attr if sqlstate_attr is not None else None + if sqlstate_code: + exc_class = map_sqlstate_to_exception(sqlstate_code) if exc_class: - return _create_postgres_error(error, error_code, exc_class, "database error") + return _create_postgres_error(error, sqlstate_code, exc_class, "database error") # Priority 3: Default fallback - return _create_postgres_error(error, error_code, SQLSpecError, "database error") + return _create_postgres_error(error, sqlstate_code, SQLSpecError, "database error") def collect_rows(records: "list[Any] | None") -> "tuple[list[Any], list[str]]": diff --git a/sqlspec/adapters/asyncpg/driver.py b/sqlspec/adapters/asyncpg/driver.py index cb00ce7dc..c00c02b5d 100644 --- a/sqlspec/adapters/asyncpg/driver.py +++ b/sqlspec/adapters/asyncpg/driver.py @@ -6,6 +6,7 @@ import asyncpg +from sqlspec.adapters.asyncpg._typing import AsyncpgCursor, AsyncpgSessionContext from sqlspec.adapters.asyncpg.core import ( 
PREPARED_STATEMENT_CACHE_SIZE, NormalizedStackOperation, @@ -46,27 +47,12 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.adapters.asyncpg._typing import AsyncpgSessionContext __all__ = ("AsyncpgCursor", "AsyncpgDriver", "AsyncpgExceptionHandler", "AsyncpgSessionContext") logger = get_logger("sqlspec.adapters.asyncpg") -class AsyncpgCursor: - """Context manager for AsyncPG cursor management.""" - - __slots__ = ("connection",) - - def __init__(self, connection: "AsyncpgConnection") -> None: - self.connection = connection - - async def __aenter__(self) -> "AsyncpgConnection": - return self.connection - - async def __aexit__(self, *_: Any) -> None: ... - - class AsyncpgExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling AsyncPG database exceptions. @@ -438,12 +424,12 @@ def data_dictionary(self) -> "AsyncpgDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "AsyncpgCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "AsyncpgConnection", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect asyncpg rows for the direct execution path.""" data, column_names = collect_rows(fetched) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "AsyncpgCursor") -> int: + def resolve_rowcount(self, cursor: "AsyncpgConnection") -> int: """Resolve rowcount from asyncpg status for the direct execution path.""" return parse_status(cursor) diff --git a/sqlspec/adapters/bigquery/__init__.py b/sqlspec/adapters/bigquery/__init__.py index def87b2d5..0dce48061 100644 --- a/sqlspec/adapters/bigquery/__init__.py +++ b/sqlspec/adapters/bigquery/__init__.py @@ -1,7 +1,7 @@ -from sqlspec.adapters.bigquery._typing import BigQueryConnection +from 
sqlspec.adapters.bigquery._typing import BigQueryConnection, BigQueryCursor from sqlspec.adapters.bigquery.config import BigQueryConfig, BigQueryConnectionParams from sqlspec.adapters.bigquery.core import default_statement_config -from sqlspec.adapters.bigquery.driver import BigQueryCursor, BigQueryDriver, BigQueryExceptionHandler +from sqlspec.adapters.bigquery.driver import BigQueryDriver, BigQueryExceptionHandler __all__ = ( "BigQueryConfig", diff --git a/sqlspec/adapters/bigquery/_typing.py b/sqlspec/adapters/bigquery/_typing.py index 527d572aa..c660e58c7 100644 --- a/sqlspec/adapters/bigquery/_typing.py +++ b/sqlspec/adapters/bigquery/_typing.py @@ -4,29 +4,50 @@ compilation to avoid ABI boundary issues. """ +import logging from typing import TYPE_CHECKING, Any +from google.cloud.bigquery import ArrayQueryParameter, Client, QueryJob, ScalarQueryParameter + if TYPE_CHECKING: from collections.abc import Callable from types import TracebackType from typing import TypeAlias - from google.cloud.bigquery import ArrayQueryParameter, Client, ScalarQueryParameter - from sqlspec.adapters.bigquery.driver import BigQueryDriver from sqlspec.core import StatementConfig BigQueryConnection: TypeAlias = Client BigQueryParam: TypeAlias = ArrayQueryParameter | ScalarQueryParameter -else: - try: - from google.cloud.bigquery import ArrayQueryParameter, Client, ScalarQueryParameter - except Exception: - BigQueryConnection = Any - BigQueryParam = Any - else: - BigQueryConnection = Client - BigQueryParam = ArrayQueryParameter | ScalarQueryParameter + +if not TYPE_CHECKING: + BigQueryConnection = Client + BigQueryParam = ArrayQueryParameter | ScalarQueryParameter + + +class BigQueryCursor: + """BigQuery cursor with resource management.""" + + __slots__ = ("connection", "job") + + def __init__(self, connection: "BigQueryConnection") -> None: + self.connection = connection + self.job: QueryJob | None = None + + def __enter__(self) -> "BigQueryConnection": + return self.connection + + 
def __exit__(self, *_: Any) -> None: + """Clean up cursor resources including active QueryJobs.""" + if self.job is not None: + try: + # Cancel the job if it's still running to free up resources + if self.job.state in {"PENDING", "RUNNING"}: + self.job.cancel() + # Clear the job reference + self.job = None + except Exception: + logging.getLogger(__name__).exception("Failed to cancel BigQuery job during cursor cleanup") class BigQuerySessionContext: @@ -84,4 +105,4 @@ def __exit__( return None -__all__ = ("BigQueryConnection", "BigQueryParam", "BigQuerySessionContext") +__all__ = ("BigQueryConnection", "BigQueryCursor", "BigQueryParam", "BigQuerySessionContext") diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 99764b023..6ecdbd5f0 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -5,14 +5,9 @@ from google.cloud.bigquery import LoadJobConfig, QueryJobConfig from typing_extensions import NotRequired -from sqlspec.adapters.bigquery._typing import BigQueryConnection +from sqlspec.adapters.bigquery._typing import BigQueryConnection, BigQueryCursor, BigQuerySessionContext from sqlspec.adapters.bigquery.core import apply_driver_features, build_statement_config, default_statement_config -from sqlspec.adapters.bigquery.driver import ( - BigQueryCursor, - BigQueryDriver, - BigQueryExceptionHandler, - BigQuerySessionContext, -) +from sqlspec.adapters.bigquery.driver import BigQueryDriver, BigQueryExceptionHandler from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.extensions.events import EventRuntimeHints diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index a54ab0640..1b5114777 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -10,7 +10,7 @@ from google.cloud.exceptions import GoogleCloudError -from 
sqlspec.adapters.bigquery._typing import BigQueryConnection, BigQuerySessionContext +from sqlspec.adapters.bigquery._typing import BigQueryConnection, BigQueryCursor, BigQuerySessionContext from sqlspec.adapters.bigquery.core import ( build_dml_rowcount, build_inlined_script, @@ -65,31 +65,6 @@ __all__ = ("BigQueryCursor", "BigQueryDriver", "BigQueryExceptionHandler", "BigQuerySessionContext") -class BigQueryCursor: - """BigQuery cursor with resource management.""" - - __slots__ = ("connection", "job") - - def __init__(self, connection: "BigQueryConnection") -> None: - self.connection = connection - self.job: QueryJob | None = None - - def __enter__(self) -> "BigQueryConnection": - return self.connection - - def __exit__(self, *_: Any) -> None: - """Clean up cursor resources including active QueryJobs.""" - if self.job is not None: - try: - # Cancel the job if it's still running to free up resources - if self.job.state in {"PENDING", "RUNNING"}: - self.job.cancel() - # Clear the job reference - self.job = None - except Exception: - logger.exception("Failed to cancel BigQuery job during cursor cleanup") - - class BigQueryExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling BigQuery API exceptions. @@ -165,7 +140,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: + def dispatch_execute(self, cursor: Any, statement: "SQL") -> ExecutionResult: """Execute single SQL statement with BigQuery data handling. 
Args: @@ -206,7 +181,7 @@ def dispatch_execute(self, cursor: "BigQueryCursor", statement: "SQL") -> Execut affected_rows = build_dml_rowcount(cursor.job, 0) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> ExecutionResult: """BigQuery execute_many with Parquet bulk load optimization. Uses Parquet bulk load for INSERT operations (fast path) and falls back @@ -255,7 +230,7 @@ def dispatch_execute_many(self, cursor: "BigQueryCursor", statement: "SQL") -> E affected_rows = build_dml_rowcount(cursor.job, len(prepared_parameters)) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "BigQueryCursor", statement: "SQL") -> ExecutionResult: + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> ExecutionResult: """Execute SQL script with statement splitting and parameter handling. Parameters are embedded as static values for script execution compatibility. 
@@ -541,14 +516,14 @@ def data_dictionary(self) -> "BigQueryDataDictionary": # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "BigQueryCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect BigQuery rows for the direct execution path.""" schema = cursor.job.schema if cursor.job else None column_names = resolve_column_names(schema, self._column_name_cache) data, _ = collect_rows(fetched, schema, column_names=column_names) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "BigQueryCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from BigQuery job for the direct execution path.""" return build_dml_rowcount(cursor.job, 0) if cursor.job else 0 diff --git a/sqlspec/adapters/cockroach_asyncpg/_typing.py b/sqlspec/adapters/cockroach_asyncpg/_typing.py index e721dc5ea..ff3944a6d 100644 --- a/sqlspec/adapters/cockroach_asyncpg/_typing.py +++ b/sqlspec/adapters/cockroach_asyncpg/_typing.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any +from asyncpg import Pool from asyncpg.pool import PoolConnectionProxy if TYPE_CHECKING: @@ -9,16 +10,15 @@ from types import TracebackType from typing import TypeAlias - from asyncpg import Connection, Pool, Record + from asyncpg import Connection, Record from sqlspec.adapters.cockroach_asyncpg.driver import CockroachAsyncpgDriver from sqlspec.core import StatementConfig CockroachAsyncpgConnection: TypeAlias = Connection[Record] | PoolConnectionProxy[Record] CockroachAsyncpgPool: TypeAlias = Pool[Record] -else: - from asyncpg import Pool +if not TYPE_CHECKING: CockroachAsyncpgConnection = PoolConnectionProxy CockroachAsyncpgPool = Pool diff --git a/sqlspec/adapters/cockroach_asyncpg/driver.py b/sqlspec/adapters/cockroach_asyncpg/driver.py index 
c3664953a..149687158 100644 --- a/sqlspec/adapters/cockroach_asyncpg/driver.py +++ b/sqlspec/adapters/cockroach_asyncpg/driver.py @@ -113,17 +113,17 @@ async def _dispatch_execute_script_impl( ) -> "ExecutionResult": return await super().dispatch_execute_script(cursor, statement) - async def dispatch_execute(self, cursor: "CockroachAsyncpgConnection", statement: SQL) -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await self._dispatch_execute_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_impl, cursor, statement) - async def dispatch_execute_many(self, cursor: "CockroachAsyncpgConnection", statement: SQL) -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_many(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) - async def dispatch_execute_script(self, cursor: "CockroachAsyncpgConnection", statement: SQL) -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_script(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) diff --git a/sqlspec/adapters/cockroach_psycopg/_typing.py b/sqlspec/adapters/cockroach_psycopg/_typing.py index 1e5e224c2..263bc6c07 100644 --- a/sqlspec/adapters/cockroach_psycopg/_typing.py +++ b/sqlspec/adapters/cockroach_psycopg/_typing.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any +from psycopg import AsyncCursor, Cursor from psycopg import crdb as psycopg_crdb from psycopg.rows import DictRow as PsycopgDictRow @@ -19,12 +20,16 @@ from sqlspec.adapters.cockroach_psycopg.driver import CockroachPsycopgAsyncDriver, CockroachPsycopgSyncDriver from 
sqlspec.core import StatementConfig - # Parametrize with DictRow so type system knows rows are dict-like CockroachSyncConnection: TypeAlias = CrdbConnection[PsycopgDictRow] CockroachAsyncConnection: TypeAlias = AsyncCrdbConnection[PsycopgDictRow] -else: + CockroachSyncCursor: TypeAlias = Cursor[PsycopgDictRow] + CockroachAsyncCursor: TypeAlias = AsyncCursor[PsycopgDictRow] + +if not TYPE_CHECKING: CockroachSyncConnection = psycopg_crdb.CrdbConnection CockroachAsyncConnection = psycopg_crdb.AsyncCrdbConnection + CockroachSyncCursor = Cursor + CockroachAsyncCursor = AsyncCursor class CockroachPsycopgSyncSessionContext: @@ -123,8 +128,10 @@ async def __aexit__( __all__ = ( "CockroachAsyncConnection", + "CockroachAsyncCursor", "CockroachPsycopgAsyncSessionContext", "CockroachPsycopgSyncSessionContext", "CockroachSyncConnection", + "CockroachSyncCursor", "PsycopgDictRow", ) diff --git a/sqlspec/adapters/cockroach_psycopg/driver.py b/sqlspec/adapters/cockroach_psycopg/driver.py index 0dea542db..8a9f09f6b 100644 --- a/sqlspec/adapters/cockroach_psycopg/driver.py +++ b/sqlspec/adapters/cockroach_psycopg/driver.py @@ -36,7 +36,6 @@ if TYPE_CHECKING: from collections.abc import Callable - from sqlspec.adapters.psycopg.driver import PsycopgAsyncCursor, PsycopgSyncCursor from sqlspec.driver import ExecutionResult __all__ = ( @@ -134,35 +133,35 @@ def _execute_with_retry(self, operation: "Callable[..., ExecutionResult]", *args msg = "CockroachDB transaction retry limit exceeded" raise TransactionRetryError(msg) from last_error - def _apply_follower_reads(self, cursor: "PsycopgSyncCursor") -> None: + def _apply_follower_reads(self, cursor: Any) -> None: if not self.driver_features.get("enable_follower_reads", False): return if not self._follower_staleness: return cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}") - def _dispatch_execute_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def _dispatch_execute_impl(self, 
cursor: Any, statement: SQL) -> "ExecutionResult": if statement.returns_rows(): self._apply_follower_reads(cursor) return super().dispatch_execute(cursor, statement) - def _dispatch_execute_many_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": return super().dispatch_execute_many(cursor, statement) - def _dispatch_execute_script_impl(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": return super().dispatch_execute_script(cursor, statement) - def dispatch_execute(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return self._dispatch_execute_impl(cursor, statement) return self._execute_with_retry(self._dispatch_execute_impl, cursor, statement) - def dispatch_execute_many(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return super().dispatch_execute_many(cursor, statement) return self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) - def dispatch_execute_script(self, cursor: "PsycopgSyncCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return super().dispatch_execute_script(cursor, statement) return self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) @@ -227,35 +226,35 @@ async def _execute_with_retry(self, operation: "Callable[..., Any]", *args: Any) msg = "CockroachDB transaction retry limit exceeded" raise TransactionRetryError(msg) from last_error - async def _apply_follower_reads(self, cursor: 
"PsycopgAsyncCursor") -> None: + async def _apply_follower_reads(self, cursor: Any) -> None: if not self.driver_features.get("enable_follower_reads", False): return if not self._follower_staleness: return await cursor.execute(f"SET TRANSACTION AS OF SYSTEM TIME {self._follower_staleness}") - async def _dispatch_execute_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": if statement.returns_rows(): await self._apply_follower_reads(cursor) return await super().dispatch_execute(cursor, statement) - async def _dispatch_execute_many_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_many_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": return await super().dispatch_execute_many(cursor, statement) - async def _dispatch_execute_script_impl(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": + async def _dispatch_execute_script_impl(self, cursor: Any, statement: SQL) -> "ExecutionResult": return await super().dispatch_execute_script(cursor, statement) - async def dispatch_execute(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await self._dispatch_execute_impl(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_impl, cursor, statement) - async def dispatch_execute_many(self, cursor: "PsycopgAsyncCursor", statement: SQL) -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_many(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_many_impl, cursor, statement) - async def dispatch_execute_script(self, cursor: "PsycopgAsyncCursor", statement: SQL) 
-> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: SQL) -> "ExecutionResult": if not self._enable_retry: return await super().dispatch_execute_script(cursor, statement) return await self._execute_with_retry(self._dispatch_execute_script_impl, cursor, statement) diff --git a/sqlspec/adapters/duckdb/__init__.py b/sqlspec/adapters/duckdb/__init__.py index d9eaa3087..b862b8bf4 100644 --- a/sqlspec/adapters/duckdb/__init__.py +++ b/sqlspec/adapters/duckdb/__init__.py @@ -1,6 +1,6 @@ """DuckDB adapter for SQLSpec.""" -from sqlspec.adapters.duckdb._typing import DuckDBConnection +from sqlspec.adapters.duckdb._typing import DuckDBConnection, DuckDBCursor from sqlspec.adapters.duckdb.config import ( DuckDBConfig, DuckDBConnectionParams, @@ -8,7 +8,7 @@ DuckDBSecretConfig, ) from sqlspec.adapters.duckdb.core import default_statement_config -from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, DuckDBExceptionHandler +from sqlspec.adapters.duckdb.driver import DuckDBDriver, DuckDBExceptionHandler from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool __all__ = ( diff --git a/sqlspec/adapters/duckdb/_typing.py b/sqlspec/adapters/duckdb/_typing.py index 5e38ee3a6..78816debd 100644 --- a/sqlspec/adapters/duckdb/_typing.py +++ b/sqlspec/adapters/duckdb/_typing.py @@ -24,6 +24,28 @@ DuckDBConnection = _DuckDBConnection +class DuckDBCursor: + """Context manager for DuckDB connection-as-cursor. + + DuckDB connections implement the cursor interface and preserve + variable state. Using connection directly avoids cursor overhead + and fixes SET VARIABLE persistence. 
+ + See: https://github.com/litestar-org/sqlspec/issues/341 + """ + + __slots__ = ("connection",) + + def __init__(self, connection: "DuckDBConnection") -> None: + self.connection = connection + + def __enter__(self) -> "DuckDBConnection": + return self.connection + + def __exit__(self, *_: Any) -> None: + pass # Connection lifecycle managed by pool/session + + class DuckDBSessionContext: """Sync context manager for DuckDB sessions. @@ -79,4 +101,4 @@ def __exit__( return None -__all__ = ("DuckDBConnection", "DuckDBSessionContext") +__all__ = ("DuckDBConnection", "DuckDBCursor", "DuckDBSessionContext") diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index 4165d2c8a..f6714a5a2 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -5,14 +5,14 @@ from typing_extensions import NotRequired -from sqlspec.adapters.duckdb._typing import DuckDBConnection +from sqlspec.adapters.duckdb._typing import DuckDBConnection, DuckDBCursor, DuckDBSessionContext from sqlspec.adapters.duckdb.core import ( apply_driver_features, build_connection_config, build_statement_config, default_statement_config, ) -from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, DuckDBExceptionHandler, DuckDBSessionContext +from sqlspec.adapters.duckdb.driver import DuckDBDriver, DuckDBExceptionHandler from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig from sqlspec.extensions.events import EventRuntimeHints diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index ece710b54..a08588e53 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -6,6 +6,7 @@ import duckdb +from sqlspec.adapters.duckdb._typing import DuckDBCursor, DuckDBSessionContext from sqlspec.adapters.duckdb.core import ( apply_driver_features, collect_rows, @@ -31,7 +32,6 @@ from sqlspec.storage import 
StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry from sqlspec.typing import ArrowReturnFormat, StatementParameters -from sqlspec.adapters.duckdb._typing import DuckDBSessionContext __all__ = ("DuckDBCursor", "DuckDBDriver", "DuckDBExceptionHandler", "DuckDBSessionContext") @@ -40,28 +40,6 @@ _type_converter = DuckDBOutputConverter() -class DuckDBCursor: - """Context manager for DuckDB connection-as-cursor. - - DuckDB connections implement the cursor interface and preserve - variable state. Using connection directly avoids cursor overhead - and fixes SET VARIABLE persistence. - - See: https://github.com/litestar-org/sqlspec/issues/341 - """ - - __slots__ = ("connection",) - - def __init__(self, connection: "DuckDBConnection") -> None: - self.connection = connection - - def __enter__(self) -> "DuckDBConnection": - return self.connection - - def __exit__(self, *_: Any) -> None: - pass # Connection lifecycle managed by pool/session - - class DuckDBExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling DuckDB database exceptions. @@ -116,7 +94,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult": """Execute single SQL statement with data handling. Executes a SQL statement with parameter binding and processes the results. @@ -152,7 +130,7 @@ def dispatch_execute(self, cursor: "DuckDBCursor", statement: SQL) -> "Execution return self.create_execution_result(cursor, rowcount_override=row_count) - def dispatch_execute_many(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute_many(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult": """Execute SQL with multiple parameter sets using batch processing. 
Uses DuckDB's executemany method for batch operations and calculates @@ -177,7 +155,7 @@ def dispatch_execute_many(self, cursor: "DuckDBCursor", statement: SQL) -> "Exec return self.create_execution_result(cursor, rowcount_override=row_count, is_many_result=True) - def dispatch_execute_script(self, cursor: "DuckDBCursor", statement: SQL) -> "ExecutionResult": + def dispatch_execute_script(self, cursor: "DuckDBConnection", statement: SQL) -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Parses multi-statement scripts and executes each statement sequentially @@ -422,12 +400,12 @@ def data_dictionary(self) -> "DuckDBDataDictionary": # PRIVATE / INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "DuckDBCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "DuckDBConnection", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect DuckDB rows for the direct execution path.""" data, column_names = collect_rows(cast("list[Any] | None", fetched), cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "DuckDBCursor") -> int: + def resolve_rowcount(self, cursor: "DuckDBConnection") -> int: """Resolve rowcount from DuckDB cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/mock/__init__.py b/sqlspec/adapters/mock/__init__.py index 44bb7710f..5934509ff 100644 --- a/sqlspec/adapters/mock/__init__.py +++ b/sqlspec/adapters/mock/__init__.py @@ -50,13 +50,20 @@ 1 """ -from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockConnection, MockSyncSessionContext +from sqlspec.adapters.mock._typing import ( + MockAsyncCursor, + MockAsyncSessionContext, + MockConnection, + MockCursor, + MockSyncSessionContext, +) from sqlspec.adapters.mock.config import MockAsyncConfig, MockConnectionParams, 
MockDriverFeatures, MockSyncConfig from sqlspec.adapters.mock.data_dictionary import MockAsyncDataDictionary, MockDataDictionary -from sqlspec.adapters.mock.driver import MockAsyncDriver, MockCursor, MockExceptionHandler, MockSyncDriver +from sqlspec.adapters.mock.driver import MockAsyncDriver, MockExceptionHandler, MockSyncDriver __all__ = ( "MockAsyncConfig", + "MockAsyncCursor", "MockAsyncDataDictionary", "MockAsyncDriver", "MockAsyncSessionContext", diff --git a/sqlspec/adapters/mock/_typing.py b/sqlspec/adapters/mock/_typing.py index 53ab3d38e..2bc615a47 100644 --- a/sqlspec/adapters/mock/_typing.py +++ b/sqlspec/adapters/mock/_typing.py @@ -4,6 +4,7 @@ compilation to avoid ABI boundary issues. """ +import contextlib import sqlite3 from typing import TYPE_CHECKING, Any @@ -18,9 +19,76 @@ from sqlspec.core import StatementConfig MockConnection: TypeAlias = _MockConnection + MockRawCursor: TypeAlias = sqlite3.Cursor if not TYPE_CHECKING: MockConnection = _MockConnection + MockRawCursor = sqlite3.Cursor + + +class MockCursor: + """Context manager for Mock SQLite cursor management. + + Provides automatic cursor creation and cleanup for SQLite database operations. + """ + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "MockConnection") -> None: + """Initialize cursor manager. + + Args: + connection: SQLite database connection + """ + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + """Create and return a new cursor. 
+ + Returns: + Active SQLite cursor object + """ + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + """Clean up cursor resources.""" + if self.cursor is not None: + with contextlib.suppress(Exception): + self.cursor.close() + + +class MockAsyncCursor: + """Async context manager for Mock SQLite cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "MockConnection") -> None: + """Initialize async cursor manager. + + Args: + connection: SQLite database connection + """ + self.connection = connection + self.cursor: Any = None + + async def __aenter__(self) -> Any: + """Create and return a new cursor. + + Returns: + Active SQLite cursor object + """ + self.cursor = self.connection.cursor() + return self.cursor + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> None: + """Clean up cursor resources.""" + if self.cursor is not None: + with contextlib.suppress(Exception): + self.cursor.close() class MockSyncSessionContext: @@ -145,4 +213,11 @@ async def __aexit__( return None -__all__ = ("MockAsyncSessionContext", "MockConnection", "MockSyncSessionContext") +__all__ = ( + "MockAsyncCursor", + "MockAsyncSessionContext", + "MockConnection", + "MockCursor", + "MockRawCursor", + "MockSyncSessionContext", +) diff --git a/sqlspec/adapters/mock/config.py b/sqlspec/adapters/mock/config.py index 8dd59f23e..b4accaab1 100644 --- a/sqlspec/adapters/mock/config.py +++ b/sqlspec/adapters/mock/config.py @@ -10,9 +10,9 @@ from typing_extensions import NotRequired -from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockConnection, MockSyncSessionContext +from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockConnection, MockCursor, MockSyncSessionContext from sqlspec.adapters.mock.core import apply_driver_features, default_statement_config -from 
sqlspec.adapters.mock.driver import MockAsyncDriver, MockCursor, MockExceptionHandler, MockSyncDriver +from sqlspec.adapters.mock.driver import MockAsyncDriver, MockExceptionHandler, MockSyncDriver from sqlspec.config import ExtensionConfigs, NoPoolAsyncConfig, NoPoolSyncConfig from sqlspec.driver import convert_to_dialect from sqlspec.utils.sync_tools import async_ diff --git a/sqlspec/adapters/mock/driver.py b/sqlspec/adapters/mock/driver.py index 3466b1d39..b2cb289c0 100644 --- a/sqlspec/adapters/mock/driver.py +++ b/sqlspec/adapters/mock/driver.py @@ -6,11 +6,10 @@ execution using sqlglot. """ -import contextlib import sqlite3 from typing import TYPE_CHECKING, Any -from sqlspec.adapters.mock._typing import MockAsyncSessionContext, MockSyncSessionContext +from sqlspec.adapters.mock._typing import MockAsyncCursor, MockAsyncSessionContext, MockCursor, MockSyncSessionContext from sqlspec.adapters.mock.core import ( build_insert_statement, collect_rows, @@ -35,14 +34,13 @@ from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: - from types import TracebackType - from sqlspec.adapters.mock._typing import MockConnection from sqlspec.core import SQL, StatementConfig from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry __all__ = ( + "MockAsyncCursor", "MockAsyncDriver", "MockAsyncSessionContext", "MockCursor", @@ -52,71 +50,6 @@ ) -class MockCursor: - """Context manager for Mock SQLite cursor management. - - Provides automatic cursor creation and cleanup for SQLite database operations. - """ - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MockConnection") -> None: - """Initialize cursor manager. - - Args: - connection: SQLite database connection - """ - self.connection = connection - self.cursor: sqlite3.Cursor | None = None - - def __enter__(self) -> "sqlite3.Cursor": - """Create and return a new cursor. 
- - Returns: - Active SQLite cursor object - """ - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - """Clean up cursor resources.""" - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() - - -class MockAsyncCursor: - """Async context manager for Mock SQLite cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MockConnection") -> None: - """Initialize async cursor manager. - - Args: - connection: SQLite database connection - """ - self.connection = connection - self.cursor: sqlite3.Cursor | None = None - - async def __aenter__(self) -> "sqlite3.Cursor": - """Create and return a new cursor. - - Returns: - Active SQLite cursor object - """ - self.cursor = self.connection.cursor() - return self.cursor - - async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> None: - """Clean up cursor resources.""" - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() - - class MockExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling SQLite database exceptions. @@ -195,7 +128,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. 
Args: @@ -224,7 +157,7 @@ def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "Execu affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets. Args: @@ -242,7 +175,7 @@ def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> " return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Args: @@ -431,11 +364,11 @@ def _transpile_to_sqlite(self, statement: "SQL") -> str: return sql return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - def collect_rows(self, cursor: "MockCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mock sync rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: "MockCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from mock cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -487,7 +420,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement 
asynchronously. Args: @@ -519,7 +452,7 @@ async def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets asynchronously. Args: @@ -538,7 +471,7 @@ async def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL" return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script asynchronously. Args: @@ -735,11 +668,11 @@ def _transpile_to_sqlite(self, statement: "SQL") -> str: return sql return convert_to_dialect(statement, self._target_dialect, "sqlite", pretty=False) - def collect_rows(self, cursor: "MockAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mock async rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: "MockAsyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from mock cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/mysqlconnector/__init__.py b/sqlspec/adapters/mysqlconnector/__init__.py index 2923e440b..49b72b3e8 100644 --- a/sqlspec/adapters/mysqlconnector/__init__.py +++ b/sqlspec/adapters/mysqlconnector/__init__.py @@ -1,4 +1,9 @@ -from sqlspec.adapters.mysqlconnector._typing import MysqlConnectorAsyncConnection, 
MysqlConnectorSyncConnection +from sqlspec.adapters.mysqlconnector._typing import ( + MysqlConnectorAsyncConnection, + MysqlConnectorAsyncCursor, + MysqlConnectorSyncConnection, + MysqlConnectorSyncCursor, +) from sqlspec.adapters.mysqlconnector.config import ( MysqlConnectorAsyncConfig, MysqlConnectorAsyncConnectionParams, @@ -9,10 +14,8 @@ ) from sqlspec.adapters.mysqlconnector.core import default_statement_config from sqlspec.adapters.mysqlconnector.driver import ( - MysqlConnectorAsyncCursor, MysqlConnectorAsyncDriver, MysqlConnectorAsyncExceptionHandler, - MysqlConnectorSyncCursor, MysqlConnectorSyncDriver, MysqlConnectorSyncExceptionHandler, ) diff --git a/sqlspec/adapters/mysqlconnector/_typing.py b/sqlspec/adapters/mysqlconnector/_typing.py index 0d9066a1e..e625f3f8b 100644 --- a/sqlspec/adapters/mysqlconnector/_typing.py +++ b/sqlspec/adapters/mysqlconnector/_typing.py @@ -35,11 +35,48 @@ async def close(self) -> Any: ... MysqlConnectorSyncConnection: TypeAlias = _MysqlConnectorSyncConnection MysqlConnectorAsyncConnection: TypeAlias = MysqlConnectorAsyncConnectionProtocol -else: + +if not TYPE_CHECKING: MysqlConnectorSyncConnection = _MysqlConnectorSyncConnection MysqlConnectorAsyncConnection = _MysqlConnectorAsyncConnection +class MysqlConnectorSyncCursor: + """Context manager for mysql-connector sync cursor operations.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "MysqlConnectorSyncConnection") -> None: + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + if self.cursor is not None: + self.cursor.close() + + +class MysqlConnectorAsyncCursor: + """Async context manager for mysql-connector async cursor operations.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "MysqlConnectorAsyncConnection") -> None: + self.connection = connection + self.cursor: Any | 
None = None + + async def __aenter__(self) -> Any: + self.cursor = await self.connection.cursor() + return self.cursor + + async def __aexit__(self, *_: Any) -> None: + if self.cursor is not None: + await self.cursor.close() + + class MysqlConnectorSyncSessionContext: """Sync context manager for mysql-connector sessions.""" @@ -136,7 +173,9 @@ async def __aexit__( __all__ = ( "MysqlConnectorAsyncConnection", + "MysqlConnectorAsyncCursor", "MysqlConnectorAsyncSessionContext", "MysqlConnectorSyncConnection", + "MysqlConnectorSyncCursor", "MysqlConnectorSyncSessionContext", ) diff --git a/sqlspec/adapters/mysqlconnector/config.py b/sqlspec/adapters/mysqlconnector/config.py index 72759b25b..63dd3e062 100644 --- a/sqlspec/adapters/mysqlconnector/config.py +++ b/sqlspec/adapters/mysqlconnector/config.py @@ -10,16 +10,16 @@ from sqlspec.adapters.mysqlconnector._typing import ( MysqlConnectorAsyncConnection, + MysqlConnectorAsyncCursor, MysqlConnectorAsyncSessionContext, MysqlConnectorSyncConnection, + MysqlConnectorSyncCursor, MysqlConnectorSyncSessionContext, ) from sqlspec.adapters.mysqlconnector.core import apply_driver_features, default_statement_config from sqlspec.adapters.mysqlconnector.driver import ( - MysqlConnectorAsyncCursor, MysqlConnectorAsyncDriver, MysqlConnectorAsyncExceptionHandler, - MysqlConnectorSyncCursor, MysqlConnectorSyncDriver, MysqlConnectorSyncExceptionHandler, ) diff --git a/sqlspec/adapters/mysqlconnector/driver.py b/sqlspec/adapters/mysqlconnector/driver.py index f488159b2..0b82ee88d 100644 --- a/sqlspec/adapters/mysqlconnector/driver.py +++ b/sqlspec/adapters/mysqlconnector/driver.py @@ -49,7 +49,12 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.adapters.mysqlconnector._typing import MysqlConnectorAsyncSessionContext, MysqlConnectorSyncSessionContext +from sqlspec.adapters.mysqlconnector._typing import ( + 
MysqlConnectorAsyncCursor, + MysqlConnectorAsyncSessionContext, + MysqlConnectorSyncCursor, + MysqlConnectorSyncSessionContext, +) __all__ = ( "MysqlConnectorAsyncCursor", @@ -68,24 +73,6 @@ MYSQLCONNECTOR_JSON_TYPE_CODES: Final[set[int]] = {json_type_value} if json_type_value is not None else set() -class MysqlConnectorSyncCursor: - """Context manager for mysql-connector sync cursor operations.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MysqlConnectorSyncConnection") -> None: - self.connection = connection - self.cursor: Any | None = None - - def __enter__(self) -> Any: - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - if self.cursor is not None: - self.cursor.close() - - class MysqlConnectorSyncExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling mysql-connector sync exceptions.""" @@ -123,7 +110,7 @@ def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: MysqlConnectorSyncDataDictionary | None = None - def dispatch_execute(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -150,7 +137,7 @@ def dispatch_execute(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - def dispatch_execute_many(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) 
prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -160,7 +147,7 @@ def dispatch_execute_many(self, cursor: "MysqlConnectorSyncCursor", statement: " affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "MysqlConnectorSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -275,9 +262,7 @@ def data_dictionary(self) -> "MysqlConnectorSyncDataDictionary": self._data_dictionary = MysqlConnectorSyncDataDictionary() return self._data_dictionary - def collect_rows( - self, cursor: "MysqlConnectorSyncCursor", fetched: "list[Any]" - ) -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mysql-connector sync rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -288,7 +273,7 @@ def collect_rows( ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: "MysqlConnectorSyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from mysql-connector cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -302,24 +287,6 @@ def _connection_in_transaction(self) -> bool: return False -class MysqlConnectorAsyncCursor: - """Async context manager for mysql-connector async cursor operations.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "MysqlConnectorAsyncConnection") -> None: - self.connection = connection - self.cursor: Any | 
None = None - - async def __aenter__(self) -> Any: - self.cursor = await self.connection.cursor() - return self.cursor - - async def __aexit__(self, *_: Any) -> None: - if self.cursor is not None: - await self.cursor.close() - - class MysqlConnectorAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling mysql-connector exceptions.""" @@ -357,7 +324,7 @@ def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: MysqlConnectorAsyncDataDictionary | None = None - async def dispatch_execute(self, cursor: "MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) await cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -384,7 +351,7 @@ async def dispatch_execute(self, cursor: "MysqlConnectorAsyncCursor", statement: last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - async def dispatch_execute_many(self, cursor: "MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -394,7 +361,7 @@ async def dispatch_execute_many(self, cursor: "MysqlConnectorAsyncCursor", state affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "MysqlConnectorAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def 
dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -511,9 +478,7 @@ def data_dictionary(self) -> "MysqlConnectorAsyncDataDictionary": self._data_dictionary = MysqlConnectorAsyncDataDictionary() return self._data_dictionary - def collect_rows( - self, cursor: "MysqlConnectorAsyncCursor", fetched: "list[Any]" - ) -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect mysql-connector async rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -524,7 +489,7 @@ def collect_rows( ) return rows, column_names, len(rows) - def resolve_rowcount(self, cursor: "MysqlConnectorAsyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from mysql-connector cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/oracledb/__init__.py b/sqlspec/adapters/oracledb/__init__.py index 2b426a322..dee45f5bc 100644 --- a/sqlspec/adapters/oracledb/__init__.py +++ b/sqlspec/adapters/oracledb/__init__.py @@ -7,7 +7,12 @@ numpy_output_type_handler, register_numpy_handlers, ) -from sqlspec.adapters.oracledb._typing import OracleAsyncConnection, OracleSyncConnection +from sqlspec.adapters.oracledb._typing import ( + OracleAsyncConnection, + OracleAsyncCursor, + OracleSyncConnection, + OracleSyncCursor, +) from sqlspec.adapters.oracledb._uuid_handlers import ( register_uuid_handlers, uuid_converter_in, @@ -23,10 +28,8 @@ ) from sqlspec.adapters.oracledb.core import default_statement_config from sqlspec.adapters.oracledb.driver import ( - OracleAsyncCursor, OracleAsyncDriver, OracleAsyncExceptionHandler, - OracleSyncCursor, 
OracleSyncDriver, OracleSyncExceptionHandler, ) diff --git a/sqlspec/adapters/oracledb/_typing.py b/sqlspec/adapters/oracledb/_typing.py index 34342f6d7..c58304d2b 100644 --- a/sqlspec/adapters/oracledb/_typing.py +++ b/sqlspec/adapters/oracledb/_typing.py @@ -4,9 +4,11 @@ compilation to avoid ABI boundary issues. """ +import contextlib from typing import TYPE_CHECKING, Any, Protocol -from oracledb import AsyncConnection, Connection +from oracledb import AsyncConnection, AsyncCursor, Connection, Cursor +from oracledb.pool import AsyncConnectionPool, ConnectionPool if TYPE_CHECKING: from collections.abc import Callable @@ -14,7 +16,6 @@ from typing import TypeAlias from oracledb import DB_TYPE_VECTOR # pyright: ignore[reportUnknownVariableType] - from oracledb.pool import AsyncConnectionPool, ConnectionPool from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver from sqlspec.builder import QueryBuilder @@ -24,10 +25,11 @@ OracleAsyncConnection: TypeAlias = AsyncConnection OracleSyncConnectionPool: TypeAlias = ConnectionPool OracleAsyncConnectionPool: TypeAlias = AsyncConnectionPool + OracleSyncCursorType: TypeAlias = Cursor + OracleAsyncCursorType: TypeAlias = AsyncCursor OracleVectorType: TypeAlias = int -else: - from oracledb.pool import AsyncConnectionPool, ConnectionPool +if not TYPE_CHECKING: try: from oracledb import DB_TYPE_VECTOR @@ -40,6 +42,50 @@ OracleAsyncConnection = AsyncConnection OracleSyncConnectionPool = ConnectionPool OracleAsyncConnectionPool = AsyncConnectionPool + OracleSyncCursorType = Cursor + OracleAsyncCursorType = AsyncCursor + + +class OracleSyncCursor: + """Sync context manager for Oracle cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: OracleSyncConnection) -> None: + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: object) -> None: + if 
self.cursor is not None: + self.cursor.close() + + +class OracleAsyncCursor: + """Async context manager for Oracle cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: OracleAsyncConnection) -> None: + self.connection = connection + self.cursor: Any = None + + async def __aenter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> None: + _ = (exc_type, exc_val, exc_tb) # Mark as intentionally unused + if self.cursor is not None: + with contextlib.suppress(Exception): + # Oracle async cursors have a synchronous close method + # but we need to ensure proper cleanup in the event loop context + self.cursor.close() class OraclePipelineDriver(Protocol): @@ -174,10 +220,14 @@ async def __aexit__( "DB_TYPE_VECTOR", "OracleAsyncConnection", "OracleAsyncConnectionPool", + "OracleAsyncCursor", + "OracleAsyncCursorType", "OracleAsyncSessionContext", "OraclePipelineDriver", "OracleSyncConnection", "OracleSyncConnectionPool", + "OracleSyncCursor", + "OracleSyncCursorType", "OracleSyncSessionContext", "OracleVectorType", ) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 27e27b3b2..d21f9415a 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -10,20 +10,20 @@ from sqlspec.adapters.oracledb._typing import ( OracleAsyncConnection, OracleAsyncConnectionPool, + OracleAsyncCursor, + OracleAsyncSessionContext, OracleSyncConnection, OracleSyncConnectionPool, + OracleSyncCursor, + OracleSyncSessionContext, ) from sqlspec.adapters.oracledb._uuid_handlers import register_uuid_handlers from sqlspec.adapters.oracledb.core import apply_driver_features, default_statement_config, requires_session_callback from sqlspec.adapters.oracledb.driver import ( - OracleAsyncCursor, OracleAsyncDriver, 
OracleAsyncExceptionHandler, - OracleAsyncSessionContext, - OracleSyncCursor, OracleSyncDriver, OracleSyncExceptionHandler, - OracleSyncSessionContext, ) from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker, OracleSyncMigrationTracker from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 16ab79cdd..9faa7ad23 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1,16 +1,16 @@ """Oracle Driver""" -import contextlib import logging from typing import TYPE_CHECKING, Any, NamedTuple, cast import oracledb -from oracledb import AsyncCursor, Cursor from sqlspec.adapters.oracledb._typing import ( OracleAsyncConnection, + OracleAsyncCursor, OracleAsyncSessionContext, OracleSyncConnection, + OracleSyncCursor, OracleSyncSessionContext, ) from sqlspec.adapters.oracledb.core import ( @@ -57,7 +57,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from types import TracebackType from sqlspec.adapters.oracledb._typing import OraclePipelineDriver from sqlspec.builder import QueryBuilder @@ -228,48 +227,6 @@ def _wrap_pipeline_error( ) -class OracleSyncCursor: - """Sync context manager for Oracle cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: OracleSyncConnection) -> None: - self.connection = connection - self.cursor: Cursor | None = None - - def __enter__(self) -> Cursor: - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: object) -> None: - if self.cursor is not None: - self.cursor.close() - - -class OracleAsyncCursor: - """Async context manager for Oracle cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: OracleAsyncConnection) -> None: - self.connection = connection - self.cursor: AsyncCursor | None = None - - async def __aenter__(self) -> 
AsyncCursor: - self.cursor = self.connection.cursor() - return self.cursor - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" - ) -> None: - _ = (exc_type, exc_val, exc_tb) # Mark as intentionally unused - if self.cursor is not None: - with contextlib.suppress(Exception): - # Oracle async cursors have a synchronous close method - # but we need to ensure proper cleanup in the event loop context - self.cursor.close() - - class OracleSyncExceptionHandler(BaseSyncExceptionHandler): """Sync Context manager for handling Oracle database exceptions. @@ -352,7 +309,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement with Oracle data handling. Args: @@ -402,7 +359,7 @@ def dispatch_execute(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResu affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets using Oracle batch processing. Args: @@ -425,7 +382,7 @@ def dispatch_execute_many(self, cursor: "Cursor", statement: "SQL") -> "Executio return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. 
Parameters are embedded as static values for script execution compatibility. @@ -701,7 +658,7 @@ def data_dictionary(self) -> "OracledbSyncDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "OracleSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Oracle sync rows for the direct execution path.""" column_names, requires_lob_coercion = self._resolve_row_metadata(cursor.description) data, column_names = collect_sync_rows( @@ -713,7 +670,7 @@ def collect_rows(self, cursor: "OracleSyncCursor", fetched: "list[Any]") -> "tup ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "OracleSyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from Oracle cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -848,7 +805,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: "AsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement with Oracle data handling. Args: @@ -900,7 +857,7 @@ async def dispatch_execute(self, cursor: "AsyncCursor", statement: "SQL") -> "Ex affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: "AsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets using Oracle batch processing. 
Args: @@ -923,7 +880,7 @@ async def dispatch_execute_many(self, cursor: "AsyncCursor", statement: "SQL") - return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "AsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Parameters are embedded as static values for script execution compatibility. @@ -1203,7 +1160,7 @@ def data_dictionary(self) -> "OracledbAsyncDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "OracleAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Oracle async rows for the direct execution path. Uses synchronous LOB coercion. 
For async LOB coercion, the standard @@ -1219,7 +1176,7 @@ def collect_rows(self, cursor: "OracleAsyncCursor", fetched: "list[Any]") -> "tu ) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "OracleAsyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from Oracle cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/psqlpy/__init__.py b/sqlspec/adapters/psqlpy/__init__.py index a8c2fcf88..8d8e2561d 100644 --- a/sqlspec/adapters/psqlpy/__init__.py +++ b/sqlspec/adapters/psqlpy/__init__.py @@ -1,9 +1,9 @@ """Psqlpy adapter for SQLSpec.""" -from sqlspec.adapters.psqlpy._typing import PsqlpyConnection +from sqlspec.adapters.psqlpy._typing import PsqlpyConnection, PsqlpyCursor from sqlspec.adapters.psqlpy.config import PsqlpyConfig, PsqlpyConnectionParams, PsqlpyPoolParams from sqlspec.adapters.psqlpy.core import default_statement_config -from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, PsqlpyExceptionHandler +from sqlspec.adapters.psqlpy.driver import PsqlpyDriver, PsqlpyExceptionHandler from sqlspec.dialects import postgres # noqa: F401 __all__ = ( diff --git a/sqlspec/adapters/psqlpy/_typing.py b/sqlspec/adapters/psqlpy/_typing.py index 3662eb60d..31817e6c5 100644 --- a/sqlspec/adapters/psqlpy/_typing.py +++ b/sqlspec/adapters/psqlpy/_typing.py @@ -22,6 +22,43 @@ PsqlpyConnection = _PsqlpyConnection +class PsqlpyCursor: + """Context manager for psqlpy cursor management.""" + + __slots__ = ("_in_use", "connection") + + def __init__(self, connection: "PsqlpyConnection") -> None: + self.connection = connection + self._in_use = False + + async def __aenter__(self) -> "PsqlpyConnection": + """Enter cursor context. + + Returns: + Psqlpy connection object + """ + self._in_use = True + return self.connection + + async def __aexit__(self, *_: Any) -> None: + """Exit cursor context. 
+ + Args: + exc_type: Exception type + exc_val: Exception value + exc_tb: Exception traceback + """ + self._in_use = False + + def is_in_use(self) -> bool: + """Check if cursor is currently in use. + + Returns: + True if cursor is in use, False otherwise + """ + return self._in_use + + class PsqlpySessionContext: """Async context manager for psqlpy sessions. @@ -78,4 +115,4 @@ async def __aexit__( return None -__all__ = ("PsqlpyConnection", "PsqlpySessionContext") +__all__ = ("PsqlpyConnection", "PsqlpyCursor", "PsqlpySessionContext") diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index bc4036ce5..2d9d64926 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -6,7 +6,7 @@ from psqlpy import ConnectionPool from typing_extensions import NotRequired -from sqlspec.adapters.psqlpy._typing import PsqlpyConnection +from sqlspec.adapters.psqlpy._typing import PsqlpyConnection, PsqlpyCursor, PsqlpySessionContext from sqlspec.adapters.psqlpy.core import ( apply_driver_features, build_connection_config, @@ -15,7 +15,7 @@ resolve_postgres_extension_state, resolve_runtime_statement_config, ) -from sqlspec.adapters.psqlpy.driver import PsqlpyCursor, PsqlpyDriver, PsqlpyExceptionHandler, PsqlpySessionContext +from sqlspec.adapters.psqlpy.driver import PsqlpyDriver, PsqlpyExceptionHandler from sqlspec.adapters.psqlpy.type_converter import register_pgvector from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs from sqlspec.extensions.events import EventRuntimeHints diff --git a/sqlspec/adapters/psqlpy/driver.py b/sqlspec/adapters/psqlpy/driver.py index ecc1615cc..78bda1f9e 100644 --- a/sqlspec/adapters/psqlpy/driver.py +++ b/sqlspec/adapters/psqlpy/driver.py @@ -9,7 +9,7 @@ import psqlpy.exceptions -from sqlspec.adapters.psqlpy._typing import PsqlpySessionContext +from sqlspec.adapters.psqlpy._typing import PsqlpyCursor, PsqlpySessionContext from sqlspec.adapters.psqlpy.core import ( 
build_insert_statement, coerce_numeric_for_write, @@ -49,43 +49,6 @@ _type_converter = PostgreSQLOutputConverter() -class PsqlpyCursor: - """Context manager for psqlpy cursor management.""" - - __slots__ = ("_in_use", "connection") - - def __init__(self, connection: "PsqlpyConnection") -> None: - self.connection = connection - self._in_use = False - - async def __aenter__(self) -> "PsqlpyConnection": - """Enter cursor context. - - Returns: - Psqlpy connection object - """ - self._in_use = True - return self.connection - - async def __aexit__(self, *_: Any) -> None: - """Exit cursor context. - - Args: - exc_type: Exception type - exc_val: Exception value - exc_tb: Exception traceback - """ - self._in_use = False - - def is_in_use(self) -> bool: - """Check if cursor is currently in use. - - Returns: - True if cursor is in use, False otherwise - """ - return self._in_use - - class PsqlpyExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling psqlpy database exceptions. @@ -420,7 +383,7 @@ def data_dictionary(self) -> "PsqlpyDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "PsqlpyCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "PsqlpyConnection", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psqlpy rows for the direct execution path. The ``fetched`` argument may be a psqlpy query result or a plain list. 
@@ -428,7 +391,7 @@ def collect_rows(self, cursor: "PsqlpyCursor", fetched: "list[Any]") -> "tuple[l dict_rows, column_names = collect_rows(fetched) return dict_rows, column_names, len(dict_rows) - def resolve_rowcount(self, cursor: "PsqlpyCursor") -> int: + def resolve_rowcount(self, cursor: "PsqlpyConnection") -> int: """Resolve rowcount from psqlpy result for the direct execution path.""" return extract_rows_affected(cursor) diff --git a/sqlspec/adapters/psycopg/__init__.py b/sqlspec/adapters/psycopg/__init__.py index 12f92f84f..8c9cbc8bb 100644 --- a/sqlspec/adapters/psycopg/__init__.py +++ b/sqlspec/adapters/psycopg/__init__.py @@ -1,4 +1,9 @@ -from sqlspec.adapters.psycopg._typing import PsycopgAsyncConnection, PsycopgSyncConnection +from sqlspec.adapters.psycopg._typing import ( + PsycopgAsyncConnection, + PsycopgAsyncCursor, + PsycopgSyncConnection, + PsycopgSyncCursor, +) from sqlspec.adapters.psycopg.config import ( PsycopgAsyncConfig, PsycopgConnectionParams, @@ -7,10 +12,8 @@ ) from sqlspec.adapters.psycopg.core import default_statement_config from sqlspec.adapters.psycopg.driver import ( - PsycopgAsyncCursor, PsycopgAsyncDriver, PsycopgAsyncExceptionHandler, - PsycopgSyncCursor, PsycopgSyncDriver, PsycopgSyncExceptionHandler, ) diff --git a/sqlspec/adapters/psycopg/_typing.py b/sqlspec/adapters/psycopg/_typing.py index f74c318be..05ce4fb07 100644 --- a/sqlspec/adapters/psycopg/_typing.py +++ b/sqlspec/adapters/psycopg/_typing.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Protocol +from psycopg import AsyncConnection, Connection from psycopg.rows import DictRow as PsycopgDictRow if TYPE_CHECKING: @@ -13,21 +14,57 @@ from types import TracebackType from typing import TypeAlias - from psycopg import AsyncConnection, Connection - from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver from sqlspec.builder import QueryBuilder from sqlspec.core import SQL, Statement, StatementConfig PsycopgSyncConnection: TypeAlias = 
Connection[PsycopgDictRow] PsycopgAsyncConnection: TypeAlias = AsyncConnection[PsycopgDictRow] -else: - from psycopg import AsyncConnection, Connection +if not TYPE_CHECKING: PsycopgSyncConnection = Connection PsycopgAsyncConnection = AsyncConnection +class PsycopgSyncCursor: + """Context manager for PostgreSQL psycopg cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "PsycopgSyncConnection") -> None: + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + if self.cursor is not None: + self.cursor.close() + + +class PsycopgAsyncCursor: + """Async context manager for PostgreSQL psycopg cursor management.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "PsycopgAsyncConnection") -> None: + self.connection = connection + self.cursor: Any = None + + async def __aenter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> None: + _ = (exc_type, exc_val, exc_tb) + if self.cursor is not None: + await self.cursor.close() + + class PsycopgPipelineDriver(Protocol): """Protocol for psycopg pipeline driver methods used in stack execution.""" @@ -159,9 +196,11 @@ async def __aexit__( __all__ = ( "PsycopgAsyncConnection", + "PsycopgAsyncCursor", "PsycopgAsyncSessionContext", "PsycopgDictRow", "PsycopgPipelineDriver", "PsycopgSyncConnection", + "PsycopgSyncCursor", "PsycopgSyncSessionContext", ) diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 0159fb44f..7bae30e50 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -6,7 +6,14 @@ from psycopg_pool import AsyncConnectionPool, ConnectionPool from typing_extensions import 
NotRequired -from sqlspec.adapters.psycopg._typing import PsycopgAsyncConnection, PsycopgSyncConnection +from sqlspec.adapters.psycopg._typing import ( + PsycopgAsyncConnection, + PsycopgAsyncCursor, + PsycopgAsyncSessionContext, + PsycopgSyncConnection, + PsycopgSyncCursor, + PsycopgSyncSessionContext, +) from sqlspec.adapters.psycopg.core import ( apply_driver_features, build_postgres_extension_probe_names, @@ -15,14 +22,10 @@ resolve_runtime_statement_config, ) from sqlspec.adapters.psycopg.driver import ( - PsycopgAsyncCursor, PsycopgAsyncDriver, PsycopgAsyncExceptionHandler, - PsycopgAsyncSessionContext, - PsycopgSyncCursor, PsycopgSyncDriver, PsycopgSyncExceptionHandler, - PsycopgSyncSessionContext, ) from sqlspec.adapters.psycopg.type_converter import register_pgvector_async, register_pgvector_sync from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs, SyncDatabaseConfig diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index 58eb34ed7..a5806c9e2 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -8,8 +8,10 @@ from sqlspec.adapters.psycopg._typing import ( PsycopgAsyncConnection, + PsycopgAsyncCursor, PsycopgAsyncSessionContext, PsycopgSyncConnection, + PsycopgSyncCursor, PsycopgSyncSessionContext, ) from sqlspec.adapters.psycopg.core import ( @@ -55,8 +57,6 @@ from sqlspec.utils.type_guards import is_readable if TYPE_CHECKING: - from types import TracebackType - from sqlspec.adapters.psycopg._typing import PsycopgPipelineDriver from sqlspec.core import ArrowResult from sqlspec.driver import ExecutionResult @@ -116,24 +116,6 @@ def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[Prepare return prepared -class PsycopgSyncCursor: - """Context manager for PostgreSQL psycopg cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: PsycopgSyncConnection) -> None: - self.connection = connection - self.cursor: Any | 
None = None - - def __enter__(self) -> Any: - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - if self.cursor is not None: - self.cursor.close() - - class PsycopgSyncExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling PostgreSQL psycopg database exceptions. @@ -188,7 +170,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Args: @@ -219,7 +201,7 @@ def dispatch_execute(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "Ex affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets. Args: @@ -240,7 +222,7 @@ def dispatch_execute_many(self, cursor: "PsycopgSyncCursor", statement: "SQL") - return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements. 
Args: @@ -264,7 +246,7 @@ def dispatch_execute_script(self, cursor: "PsycopgSyncCursor", statement: "SQL") last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True ) - def dispatch_special_handling(self, cursor: "PsycopgSyncCursor", statement: "SQL") -> "SQLResult | None": + def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. Args: @@ -568,13 +550,13 @@ def _resolve_column_names(self, description: Any) -> list[str]: self._column_name_cache[cache_key] = (description, column_names) return column_names - def collect_rows(self, cursor: "PsycopgSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg sync rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "PsycopgSyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor) @@ -583,27 +565,6 @@ def _connection_in_transaction(self) -> bool: return bool(self.connection.info.transaction_status != TRANSACTION_STATUS_IDLE) -class PsycopgAsyncCursor: - """Async context manager for PostgreSQL psycopg cursor management.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "PsycopgAsyncConnection") -> None: - self.connection = connection - self.cursor: Any | None = None - - async def __aenter__(self) -> Any: - self.cursor = self.connection.cursor() - return self.cursor - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: "TracebackType | None" - ) -> None: - _ = (exc_type, exc_val, exc_tb) - if 
self.cursor is not None: - await self.cursor.close() - - class PsycopgAsyncExceptionHandler(BaseAsyncExceptionHandler): """Async context manager for handling PostgreSQL psycopg database exceptions. @@ -659,7 +620,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement (async). Args: @@ -690,7 +651,7 @@ async def dispatch_execute(self, cursor: "PsycopgAsyncCursor", statement: "SQL") affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets (async). Args: @@ -711,7 +672,7 @@ async def dispatch_execute_many(self, cursor: "PsycopgAsyncCursor", statement: " return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with multiple statements (async). 
Args: @@ -735,7 +696,7 @@ async def dispatch_execute_script(self, cursor: "PsycopgAsyncCursor", statement: last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True ) - async def dispatch_special_handling(self, cursor: "PsycopgAsyncCursor", statement: "SQL") -> "SQLResult | None": + async def dispatch_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None": """Hook for PostgreSQL-specific special operations. Args: @@ -1049,13 +1010,13 @@ def _resolve_column_names(self, description: Any) -> list[str]: self._column_name_cache[cache_key] = (description, column_names) return column_names - def collect_rows(self, cursor: "PsycopgAsyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect psycopg async rows for the direct execution path.""" data = cast("list[Any] | None", fetched) or [] column_names = self._resolve_column_names(cursor.description) return data, column_names, len(data) - def resolve_rowcount(self, cursor: "PsycopgAsyncCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from psycopg cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/pymysql/__init__.py b/sqlspec/adapters/pymysql/__init__.py index 79a11774a..a41ecaa88 100644 --- a/sqlspec/adapters/pymysql/__init__.py +++ b/sqlspec/adapters/pymysql/__init__.py @@ -1,4 +1,4 @@ -from sqlspec.adapters.pymysql._typing import PyMysqlConnection +from sqlspec.adapters.pymysql._typing import PyMysqlConnection, PyMysqlCursor from sqlspec.adapters.pymysql.config import ( PyMysqlConfig, PyMysqlConnectionParams, @@ -6,7 +6,7 @@ PyMysqlPoolParams, ) from sqlspec.adapters.pymysql.core import default_statement_config -from sqlspec.adapters.pymysql.driver import PyMysqlCursor, PyMysqlDriver, PyMysqlExceptionHandler +from 
sqlspec.adapters.pymysql.driver import PyMysqlDriver, PyMysqlExceptionHandler __all__ = ( "PyMysqlConfig", diff --git a/sqlspec/adapters/pymysql/_typing.py b/sqlspec/adapters/pymysql/_typing.py index 481a558f0..5ff2909af 100644 --- a/sqlspec/adapters/pymysql/_typing.py +++ b/sqlspec/adapters/pymysql/_typing.py @@ -22,6 +22,24 @@ PyMysqlConnection = pymysql.connections.Connection +class PyMysqlCursor: + """Context manager for PyMySQL cursor operations.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "PyMysqlConnection") -> None: + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + if self.cursor is not None: + self.cursor.close() + + class PyMysqlSessionContext: """Sync context manager for PyMySQL sessions.""" @@ -69,4 +87,4 @@ def __exit__( return None -__all__ = ("PyMysqlConnection", "PyMysqlSessionContext") +__all__ = ("PyMysqlConnection", "PyMysqlCursor", "PyMysqlSessionContext") diff --git a/sqlspec/adapters/pymysql/config.py b/sqlspec/adapters/pymysql/config.py index 1cb494559..e93ab8ee0 100644 --- a/sqlspec/adapters/pymysql/config.py +++ b/sqlspec/adapters/pymysql/config.py @@ -4,9 +4,9 @@ from typing_extensions import NotRequired -from sqlspec.adapters.pymysql._typing import PyMysqlConnection, PyMysqlSessionContext +from sqlspec.adapters.pymysql._typing import PyMysqlConnection, PyMysqlCursor, PyMysqlSessionContext from sqlspec.adapters.pymysql.core import apply_driver_features, default_statement_config -from sqlspec.adapters.pymysql.driver import PyMysqlCursor, PyMysqlDriver, PyMysqlExceptionHandler +from sqlspec.adapters.pymysql.driver import PyMysqlDriver, PyMysqlExceptionHandler from sqlspec.adapters.pymysql.pool import PyMysqlConnectionPool from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig from sqlspec.extensions.events import EventRuntimeHints diff --git 
a/sqlspec/adapters/pymysql/driver.py b/sqlspec/adapters/pymysql/driver.py index ab6c5ff94..c32dbbdea 100644 --- a/sqlspec/adapters/pymysql/driver.py +++ b/sqlspec/adapters/pymysql/driver.py @@ -6,6 +6,7 @@ import pymysql from pymysql.constants import FIELD_TYPE +from sqlspec.adapters.pymysql._typing import PyMysqlCursor, PyMysqlSessionContext from sqlspec.adapters.pymysql.core import ( build_insert_statement, collect_rows, @@ -37,8 +38,6 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -from sqlspec.adapters.pymysql._typing import PyMysqlSessionContext - __all__ = ("PyMysqlCursor", "PyMysqlDriver", "PyMysqlExceptionHandler", "PyMysqlSessionContext") logger = get_logger("sqlspec.adapters.pymysql") @@ -47,24 +46,6 @@ PYMYSQL_JSON_TYPE_CODES: Final[set[int]] = {json_type_value} if json_type_value is not None else set() -class PyMysqlCursor: - """Context manager for PyMySQL cursor operations.""" - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "PyMysqlConnection") -> None: - self.connection = connection - self.cursor: Any | None = None - - def __enter__(self) -> Any: - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - if self.cursor is not None: - self.cursor.close() - - class PyMysqlExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling PyMySQL exceptions.""" @@ -102,7 +83,7 @@ def __init__( super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) self._data_dictionary: PyMysqlDataDictionary | None = None - def dispatch_execute(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, 
normalize_execute_parameters(prepared_parameters)) @@ -129,7 +110,7 @@ def dispatch_execute(self, cursor: "PyMysqlCursor", statement: "SQL") -> "Execut last_id = normalize_lastrowid(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) - def dispatch_execute_many(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) prepared_parameters = normalize_execute_many_parameters(prepared_parameters) @@ -139,7 +120,7 @@ def dispatch_execute_many(self, cursor: "PyMysqlCursor", statement: "SQL") -> "E affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "PyMysqlCursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) @@ -254,7 +235,7 @@ def data_dictionary(self) -> "PyMysqlDataDictionary": self._data_dictionary = PyMysqlDataDictionary() return self._data_dictionary - def collect_rows(self, cursor: "PyMysqlCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect PyMySQL rows for the direct execution path.""" description = cursor.description or None column_names = resolve_column_names(description) @@ -265,7 +246,7 @@ def collect_rows(self, cursor: "PyMysqlCursor", fetched: "list[Any]") -> "tuple[ ) return rows, column_names, len(rows) - def resolve_rowcount(self, 
cursor: "PyMysqlCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from PyMySQL cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/adapters/spanner/_typing.py b/sqlspec/adapters/spanner/_typing.py index 73da255c8..ba4b3d69e 100644 --- a/sqlspec/adapters/spanner/_typing.py +++ b/sqlspec/adapters/spanner/_typing.py @@ -23,6 +23,21 @@ SpannerConnection = Any +class SpannerSyncCursor: + """Context manager that yields the active Spanner connection.""" + + __slots__ = ("connection",) + + def __init__(self, connection: "SpannerConnection") -> None: + self.connection = connection + + def __enter__(self) -> "SpannerConnection": + return self.connection + + def __exit__(self, *_: Any) -> None: + return None + + class SpannerSessionContext: """Sync context manager for Spanner sessions. @@ -84,4 +99,4 @@ def __exit__( return None -__all__ = ("SpannerConnection", "SpannerSessionContext") +__all__ = ("SpannerConnection", "SpannerSessionContext", "SpannerSyncCursor") diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py index 15f175e0a..957e62fbf 100644 --- a/sqlspec/adapters/spanner/driver.py +++ b/sqlspec/adapters/spanner/driver.py @@ -6,7 +6,7 @@ from google.api_core import exceptions as api_exceptions from google.cloud.spanner_v1.transaction import Transaction -from sqlspec.adapters.spanner._typing import SpannerSessionContext +from sqlspec.adapters.spanner._typing import SpannerSessionContext, SpannerSyncCursor from sqlspec.adapters.spanner.core import ( build_param_type_signature, coerce_params, @@ -95,21 +95,6 @@ def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "Ba return False -class SpannerSyncCursor: - """Context manager that yields the active Spanner connection.""" - - __slots__ = ("connection",) - - def __init__(self, connection: "SpannerConnection") -> None: - self.connection = connection - - def __enter__(self) -> 
"SpannerConnection": - return self.connection - - def __exit__(self, *_: Any) -> None: - return None - - class SpannerSyncDriver(SyncDriverAdapterBase): """Synchronous Spanner driver operating on Snapshot or Transaction contexts.""" @@ -372,7 +357,7 @@ def data_dictionary(self) -> "SpannerDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "SpannerSyncCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: "SpannerConnection", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect Spanner rows for the direct execution path. Note: Spanner's collect_rows requires result set fields and a type converter. @@ -390,7 +375,7 @@ def collect_rows(self, cursor: "SpannerSyncCursor", fetched: "list[Any]") -> "tu # For tuple rows without metadata, return as-is return fetched, [], len(fetched) - def resolve_rowcount(self, cursor: "SpannerSyncCursor") -> int: + def resolve_rowcount(self, cursor: "SpannerConnection") -> int: """Resolve rowcount from Spanner cursor for the direct execution path.""" # Spanner uses execute_update return value, not cursor.rowcount return 0 diff --git a/sqlspec/adapters/sqlite/__init__.py b/sqlspec/adapters/sqlite/__init__.py index 97360c73d..48d07e148 100644 --- a/sqlspec/adapters/sqlite/__init__.py +++ b/sqlspec/adapters/sqlite/__init__.py @@ -1,9 +1,9 @@ """SQLite adapter for SQLSpec.""" -from sqlspec.adapters.sqlite._typing import SqliteConnection +from sqlspec.adapters.sqlite._typing import SqliteConnection, SqliteCursor from sqlspec.adapters.sqlite.config import SqliteConfig, SqliteConnectionParams, SqliteDriverFeatures from sqlspec.adapters.sqlite.core import default_statement_config -from sqlspec.adapters.sqlite.driver import SqliteCursor, SqliteDriver, SqliteExceptionHandler +from sqlspec.adapters.sqlite.driver import SqliteDriver, SqliteExceptionHandler from 
sqlspec.adapters.sqlite.pool import SqliteConnectionPool __all__ = ( diff --git a/sqlspec/adapters/sqlite/_typing.py b/sqlspec/adapters/sqlite/_typing.py index cfc0428ab..78e6e21b3 100644 --- a/sqlspec/adapters/sqlite/_typing.py +++ b/sqlspec/adapters/sqlite/_typing.py @@ -4,6 +4,7 @@ compilation to avoid ABI boundary issues. """ +import contextlib import sqlite3 from typing import TYPE_CHECKING, Any @@ -18,9 +19,50 @@ from sqlspec.core import StatementConfig SqliteConnection: TypeAlias = _SqliteConnection + SqliteCursorType: TypeAlias = sqlite3.Cursor if not TYPE_CHECKING: SqliteConnection = _SqliteConnection + SqliteCursorType = sqlite3.Cursor + + +class SqliteCursor: + """Context manager for SQLite cursor management. + + Provides automatic cursor creation and cleanup for SQLite database operations. + """ + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "SqliteConnection") -> None: + """Initialize cursor manager. + + Args: + connection: SQLite database connection + """ + self.connection = connection + self.cursor: Any = None + + def __enter__(self) -> Any: + """Create and return a new cursor. + + Returns: + Active SQLite cursor object + """ + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + """Clean up cursor resources. 
+ + Args: + exc_type: Exception type if an exception occurred + exc_val: Exception value if an exception occurred + exc_tb: Exception traceback if an exception occurred + """ + if self.cursor is not None: + with contextlib.suppress(Exception): + self.cursor.close() class SqliteSessionContext: @@ -78,4 +120,4 @@ def __exit__( return None -__all__ = ("SqliteConnection", "SqliteSessionContext") +__all__ = ("SqliteConnection", "SqliteCursor", "SqliteCursorType", "SqliteSessionContext") diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 2147665ad..de35298e7 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -5,9 +5,9 @@ from typing_extensions import NotRequired -from sqlspec.adapters.sqlite._typing import SqliteConnection +from sqlspec.adapters.sqlite._typing import SqliteConnection, SqliteCursor, SqliteSessionContext from sqlspec.adapters.sqlite.core import apply_driver_features, build_connection_config, default_statement_config -from sqlspec.adapters.sqlite.driver import SqliteCursor, SqliteDriver, SqliteExceptionHandler, SqliteSessionContext +from sqlspec.adapters.sqlite.driver import SqliteDriver, SqliteExceptionHandler from sqlspec.adapters.sqlite.pool import SqliteConnectionPool from sqlspec.adapters.sqlite.type_converter import register_type_handlers from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig diff --git a/sqlspec/adapters/sqlite/driver.py b/sqlspec/adapters/sqlite/driver.py index 574cc43e8..c5d6f6eeb 100644 --- a/sqlspec/adapters/sqlite/driver.py +++ b/sqlspec/adapters/sqlite/driver.py @@ -1,10 +1,9 @@ """SQLite driver implementation.""" -import contextlib import sqlite3 from typing import TYPE_CHECKING, Any -from sqlspec.adapters.sqlite._typing import SqliteSessionContext +from sqlspec.adapters.sqlite._typing import SqliteCursor, SqliteSessionContext from sqlspec.adapters.sqlite.core import ( build_insert_statement, collect_rows, @@ -37,45 +36,6 @@ __all__ = 
("SqliteCursor", "SqliteDriver", "SqliteExceptionHandler", "SqliteSessionContext") -class SqliteCursor: - """Context manager for SQLite cursor management. - - Provides automatic cursor creation and cleanup for SQLite database operations. - """ - - __slots__ = ("connection", "cursor") - - def __init__(self, connection: "SqliteConnection") -> None: - """Initialize cursor manager. - - Args: - connection: SQLite database connection - """ - self.connection = connection - self.cursor: sqlite3.Cursor | None = None - - def __enter__(self) -> "sqlite3.Cursor": - """Create and return a new cursor. - - Returns: - Active SQLite cursor object - """ - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, *_: Any) -> None: - """Clean up cursor resources. - - Args: - exc_type: Exception type if an exception occurred - exc_val: Exception value if an exception occurred - exc_tb: Exception traceback if an exception occurred - """ - if self.cursor is not None: - with contextlib.suppress(Exception): - self.cursor.close() - - class SqliteExceptionHandler(BaseSyncExceptionHandler): """Context manager for handling SQLite database exceptions. @@ -133,7 +93,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute single SQL statement. Args: @@ -162,7 +122,7 @@ def dispatch_execute(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "Execu affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL with multiple parameter sets. 
Args: @@ -177,7 +137,7 @@ def dispatch_execute_many(self, cursor: "sqlite3.Cursor", statement: "SQL") -> " affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - def dispatch_execute_script(self, cursor: "sqlite3.Cursor", statement: "SQL") -> "ExecutionResult": + def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": """Execute SQL script with statement splitting and parameter handling. Args: @@ -544,11 +504,11 @@ def data_dictionary(self) -> "SqliteDataDictionary": # PRIVATE/INTERNAL METHODS # ───────────────────────────────────────────────────────────────────────────── - def collect_rows(self, cursor: "SqliteCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": """Collect SQLite rows for the direct execution path.""" return collect_rows(fetched, cursor.description) - def resolve_rowcount(self, cursor: "SqliteCursor") -> int: + def resolve_rowcount(self, cursor: Any) -> int: """Resolve rowcount from SQLite cursor for the direct execution path.""" return resolve_rowcount(cursor) diff --git a/sqlspec/config.py b/sqlspec/config.py index 92a932a05..e3a96c35e 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1564,7 +1564,7 @@ def create_pool(self) -> PoolT: self.get_observability_runtime().emit_pool_create, ) self.connection_instance = created_pool - return created_pool + return cast("PoolT", created_pool) def close_pool(self) -> None: """Close the connection pool.""" @@ -1775,7 +1775,7 @@ async def create_pool(self) -> PoolT: self.get_observability_runtime().emit_pool_create, ) self.connection_instance = created_pool - return created_pool + return cast("PoolT", created_pool) async def close_pool(self) -> None: """Close the connection pool.""" diff --git a/sqlspec/core/parameters/_processor.py 
b/sqlspec/core/parameters/_processor.py index dd44ebcdc..e7ff1b58b 100644 --- a/sqlspec/core/parameters/_processor.py +++ b/sqlspec/core/parameters/_processor.py @@ -680,7 +680,7 @@ def _map_named_to_positional( for idx, row in enumerate(parameter_rows): row_type = type(row) if row_type is dict: - row_dict: dict[str, Any] = row # type: ignore[assignment] + row_dict: dict[str, Any] = row if strict: missing = [name for name in named_order if name not in row_dict] if missing: diff --git a/sqlspec/core/result/_base.py b/sqlspec/core/result/_base.py index 67175ee0a..042f1d06c 100644 --- a/sqlspec/core/result/_base.py +++ b/sqlspec/core/result/_base.py @@ -415,7 +415,7 @@ def _get_schema_row(self, schema_type: "type[SchemaT]", row: "dict[str, Any]") - cached_row = row_cache.get(schema_type) if cached_row is not None: return cast("SchemaT", cached_row) - converted_row = cast("SchemaT", to_schema(row, schema_type=schema_type)) + converted_row = to_schema(row, schema_type=schema_type) if row_cache is None: self._schema_row_cache = {schema_type: converted_row} else: diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index b5c6d5c24..8205b0f39 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -37,9 +37,9 @@ from sqlspec.observability import ObservabilityRuntime, get_trace_context, resolve_db_system from sqlspec.protocols import HasDataProtocol, HasExecuteProtocol, StatementProtocol from sqlspec.typing import VersionCacheResult, VersionInfo +from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.logging import get_logger, log_with_context from sqlspec.utils.schema import to_schema as _to_schema_impl -from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.type_guards import ( has_array_interface, has_cursor_metadata, @@ -195,7 +195,8 @@ class SyncExceptionHandler(Protocol): handlers store mapped exceptions in pending_exception for the caller to raise. 
""" - pending_exception: Exception | None + @property + def pending_exception(self) -> Exception | None: ... def __enter__(self) -> Self: ... @@ -212,7 +213,8 @@ class AsyncExceptionHandler(Protocol): handlers store mapped exceptions in pending_exception for the caller to raise. """ - pending_exception: Exception | None + @property + def pending_exception(self) -> Exception | None: ... async def __aenter__(self) -> Self: ... @@ -452,7 +454,7 @@ def __exit__( ) -> Literal[False]: duration = perf_counter() - self.started self.metrics.record_duration(duration) - if exc_val is not None: + if isinstance(exc_val, Exception): self.metrics.record_error(exc_val) self.runtime.span_manager.end_span(self.span, error=exc_val if exc_val is not None else None) self.metrics.emit(self.runtime) diff --git a/sqlspec/driver/_exception_handler.py b/sqlspec/driver/_exception_handler.py index 388231c63..98b5e46ab 100644 --- a/sqlspec/driver/_exception_handler.py +++ b/sqlspec/driver/_exception_handler.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from types import TracebackType -__all__ = ("BaseAsyncExceptionHandler", "BaseSyncExceptionHandler", ) +__all__ = ("BaseAsyncExceptionHandler", "BaseSyncExceptionHandler") @mypyc_attr(allow_interpreted_subclasses=True) diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 6ce92a326..77fc44f07 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -21,7 +21,7 @@ from sqlspec.utils.sync_tools import async_, await_ if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Coroutine + from collections.abc import Awaitable, Callable from sqlspec.config import DatabaseConfigProtocol from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase diff --git a/sqlspec/utils/arrow_helpers.py b/sqlspec/utils/arrow_helpers.py index ca7cf09eb..b6696fd97 100644 --- a/sqlspec/utils/arrow_helpers.py +++ b/sqlspec/utils/arrow_helpers.py @@ -8,7 +8,7 @@ """ from collections.abc import Iterable -from 
typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.module_loader import ensure_pandas, ensure_polars, ensure_pyarrow @@ -167,12 +167,12 @@ def coerce_arrow_table(source: "ArrowResult | Any") -> "ArrowTable": if has_get_data(source): table = source.get_data() if _get_arrow_table_coercer().get(table) is _coerce_arrow_table_identity: - return table + return cast("ArrowTable", table) msg = "ArrowResult did not return a pyarrow.Table instance" raise TypeError(msg) coercer = _get_arrow_table_coercer().get(source) if coercer is not None: - return coercer(source) + return cast("ArrowTable", coercer(source)) if isinstance(source, Iterable): import pyarrow as pa diff --git a/sqlspec/utils/schema.py b/sqlspec/utils/schema.py index 585ed1265..5f22e4778 100644 --- a/sqlspec/utils/schema.py +++ b/sqlspec/utils/schema.py @@ -22,8 +22,8 @@ convert, get_type_adapter, ) -from sqlspec.utils.logging import get_logger from sqlspec.utils.dispatch import TypeDispatcher +from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json from sqlspec.utils.text import camelize, kebabize, pascalize from sqlspec.utils.type_guards import ( @@ -1004,7 +1004,7 @@ def to_value_type(value: Any, value_type: "type[ValueT]") -> "ValueT": # Schema types (Pydantic, dataclass, msgspec, attrs, TypedDict) # Deferred after scalar checks to avoid overhead for common scalar queries - schema_converter = _get_schema_converter(value_type) # type: ignore[arg-type] + schema_converter = _get_schema_converter(value_type) if schema_converter is not None: parsed = _ensure_json_parsed(value) return cast("ValueT", schema_converter(parsed, value_type)) diff --git a/sqlspec/utils/type_converters.py b/sqlspec/utils/type_converters.py index e28c657b2..7c78888a5 100644 --- a/sqlspec/utils/type_converters.py +++ b/sqlspec/utils/type_converters.py @@ -75,9 +75,12 @@ 
def __call__(self, value: Any) -> Any: return value return handler(self, value) + def convert_decimal(self, value: "decimal.Decimal") -> Any: + return self._decimal_converter(value) + def _normalize_decimal_value(normalizer: "_DecimalNormalizer", value: Any) -> Any: - return normalizer._decimal_converter(value) + return normalizer.convert_decimal(value) def _normalize_decimal_list(normalizer: "_DecimalNormalizer", value: Any) -> Any: diff --git a/tests/integration/adapters/asyncpg/test_cloud_connectors.py b/tests/integration/adapters/asyncpg/test_cloud_connectors.py index 35067dd51..13ce27e12 100644 --- a/tests/integration/adapters/asyncpg/test_cloud_connectors.py +++ b/tests/integration/adapters/asyncpg/test_cloud_connectors.py @@ -142,7 +142,7 @@ async def test_alloydb_connection_basic() -> None: driver_features={ "enable_alloydb": True, "alloydb_instance_uri": instance_uri, - "alloydb_enable_iam_auth": False, + "enable_alloydb_iam_auth": False, }, ) @@ -169,7 +169,7 @@ async def test_alloydb_query_execution() -> None: driver_features={ "enable_alloydb": True, "alloydb_instance_uri": instance_uri, - "alloydb_enable_iam_auth": False, + "enable_alloydb_iam_auth": False, }, ) @@ -193,7 +193,7 @@ async def test_alloydb_iam_auth() -> None: config = AsyncpgConfig( connection_config={"user": user, "database": database, "min_size": 1, "max_size": 2}, - driver_features={"enable_alloydb": True, "alloydb_instance_uri": instance_uri, "alloydb_enable_iam_auth": True}, + driver_features={"enable_alloydb": True, "alloydb_instance_uri": instance_uri, "enable_alloydb_iam_auth": True}, ) await config.create_pool() diff --git a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py index 89e0ac2e8..55d7e49c9 100644 --- a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py +++ b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py @@ -384,7 +384,7 @@ async def mock_connect(**kwargs): driver_features={ 
"enable_alloydb": True, "alloydb_instance_uri": "projects/p/locations/r/clusters/c/instances/i", - "alloydb_enable_iam_auth": True, + "enable_alloydb_iam_auth": True, }, ) diff --git a/tests/unit/adapters/test_asyncpg/test_type_handlers.py b/tests/unit/adapters/test_asyncpg/test_type_handlers.py index f5a682170..ccdd4bd3d 100644 --- a/tests/unit/adapters/test_asyncpg/test_type_handlers.py +++ b/tests/unit/adapters/test_asyncpg/test_type_handlers.py @@ -1,13 +1,16 @@ """Unit tests for asyncpg type handlers.""" -import asyncpg from unittest.mock import AsyncMock, MagicMock, patch -from sqlspec.adapters.asyncpg.core import create_mapped_exception +import asyncpg +import pytest + from sqlspec.adapters.asyncpg.config import register_json_codecs, register_pgvector_support +from sqlspec.adapters.asyncpg.core import create_mapped_exception from sqlspec.exceptions import PermissionDeniedError, UniqueViolationError +@pytest.mark.anyio async def test_register_json_codecs_success() -> None: """Test successful JSON codec registration.""" connection = AsyncMock() @@ -27,6 +30,7 @@ async def test_register_json_codecs_success() -> None: assert jsonb_call.kwargs == {"encoder": encoder, "decoder": decoder, "schema": "pg_catalog"} +@pytest.mark.anyio async def test_register_json_codecs_handles_exception() -> None: """Test that JSON codec registration handles exceptions gracefully.""" connection = AsyncMock() @@ -40,6 +44,7 @@ async def test_register_json_codecs_handles_exception() -> None: @patch("sqlspec.adapters.asyncpg.config.PGVECTOR_INSTALLED", False) +@pytest.mark.anyio async def test_register_pgvector_support_not_installed() -> None: """Test pgvector registration when library not installed.""" connection = AsyncMock() @@ -50,6 +55,7 @@ async def test_register_pgvector_support_not_installed() -> None: @patch("sqlspec.adapters.asyncpg.config.PGVECTOR_INSTALLED", True) +@pytest.mark.anyio async def test_register_pgvector_support_success() -> None: """Test successful pgvector 
registration.""" connection = AsyncMock() @@ -60,6 +66,7 @@ async def test_register_pgvector_support_success() -> None: @patch("sqlspec.adapters.asyncpg.config.PGVECTOR_INSTALLED", True) +@pytest.mark.anyio async def test_register_pgvector_support_handles_exception() -> None: """Test that pgvector registration handles exceptions gracefully.""" connection = AsyncMock() diff --git a/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py b/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py index dff9704e9..0e1d385cb 100644 --- a/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py +++ b/tests/unit/adapters/test_mock/test_cursor_and_exceptions.py @@ -4,7 +4,8 @@ import pytest -from sqlspec.adapters.mock.driver import MockAsyncCursor, MockAsyncExceptionHandler, MockCursor, MockExceptionHandler +from sqlspec.adapters.mock._typing import MockAsyncCursor, MockCursor +from sqlspec.adapters.mock.driver import MockAsyncExceptionHandler, MockExceptionHandler from sqlspec.exceptions import UniqueViolationError diff --git a/tests/unit/adapters/test_mock/test_data_dictionary.py b/tests/unit/adapters/test_mock/test_data_dictionary.py index 1ed3cf492..d719b822c 100644 --- a/tests/unit/adapters/test_mock/test_data_dictionary.py +++ b/tests/unit/adapters/test_mock/test_data_dictionary.py @@ -1,11 +1,8 @@ """Unit tests for mock data dictionary.""" -from typing import cast - import pytest from sqlspec.adapters.mock import MockAsyncConfig, MockSyncConfig -from sqlspec.typing import VersionInfo def test_mock_data_dictionary_get_version() -> None: @@ -29,11 +26,11 @@ def test_mock_data_dictionary_version_caching() -> None: dd = session.data_dictionary driver_id = id(session) - was_cached, cached_version = cast("tuple[bool, VersionInfo | None]", dd.get_cached_version(driver_id)) + was_cached, cached_version = dd.get_cached_version(driver_id) assert was_cached is False version1 = dd.get_version(session) - was_cached, cached_version = cast("tuple[bool, VersionInfo | 
None]", dd.get_cached_version(driver_id)) + was_cached, cached_version = dd.get_cached_version(driver_id) assert was_cached is True assert cached_version == version1 @@ -211,11 +208,11 @@ async def test_mock_async_data_dictionary_version_caching() -> None: dd = session.data_dictionary driver_id = id(session) - was_cached, cached_version = cast("tuple[bool, VersionInfo | None]", dd.get_cached_version(driver_id)) + was_cached, cached_version = dd.get_cached_version(driver_id) assert was_cached is False version1 = await dd.get_version(session) - was_cached, cached_version = cast("tuple[bool, VersionInfo | None]", dd.get_cached_version(driver_id)) + was_cached, cached_version = dd.get_cached_version(driver_id) assert was_cached is True assert cached_version == version1 diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 569a83533..84315793a 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -10,7 +10,7 @@ import json import math -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import date, datetime from decimal import Decimal from importlib import import_module @@ -1195,10 +1195,12 @@ class MyInt(int): def test_resolve_type_coercion_supports_virtual_abc_fallback() -> None: """ABC-registered coercions should still resolve for builtin sequence payloads.""" - type_map = {Sequence: lambda value: tuple(value)} - fallback_items = _processor_module._type_coercion_fallbacks(type_map) + type_map: dict[type, Callable[[Any], Any]] = {Sequence: lambda value: tuple(value)} # type: ignore[dict-item] + fallback_items = _processor_module._type_coercion_fallbacks(type_map) # pyright: ignore[reportPrivateUsage] - assert _processor_module._resolve_type_coercion([1, 2, 3], type_map, fallback_items) == (1, 2, 3) + assert _processor_module._resolve_type_coercion( # pyright: ignore[reportPrivateUsage] + [1, 2, 3], type_map, fallback_items + ) == (1, 2, 3) def 
test_map_named_to_positional_preserves_execute_many_identity_when_rows_are_already_positional( diff --git a/tests/unit/driver/test_query_cache.py b/tests/unit/driver/test_query_cache.py index c48edefc1..dda58051e 100644 --- a/tests/unit/driver/test_query_cache.py +++ b/tests/unit/driver/test_query_cache.py @@ -1,8 +1,8 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Unit tests for fast-path query cache behavior.""" -from concurrent.futures import ThreadPoolExecutor from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor from typing import Any, Literal, cast import pytest @@ -199,13 +199,13 @@ def test_prepare_driver_parameters_many_coerces_virtual_abc_rows_when_needed() - ) ) driver = _FakeDriver(object(), config) - parameters = [[1, 2], ["b"]] + fallback_items = ((Sequence, lambda value: tuple(value)),) - prepared = driver.prepare_driver_parameters(parameters, config, is_many=True) + prepared = driver._apply_coercion_with_fallback( # pyright: ignore[reportPrivateUsage] + [1, 2], config.parameter_config.type_coercion_map, fallback_items + ) - assert isinstance(prepared, list) - assert prepared is not parameters - assert prepared == [(1, 2), ("b",)] + assert prepared == (1, 2) def test_sync_stmt_cache_execute_direct_uses_dispatch_path(mock_sync_driver, monkeypatch) -> None: @@ -292,7 +292,7 @@ def test_execute_populates_fast_path_cache_on_normal_path(mock_sync_driver) -> N assert result.operation_type == "SELECT" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_execute_uses_fast_path_when_eligible(mock_async_driver, monkeypatch) -> None: sentinel = object() called: dict[str, object] = {} @@ -310,7 +310,7 @@ async def _fake_try(statement: str, params: tuple[Any, ...] 
| list[Any]) -> obje assert called["args"] == ("SELECT ?", (1,)) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_execute_skips_fast_path_with_statement_config_override(mock_async_driver, monkeypatch) -> None: called = False @@ -329,7 +329,7 @@ async def _fake_try(statement: str, params: tuple[Any, ...] | list[Any]) -> obje assert result.operation_type == "SELECT" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_execute_populates_fast_path_cache_on_normal_path(mock_async_driver) -> None: mock_async_driver._stmt_cache_enabled = True diff --git a/tests/unit/extensions/test_events/test_channel_extended.py b/tests/unit/extensions/test_events/test_channel_extended.py index 81bf105f1..b93d4f1fb 100644 --- a/tests/unit/extensions/test_events/test_channel_extended.py +++ b/tests/unit/extensions/test_events/test_channel_extended.py @@ -186,7 +186,7 @@ class CustomConfig: is_async = False extension_config: dict[str, Any] = {} driver_features: dict[str, Any] = {} - statement_config = None + statement_config: None = None def get_observability_runtime(self) -> Any: diff --git a/tests/unit/storage/test_storage_registry_source.py b/tests/unit/storage/test_storage_registry_source.py index 45498f42f..0556e4db3 100644 --- a/tests/unit/storage/test_storage_registry_source.py +++ b/tests/unit/storage/test_storage_registry_source.py @@ -5,10 +5,11 @@ """ import importlib.util +import types from pathlib import Path -def _load_registry_source_module(): +def _load_registry_source_module() -> "types.ModuleType": module_path = Path(__file__).resolve().parents[3] / "sqlspec" / "storage" / "registry.py" spec = importlib.util.spec_from_file_location("storage_registry_source_tests", module_path) assert spec is not None diff --git a/tests/unit/test_mypyc_config.py b/tests/unit/test_mypyc_config.py index a3fb2ba74..7b0dce8f8 100644 --- a/tests/unit/test_mypyc_config.py +++ b/tests/unit/test_mypyc_config.py @@ -3,7 +3,7 @@ from pathlib import Path try: - import tomllib 
+ import tomllib # type: ignore[import-not-found] except ModuleNotFoundError: # pragma: no cover import tomli as tomllib diff --git a/uv.lock b/uv.lock index 2c09320f5..2f032d6a7 100644 --- a/uv.lock +++ b/uv.lock @@ -892,91 +892,107 @@ wheels = [ [[package]] name = "charset-normalizer" -version = "3.4.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/35/02daf95b9cd686320bb622eb148792655c9412dbb9b67abb5694e5910a24/charset_normalizer-3.4.5.tar.gz", hash = "sha256:95adae7b6c42a6c5b5b559b1a99149f090a57128155daeea91732c8d970d8644", size = 134804, upload-time = "2026-03-06T06:03:19.46Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/21/a2b1505639008ba2e6ef03733a81fc6cfd6a07ea6139a2b76421230b8dad/charset_normalizer-3.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4167a621a9a1a986c73777dbc15d4b5eac8ac5c10393374109a343d4013ec765", size = 283319, upload-time = "2026-03-06T06:00:26.433Z" }, - { url = "https://files.pythonhosted.org/packages/70/67/df234c29b68f4e1e095885c9db1cb4b69b8aba49cf94fac041db4aaf1267/charset_normalizer-3.4.5-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f64c6bf8f32f9133b668c7f7a7cbdbc453412bc95ecdbd157f3b1e377a92990", size = 189974, upload-time = "2026-03-06T06:00:28.222Z" }, - { url = "https://files.pythonhosted.org/packages/df/7f/fc66af802961c6be42e2c7b69c58f95cbd1f39b0e81b3365d8efe2a02a04/charset_normalizer-3.4.5-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:568e3c34b58422075a1b49575a6abc616d9751b4d61b23f712e12ebb78fe47b2", size = 207866, upload-time = "2026-03-06T06:00:29.769Z" }, - { url = "https://files.pythonhosted.org/packages/c9/23/404eb36fac4e95b833c50e305bba9a241086d427bb2167a42eac7c4f7da4/charset_normalizer-3.4.5-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:036c079aa08a6a592b82487f97c60b439428320ed1b2ea0b3912e99d30c77765", size = 203239, upload-time = "2026-03-06T06:00:31.086Z" }, - { url = "https://files.pythonhosted.org/packages/4b/2f/8a1d989bfadd120c90114ab33e0d2a0cbde05278c1fc15e83e62d570f50a/charset_normalizer-3.4.5-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:340810d34ef83af92148e96e3e44cb2d3f910d2bf95e5618a5c467d9f102231d", size = 196529, upload-time = "2026-03-06T06:00:32.608Z" }, - { url = "https://files.pythonhosted.org/packages/a5/0c/c75f85ff7ca1f051958bb518cd43922d86f576c03947a050fbedfdfb4f15/charset_normalizer-3.4.5-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:cd2d0f0ec9aa977a27731a3209ebbcacebebaf41f902bd453a928bfd281cf7f8", size = 184152, upload-time = "2026-03-06T06:00:33.93Z" }, - { url = "https://files.pythonhosted.org/packages/f9/20/4ed37f6199af5dde94d4aeaf577f3813a5ec6635834cda1d957013a09c76/charset_normalizer-3.4.5-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0b362bcd27819f9c07cbf23db4e0e8cd4b44c5ecd900c2ff907b2b92274a7412", size = 195226, upload-time = "2026-03-06T06:00:35.469Z" }, - { url = "https://files.pythonhosted.org/packages/28/31/7ba1102178cba7c34dcc050f43d427172f389729e356038f0726253dd914/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:77be992288f720306ab4108fe5c74797de327f3248368dfc7e1a916d6ed9e5a2", size = 192933, upload-time = "2026-03-06T06:00:36.83Z" }, - { url = "https://files.pythonhosted.org/packages/4b/23/f86443ab3921e6a60b33b93f4a1161222231f6c69bc24fb18f3bee7b8518/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:8b78d8a609a4b82c273257ee9d631ded7fac0d875bdcdccc109f3ee8328cfcb1", size = 185647, upload-time = "2026-03-06T06:00:38.367Z" }, - { url = 
"https://files.pythonhosted.org/packages/82/44/08b8be891760f1f5a6d23ce11d6d50c92981603e6eb740b4f72eea9424e2/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ba20bdf69bd127f66d0174d6f2a93e69045e0b4036dc1ca78e091bcc765830c4", size = 209533, upload-time = "2026-03-06T06:00:41.931Z" }, - { url = "https://files.pythonhosted.org/packages/3b/5f/df114f23406199f8af711ddccfbf409ffbc5b7cdc18fa19644997ff0c9bb/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:76a9d0de4d0eab387822e7b35d8f89367dd237c72e82ab42b9f7bf5e15ada00f", size = 195901, upload-time = "2026-03-06T06:00:43.978Z" }, - { url = "https://files.pythonhosted.org/packages/07/83/71ef34a76fe8aa05ff8f840244bda2d61e043c2ef6f30d200450b9f6a1be/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8fff79bf5978c693c9b1a4d71e4a94fddfb5fe744eb062a318e15f4a2f63a550", size = 204950, upload-time = "2026-03-06T06:00:45.202Z" }, - { url = "https://files.pythonhosted.org/packages/58/40/0253be623995365137d7dc68e45245036207ab2227251e69a3d93ce43183/charset_normalizer-3.4.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c7e84e0c0005e3bdc1a9211cd4e62c78ba80bc37b2365ef4410cd2007a9047f2", size = 198546, upload-time = "2026-03-06T06:00:46.481Z" }, - { url = "https://files.pythonhosted.org/packages/ed/5c/5f3cb5b259a130895ef5ae16b38eaf141430fa3f7af50cd06c5d67e4f7b2/charset_normalizer-3.4.5-cp310-cp310-win32.whl", hash = "sha256:58ad8270cfa5d4bef1bc85bd387217e14ff154d6630e976c6f56f9a040757475", size = 132516, upload-time = "2026-03-06T06:00:47.924Z" }, - { url = "https://files.pythonhosted.org/packages/a5/c3/84fb174e7770f2df2e1a2115090771bfbc2227fb39a765c6d00568d1aab4/charset_normalizer-3.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:02a9d1b01c1e12c27883b0c9349e0bcd9ae92e727ff1a277207e1a262b1cbf05", size = 142906, upload-time = "2026-03-06T06:00:49.389Z" }, - { url = 
"https://files.pythonhosted.org/packages/d7/b2/6f852f8b969f2cbd0d4092d2e60139ab1af95af9bb651337cae89ec0f684/charset_normalizer-3.4.5-cp310-cp310-win_arm64.whl", hash = "sha256:039215608ac7b358c4da0191d10fc76868567fbf276d54c14721bdedeb6de064", size = 133258, upload-time = "2026-03-06T06:00:51.051Z" }, - { url = "https://files.pythonhosted.org/packages/8f/9e/bcec3b22c64ecec47d39bf5167c2613efd41898c019dccd4183f6aa5d6a7/charset_normalizer-3.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:610f72c0ee565dfb8ae1241b666119582fdbfe7c0975c175be719f940e110694", size = 279531, upload-time = "2026-03-06T06:00:52.252Z" }, - { url = "https://files.pythonhosted.org/packages/58/12/81fd25f7e7078ab5d1eedbb0fac44be4904ae3370a3bf4533c8f2d159acd/charset_normalizer-3.4.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60d68e820af339df4ae8358c7a2e7596badeb61e544438e489035f9fbf3246a5", size = 188006, upload-time = "2026-03-06T06:00:53.8Z" }, - { url = "https://files.pythonhosted.org/packages/ae/6e/f2d30e8c27c1b0736a6520311982cf5286cfc7f6cac77d7bc1325e3a23f2/charset_normalizer-3.4.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:10b473fc8dca1c3ad8559985794815f06ca3fc71942c969129070f2c3cdf7281", size = 205085, upload-time = "2026-03-06T06:00:55.311Z" }, - { url = "https://files.pythonhosted.org/packages/d0/90/d12cefcb53b5931e2cf792a33718d7126efb116a320eaa0742c7059a95e4/charset_normalizer-3.4.5-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d4eb8ac7469b2a5d64b5b8c04f84d8bf3ad340f4514b98523805cbf46e3b3923", size = 200545, upload-time = "2026-03-06T06:00:56.532Z" }, - { url = "https://files.pythonhosted.org/packages/03/f4/44d3b830a20e89ff82a3134912d9a1cf6084d64f3b95dcad40f74449a654/charset_normalizer-3.4.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:5bcb3227c3d9aaf73eaaab1db7ccd80a8995c509ee9941e2aae060ca6e4e5d81", size = 193863, upload-time = "2026-03-06T06:00:57.823Z" }, - { url = "https://files.pythonhosted.org/packages/25/4b/f212119c18a6320a9d4a730d1b4057875cdeabf21b3614f76549042ef8a8/charset_normalizer-3.4.5-cp311-cp311-manylinux_2_31_armv7l.whl", hash = "sha256:75ee9c1cce2911581a70a3c0919d8bccf5b1cbc9b0e5171400ec736b4b569497", size = 181827, upload-time = "2026-03-06T06:00:59.323Z" }, - { url = "https://files.pythonhosted.org/packages/74/00/b26158e48b425a202a92965f8069e8a63d9af1481dfa206825d7f74d2a3c/charset_normalizer-3.4.5-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:1d1401945cb77787dbd3af2446ff2d75912327c4c3a1526ab7955ecf8600687c", size = 191085, upload-time = "2026-03-06T06:01:00.546Z" }, - { url = "https://files.pythonhosted.org/packages/c4/c2/1c1737bf6fd40335fe53d28fe49afd99ee4143cc57a845e99635ce0b9b6d/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a45e504f5e1be0bd385935a8e1507c442349ca36f511a47057a71c9d1d6ea9e", size = 190688, upload-time = "2026-03-06T06:01:02.479Z" }, - { url = "https://files.pythonhosted.org/packages/5a/3d/abb5c22dc2ef493cd56522f811246a63c5427c08f3e3e50ab663de27fcf4/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:e09f671a54ce70b79a1fc1dc6da3072b7ef7251fadb894ed92d9aa8218465a5f", size = 183077, upload-time = "2026-03-06T06:01:04.231Z" }, - { url = "https://files.pythonhosted.org/packages/44/33/5298ad4d419a58e25b3508e87f2758d1442ff00c2471f8e0403dab8edad5/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:d01de5e768328646e6a3fa9e562706f8f6641708c115c62588aef2b941a4f88e", size = 206706, upload-time = "2026-03-06T06:01:05.773Z" }, - { url = "https://files.pythonhosted.org/packages/7b/17/51e7895ac0f87c3b91d276a449ef09f5532a7529818f59646d7a55089432/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_riscv64.whl", hash = 
"sha256:131716d6786ad5e3dc542f5cc6f397ba3339dc0fb87f87ac30e550e8987756af", size = 191665, upload-time = "2026-03-06T06:01:07.473Z" }, - { url = "https://files.pythonhosted.org/packages/90/8f/cce9adf1883e98906dbae380d769b4852bb0fa0004bc7d7a2243418d3ea8/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:1a374cc0b88aa710e8865dc1bd6edb3743c59f27830f0293ab101e4cf3ce9f85", size = 201950, upload-time = "2026-03-06T06:01:08.973Z" }, - { url = "https://files.pythonhosted.org/packages/08/ca/bce99cd5c397a52919e2769d126723f27a4c037130374c051c00470bcd38/charset_normalizer-3.4.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d31f0d1671e1534e395f9eb84a68e0fb670e1edb1fe819a9d7f564ae3bc4e53f", size = 195830, upload-time = "2026-03-06T06:01:10.155Z" }, - { url = "https://files.pythonhosted.org/packages/87/4f/2e3d023a06911f1281f97b8f036edc9872167036ca6f55cc874a0be6c12c/charset_normalizer-3.4.5-cp311-cp311-win32.whl", hash = "sha256:cace89841c0599d736d3d74a27bc5821288bb47c5441923277afc6059d7fbcb4", size = 132029, upload-time = "2026-03-06T06:01:11.706Z" }, - { url = "https://files.pythonhosted.org/packages/fe/1f/a853b73d386521fd44b7f67ded6b17b7b2367067d9106a5c4b44f9a34274/charset_normalizer-3.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:f8102ae93c0bc863b1d41ea0f4499c20a83229f52ed870850892df555187154a", size = 142404, upload-time = "2026-03-06T06:01:12.865Z" }, - { url = "https://files.pythonhosted.org/packages/b4/10/dba36f76b71c38e9d391abe0fd8a5b818790e053c431adecfc98c35cd2a9/charset_normalizer-3.4.5-cp311-cp311-win_arm64.whl", hash = "sha256:ed98364e1c262cf5f9363c3eca8c2df37024f52a8fa1180a3610014f26eac51c", size = 132796, upload-time = "2026-03-06T06:01:14.106Z" }, - { url = "https://files.pythonhosted.org/packages/9c/b6/9ee9c1a608916ca5feae81a344dffbaa53b26b90be58cc2159e3332d44ec/charset_normalizer-3.4.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ed97c282ee4f994ef814042423a529df9497e3c666dca19be1d4cd1129dc7ade", size = 280976, 
upload-time = "2026-03-06T06:01:15.276Z" }, - { url = "https://files.pythonhosted.org/packages/f8/d8/a54f7c0b96f1df3563e9190f04daf981e365a9b397eedfdfb5dbef7e5c6c/charset_normalizer-3.4.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0294916d6ccf2d069727d65973c3a1ca477d68708db25fd758dd28b0827cff54", size = 189356, upload-time = "2026-03-06T06:01:16.511Z" }, - { url = "https://files.pythonhosted.org/packages/42/69/2bf7f76ce1446759a5787cb87d38f6a61eb47dbbdf035cfebf6347292a65/charset_normalizer-3.4.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:dc57a0baa3eeedd99fafaef7511b5a6ef4581494e8168ee086031744e2679467", size = 206369, upload-time = "2026-03-06T06:01:17.853Z" }, - { url = "https://files.pythonhosted.org/packages/10/9c/949d1a46dab56b959d9a87272482195f1840b515a3380e39986989a893ae/charset_normalizer-3.4.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ed1a9a204f317ef879b32f9af507d47e49cd5e7f8e8d5d96358c98373314fc60", size = 203285, upload-time = "2026-03-06T06:01:19.473Z" }, - { url = "https://files.pythonhosted.org/packages/67/5c/ae30362a88b4da237d71ea214a8c7eb915db3eec941adda511729ac25fa2/charset_normalizer-3.4.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ad83b8f9379176c841f8865884f3514d905bcd2a9a3b210eaa446e7d2223e4d", size = 196274, upload-time = "2026-03-06T06:01:20.728Z" }, - { url = "https://files.pythonhosted.org/packages/b2/07/c9f2cb0e46cb6d64fdcc4f95953747b843bb2181bda678dc4e699b8f0f9a/charset_normalizer-3.4.5-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:a118e2e0b5ae6b0120d5efa5f866e58f2bb826067a646431da4d6a2bdae7950e", size = 184715, upload-time = "2026-03-06T06:01:22.194Z" }, - { url = 
"https://files.pythonhosted.org/packages/36/64/6b0ca95c44fddf692cd06d642b28f63009d0ce325fad6e9b2b4d0ef86a52/charset_normalizer-3.4.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:754f96058e61a5e22e91483f823e07df16416ce76afa4ebf306f8e1d1296d43f", size = 193426, upload-time = "2026-03-06T06:01:23.795Z" }, - { url = "https://files.pythonhosted.org/packages/50/bc/a730690d726403743795ca3f5bb2baf67838c5fea78236098f324b965e40/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0c300cefd9b0970381a46394902cd18eaf2aa00163f999590ace991989dcd0fc", size = 191780, upload-time = "2026-03-06T06:01:25.053Z" }, - { url = "https://files.pythonhosted.org/packages/97/4f/6c0bc9af68222b22951552d73df4532b5be6447cee32d58e7e8c74ecbb7b/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c108f8619e504140569ee7de3f97d234f0fbae338a7f9f360455071ef9855a95", size = 185805, upload-time = "2026-03-06T06:01:26.294Z" }, - { url = "https://files.pythonhosted.org/packages/dd/b9/a523fb9b0ee90814b503452b2600e4cbc118cd68714d57041564886e7325/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d1028de43596a315e2720a9849ee79007ab742c06ad8b45a50db8cdb7ed4a82a", size = 208342, upload-time = "2026-03-06T06:01:27.55Z" }, - { url = "https://files.pythonhosted.org/packages/4d/61/c59e761dee4464050713e50e27b58266cc8e209e518c0b378c1580c959ba/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:19092dde50335accf365cce21998a1c6dd8eafd42c7b226eb54b2747cdce2fac", size = 193661, upload-time = "2026-03-06T06:01:29.051Z" }, - { url = "https://files.pythonhosted.org/packages/1c/43/729fa30aad69783f755c5ad8649da17ee095311ca42024742701e202dc59/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4354e401eb6dab9aed3c7b4030514328a6c748d05e1c3e19175008ca7de84fb1", size = 204819, upload-time = "2026-03-06T06:01:30.298Z" }, - { url = 
"https://files.pythonhosted.org/packages/87/33/d9b442ce5a91b96fc0840455a9e49a611bbadae6122778d0a6a79683dd31/charset_normalizer-3.4.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a68766a3c58fde7f9aaa22b3786276f62ab2f594efb02d0a1421b6282e852e98", size = 198080, upload-time = "2026-03-06T06:01:31.478Z" }, - { url = "https://files.pythonhosted.org/packages/56/5a/b8b5a23134978ee9885cee2d6995f4c27cc41f9baded0a9685eabc5338f0/charset_normalizer-3.4.5-cp312-cp312-win32.whl", hash = "sha256:1827734a5b308b65ac54e86a618de66f935a4f63a8a462ff1e19a6788d6c2262", size = 132630, upload-time = "2026-03-06T06:01:33.056Z" }, - { url = "https://files.pythonhosted.org/packages/70/53/e44a4c07e8904500aec95865dc3f6464dc3586a039ef0df606eb3ac38e35/charset_normalizer-3.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:728c6a963dfab66ef865f49286e45239384249672cd598576765acc2a640a636", size = 142856, upload-time = "2026-03-06T06:01:34.489Z" }, - { url = "https://files.pythonhosted.org/packages/ea/aa/c5628f7cad591b1cf45790b7a61483c3e36cf41349c98af7813c483fd6e8/charset_normalizer-3.4.5-cp312-cp312-win_arm64.whl", hash = "sha256:75dfd1afe0b1647449e852f4fb428195a7ed0588947218f7ba929f6538487f02", size = 132982, upload-time = "2026-03-06T06:01:35.641Z" }, - { url = "https://files.pythonhosted.org/packages/f5/48/9f34ec4bb24aa3fdba1890c1bddb97c8a4be1bd84ef5c42ac2352563ad05/charset_normalizer-3.4.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ac59c15e3f1465f722607800c68713f9fbc2f672b9eb649fe831da4019ae9b23", size = 280788, upload-time = "2026-03-06T06:01:37.126Z" }, - { url = "https://files.pythonhosted.org/packages/0e/09/6003e7ffeb90cc0560da893e3208396a44c210c5ee42efff539639def59b/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:165c7b21d19365464e8f70e5ce5e12524c58b48c78c1f5a57524603c1ab003f8", size = 188890, upload-time = "2026-03-06T06:01:38.73Z" }, - { url = 
"https://files.pythonhosted.org/packages/42/1e/02706edf19e390680daa694d17e2b8eab4b5f7ac285e2a51168b4b22ee6b/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:28269983f25a4da0425743d0d257a2d6921ea7d9b83599d4039486ec5b9f911d", size = 206136, upload-time = "2026-03-06T06:01:40.016Z" }, - { url = "https://files.pythonhosted.org/packages/c7/87/942c3def1b37baf3cf786bad01249190f3ca3d5e63a84f831e704977de1f/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d27ce22ec453564770d29d03a9506d449efbb9fa13c00842262b2f6801c48cce", size = 202551, upload-time = "2026-03-06T06:01:41.522Z" }, - { url = "https://files.pythonhosted.org/packages/94/0a/af49691938dfe175d71b8a929bd7e4ace2809c0c5134e28bc535660d5262/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0625665e4ebdddb553ab185de5db7054393af8879fb0c87bd5690d14379d6819", size = 195572, upload-time = "2026-03-06T06:01:43.208Z" }, - { url = "https://files.pythonhosted.org/packages/20/ea/dfb1792a8050a8e694cfbde1570ff97ff74e48afd874152d38163d1df9ae/charset_normalizer-3.4.5-cp313-cp313-manylinux_2_31_armv7l.whl", hash = "sha256:c23eb3263356d94858655b3e63f85ac5d50970c6e8febcdde7830209139cc37d", size = 184438, upload-time = "2026-03-06T06:01:44.755Z" }, - { url = "https://files.pythonhosted.org/packages/72/12/c281e2067466e3ddd0595bfaea58a6946765ace5c72dfa3edc2f5f118026/charset_normalizer-3.4.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e6302ca4ae283deb0af68d2fbf467474b8b6aedcd3dab4db187e07f94c109763", size = 193035, upload-time = "2026-03-06T06:01:46.051Z" }, - { url = "https://files.pythonhosted.org/packages/ba/4f/3792c056e7708e10464bad0438a44708886fb8f92e3c3d29ec5e2d964d42/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:e51ae7d81c825761d941962450f50d041db028b7278e7b08930b4541b3e45cb9", size = 191340, upload-time = "2026-03-06T06:01:47.547Z" }, - { url = "https://files.pythonhosted.org/packages/e7/86/80ddba897127b5c7a9bccc481b0cd36c8fefa485d113262f0fe4332f0bf4/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:597d10dec876923e5c59e48dbd366e852eacb2b806029491d307daea6b917d7c", size = 185464, upload-time = "2026-03-06T06:01:48.764Z" }, - { url = "https://files.pythonhosted.org/packages/4d/00/b5eff85ba198faacab83e0e4b6f0648155f072278e3b392a82478f8b988b/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:5cffde4032a197bd3b42fd0b9509ec60fb70918d6970e4cc773f20fc9180ca67", size = 208014, upload-time = "2026-03-06T06:01:50.371Z" }, - { url = "https://files.pythonhosted.org/packages/c8/11/d36f70be01597fd30850dde8a1269ebc8efadd23ba5785808454f2389bde/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:2da4eedcb6338e2321e831a0165759c0c620e37f8cd044a263ff67493be8ffb3", size = 193297, upload-time = "2026-03-06T06:01:51.933Z" }, - { url = "https://files.pythonhosted.org/packages/1a/1d/259eb0a53d4910536c7c2abb9cb25f4153548efb42800c6a9456764649c0/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:65a126fb4b070d05340a84fc709dd9e7c75d9b063b610ece8a60197a291d0adf", size = 204321, upload-time = "2026-03-06T06:01:53.887Z" }, - { url = "https://files.pythonhosted.org/packages/84/31/faa6c5b9d3688715e1ed1bb9d124c384fe2fc1633a409e503ffe1c6398c1/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7a80a9242963416bd81f99349d5f3fce1843c303bd404f204918b6d75a75fd6", size = 197509, upload-time = "2026-03-06T06:01:56.439Z" }, - { url = "https://files.pythonhosted.org/packages/fd/a5/c7d9dd1503ffc08950b3260f5d39ec2366dd08254f0900ecbcf3a6197c7c/charset_normalizer-3.4.5-cp313-cp313-win32.whl", hash = "sha256:f1d725b754e967e648046f00c4facc42d414840f5ccc670c5670f59f83693e4f", 
size = 132284, upload-time = "2026-03-06T06:01:57.812Z" }, - { url = "https://files.pythonhosted.org/packages/b9/0f/57072b253af40c8aa6636e6de7d75985624c1eb392815b2f934199340a89/charset_normalizer-3.4.5-cp313-cp313-win_amd64.whl", hash = "sha256:e37bd100d2c5d3ba35db9c7c5ba5a9228cbcffe5c4778dc824b164e5257813d7", size = 142630, upload-time = "2026-03-06T06:01:59.062Z" }, - { url = "https://files.pythonhosted.org/packages/31/41/1c4b7cc9f13bd9d369ce3bc993e13d374ce25fa38a2663644283ecf422c1/charset_normalizer-3.4.5-cp313-cp313-win_arm64.whl", hash = "sha256:93b3b2cc5cf1b8743660ce77a4f45f3f6d1172068207c1defc779a36eea6bb36", size = 133254, upload-time = "2026-03-06T06:02:00.281Z" }, - { url = "https://files.pythonhosted.org/packages/43/be/0f0fd9bb4a7fa4fb5067fb7d9ac693d4e928d306f80a0d02bde43a7c4aee/charset_normalizer-3.4.5-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8197abe5ca1ffb7d91e78360f915eef5addff270f8a71c1fc5be24a56f3e4873", size = 280232, upload-time = "2026-03-06T06:02:01.508Z" }, - { url = "https://files.pythonhosted.org/packages/28/02/983b5445e4bef49cd8c9da73a8e029f0825f39b74a06d201bfaa2e55142a/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2aecdb364b8a1802afdc7f9327d55dad5366bc97d8502d0f5854e50712dbc5f", size = 189688, upload-time = "2026-03-06T06:02:02.857Z" }, - { url = "https://files.pythonhosted.org/packages/d0/88/152745c5166437687028027dc080e2daed6fe11cfa95a22f4602591c42db/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a66aa5022bf81ab4b1bebfb009db4fd68e0c6d4307a1ce5ef6a26e5878dfc9e4", size = 206833, upload-time = "2026-03-06T06:02:05.127Z" }, - { url = "https://files.pythonhosted.org/packages/cb/0f/ebc15c8b02af2f19be9678d6eed115feeeccc45ce1f4b098d986c13e8769/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:d77f97e515688bd615c1d1f795d540f32542d514242067adcb8ef532504cb9ee", size = 202879, upload-time = "2026-03-06T06:02:06.446Z" }, - { url = "https://files.pythonhosted.org/packages/38/9c/71336bff6934418dc8d1e8a1644176ac9088068bc571da612767619c97b3/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01a1ed54b953303ca7e310fafe0fe347aab348bd81834a0bcd602eb538f89d66", size = 195764, upload-time = "2026-03-06T06:02:08.763Z" }, - { url = "https://files.pythonhosted.org/packages/b7/95/ce92fde4f98615661871bc282a856cf9b8a15f686ba0af012984660d480b/charset_normalizer-3.4.5-cp314-cp314-manylinux_2_31_armv7l.whl", hash = "sha256:b2d37d78297b39a9eb9eb92c0f6df98c706467282055419df141389b23f93362", size = 183728, upload-time = "2026-03-06T06:02:10.137Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e7/f5b4588d94e747ce45ae680f0f242bc2d98dbd4eccfab73e6160b6893893/charset_normalizer-3.4.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e71bbb595973622b817c042bd943c3f3667e9c9983ce3d205f973f486fec98a7", size = 192937, upload-time = "2026-03-06T06:02:11.663Z" }, - { url = "https://files.pythonhosted.org/packages/f9/29/9d94ed6b929bf9f48bf6ede6e7474576499f07c4c5e878fb186083622716/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cd966c2559f501c6fd69294d082c2934c8dd4719deb32c22961a5ac6db0df1d", size = 192040, upload-time = "2026-03-06T06:02:13.489Z" }, - { url = "https://files.pythonhosted.org/packages/15/d2/1a093a1cf827957f9445f2fe7298bcc16f8fc5e05c1ed2ad1af0b239035e/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:d5e52d127045d6ae01a1e821acfad2f3a1866c54d0e837828538fabe8d9d1bd6", size = 184107, upload-time = "2026-03-06T06:02:14.83Z" }, - { url = 
"https://files.pythonhosted.org/packages/0f/7d/82068ce16bd36135df7b97f6333c5d808b94e01d4599a682e2337ed5fd14/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:30a2b1a48478c3428d047ed9690d57c23038dac838a87ad624c85c0a78ebeb39", size = 208310, upload-time = "2026-03-06T06:02:16.165Z" }, - { url = "https://files.pythonhosted.org/packages/84/4e/4dfb52307bb6af4a5c9e73e482d171b81d36f522b21ccd28a49656baa680/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:d8ed79b8f6372ca4254955005830fd61c1ccdd8c0fac6603e2c145c61dd95db6", size = 192918, upload-time = "2026-03-06T06:02:18.144Z" }, - { url = "https://files.pythonhosted.org/packages/08/a4/159ff7da662cf7201502ca89980b8f06acf3e887b278956646a8aeb178ab/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:c5af897b45fa606b12464ccbe0014bbf8c09191e0a66aab6aa9d5cf6e77e0c94", size = 204615, upload-time = "2026-03-06T06:02:19.821Z" }, - { url = "https://files.pythonhosted.org/packages/d6/62/0dd6172203cb6b429ffffc9935001fde42e5250d57f07b0c28c6046deb6b/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1088345bcc93c58d8d8f3d783eca4a6e7a7752bbff26c3eee7e73c597c191c2e", size = 197784, upload-time = "2026-03-06T06:02:21.86Z" }, - { url = "https://files.pythonhosted.org/packages/c7/5e/1aab5cb737039b9c59e63627dc8bbc0d02562a14f831cc450e5f91d84ce1/charset_normalizer-3.4.5-cp314-cp314-win32.whl", hash = "sha256:ee57b926940ba00bca7ba7041e665cc956e55ef482f851b9b65acb20d867e7a2", size = 133009, upload-time = "2026-03-06T06:02:23.289Z" }, - { url = "https://files.pythonhosted.org/packages/40/65/e7c6c77d7aaa4c0d7974f2e403e17f0ed2cb0fc135f77d686b916bf1eead/charset_normalizer-3.4.5-cp314-cp314-win_amd64.whl", hash = "sha256:4481e6da1830c8a1cc0b746b47f603b653dadb690bcd851d039ffaefe70533aa", size = 143511, upload-time = "2026-03-06T06:02:26.195Z" }, - { url = 
"https://files.pythonhosted.org/packages/ba/91/52b0841c71f152f563b8e072896c14e3d83b195c188b338d3cc2e582d1d4/charset_normalizer-3.4.5-cp314-cp314-win_arm64.whl", hash = "sha256:97ab7787092eb9b50fb47fa04f24c75b768a606af1bcba1957f07f128a7219e4", size = 133775, upload-time = "2026-03-06T06:02:27.473Z" }, - { url = "https://files.pythonhosted.org/packages/c5/60/3a621758945513adfd4db86827a5bafcc615f913dbd0b4c2ed64a65731be/charset_normalizer-3.4.5-py3-none-any.whl", hash = "sha256:9db5e3fcdcee89a78c04dffb3fe33c79f77bd741a624946db2591c81b2fc85b0", size = 55455, upload-time = "2026-03-06T06:03:17.827Z" }, +version = "3.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/60/e3bec1881450851b087e301bedc3daa9377a4d45f1c26aa90b0b235e38aa/charset_normalizer-3.4.6.tar.gz", hash = "sha256:1ae6b62897110aa7c79ea2f5dd38d1abca6db663687c0b1ad9aed6f6bae3d9d6", size = 143363, upload-time = "2026-03-15T18:53:25.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/8c/2c56124c6dc53a774d435f985b5973bc592f42d437be58c0c92d65ae7296/charset_normalizer-3.4.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2e1d8ca8611099001949d1cdfaefc510cf0f212484fe7c565f735b68c78c3c95", size = 298751, upload-time = "2026-03-15T18:50:00.003Z" }, + { url = "https://files.pythonhosted.org/packages/86/2a/2a7db6b314b966a3bcad8c731c0719c60b931b931de7ae9f34b2839289ee/charset_normalizer-3.4.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e25369dc110d58ddf29b949377a93e0716d72a24f62bad72b2b39f155949c1fd", size = 200027, upload-time = "2026-03-15T18:50:01.702Z" }, + { url = "https://files.pythonhosted.org/packages/68/f2/0fe775c74ae25e2a3b07b01538fc162737b3e3f795bada3bc26f4d4d495c/charset_normalizer-3.4.6-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:259695e2ccc253feb2a016303543d691825e920917e31f894ca1a687982b1de4", size = 
220741, upload-time = "2026-03-15T18:50:03.194Z" }, + { url = "https://files.pythonhosted.org/packages/10/98/8085596e41f00b27dd6aa1e68413d1ddda7e605f34dd546833c61fddd709/charset_normalizer-3.4.6-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dda86aba335c902b6149a02a55b38e96287157e609200811837678214ba2b1db", size = 215802, upload-time = "2026-03-15T18:50:05.859Z" }, + { url = "https://files.pythonhosted.org/packages/fd/ce/865e4e09b041bad659d682bbd98b47fb490b8e124f9398c9448065f64fee/charset_normalizer-3.4.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51fb3c322c81d20567019778cb5a4a6f2dc1c200b886bc0d636238e364848c89", size = 207908, upload-time = "2026-03-15T18:50:07.676Z" }, + { url = "https://files.pythonhosted.org/packages/a8/54/8c757f1f7349262898c2f169e0d562b39dcb977503f18fdf0814e923db78/charset_normalizer-3.4.6-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:4482481cb0572180b6fd976a4d5c72a30263e98564da68b86ec91f0fe35e8565", size = 194357, upload-time = "2026-03-15T18:50:09.327Z" }, + { url = "https://files.pythonhosted.org/packages/6f/29/e88f2fac9218907fc7a70722b393d1bbe8334c61fe9c46640dba349b6e66/charset_normalizer-3.4.6-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:39f5068d35621da2881271e5c3205125cc456f54e9030d3f723288c873a71bf9", size = 205610, upload-time = "2026-03-15T18:50:10.732Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c5/21d7bb0cb415287178450171d130bed9d664211fdd59731ed2c34267b07d/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8bea55c4eef25b0b19a0337dc4e3f9a15b00d569c77211fa8cde38684f234fb7", size = 203512, upload-time = "2026-03-15T18:50:12.535Z" }, + { url = "https://files.pythonhosted.org/packages/a4/be/ce52f3c7fdb35cc987ad38a53ebcef52eec498f4fb6c66ecfe62cfe57ba2/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_armv7l.whl", hash = 
"sha256:f0cdaecd4c953bfae0b6bb64910aaaca5a424ad9c72d85cb88417bb9814f7550", size = 195398, upload-time = "2026-03-15T18:50:14.236Z" }, + { url = "https://files.pythonhosted.org/packages/81/a0/3ab5dd39d4859a3555e5dadfc8a9fa7f8352f8c183d1a65c90264517da0e/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:150b8ce8e830eb7ccb029ec9ca36022f756986aaaa7956aad6d9ec90089338c0", size = 221772, upload-time = "2026-03-15T18:50:15.581Z" }, + { url = "https://files.pythonhosted.org/packages/04/6e/6a4e41a97ba6b2fa87f849c41e4d229449a586be85053c4d90135fe82d26/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:e68c14b04827dd76dcbd1aeea9e604e3e4b78322d8faf2f8132c7138efa340a8", size = 205759, upload-time = "2026-03-15T18:50:17.047Z" }, + { url = "https://files.pythonhosted.org/packages/db/3b/34a712a5ee64a6957bf355b01dc17b12de457638d436fdb05d01e463cd1c/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:3778fd7d7cd04ae8f54651f4a7a0bd6e39a0cf20f801720a4c21d80e9b7ad6b0", size = 216938, upload-time = "2026-03-15T18:50:18.44Z" }, + { url = "https://files.pythonhosted.org/packages/cb/05/5bd1e12da9ab18790af05c61aafd01a60f489778179b621ac2a305243c62/charset_normalizer-3.4.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:dad6e0f2e481fffdcf776d10ebee25e0ef89f16d691f1e5dee4b586375fdc64b", size = 210138, upload-time = "2026-03-15T18:50:19.852Z" }, + { url = "https://files.pythonhosted.org/packages/bd/8e/3cb9e2d998ff6b21c0a1860343cb7b83eba9cdb66b91410e18fc4969d6ab/charset_normalizer-3.4.6-cp310-cp310-win32.whl", hash = "sha256:74a2e659c7ecbc73562e2a15e05039f1e22c75b7c7618b4b574a3ea9118d1557", size = 144137, upload-time = "2026-03-15T18:50:21.505Z" }, + { url = "https://files.pythonhosted.org/packages/d8/8f/78f5489ffadb0db3eb7aff53d31c24531d33eb545f0c6f6567c25f49a5ff/charset_normalizer-3.4.6-cp310-cp310-win_amd64.whl", hash = "sha256:aa9cccf4a44b9b62d8ba8b4dd06c649ba683e4bf04eea606d2e94cfc2d6ff4d6", size = 
154244, upload-time = "2026-03-15T18:50:22.81Z" }, + { url = "https://files.pythonhosted.org/packages/e4/74/e472659dffb0cadb2f411282d2d76c60da1fc94076d7fffed4ae8a93ec01/charset_normalizer-3.4.6-cp310-cp310-win_arm64.whl", hash = "sha256:e985a16ff513596f217cee86c21371b8cd011c0f6f056d0920aa2d926c544058", size = 143312, upload-time = "2026-03-15T18:50:24.074Z" }, + { url = "https://files.pythonhosted.org/packages/62/28/ff6f234e628a2de61c458be2779cb182bc03f6eec12200d4a525bbfc9741/charset_normalizer-3.4.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:82060f995ab5003a2d6e0f4ad29065b7672b6593c8c63559beefe5b443242c3e", size = 293582, upload-time = "2026-03-15T18:50:25.454Z" }, + { url = "https://files.pythonhosted.org/packages/1c/b7/b1a117e5385cbdb3205f6055403c2a2a220c5ea80b8716c324eaf75c5c95/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60c74963d8350241a79cb8feea80e54d518f72c26db618862a8f53e5023deaf9", size = 197240, upload-time = "2026-03-15T18:50:27.196Z" }, + { url = "https://files.pythonhosted.org/packages/a1/5f/2574f0f09f3c3bc1b2f992e20bce6546cb1f17e111c5be07308dc5427956/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f6e4333fb15c83f7d1482a76d45a0818897b3d33f00efd215528ff7c51b8e35d", size = 217363, upload-time = "2026-03-15T18:50:28.601Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d1/0ae20ad77bc949ddd39b51bf383b6ca932f2916074c95cad34ae465ab71f/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:bc72863f4d9aba2e8fd9085e63548a324ba706d2ea2c83b260da08a59b9482de", size = 212994, upload-time = "2026-03-15T18:50:30.102Z" }, + { url = 
"https://files.pythonhosted.org/packages/60/ac/3233d262a310c1b12633536a07cde5ddd16985e6e7e238e9f3f9423d8eb9/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9cc4fc6c196d6a8b76629a70ddfcd4635a6898756e2d9cac5565cf0654605d73", size = 204697, upload-time = "2026-03-15T18:50:31.654Z" }, + { url = "https://files.pythonhosted.org/packages/25/3c/8a18fc411f085b82303cfb7154eed5bd49c77035eb7608d049468b53f87c/charset_normalizer-3.4.6-cp311-cp311-manylinux_2_31_armv7l.whl", hash = "sha256:0c173ce3a681f309f31b87125fecec7a5d1347261ea11ebbb856fa6006b23c8c", size = 191673, upload-time = "2026-03-15T18:50:33.433Z" }, + { url = "https://files.pythonhosted.org/packages/ff/a7/11cfe61d6c5c5c7438d6ba40919d0306ed83c9ab957f3d4da2277ff67836/charset_normalizer-3.4.6-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c907cdc8109f6c619e6254212e794d6548373cc40e1ec75e6e3823d9135d29cc", size = 201120, upload-time = "2026-03-15T18:50:35.105Z" }, + { url = "https://files.pythonhosted.org/packages/b5/10/cf491fa1abd47c02f69687046b896c950b92b6cd7337a27e6548adbec8e4/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:404a1e552cf5b675a87f0651f8b79f5f1e6fd100ee88dc612f89aa16abd4486f", size = 200911, upload-time = "2026-03-15T18:50:36.819Z" }, + { url = "https://files.pythonhosted.org/packages/28/70/039796160b48b18ed466fde0af84c1b090c4e288fae26cd674ad04a2d703/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:e3c701e954abf6fc03a49f7c579cc80c2c6cc52525340ca3186c41d3f33482ef", size = 192516, upload-time = "2026-03-15T18:50:38.228Z" }, + { url = "https://files.pythonhosted.org/packages/ff/34/c56f3223393d6ff3124b9e78f7de738047c2d6bc40a4f16ac0c9d7a1cb3c/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7a6967aaf043bceabab5412ed6bd6bd26603dae84d5cb75bf8d9a74a4959d398", size = 218795, upload-time = "2026-03-15T18:50:39.664Z" }, 
+ { url = "https://files.pythonhosted.org/packages/e8/3b/ce2d4f86c5282191a041fdc5a4ce18f1c6bd40a5bd1f74cf8625f08d51c1/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:5feb91325bbceade6afab43eb3b508c63ee53579fe896c77137ded51c6b6958e", size = 201833, upload-time = "2026-03-15T18:50:41.552Z" }, + { url = "https://files.pythonhosted.org/packages/3b/9b/b6a9f76b0fd7c5b5ec58b228ff7e85095370282150f0bd50b3126f5506d6/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f820f24b09e3e779fe84c3c456cb4108a7aa639b0d1f02c28046e11bfcd088ed", size = 213920, upload-time = "2026-03-15T18:50:43.33Z" }, + { url = "https://files.pythonhosted.org/packages/ae/98/7bc23513a33d8172365ed30ee3a3b3fe1ece14a395e5fc94129541fc6003/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b35b200d6a71b9839a46b9b7fff66b6638bb52fc9658aa58796b0326595d3021", size = 206951, upload-time = "2026-03-15T18:50:44.789Z" }, + { url = "https://files.pythonhosted.org/packages/32/73/c0b86f3d1458468e11aec870e6b3feac931facbe105a894b552b0e518e79/charset_normalizer-3.4.6-cp311-cp311-win32.whl", hash = "sha256:9ca4c0b502ab399ef89248a2c84c54954f77a070f28e546a85e91da627d1301e", size = 143703, upload-time = "2026-03-15T18:50:46.103Z" }, + { url = "https://files.pythonhosted.org/packages/c6/e3/76f2facfe8eddee0bbd38d2594e709033338eae44ebf1738bcefe0a06185/charset_normalizer-3.4.6-cp311-cp311-win_amd64.whl", hash = "sha256:a9e68c9d88823b274cf1e72f28cb5dc89c990edf430b0bfd3e2fb0785bfeabf4", size = 153857, upload-time = "2026-03-15T18:50:47.563Z" }, + { url = "https://files.pythonhosted.org/packages/e2/dc/9abe19c9b27e6cd3636036b9d1b387b78c40dedbf0b47f9366737684b4b0/charset_normalizer-3.4.6-cp311-cp311-win_arm64.whl", hash = "sha256:97d0235baafca5f2b09cf332cc275f021e694e8362c6bb9c96fc9a0eb74fc316", size = 142751, upload-time = "2026-03-15T18:50:49.234Z" }, + { url = 
"https://files.pythonhosted.org/packages/e5/62/c0815c992c9545347aeea7859b50dc9044d147e2e7278329c6e02ac9a616/charset_normalizer-3.4.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:2ef7fedc7a6ecbe99969cd09632516738a97eeb8bd7258bf8a0f23114c057dab", size = 295154, upload-time = "2026-03-15T18:50:50.88Z" }, + { url = "https://files.pythonhosted.org/packages/a8/37/bdca6613c2e3c58c7421891d80cc3efa1d32e882f7c4a7ee6039c3fc951a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4ea868bc28109052790eb2b52a9ab33f3aa7adc02f96673526ff47419490e21", size = 199191, upload-time = "2026-03-15T18:50:52.658Z" }, + { url = "https://files.pythonhosted.org/packages/6c/92/9934d1bbd69f7f398b38c5dae1cbf9cc672e7c34a4adf7b17c0a9c17d15d/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:836ab36280f21fc1a03c99cd05c6b7af70d2697e374c7af0b61ed271401a72a2", size = 218674, upload-time = "2026-03-15T18:50:54.102Z" }, + { url = "https://files.pythonhosted.org/packages/af/90/25f6ab406659286be929fd89ab0e78e38aa183fc374e03aa3c12d730af8a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f1ce721c8a7dfec21fcbdfe04e8f68174183cf4e8188e0645e92aa23985c57ff", size = 215259, upload-time = "2026-03-15T18:50:55.616Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ef/79a463eb0fff7f96afa04c1d4c51f8fc85426f918db467854bfb6a569ce3/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e28d62a8fc7a1fa411c43bd65e346f3bce9716dc51b897fbe930c5987b402d5", size = 207276, upload-time = "2026-03-15T18:50:57.054Z" }, + { url = "https://files.pythonhosted.org/packages/f7/72/d0426afec4b71dc159fa6b4e68f868cd5a3ecd918fec5813a15d292a7d10/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_armv7l.whl", hash = 
"sha256:530d548084c4a9f7a16ed4a294d459b4f229db50df689bfe92027452452943a0", size = 195161, upload-time = "2026-03-15T18:50:58.686Z" }, + { url = "https://files.pythonhosted.org/packages/bf/18/c82b06a68bfcb6ce55e508225d210c7e6a4ea122bfc0748892f3dc4e8e11/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30f445ae60aad5e1f8bdbb3108e39f6fbc09f4ea16c815c66578878325f8f15a", size = 203452, upload-time = "2026-03-15T18:51:00.196Z" }, + { url = "https://files.pythonhosted.org/packages/44/d6/0c25979b92f8adafdbb946160348d8d44aa60ce99afdc27df524379875cb/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ac2393c73378fea4e52aa56285a3d64be50f1a12395afef9cce47772f60334c2", size = 202272, upload-time = "2026-03-15T18:51:01.703Z" }, + { url = "https://files.pythonhosted.org/packages/2e/3d/7fea3e8fe84136bebbac715dd1221cc25c173c57a699c030ab9b8900cbb7/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:90ca27cd8da8118b18a52d5f547859cc1f8354a00cd1e8e5120df3e30d6279e5", size = 195622, upload-time = "2026-03-15T18:51:03.526Z" }, + { url = "https://files.pythonhosted.org/packages/57/8a/d6f7fd5cb96c58ef2f681424fbca01264461336d2a7fc875e4446b1f1346/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8e5a94886bedca0f9b78fecd6afb6629142fd2605aa70a125d49f4edc6037ee6", size = 220056, upload-time = "2026-03-15T18:51:05.269Z" }, + { url = "https://files.pythonhosted.org/packages/16/50/478cdda782c8c9c3fb5da3cc72dd7f331f031e7f1363a893cdd6ca0f8de0/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:695f5c2823691a25f17bc5d5ffe79fa90972cc34b002ac6c843bb8a1720e950d", size = 203751, upload-time = "2026-03-15T18:51:06.858Z" }, + { url = "https://files.pythonhosted.org/packages/75/fc/cc2fcac943939c8e4d8791abfa139f685e5150cae9f94b60f12520feaa9b/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = 
"sha256:231d4da14bcd9301310faf492051bee27df11f2bc7549bc0bb41fef11b82daa2", size = 216563, upload-time = "2026-03-15T18:51:08.564Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b7/a4add1d9a5f68f3d037261aecca83abdb0ab15960a3591d340e829b37298/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a056d1ad2633548ca18ffa2f85c202cfb48b68615129143915b8dc72a806a923", size = 209265, upload-time = "2026-03-15T18:51:10.312Z" }, + { url = "https://files.pythonhosted.org/packages/6c/18/c094561b5d64a24277707698e54b7f67bd17a4f857bbfbb1072bba07c8bf/charset_normalizer-3.4.6-cp312-cp312-win32.whl", hash = "sha256:c2274ca724536f173122f36c98ce188fd24ce3dad886ec2b7af859518ce008a4", size = 144229, upload-time = "2026-03-15T18:51:11.694Z" }, + { url = "https://files.pythonhosted.org/packages/ab/20/0567efb3a8fd481b8f34f739ebddc098ed062a59fed41a8d193a61939e8f/charset_normalizer-3.4.6-cp312-cp312-win_amd64.whl", hash = "sha256:c8ae56368f8cc97c7e40a7ee18e1cedaf8e780cd8bc5ed5ac8b81f238614facb", size = 154277, upload-time = "2026-03-15T18:51:13.004Z" }, + { url = "https://files.pythonhosted.org/packages/15/57/28d79b44b51933119e21f65479d0864a8d5893e494cf5daab15df0247c17/charset_normalizer-3.4.6-cp312-cp312-win_arm64.whl", hash = "sha256:899d28f422116b08be5118ef350c292b36fc15ec2daeb9ea987c89281c7bb5c4", size = 142817, upload-time = "2026-03-15T18:51:14.408Z" }, + { url = "https://files.pythonhosted.org/packages/1e/1d/4fdabeef4e231153b6ed7567602f3b68265ec4e5b76d6024cf647d43d981/charset_normalizer-3.4.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:11afb56037cbc4b1555a34dd69151e8e069bee82e613a73bef6e714ce733585f", size = 294823, upload-time = "2026-03-15T18:51:15.755Z" }, + { url = "https://files.pythonhosted.org/packages/47/7b/20e809b89c69d37be748d98e84dce6820bf663cf19cf6b942c951a3e8f41/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:423fb7e748a08f854a08a222b983f4df1912b1daedce51a72bd24fe8f26a1843", size = 198527, upload-time = "2026-03-15T18:51:17.177Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/4f8d27527d59c039dce6f7622593cdcd3d70a8504d87d09eb11e9fdc6062/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d73beaac5e90173ac3deb9928a74763a6d230f494e4bfb422c217a0ad8e629bf", size = 218388, upload-time = "2026-03-15T18:51:18.934Z" }, + { url = "https://files.pythonhosted.org/packages/f6/9b/4770ccb3e491a9bacf1c46cc8b812214fe367c86a96353ccc6daf87b01ec/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d60377dce4511655582e300dc1e5a5f24ba0cb229005a1d5c8d0cb72bb758ab8", size = 214563, upload-time = "2026-03-15T18:51:20.374Z" }, + { url = "https://files.pythonhosted.org/packages/2b/58/a199d245894b12db0b957d627516c78e055adc3a0d978bc7f65ddaf7c399/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:530e8cebeea0d76bdcf93357aa5e41336f48c3dc709ac52da2bb167c5b8271d9", size = 206587, upload-time = "2026-03-15T18:51:21.807Z" }, + { url = "https://files.pythonhosted.org/packages/7e/70/3def227f1ec56f5c69dfc8392b8bd63b11a18ca8178d9211d7cc5e5e4f27/charset_normalizer-3.4.6-cp313-cp313-manylinux_2_31_armv7l.whl", hash = "sha256:a26611d9987b230566f24a0a125f17fe0de6a6aff9f25c9f564aaa2721a5fb88", size = 194724, upload-time = "2026-03-15T18:51:23.508Z" }, + { url = "https://files.pythonhosted.org/packages/58/ab/9318352e220c05efd31c2779a23b50969dc94b985a2efa643ed9077bfca5/charset_normalizer-3.4.6-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:34315ff4fc374b285ad7f4a0bf7dcbfe769e1b104230d40f49f700d4ab6bbd84", size = 202956, upload-time = "2026-03-15T18:51:25.239Z" }, + { url = 
"https://files.pythonhosted.org/packages/75/13/f3550a3ac25b70f87ac98c40d3199a8503676c2f1620efbf8d42095cfc40/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5f8ddd609f9e1af8c7bd6e2aca279c931aefecd148a14402d4e368f3171769fd", size = 201923, upload-time = "2026-03-15T18:51:26.682Z" }, + { url = "https://files.pythonhosted.org/packages/1b/db/c5c643b912740b45e8eec21de1bbab8e7fc085944d37e1e709d3dcd9d72f/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:80d0a5615143c0b3225e5e3ef22c8d5d51f3f72ce0ea6fb84c943546c7b25b6c", size = 195366, upload-time = "2026-03-15T18:51:28.129Z" }, + { url = "https://files.pythonhosted.org/packages/5a/67/3b1c62744f9b2448443e0eb160d8b001c849ec3fef591e012eda6484787c/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:92734d4d8d187a354a556626c221cd1a892a4e0802ccb2af432a1d85ec012194", size = 219752, upload-time = "2026-03-15T18:51:29.556Z" }, + { url = "https://files.pythonhosted.org/packages/f6/98/32ffbaf7f0366ffb0445930b87d103f6b406bc2c271563644bde8a2b1093/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:613f19aa6e082cf96e17e3ffd89383343d0d589abda756b7764cf78361fd41dc", size = 203296, upload-time = "2026-03-15T18:51:30.921Z" }, + { url = "https://files.pythonhosted.org/packages/41/12/5d308c1bbe60cabb0c5ef511574a647067e2a1f631bc8634fcafaccd8293/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:2b1a63e8224e401cafe7739f77efd3f9e7f5f2026bda4aead8e59afab537784f", size = 215956, upload-time = "2026-03-15T18:51:32.399Z" }, + { url = "https://files.pythonhosted.org/packages/53/e9/5f85f6c5e20669dbe56b165c67b0260547dea97dba7e187938833d791687/charset_normalizer-3.4.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6cceb5473417d28edd20c6c984ab6fee6c6267d38d906823ebfe20b03d607dc2", size = 208652, upload-time = "2026-03-15T18:51:34.214Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/11/897052ea6af56df3eef3ca94edafee410ca699ca0c7b87960ad19932c55e/charset_normalizer-3.4.6-cp313-cp313-win32.whl", hash = "sha256:d7de2637729c67d67cf87614b566626057e95c303bc0a55ffe391f5205e7003d", size = 143940, upload-time = "2026-03-15T18:51:36.15Z" }, + { url = "https://files.pythonhosted.org/packages/a1/5c/724b6b363603e419829f561c854b87ed7c7e31231a7908708ac086cdf3e2/charset_normalizer-3.4.6-cp313-cp313-win_amd64.whl", hash = "sha256:572d7c822caf521f0525ba1bce1a622a0b85cf47ffbdae6c9c19e3b5ac3c4389", size = 154101, upload-time = "2026-03-15T18:51:37.876Z" }, + { url = "https://files.pythonhosted.org/packages/01/a5/7abf15b4c0968e47020f9ca0935fb3274deb87cb288cd187cad92e8cdffd/charset_normalizer-3.4.6-cp313-cp313-win_arm64.whl", hash = "sha256:a4474d924a47185a06411e0064b803c68be044be2d60e50e8bddcc2649957c1f", size = 143109, upload-time = "2026-03-15T18:51:39.565Z" }, + { url = "https://files.pythonhosted.org/packages/25/6f/ffe1e1259f384594063ea1869bfb6be5cdb8bc81020fc36c3636bc8302a1/charset_normalizer-3.4.6-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:9cc6e6d9e571d2f863fa77700701dae73ed5f78881efc8b3f9a4398772ff53e8", size = 294458, upload-time = "2026-03-15T18:51:41.134Z" }, + { url = "https://files.pythonhosted.org/packages/56/60/09bb6c13a8c1016c2ed5c6a6488e4ffef506461aa5161662bd7636936fb1/charset_normalizer-3.4.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5960d965e67165d75b7c7ffc60a83ec5abfc5c11b764ec13ea54fbef8b4421", size = 199277, upload-time = "2026-03-15T18:51:42.953Z" }, + { url = "https://files.pythonhosted.org/packages/00/50/dcfbb72a5138bbefdc3332e8d81a23494bf67998b4b100703fd15fa52d81/charset_normalizer-3.4.6-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b3694e3f87f8ac7ce279d4355645b3c878d24d1424581b46282f24b92f5a4ae2", size = 218758, upload-time = "2026-03-15T18:51:44.339Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/b3/d79a9a191bb75f5aa81f3aaaa387ef29ce7cb7a9e5074ba8ea095cc073c2/charset_normalizer-3.4.6-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5d11595abf8dd942a77883a39d81433739b287b6aa71620f15164f8096221b30", size = 215299, upload-time = "2026-03-15T18:51:45.871Z" }, + { url = "https://files.pythonhosted.org/packages/76/7e/bc8911719f7084f72fd545f647601ea3532363927f807d296a8c88a62c0d/charset_normalizer-3.4.6-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7bda6eebafd42133efdca535b04ccb338ab29467b3f7bf79569883676fc628db", size = 206811, upload-time = "2026-03-15T18:51:47.308Z" }, + { url = "https://files.pythonhosted.org/packages/e2/40/c430b969d41dda0c465aa36cc7c2c068afb67177bef50905ac371b28ccc7/charset_normalizer-3.4.6-cp314-cp314-manylinux_2_31_armv7l.whl", hash = "sha256:bbc8c8650c6e51041ad1be191742b8b421d05bbd3410f43fa2a00c8db87678e8", size = 193706, upload-time = "2026-03-15T18:51:48.849Z" }, + { url = "https://files.pythonhosted.org/packages/48/15/e35e0590af254f7df984de1323640ef375df5761f615b6225ba8deb9799a/charset_normalizer-3.4.6-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:22c6f0c2fbc31e76c3b8a86fba1a56eda6166e238c29cdd3d14befdb4a4e4815", size = 202706, upload-time = "2026-03-15T18:51:50.257Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bd/f736f7b9cc5e93a18b794a50346bb16fbfd6b37f99e8f306f7951d27c17c/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7edbed096e4a4798710ed6bc75dcaa2a21b68b6c356553ac4823c3658d53743a", size = 202497, upload-time = "2026-03-15T18:51:52.012Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ba/2cc9e3e7dfdf7760a6ed8da7446d22536f3d0ce114ac63dee2a5a3599e62/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:7f9019c9cb613f084481bd6a100b12e1547cf2efe362d873c2e31e4035a6fa43", size = 193511, 
upload-time = "2026-03-15T18:51:53.723Z" }, + { url = "https://files.pythonhosted.org/packages/9e/cb/5be49b5f776e5613be07298c80e1b02a2d900f7a7de807230595c85a8b2e/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:58c948d0d086229efc484fe2f30c2d382c86720f55cd9bc33591774348ad44e0", size = 220133, upload-time = "2026-03-15T18:51:55.333Z" }, + { url = "https://files.pythonhosted.org/packages/83/43/99f1b5dad345accb322c80c7821071554f791a95ee50c1c90041c157ae99/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:419a9d91bd238052642a51938af8ac05da5b3343becde08d5cdeab9046df9ee1", size = 203035, upload-time = "2026-03-15T18:51:56.736Z" }, + { url = "https://files.pythonhosted.org/packages/87/9a/62c2cb6a531483b55dddff1a68b3d891a8b498f3ca555fbcf2978e804d9d/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:5273b9f0b5835ff0350c0828faea623c68bfa65b792720c453e22b25cc72930f", size = 216321, upload-time = "2026-03-15T18:51:58.17Z" }, + { url = "https://files.pythonhosted.org/packages/6e/79/94a010ff81e3aec7c293eb82c28f930918e517bc144c9906a060844462eb/charset_normalizer-3.4.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:0e901eb1049fdb80f5bd11ed5ea1e498ec423102f7a9b9e4645d5b8204ff2815", size = 208973, upload-time = "2026-03-15T18:51:59.998Z" }, + { url = "https://files.pythonhosted.org/packages/2a/57/4ecff6d4ec8585342f0c71bc03efaa99cb7468f7c91a57b105bcd561cea8/charset_normalizer-3.4.6-cp314-cp314-win32.whl", hash = "sha256:b4ff1d35e8c5bd078be89349b6f3a845128e685e751b6ea1169cf2160b344c4d", size = 144610, upload-time = "2026-03-15T18:52:02.213Z" }, + { url = "https://files.pythonhosted.org/packages/80/94/8434a02d9d7f168c25767c64671fead8d599744a05d6a6c877144c754246/charset_normalizer-3.4.6-cp314-cp314-win_amd64.whl", hash = "sha256:74119174722c4349af9708993118581686f343adc1c8c9c007d59be90d077f3f", size = 154962, upload-time = "2026-03-15T18:52:03.658Z" }, + { url = 
"https://files.pythonhosted.org/packages/46/4c/48f2cdbfd923026503dfd67ccea45c94fd8fe988d9056b468579c66ed62b/charset_normalizer-3.4.6-cp314-cp314-win_arm64.whl", hash = "sha256:e5bcc1a1ae744e0bb59641171ae53743760130600da8db48cbb6e4918e186e4e", size = 143595, upload-time = "2026-03-15T18:52:05.123Z" }, + { url = "https://files.pythonhosted.org/packages/31/93/8878be7569f87b14f1d52032946131bcb6ebbd8af3e20446bc04053dc3f1/charset_normalizer-3.4.6-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:ad8faf8df23f0378c6d527d8b0b15ea4a2e23c89376877c598c4870d1b2c7866", size = 314828, upload-time = "2026-03-15T18:52:06.831Z" }, + { url = "https://files.pythonhosted.org/packages/06/b6/fae511ca98aac69ecc35cde828b0a3d146325dd03d99655ad38fc2cc3293/charset_normalizer-3.4.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f5ea69428fa1b49573eef0cc44a1d43bebd45ad0c611eb7d7eac760c7ae771bc", size = 208138, upload-time = "2026-03-15T18:52:08.239Z" }, + { url = "https://files.pythonhosted.org/packages/54/57/64caf6e1bf07274a1e0b7c160a55ee9e8c9ec32c46846ce59b9c333f7008/charset_normalizer-3.4.6-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:06a7e86163334edfc5d20fe104db92fcd666e5a5df0977cb5680a506fe26cc8e", size = 224679, upload-time = "2026-03-15T18:52:10.043Z" }, + { url = "https://files.pythonhosted.org/packages/aa/cb/9ff5a25b9273ef160861b41f6937f86fae18b0792fe0a8e75e06acb08f1d/charset_normalizer-3.4.6-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e1f6e2f00a6b8edb562826e4632e26d063ac10307e80f7461f7de3ad8ef3f077", size = 223475, upload-time = "2026-03-15T18:52:11.854Z" }, + { url = "https://files.pythonhosted.org/packages/fc/97/440635fc093b8d7347502a377031f9605a1039c958f3cd18dcacffb37743/charset_normalizer-3.4.6-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:95b52c68d64c1878818687a473a10547b3292e82b6f6fe483808fb1468e2f52f", size = 215230, upload-time = "2026-03-15T18:52:13.325Z" }, + { url = "https://files.pythonhosted.org/packages/cd/24/afff630feb571a13f07c8539fbb502d2ab494019492aaffc78ef41f1d1d0/charset_normalizer-3.4.6-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:7504e9b7dc05f99a9bbb4525c67a2c155073b44d720470a148b34166a69c054e", size = 199045, upload-time = "2026-03-15T18:52:14.752Z" }, + { url = "https://files.pythonhosted.org/packages/e5/17/d1399ecdaf7e0498c327433e7eefdd862b41236a7e484355b8e0e5ebd64b/charset_normalizer-3.4.6-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:172985e4ff804a7ad08eebec0a1640ece87ba5041d565fff23c8f99c1f389484", size = 211658, upload-time = "2026-03-15T18:52:16.278Z" }, + { url = "https://files.pythonhosted.org/packages/b5/38/16baa0affb957b3d880e5ac2144caf3f9d7de7bc4a91842e447fbb5e8b67/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:4be9f4830ba8741527693848403e2c457c16e499100963ec711b1c6f2049b7c7", size = 210769, upload-time = "2026-03-15T18:52:17.782Z" }, + { url = "https://files.pythonhosted.org/packages/05/34/c531bc6ac4c21da9ddfddb3107be2287188b3ea4b53b70fc58f2a77ac8d8/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:79090741d842f564b1b2827c0b82d846405b744d31e84f18d7a7b41c20e473ff", size = 201328, upload-time = "2026-03-15T18:52:19.553Z" }, + { url = "https://files.pythonhosted.org/packages/fa/73/a5a1e9ca5f234519c1953608a03fe109c306b97fdfb25f09182babad51a7/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:87725cfb1a4f1f8c2fc9890ae2f42094120f4b44db9360be5d99a4c6b0e03a9e", size = 225302, upload-time = "2026-03-15T18:52:21.043Z" }, + { url = "https://files.pythonhosted.org/packages/ba/f6/cd782923d112d296294dea4bcc7af5a7ae0f86ab79f8fefbda5526b6cfc0/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = 
"sha256:fcce033e4021347d80ed9c66dcf1e7b1546319834b74445f561d2e2221de5659", size = 211127, upload-time = "2026-03-15T18:52:22.491Z" }, + { url = "https://files.pythonhosted.org/packages/0e/c5/0b6898950627af7d6103a449b22320372c24c6feda91aa24e201a478d161/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:ca0276464d148c72defa8bb4390cce01b4a0e425f3b50d1435aa6d7a18107602", size = 222840, upload-time = "2026-03-15T18:52:24.113Z" }, + { url = "https://files.pythonhosted.org/packages/7d/25/c4bba773bef442cbdc06111d40daa3de5050a676fa26e85090fc54dd12f0/charset_normalizer-3.4.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:197c1a244a274bb016dd8b79204850144ef77fe81c5b797dc389327adb552407", size = 216890, upload-time = "2026-03-15T18:52:25.541Z" }, + { url = "https://files.pythonhosted.org/packages/35/1a/05dacadb0978da72ee287b0143097db12f2e7e8d3ffc4647da07a383b0b7/charset_normalizer-3.4.6-cp314-cp314t-win32.whl", hash = "sha256:2a24157fa36980478dd1770b585c0f30d19e18f4fb0c47c13aa568f871718579", size = 155379, upload-time = "2026-03-15T18:52:27.05Z" }, + { url = "https://files.pythonhosted.org/packages/5d/7a/d269d834cb3a76291651256f3b9a5945e81d0a49ab9f4a498964e83c0416/charset_normalizer-3.4.6-cp314-cp314t-win_amd64.whl", hash = "sha256:cd5e2801c89992ed8c0a3f0293ae83c159a60d9a5d685005383ef4caca77f2c4", size = 169043, upload-time = "2026-03-15T18:52:28.502Z" }, + { url = "https://files.pythonhosted.org/packages/23/06/28b29fba521a37a8932c6a84192175c34d49f84a6d4773fa63d05f9aff22/charset_normalizer-3.4.6-cp314-cp314t-win_arm64.whl", hash = "sha256:47955475ac79cc504ef2704b192364e51d0d473ad452caedd0002605f780101c", size = 148523, upload-time = "2026-03-15T18:52:29.956Z" }, + { url = "https://files.pythonhosted.org/packages/2a/68/687187c7e26cb24ccbd88e5069f5ef00eba804d36dde11d99aad0838ab45/charset_normalizer-3.4.6-py3-none-any.whl", hash = "sha256:947cf925bc916d90adba35a64c82aace04fa39b46b52d4630ece166655905a69", size = 61455, upload-time = 
"2026-03-15T18:53:23.833Z" }, ] [[package]] @@ -2684,11 +2700,11 @@ wheels = [ [[package]] name = "identify" -version = "2.6.17" +version = "2.6.18" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/84/376a3b96e5a8d33a7aa2c5b3b31a4b3c364117184bf0b17418055f6ace66/identify-2.6.17.tar.gz", hash = "sha256:f816b0b596b204c9fdf076ded172322f2723cf958d02f9c3587504834c8ff04d", size = 99579, upload-time = "2026-03-01T20:04:12.702Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/c4/7fb4db12296cdb11893d61c92048fe617ee853f8523b9b296ac03b43757e/identify-2.6.18.tar.gz", hash = "sha256:873ac56a5e3fd63e7438a7ecbc4d91aca692eb3fefa4534db2b7913f3fc352fd", size = 99580, upload-time = "2026-03-15T18:39:50.319Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl", hash = "sha256:be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0", size = 99382, upload-time = "2026-03-01T20:04:11.439Z" }, + { url = "https://files.pythonhosted.org/packages/46/33/92ef41c6fad0233e41d3d84ba8e8ad18d1780f1e5d99b3c683e6d7f98b63/identify-2.6.18-py2.py3-none-any.whl", hash = "sha256:8db9d3c8ea9079db92cafb0ebf97abdc09d52e97f4dcf773a2e694048b7cd737", size = 99394, upload-time = "2026-03-15T18:39:48.915Z" }, ] [[package]] From 82b016bad308d09f9654a0c5a16d1e83363f3265 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 22:21:43 +0000 Subject: [PATCH 30/39] feat: Improve parameter conversion for `POSITIONAL_PYFORMAT` to ensure distinct placeholders and add robust rollback handling to the migration tracker. 
--- sqlspec/core/parameters/_converter.py | 25 ++++----- sqlspec/migrations/tracker.py | 5 ++ .../test_parameter_styles.py | 20 ++++---- tests/unit/core/test_parameters.py | 31 +++++++++++ .../migrations/test_tracker_idempotency.py | 51 ++++++++++++++++--- 5 files changed, 104 insertions(+), 28 deletions(-) diff --git a/sqlspec/core/parameters/_converter.py b/sqlspec/core/parameters/_converter.py index 94bbbd0dd..50eb68931 100644 --- a/sqlspec/core/parameters/_converter.py +++ b/sqlspec/core/parameters/_converter.py @@ -87,8 +87,8 @@ def _is_positional_style(style: "ParameterStyle") -> bool: } -def _parameter_lookup_key(param: "ParameterInfo", use_sequential_for_qmark: bool) -> str: - if use_sequential_for_qmark and param.style == ParameterStyle.QMARK: +def _parameter_lookup_key(param: "ParameterInfo") -> str: + if param.style in {ParameterStyle.QMARK, ParameterStyle.POSITIONAL_PYFORMAT}: return f"{param.placeholder_text}_{param.ordinal}" return param.placeholder_text @@ -162,18 +162,16 @@ def convert_placeholder_style( def _build_conversion_plan( self, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" - ) -> "tuple[list[ParameterInfo], dict[str, int], bool]": + ) -> "tuple[list[ParameterInfo], dict[str, int]]": ordered_params = _ordered_parameter_info(param_info) - source_style = _single_parameter_style(ordered_params) - use_sequential_for_qmark = source_style == ParameterStyle.QMARK and target_style == ParameterStyle.NUMERIC unique_params: dict[str, int] = {} for param in ordered_params: - param_key = _parameter_lookup_key(param, use_sequential_for_qmark) + param_key = _parameter_lookup_key(param) if param_key not in unique_params: unique_params[param_key] = len(unique_params) - return ordered_params, unique_params, use_sequential_for_qmark + return ordered_params, unique_params def _convert_placeholders_to_style( self, sql: str, param_info: "list[ParameterInfo]", target_style: "ParameterStyle" @@ -183,7 +181,7 @@ def 
_convert_placeholders_to_style( msg = f"Unsupported target parameter style: {target_style}" raise ValueError(msg) - ordered_params, unique_params, use_sequential_for_qmark = self._build_conversion_plan(param_info, target_style) + ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) placeholder_text_len_cache: dict[str, int] = {} # Build SQL using forward iteration with list join (O(n) vs O(n^2) string slicing) @@ -200,7 +198,7 @@ def _convert_placeholders_to_style( # Generate new placeholder based on target style if is_positional_style: - param_key = _parameter_lookup_key(param, use_sequential_for_qmark) + param_key = _parameter_lookup_key(param) new_placeholder = generator(unique_params[param_key]) else: param_name = _normalized_named_parameter_name(param) @@ -223,13 +221,14 @@ def convert_parameter_info_style( msg = f"Unsupported target parameter style: {target_style}" raise ValueError(msg) - ordered_params, unique_params, use_sequential_for_qmark = self._build_conversion_plan(param_info, target_style) + ordered_params, unique_params = self._build_conversion_plan(param_info, target_style) is_positional_style = _is_positional_style(target_style) converted_param_info: list[ParameterInfo] = [] + delta = 0 for param in ordered_params: if is_positional_style: - converted_index = unique_params[_parameter_lookup_key(param, use_sequential_for_qmark)] + converted_index = unique_params[_parameter_lookup_key(param)] placeholder_text = generator(converted_index) name = None if target_style in {ParameterStyle.NUMERIC, ParameterStyle.POSITIONAL_COLON}: @@ -238,15 +237,17 @@ def convert_parameter_info_style( name = _normalized_named_parameter_name(param) placeholder_text = generator(name) + converted_position = param.position + delta converted_param_info.append( ParameterInfo( name=name, style=target_style, - position=param.position, + position=converted_position, ordinal=param.ordinal, placeholder_text=placeholder_text, ) ) + delta += 
len(placeholder_text) - len(param.placeholder_text) return converted_param_info diff --git a/sqlspec/migrations/tracker.py b/sqlspec/migrations/tracker.py index 64224538e..c5f372cda 100644 --- a/sqlspec/migrations/tracker.py +++ b/sqlspec/migrations/tracker.py @@ -6,6 +6,7 @@ import logging import os from collections.abc import Mapping +from contextlib import suppress from typing import TYPE_CHECKING, Any from rich.console import Console @@ -92,6 +93,8 @@ def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None: console.print("[green]Migration tracking table schema updated successfully[/]") except Exception as exc: + with suppress(Exception): + driver.rollback() log_with_context( logger, logging.ERROR, @@ -467,6 +470,8 @@ async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> N console.print("[green]Migration tracking table schema updated successfully[/]") except Exception as exc: + with suppress(Exception): + await driver.rollback() log_with_context( logger, logging.ERROR, diff --git a/tests/integration/adapters/cockroach_psycopg/test_parameter_styles.py b/tests/integration/adapters/cockroach_psycopg/test_parameter_styles.py index 510506bb8..5f186414b 100644 --- a/tests/integration/adapters/cockroach_psycopg/test_parameter_styles.py +++ b/tests/integration/adapters/cockroach_psycopg/test_parameter_styles.py @@ -258,7 +258,7 @@ def test_sql_injection_prevention( class TestAsyncNumericParameterStyle: """Test NUMERIC ($1, $2) parameter style (native) for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_numeric_single_parameter( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -271,7 +271,7 @@ async def test_numeric_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_numeric_multiple_parameters( self, cockroach_psycopg_async_parameter_session: 
CockroachPsycopgAsyncDriver ) -> None: @@ -289,7 +289,7 @@ async def test_numeric_multiple_parameters( class TestAsyncQmarkConversion: """Test QMARK (?) to NUMERIC ($1) conversion for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_qmark_single_parameter( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -302,7 +302,7 @@ async def test_qmark_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_qmark_multiple_parameters( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -319,7 +319,7 @@ async def test_qmark_multiple_parameters( class TestAsyncNamedColonConversion: """Test NAMED_COLON (:name) to NUMERIC ($1) conversion for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_named_colon_single_parameter( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -332,7 +332,7 @@ async def test_named_colon_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_named_colon_multiple_parameters( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -350,7 +350,7 @@ async def test_named_colon_multiple_parameters( class TestAsyncNamedPyformatConversion: """Test NAMED_PYFORMAT (%(name)s) to NUMERIC ($1) conversion for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_named_pyformat_parameters( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -368,7 +368,7 @@ async def test_named_pyformat_parameters( class TestAsyncSQLObject: """Test parameter conversion with SQL objects for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_sql_object_with_qmark( self, 
cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -383,7 +383,7 @@ async def test_sql_object_with_qmark( class TestAsyncExecuteMany: """Test parameter conversion with execute_many for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_execute_many_with_numeric( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: @@ -401,7 +401,7 @@ async def test_execute_many_with_numeric( class TestAsyncEdgeCases: """Test edge cases for async driver.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_sql_injection_prevention( self, cockroach_psycopg_async_parameter_session: CockroachPsycopgAsyncDriver ) -> None: diff --git a/tests/unit/core/test_parameters.py b/tests/unit/core/test_parameters.py index 84315793a..bdce468cf 100644 --- a/tests/unit/core/test_parameters.py +++ b/tests/unit/core/test_parameters.py @@ -1625,6 +1625,22 @@ def test_duplicate_parameters_qmark_to_numeric(converter: ParameterConverter) -> assert "$2" in converted_sql +def test_distinct_positional_pyformat_parameters_to_numeric(converter: ParameterConverter) -> None: + """Repeated ``%s`` placeholders should remain distinct when normalized.""" + sql = "SELECT * FROM table WHERE col1 = %s AND col2 = %s" + + converted_sql, _ = converter.convert_placeholder_style(sql, None, ParameterStyle.NUMERIC) + converted_info = converter.convert_parameter_info_style( + converter.validator.extract_parameters(sql), ParameterStyle.NUMERIC + ) + + assert converted_sql == "SELECT * FROM table WHERE col1 = $1 AND col2 = $2" + assert [(param.name, param.position, param.placeholder_text) for param in converted_info] == [ + ("1", 33, "$1"), + ("2", 47, "$2"), + ] + + def test_vector_similarity_search_example(converter: ParameterConverter) -> None: """Test the exact example from the bug report.""" sql = """SELECT @@ -2341,6 +2357,21 @@ def test_psycopg_pyformat_normalized_for_parsing( assert result.sqlglot_sql is not None assert 
"%(name)s" not in result.sqlglot_sql + def test_psycopg_positional_pyformat_preserves_distinct_positions( + self, processor: ParameterProcessor, psycopg_config: ParameterStyleConfig | None + ) -> None: + """Psycopg: repeated ``%s`` placeholders should normalize to distinct numeric positions.""" + if psycopg_config is None: + pytest.skip("psycopg adapter not available") + + sql = "SELECT * FROM t WHERE name = ? AND value > ?" + params = ["test", 100] + + result = processor.process(sql, params, psycopg_config, dialect="postgres") + + assert result.sql == "SELECT * FROM t WHERE name = %s AND value > %s" + assert result.sqlglot_sql == "SELECT * FROM t WHERE name = $1 AND value > $2" + class TestMySQLAdaptersBehavior: """Test MySQL adapter parameter handling.""" diff --git a/tests/unit/migrations/test_tracker_idempotency.py b/tests/unit/migrations/test_tracker_idempotency.py index 3d533c5f3..500b97107 100644 --- a/tests/unit/migrations/test_tracker_idempotency.py +++ b/tests/unit/migrations/test_tracker_idempotency.py @@ -120,7 +120,25 @@ def test_sync_update_version_record_no_commit_on_idempotent_path() -> None: driver.commit.assert_not_called() -@pytest.mark.asyncio +def test_sync_schema_migration_rolls_back_on_failure() -> None: + """Schema migration should rollback after a failed column update.""" + tracker = SyncMigrationTracker() + driver = Mock() + driver.driver_features = {} + driver.data_dictionary.get_columns.return_value = [{"column_name": "version"}] + driver.rollback = Mock() + + def raise_on_add(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("boom") + + tracker._add_column = Mock(side_effect=raise_on_add) # type: ignore[method-assign] + + tracker._migrate_schema_if_needed(driver) # pyright: ignore[reportPrivateUsage] + + driver.rollback.assert_called_once() + + +@pytest.mark.anyio async def test_async_update_version_record_success() -> None: """Test async update succeeds when old version exists.""" from unittest.mock import AsyncMock @@ -144,7 
+162,7 @@ async def mock_execute(sql: Any) -> Mock: assert "ddl_migrations" in update_sql -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_update_version_record_idempotent_when_already_updated() -> None: """Test async update is idempotent when version already exists.""" from unittest.mock import AsyncMock @@ -177,7 +195,28 @@ async def mock_execute(sql: Any) -> Mock: assert driver.execute.call_count == 2 -@pytest.mark.asyncio +@pytest.mark.anyio +async def test_async_schema_migration_rolls_back_on_failure() -> None: + """Async schema migration should rollback after a failed column update.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + driver.driver_features = {} + driver.data_dictionary.get_columns = AsyncMock(return_value=[{"column_name": "version"}]) + driver.rollback = AsyncMock() + + async def raise_on_add(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("boom") + + tracker._add_column = AsyncMock(side_effect=raise_on_add) # type: ignore[method-assign] + + await tracker._migrate_schema_if_needed(driver) # pyright: ignore[reportPrivateUsage] + + driver.rollback.assert_awaited_once() + + +@pytest.mark.anyio async def test_async_update_version_record_raises_when_neither_version_exists() -> None: """Test async update raises ValueError when neither old nor new version exists.""" from unittest.mock import AsyncMock @@ -206,7 +245,7 @@ async def mock_execute(sql: Any) -> Mock: await tracker.update_version_record(driver, "20251011120000", "0001") -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_update_version_record_empty_database() -> None: """Test async update raises when database is empty.""" from unittest.mock import AsyncMock @@ -235,7 +274,7 @@ async def mock_execute(sql: Any) -> Mock: await tracker.update_version_record(driver, "20251011120000", "0001") -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_update_version_record_commits_after_success() -> None: 
"""Test async update commits transaction after successful update.""" from unittest.mock import AsyncMock @@ -259,7 +298,7 @@ async def mock_execute(sql: Any) -> Mock: driver.commit.assert_called_once() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_update_version_record_no_commit_on_idempotent_path() -> None: """Test async update does not commit when taking idempotent path.""" from unittest.mock import AsyncMock From 15c0ac8f0115b902615fc42e9c2cab6995ad5133 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 22:28:33 +0000 Subject: [PATCH 31/39] test(mypyc): derive inventory summary from current surface --- tests/unit/utils/test_mypyc_inventory.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_mypyc_inventory.py b/tests/unit/utils/test_mypyc_inventory.py index 949383df6..7ec52d10a 100644 --- a/tests/unit/utils/test_mypyc_inventory.py +++ b/tests/unit/utils/test_mypyc_inventory.py @@ -19,8 +19,23 @@ def test_build_inventory_reports_current_compiled_surface() -> None: module = _load_mypyc_inventory_module() inventory = module.build_inventory() + project_root = Path(__file__).resolve().parents[3] + include_patterns, exclude_patterns = module.load_mypyc_patterns(project_root) + modules = module.list_sqlspec_modules(project_root) + compiled = [ + path for path in modules if module.classify_module(path, include_patterns, exclude_patterns) == "compiled" + ] + interpreted = [ + path for path in modules if module.classify_module(path, include_patterns, exclude_patterns) == "interpreted" + ] - assert inventory["summary"] == {"compiled_count": 60, "interpreted_count": 335, "total_modules": 395} + assert inventory["summary"] == { + "compiled_count": len(compiled), + "interpreted_count": len(interpreted), + "total_modules": len(modules), + } + assert inventory["compiled_modules"] == compiled + assert inventory["interpreted_modules"] == interpreted hot_surfaces = inventory["hot_surfaces"] 
assert hot_surfaces["sqlspec/config.py"]["status"] == "interpreted" From 2918aff1b71793feab1ed420511566b24d2cb525 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 15 Mar 2026 23:52:47 +0000 Subject: [PATCH 32/39] feat: Explicitly define session factory, session context, and default statement configuration as class variables in adapter configs, and update OracleDB cursor type aliases. --- sqlspec/adapters/adbc/config.py | 4 + sqlspec/adapters/aiosqlite/__init__.py | 3 +- sqlspec/adapters/aiosqlite/_typing.py | 10 +-- sqlspec/adapters/aiosqlite/config.py | 8 +- sqlspec/adapters/aiosqlite/driver.py | 16 ++-- sqlspec/adapters/asyncmy/_typing.py | 12 +-- sqlspec/adapters/asyncmy/config.py | 8 +- sqlspec/adapters/asyncpg/config.py | 6 +- sqlspec/adapters/bigquery/config.py | 4 + sqlspec/adapters/cockroach_asyncpg/config.py | 6 +- sqlspec/adapters/cockroach_psycopg/config.py | 30 +++++--- sqlspec/adapters/duckdb/config.py | 8 +- sqlspec/adapters/mock/_typing.py | 14 ++-- sqlspec/adapters/mock/config.py | 8 ++ sqlspec/adapters/mysqlconnector/_typing.py | 33 ++++---- sqlspec/adapters/mysqlconnector/config.py | 18 +++-- sqlspec/adapters/oracledb/_typing.py | 20 ++--- sqlspec/adapters/oracledb/config.py | 18 ++--- sqlspec/adapters/psqlpy/config.py | 6 +- sqlspec/adapters/psycopg/_typing.py | 16 ++-- sqlspec/adapters/psycopg/config.py | 16 +++- sqlspec/adapters/pymysql/_typing.py | 8 +- sqlspec/adapters/pymysql/config.py | 8 +- sqlspec/adapters/spanner/config.py | 4 + sqlspec/adapters/sqlite/_typing.py | 10 +-- sqlspec/adapters/sqlite/config.py | 8 +- sqlspec/storage/backends/base.py | 80 ++++++++------------ 27 files changed, 221 insertions(+), 161 deletions(-) diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index 1668e0b82..01078b2ad 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -180,6 +180,10 @@ class AdbcConfig(NoPoolSyncConfig[AdbcConnection, AdbcDriver]): supports_native_parquet_export: 
"ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk") + _connection_context_class: "ClassVar[type[AdbcConnectionContext]]" = AdbcConnectionContext + _session_factory_class: "ClassVar[type[_AdbcSessionConnectionHandler]]" = _AdbcSessionConnectionHandler + _session_context_class: "ClassVar[type[AdbcSessionContext]]" = AdbcSessionContext + _default_statement_config = StatementConfig() def __init__( self, diff --git a/sqlspec/adapters/aiosqlite/__init__.py b/sqlspec/adapters/aiosqlite/__init__.py index f6aad1b27..b0c6812f9 100644 --- a/sqlspec/adapters/aiosqlite/__init__.py +++ b/sqlspec/adapters/aiosqlite/__init__.py @@ -1,4 +1,4 @@ -from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection, AiosqliteCursor +from sqlspec.adapters.aiosqlite._typing import AiosqliteConnection, AiosqliteCursor, AiosqliteRawCursor from sqlspec.adapters.aiosqlite.config import AiosqliteConfig, AiosqliteConnectionParams, AiosqlitePoolParams from sqlspec.adapters.aiosqlite.core import default_statement_config from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver, AiosqliteExceptionHandler @@ -21,5 +21,6 @@ "AiosqlitePoolClosedError", "AiosqlitePoolConnection", "AiosqlitePoolParams", + "AiosqliteRawCursor", "default_statement_config", ) diff --git a/sqlspec/adapters/aiosqlite/_typing.py b/sqlspec/adapters/aiosqlite/_typing.py index a979859fd..8a4635e3e 100644 --- a/sqlspec/adapters/aiosqlite/_typing.py +++ b/sqlspec/adapters/aiosqlite/_typing.py @@ -21,11 +21,11 @@ from sqlspec.core import StatementConfig AiosqliteConnection: TypeAlias = _AiosqliteConnection - AiosqliteCursorType: TypeAlias = aiosqlite.Cursor + AiosqliteRawCursor: TypeAlias = aiosqlite.Cursor if not TYPE_CHECKING: AiosqliteConnection = _AiosqliteConnection - AiosqliteCursorType = aiosqlite.Cursor + AiosqliteRawCursor = aiosqlite.Cursor class AiosqliteCursor: @@ -35,9 +35,9 @@ class 
AiosqliteCursor: def __init__(self, connection: "AiosqliteConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: AiosqliteRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "AiosqliteRawCursor": self.cursor = await self.connection.cursor() return self.cursor @@ -106,4 +106,4 @@ async def __aexit__( return None -__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteCursorType", "AiosqliteSessionContext") +__all__ = ("AiosqliteConnection", "AiosqliteCursor", "AiosqliteRawCursor", "AiosqliteSessionContext") diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 703bcfe38..09c040f72 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -139,15 +139,15 @@ class AiosqliteConfig(AsyncDatabaseConfig["AiosqliteConnection", AiosqliteConnec driver_type: "ClassVar[type[AiosqliteDriver]]" = AiosqliteDriver connection_type: "ClassVar[type[AiosqliteConnection]]" = AiosqliteConnection - _connection_context_class: "ClassVar[type[AiosqliteConnectionContext]]" = AiosqliteConnectionContext - _session_factory_class: "ClassVar[type[_AiosqliteSessionFactory]]" = _AiosqliteSessionFactory - _session_context_class: "ClassVar[type[AiosqliteSessionContext]]" = AiosqliteSessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[AiosqliteConnectionContext]]" = AiosqliteConnectionContext + _session_factory_class: "ClassVar[type[_AiosqliteSessionFactory]]" = _AiosqliteSessionFactory + _session_context_class: "ClassVar[type[AiosqliteSessionContext]]" = AiosqliteSessionContext + _default_statement_config = 
default_statement_config def __init__( self, diff --git a/sqlspec/adapters/aiosqlite/driver.py b/sqlspec/adapters/aiosqlite/driver.py index 9807261e1..99dd7b96e 100644 --- a/sqlspec/adapters/aiosqlite/driver.py +++ b/sqlspec/adapters/aiosqlite/driver.py @@ -7,7 +7,7 @@ import aiosqlite -from sqlspec.adapters.aiosqlite._typing import AiosqliteCursor, AiosqliteSessionContext +from sqlspec.adapters.aiosqlite._typing import AiosqliteCursor, AiosqliteRawCursor, AiosqliteSessionContext from sqlspec.adapters.aiosqlite.core import ( build_insert_statement, collect_rows, @@ -30,7 +30,13 @@ from sqlspec.driver import ExecutionResult from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry -__all__ = ("AiosqliteCursor", "AiosqliteDriver", "AiosqliteExceptionHandler", "AiosqliteSessionContext") +__all__ = ( + "AiosqliteCursor", + "AiosqliteDriver", + "AiosqliteExceptionHandler", + "AiosqliteRawCursor", + "AiosqliteSessionContext", +) SQLITE_CONSTRAINT_UNIQUE_CODE = 2067 SQLITE_CONSTRAINT_FOREIGNKEY_CODE = 787 @@ -87,7 +93,7 @@ def __init__( # CORE DISPATCH METHODS # ───────────────────────────────────────────────────────────────────────────── - async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute(self, cursor: "AiosqliteRawCursor", statement: "SQL") -> "ExecutionResult": """Execute single SQL statement.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) await cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) @@ -111,7 +117,7 @@ async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionRes affected_rows = resolve_rowcount(cursor) return self.create_execution_result(cursor, rowcount_override=affected_rows) - async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_many(self, cursor: "AiosqliteRawCursor", statement: "SQL") -> 
"ExecutionResult": """Execute SQL with multiple parameter sets.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) @@ -121,7 +127,7 @@ async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "Executi return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) - async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + async def dispatch_execute_script(self, cursor: "AiosqliteRawCursor", statement: "SQL") -> "ExecutionResult": """Execute SQL script.""" sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) diff --git a/sqlspec/adapters/asyncmy/_typing.py b/sqlspec/adapters/asyncmy/_typing.py index 4211f0f1e..f1e601c66 100644 --- a/sqlspec/adapters/asyncmy/_typing.py +++ b/sqlspec/adapters/asyncmy/_typing.py @@ -18,7 +18,7 @@ from sqlspec.core import StatementConfig class AsyncmyConnectionProtocol(Protocol): - def cursor(self) -> Any: ... + def cursor(self) -> "AsyncmyRawCursor": ... async def commit(self) -> Any: ... @@ -27,11 +27,11 @@ async def rollback(self) -> Any: ... async def close(self) -> Any: ... 
AsyncmyConnection: TypeAlias = AsyncmyConnectionProtocol - AsyncmyCursorType: TypeAlias = _AsyncmyCursor + AsyncmyRawCursor: TypeAlias = _AsyncmyCursor if not TYPE_CHECKING: AsyncmyConnection = Connection - AsyncmyCursorType = _AsyncmyCursor + AsyncmyRawCursor = _AsyncmyCursor class AsyncmyCursor: @@ -44,9 +44,9 @@ class AsyncmyCursor: def __init__(self, connection: "AsyncmyConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: AsyncmyRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "AsyncmyRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -110,4 +110,4 @@ async def __aexit__( return None -__all__ = ("AsyncmyConnection", "AsyncmyCursor", "AsyncmyCursorType", "AsyncmySessionContext") +__all__ = ("AsyncmyConnection", "AsyncmyCursor", "AsyncmyRawCursor", "AsyncmySessionContext") diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 4e6a9e3da..e887ad128 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -151,15 +151,15 @@ class AsyncmyConfig(AsyncDatabaseConfig[AsyncmyConnection, "AsyncmyPool", Asyncm driver_type: ClassVar[type[AsyncmyDriver]] = AsyncmyDriver connection_type: "ClassVar[type[Any]]" = cast("type[Any]", AsyncmyConnection) - _connection_context_class: "ClassVar[type[AsyncmyConnectionContext]]" = AsyncmyConnectionContext - _session_factory_class: "ClassVar[type[_AsyncmySessionFactory]]" = _AsyncmySessionFactory - _session_context_class: "ClassVar[type[AsyncmySessionContext]]" = AsyncmySessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: 
"ClassVar[type[AsyncmyConnectionContext]]" = AsyncmyConnectionContext + _session_factory_class: "ClassVar[type[_AsyncmySessionFactory]]" = _AsyncmySessionFactory + _session_context_class: "ClassVar[type[AsyncmySessionContext]]" = AsyncmySessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 46e429246..c3a6800e1 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -288,13 +288,15 @@ class AsyncpgConfig(AsyncDatabaseConfig[AsyncpgConnection, "Pool[Record]", Async driver_type: "ClassVar[type[AsyncpgDriver]]" = AsyncpgDriver connection_type: "ClassVar[type[AsyncpgConnection]]" = type(AsyncpgConnection) # type: ignore[assignment] - _connection_context_class: "ClassVar[type[AsyncpgConnectionContext]]" = AsyncpgConnectionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[AsyncpgConnectionContext]]" = AsyncpgConnectionContext + _session_factory_class: "ClassVar[type[_AsyncpgSessionFactory]]" = _AsyncpgSessionFactory + _session_context_class: "ClassVar[type[AsyncpgSessionContext]]" = AsyncpgSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index 6ecdbd5f0..510e8635e 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -160,6 +160,10 @@ class BigQueryConfig(NoPoolSyncConfig[BigQueryConnection, BigQueryDriver]): supports_native_parquet_export: ClassVar[bool] = True requires_staging_for_load: ClassVar[bool] = True 
staging_protocols: "ClassVar[tuple[str, ...]]" = ("gs://",) + _connection_context_class: "ClassVar[type[BigQueryConnectionContext]]" = BigQueryConnectionContext + _session_factory_class: "ClassVar[type[_BigQuerySessionConnectionHandler]]" = _BigQuerySessionConnectionHandler + _session_context_class: "ClassVar[type[BigQuerySessionContext]]" = BigQuerySessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/cockroach_asyncpg/config.py b/sqlspec/adapters/cockroach_asyncpg/config.py index 1f1d97c16..78e8de5ea 100644 --- a/sqlspec/adapters/cockroach_asyncpg/config.py +++ b/sqlspec/adapters/cockroach_asyncpg/config.py @@ -144,13 +144,15 @@ class CockroachAsyncpgConfig( driver_type: "ClassVar[type[CockroachAsyncpgDriver]]" = CockroachAsyncpgDriver connection_type: "ClassVar[type[CockroachAsyncpgConnection]]" = CockroachAsyncpgConnection # type: ignore[assignment] - _connection_context_class: "ClassVar[type[CockroachAsyncpgConnectionContext]]" = CockroachAsyncpgConnectionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[CockroachAsyncpgConnectionContext]]" = CockroachAsyncpgConnectionContext + _session_factory_class: "ClassVar[type[_CockroachAsyncpgSessionFactory]]" = _CockroachAsyncpgSessionFactory + _session_context_class: "ClassVar[type[CockroachAsyncpgSessionContext]]" = CockroachAsyncpgSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/cockroach_psycopg/config.py b/sqlspec/adapters/cockroach_psycopg/config.py index 83f1b490f..30b56810f 100644 --- a/sqlspec/adapters/cockroach_psycopg/config.py +++ 
b/sqlspec/adapters/cockroach_psycopg/config.py @@ -43,6 +43,8 @@ "CockroachPsycopgSyncConfig", ) +default_statement_config = build_statement_config() + class CockroachPsycopgConnectionConfig(TypedDict): """CockroachDB connection parameters.""" @@ -163,14 +165,19 @@ class CockroachPsycopgSyncConfig( driver_type: "ClassVar[type[CockroachPsycopgSyncDriver]]" = CockroachPsycopgSyncDriver connection_type: "ClassVar[type[CockroachSyncConnection]]" = CockroachSyncConnection - _connection_context_class: "ClassVar[type[CockroachPsycopgSyncConnectionContext]]" = ( - CockroachPsycopgSyncConnectionContext - ) supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[CockroachPsycopgSyncConnectionContext]]" = ( + CockroachPsycopgSyncConnectionContext + ) + _session_factory_class: "ClassVar[type[_CockroachPsycopgSyncSessionConnectionHandler]]" = ( + _CockroachPsycopgSyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[CockroachPsycopgSyncSessionContext]]" = CockroachPsycopgSyncSessionContext + _default_statement_config = default_statement_config def __init__( self, @@ -186,7 +193,7 @@ def __init__( **kwargs: Any, ) -> None: connection_config = normalize_connection_config(connection_config) - statement_config = statement_config or build_statement_config() + statement_config = statement_config or default_statement_config statement_config, driver_features = apply_driver_features(statement_config, driver_features) driver_features.setdefault("enable_auto_retry", True) @@ -278,7 +285,7 @@ def provide_session( return CockroachPsycopgSyncSessionContext( acquire_connection=handler.acquire_connection, release_connection=handler.release_connection, - statement_config=statement_config or self.statement_config or 
build_statement_config(), + statement_config=statement_config or self.statement_config or default_statement_config, driver_features=driver_features, prepare_driver=self._prepare_driver, ) @@ -359,14 +366,19 @@ class CockroachPsycopgAsyncConfig( driver_type: "ClassVar[type[CockroachPsycopgAsyncDriver]]" = CockroachPsycopgAsyncDriver connection_type: "ClassVar[type[CockroachAsyncConnection]]" = CockroachAsyncConnection - _connection_context_class: "ClassVar[type[CockroachPsycopgAsyncConnectionContext]]" = ( - CockroachPsycopgAsyncConnectionContext - ) supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[CockroachPsycopgAsyncConnectionContext]]" = ( + CockroachPsycopgAsyncConnectionContext + ) + _session_factory_class: "ClassVar[type[_CockroachPsycopgAsyncSessionConnectionHandler]]" = ( + _CockroachPsycopgAsyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[CockroachPsycopgAsyncSessionContext]]" = CockroachPsycopgAsyncSessionContext + _default_statement_config = default_statement_config def __init__( self, @@ -382,7 +394,7 @@ def __init__( **kwargs: Any, ) -> None: connection_config = normalize_connection_config(connection_config) - statement_config = statement_config or build_statement_config() + statement_config = statement_config or default_statement_config statement_config, driver_features = apply_driver_features(statement_config, driver_features) driver_features.setdefault("enable_auto_retry", True) diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index f6714a5a2..ae6cdce6d 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -252,16 +252,16 @@ class DuckDBConfig(SyncDatabaseConfig[DuckDBConnection, 
DuckDBConnectionPool, Du driver_type: "ClassVar[type[DuckDBDriver]]" = DuckDBDriver connection_type: "ClassVar[type[DuckDBConnection]]" = DuckDBConnection - _connection_context_class: "ClassVar[type[DuckDBConnectionContext]]" = DuckDBConnectionContext - _session_factory_class: "ClassVar[type[_DuckDBSessionConnectionHandler]]" = _DuckDBSessionConnectionHandler - _session_context_class: "ClassVar[type[DuckDBSessionContext]]" = DuckDBSessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True storage_partition_strategies: "ClassVar[tuple[str, ...]]" = ("fixed", "rows_per_chunk", "manifest") + _connection_context_class: "ClassVar[type[DuckDBConnectionContext]]" = DuckDBConnectionContext + _session_factory_class: "ClassVar[type[_DuckDBSessionConnectionHandler]]" = _DuckDBSessionConnectionHandler + _session_context_class: "ClassVar[type[DuckDBSessionContext]]" = DuckDBSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/mock/_typing.py b/sqlspec/adapters/mock/_typing.py index 2bc615a47..b347befef 100644 --- a/sqlspec/adapters/mock/_typing.py +++ b/sqlspec/adapters/mock/_typing.py @@ -39,20 +39,22 @@ def __init__(self, connection: "MockConnection") -> None: Args: connection: SQLite database connection + """ self.connection = connection - self.cursor: Any = None + self.cursor: MockRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "MockRawCursor": """Create and return a new cursor. 
Returns: Active SQLite cursor object + """ self.cursor = self.connection.cursor() return self.cursor - def __exit__(self, *_: Any) -> None: + def __exit__(self, *_: object) -> None: """Clean up cursor resources.""" if self.cursor is not None: with contextlib.suppress(Exception): @@ -69,15 +71,17 @@ def __init__(self, connection: "MockConnection") -> None: Args: connection: SQLite database connection + """ self.connection = connection - self.cursor: Any = None + self.cursor: MockRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "MockRawCursor": """Create and return a new cursor. Returns: Active SQLite cursor object + """ self.cursor = self.connection.cursor() return self.cursor diff --git a/sqlspec/adapters/mock/config.py b/sqlspec/adapters/mock/config.py index b4accaab1..c2ef5e063 100644 --- a/sqlspec/adapters/mock/config.py +++ b/sqlspec/adapters/mock/config.py @@ -171,6 +171,10 @@ class MockSyncConfig(NoPoolSyncConfig["MockConnection", "MockSyncDriver"]): supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[MockSyncConnectionContext]]" = MockSyncConnectionContext + _session_factory_class: "ClassVar[type[_MockSyncSessionFactory]]" = _MockSyncSessionFactory + _session_context_class: "ClassVar[type[MockSyncSessionContext]]" = MockSyncSessionContext + _default_statement_config = default_statement_config def __init__( self, @@ -338,6 +342,10 @@ class MockAsyncConfig(NoPoolAsyncConfig["MockConnection", "MockAsyncDriver"]): supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[MockAsyncConnectionContext]]" = MockAsyncConnectionContext + _session_factory_class: "ClassVar[type[_MockAsyncSessionFactory]]" = 
_MockAsyncSessionFactory + _session_context_class: "ClassVar[type[MockAsyncSessionContext]]" = MockAsyncSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/mysqlconnector/_typing.py b/sqlspec/adapters/mysqlconnector/_typing.py index e625f3f8b..d90a6723f 100644 --- a/sqlspec/adapters/mysqlconnector/_typing.py +++ b/sqlspec/adapters/mysqlconnector/_typing.py @@ -7,17 +7,16 @@ from typing import TYPE_CHECKING, Any from mysql.connector import MySQLConnection as _MysqlConnectorSyncConnection - -try: - from mysql.connector.aio import ( - MySQLConnection as _MysqlConnectorAsyncConnection, # pyright: ignore[reportMissingImports] - ) -except ImportError: # pragma: no cover - optional async import - _MysqlConnectorAsyncConnection = _MysqlConnectorSyncConnection # type: ignore[assignment,misc] - +from mysql.connector.aio import ( + MySQLConnection as _MysqlConnectorAsyncConnection, # pyright: ignore[reportMissingImports] +) +from mysql.connector.aio.cursor import ( + MySQLCursor as _MysqlConnectorAsyncRawCursor, # pyright: ignore[reportMissingImports] +) +from mysql.connector.cursor import MySQLCursor as _MysqlConnectorSyncRawCursor if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable from types import TracebackType from typing import Protocol, TypeAlias @@ -25,7 +24,7 @@ from sqlspec.core import StatementConfig class MysqlConnectorAsyncConnectionProtocol(Protocol): - def cursor(self, **kwargs: Any) -> Any: ... + def cursor(self, **kwargs: Any) -> "Awaitable[MysqlConnectorAsyncRawCursor]": ... async def commit(self) -> Any: ... @@ -35,10 +34,14 @@ async def close(self) -> Any: ... 
MysqlConnectorSyncConnection: TypeAlias = _MysqlConnectorSyncConnection MysqlConnectorAsyncConnection: TypeAlias = MysqlConnectorAsyncConnectionProtocol + MysqlConnectorSyncRawCursor: TypeAlias = _MysqlConnectorSyncRawCursor + MysqlConnectorAsyncRawCursor: TypeAlias = _MysqlConnectorAsyncRawCursor if not TYPE_CHECKING: MysqlConnectorSyncConnection = _MysqlConnectorSyncConnection MysqlConnectorAsyncConnection = _MysqlConnectorAsyncConnection + MysqlConnectorSyncRawCursor = _MysqlConnectorSyncRawCursor + MysqlConnectorAsyncRawCursor = _MysqlConnectorAsyncRawCursor class MysqlConnectorSyncCursor: @@ -48,9 +51,9 @@ class MysqlConnectorSyncCursor: def __init__(self, connection: "MysqlConnectorSyncConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: MysqlConnectorSyncRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "MysqlConnectorSyncRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -66,9 +69,9 @@ class MysqlConnectorAsyncCursor: def __init__(self, connection: "MysqlConnectorAsyncConnection") -> None: self.connection = connection - self.cursor: Any | None = None + self.cursor: MysqlConnectorAsyncRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "MysqlConnectorAsyncRawCursor": self.cursor = await self.connection.cursor() return self.cursor @@ -174,8 +177,10 @@ async def __aexit__( __all__ = ( "MysqlConnectorAsyncConnection", "MysqlConnectorAsyncCursor", + "MysqlConnectorAsyncRawCursor", "MysqlConnectorAsyncSessionContext", "MysqlConnectorSyncConnection", "MysqlConnectorSyncCursor", + "MysqlConnectorSyncRawCursor", "MysqlConnectorSyncSessionContext", ) diff --git a/sqlspec/adapters/mysqlconnector/config.py b/sqlspec/adapters/mysqlconnector/config.py index 63dd3e062..0aa4a5cad 100644 --- a/sqlspec/adapters/mysqlconnector/config.py +++ b/sqlspec/adapters/mysqlconnector/config.py @@ -221,6 +221,11 @@ class MysqlConnectorSyncConfig( 
driver_type: ClassVar[type[MysqlConnectorSyncDriver]] = MysqlConnectorSyncDriver connection_type: ClassVar[type[MysqlConnectorSyncConnection]] = MysqlConnectorSyncConnection + supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True _connection_context_class: "ClassVar[type[MysqlConnectorSyncConnectionContext]]" = ( MysqlConnectorSyncConnectionContext ) @@ -229,11 +234,6 @@ class MysqlConnectorSyncConfig( ) _session_context_class: "ClassVar[type[MysqlConnectorSyncSessionContext]]" = MysqlConnectorSyncSessionContext _default_statement_config = default_statement_config - supports_transactional_ddl: ClassVar[bool] = False - supports_native_arrow_export: ClassVar[bool] = True - supports_native_parquet_export: ClassVar[bool] = True - supports_native_arrow_import: ClassVar[bool] = True - supports_native_parquet_import: ClassVar[bool] = True def __init__( self, @@ -331,6 +331,14 @@ class MysqlConnectorAsyncConfig(NoPoolAsyncConfig[MysqlConnectorAsyncConnection, supports_native_parquet_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: "ClassVar[type[MysqlConnectorAsyncConnectionContext]]" = ( + MysqlConnectorAsyncConnectionContext + ) + _session_factory_class: "ClassVar[type[_MysqlConnectorAsyncSessionConnectionHandler]]" = ( + _MysqlConnectorAsyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[MysqlConnectorAsyncSessionContext]]" = MysqlConnectorAsyncSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/oracledb/_typing.py b/sqlspec/adapters/oracledb/_typing.py index c58304d2b..a30b0205b 100644 --- a/sqlspec/adapters/oracledb/_typing.py +++ 
b/sqlspec/adapters/oracledb/_typing.py @@ -25,8 +25,8 @@ OracleAsyncConnection: TypeAlias = AsyncConnection OracleSyncConnectionPool: TypeAlias = ConnectionPool OracleAsyncConnectionPool: TypeAlias = AsyncConnectionPool - OracleSyncCursorType: TypeAlias = Cursor - OracleAsyncCursorType: TypeAlias = AsyncCursor + OracleSyncRawCursor: TypeAlias = Cursor + OracleAsyncRawCursor: TypeAlias = AsyncCursor OracleVectorType: TypeAlias = int if not TYPE_CHECKING: @@ -42,8 +42,8 @@ OracleAsyncConnection = AsyncConnection OracleSyncConnectionPool = ConnectionPool OracleAsyncConnectionPool = AsyncConnectionPool - OracleSyncCursorType = Cursor - OracleAsyncCursorType = AsyncCursor + OracleSyncRawCursor = Cursor + OracleAsyncRawCursor = AsyncCursor class OracleSyncCursor: @@ -53,9 +53,9 @@ class OracleSyncCursor: def __init__(self, connection: OracleSyncConnection) -> None: self.connection = connection - self.cursor: Any = None + self.cursor: OracleSyncRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "OracleSyncRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -71,9 +71,9 @@ class OracleAsyncCursor: def __init__(self, connection: OracleAsyncConnection) -> None: self.connection = connection - self.cursor: Any = None + self.cursor: OracleAsyncRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "OracleAsyncRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -221,13 +221,13 @@ async def __aexit__( "OracleAsyncConnection", "OracleAsyncConnectionPool", "OracleAsyncCursor", - "OracleAsyncCursorType", + "OracleAsyncRawCursor", "OracleAsyncSessionContext", "OraclePipelineDriver", "OracleSyncConnection", "OracleSyncConnectionPool", "OracleSyncCursor", - "OracleSyncCursorType", + "OracleSyncRawCursor", "OracleSyncSessionContext", "OracleVectorType", ) diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index d21f9415a..033edc728 100644 --- 
a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -181,15 +181,15 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConne driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker - _connection_context_class: "ClassVar[type[OracleSyncConnectionContext]]" = OracleSyncConnectionContext - _session_factory_class: "ClassVar[type[_OracleSyncSessionConnectionHandler]]" = _OracleSyncSessionConnectionHandler - _session_context_class: "ClassVar[type[OracleSyncSessionContext]]" = OracleSyncSessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: ClassVar[bool] = False supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: "ClassVar[type[OracleSyncConnectionContext]]" = OracleSyncConnectionContext + _session_factory_class: "ClassVar[type[_OracleSyncSessionConnectionHandler]]" = _OracleSyncSessionConnectionHandler + _session_context_class: "ClassVar[type[OracleSyncSessionContext]]" = OracleSyncSessionContext + _default_statement_config = default_statement_config def __init__( self, @@ -379,17 +379,17 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncC connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker + supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + 
supports_native_parquet_export: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True _connection_context_class: "ClassVar[type[OracleAsyncConnectionContext]]" = OracleAsyncConnectionContext _session_factory_class: "ClassVar[type[_OracleAsyncSessionConnectionHandler]]" = ( _OracleAsyncSessionConnectionHandler ) _session_context_class: "ClassVar[type[OracleAsyncSessionContext]]" = OracleAsyncSessionContext _default_statement_config = default_statement_config - supports_transactional_ddl: ClassVar[bool] = False - supports_native_arrow_export: ClassVar[bool] = True - supports_native_arrow_import: ClassVar[bool] = True - supports_native_parquet_export: ClassVar[bool] = True - supports_native_parquet_import: ClassVar[bool] = True def __init__( self, diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py index 2d9d64926..6766ef563 100644 --- a/sqlspec/adapters/psqlpy/config.py +++ b/sqlspec/adapters/psqlpy/config.py @@ -180,13 +180,15 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD driver_type: ClassVar[type[PsqlpyDriver]] = PsqlpyDriver connection_type: "ClassVar[type[PsqlpyConnection]]" = PsqlpyConnection - _connection_context_class: "ClassVar[type[PsqlpyConnectionContext]]" = PsqlpyConnectionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: "ClassVar[type[PsqlpyConnectionContext]]" = PsqlpyConnectionContext + _session_factory_class: "ClassVar[type[_PsqlpySessionFactory]]" = _PsqlpySessionFactory + _session_context_class: "ClassVar[type[PsqlpySessionContext]]" = PsqlpySessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git 
a/sqlspec/adapters/psycopg/_typing.py b/sqlspec/adapters/psycopg/_typing.py index 05ce4fb07..b2c42a5bf 100644 --- a/sqlspec/adapters/psycopg/_typing.py +++ b/sqlspec/adapters/psycopg/_typing.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Protocol -from psycopg import AsyncConnection, Connection +from psycopg import AsyncConnection, AsyncCursor, Connection, Cursor from psycopg.rows import DictRow as PsycopgDictRow if TYPE_CHECKING: @@ -20,10 +20,14 @@ PsycopgSyncConnection: TypeAlias = Connection[PsycopgDictRow] PsycopgAsyncConnection: TypeAlias = AsyncConnection[PsycopgDictRow] + PsycopgSyncRawCursor: TypeAlias = Cursor[PsycopgDictRow] + PsycopgAsyncRawCursor: TypeAlias = AsyncCursor[PsycopgDictRow] if not TYPE_CHECKING: PsycopgSyncConnection = Connection PsycopgAsyncConnection = AsyncConnection + PsycopgSyncRawCursor = Cursor + PsycopgAsyncRawCursor = AsyncCursor class PsycopgSyncCursor: @@ -33,9 +37,9 @@ class PsycopgSyncCursor: def __init__(self, connection: "PsycopgSyncConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: PsycopgSyncRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "PsycopgSyncRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -51,9 +55,9 @@ class PsycopgAsyncCursor: def __init__(self, connection: "PsycopgAsyncConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: PsycopgAsyncRawCursor | None = None - async def __aenter__(self) -> Any: + async def __aenter__(self) -> "PsycopgAsyncRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -197,10 +201,12 @@ async def __aexit__( __all__ = ( "PsycopgAsyncConnection", "PsycopgAsyncCursor", + "PsycopgAsyncRawCursor", "PsycopgAsyncSessionContext", "PsycopgDictRow", "PsycopgPipelineDriver", "PsycopgSyncConnection", "PsycopgSyncCursor", + "PsycopgSyncRawCursor", "PsycopgSyncSessionContext", ) diff --git a/sqlspec/adapters/psycopg/config.py 
b/sqlspec/adapters/psycopg/config.py index 7bae30e50..c0b9fe83a 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -186,13 +186,17 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool driver_type: "ClassVar[type[PsycopgSyncDriver]]" = PsycopgSyncDriver connection_type: "ClassVar[type[PsycopgSyncConnection]]" = PsycopgSyncConnection - _connection_context_class: "ClassVar[type[PsycopgSyncConnectionContext]]" = PsycopgSyncConnectionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[PsycopgSyncConnectionContext]]" = PsycopgSyncConnectionContext + _session_factory_class: "ClassVar[type[_PsycopgSyncSessionConnectionHandler]]" = ( + _PsycopgSyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[PsycopgSyncSessionContext]]" = PsycopgSyncSessionContext + _default_statement_config = default_statement_config def __init__( self, @@ -454,13 +458,17 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec driver_type: ClassVar[type[PsycopgAsyncDriver]] = PsycopgAsyncDriver connection_type: "ClassVar[type[PsycopgAsyncConnection]]" = PsycopgAsyncConnection - _connection_context_class: "ClassVar[type[PsycopgAsyncConnectionContext]]" = PsycopgAsyncConnectionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: ClassVar[bool] = True supports_native_arrow_import: ClassVar[bool] = True supports_native_parquet_export: ClassVar[bool] = True supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: 
"ClassVar[type[PsycopgAsyncConnectionContext]]" = PsycopgAsyncConnectionContext + _session_factory_class: "ClassVar[type[_PsycopgAsyncSessionConnectionHandler]]" = ( + _PsycopgAsyncSessionConnectionHandler + ) + _session_context_class: "ClassVar[type[PsycopgAsyncSessionContext]]" = PsycopgAsyncSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/pymysql/_typing.py b/sqlspec/adapters/pymysql/_typing.py index 5ff2909af..7aae2125b 100644 --- a/sqlspec/adapters/pymysql/_typing.py +++ b/sqlspec/adapters/pymysql/_typing.py @@ -17,9 +17,11 @@ from sqlspec.core import StatementConfig PyMysqlConnection: TypeAlias = pymysql.connections.Connection + PyMysqlRawCursor: TypeAlias = pymysql.cursors.Cursor if not TYPE_CHECKING: PyMysqlConnection = pymysql.connections.Connection + PyMysqlRawCursor = pymysql.cursors.Cursor class PyMysqlCursor: @@ -29,9 +31,9 @@ class PyMysqlCursor: def __init__(self, connection: "PyMysqlConnection") -> None: self.connection = connection - self.cursor: Any = None + self.cursor: PyMysqlRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "PyMysqlRawCursor": self.cursor = self.connection.cursor() return self.cursor @@ -87,4 +89,4 @@ def __exit__( return None -__all__ = ("PyMysqlConnection", "PyMysqlCursor", "PyMysqlSessionContext") +__all__ = ("PyMysqlConnection", "PyMysqlCursor", "PyMysqlRawCursor", "PyMysqlSessionContext") diff --git a/sqlspec/adapters/pymysql/config.py b/sqlspec/adapters/pymysql/config.py index e93ab8ee0..421dc39ac 100644 --- a/sqlspec/adapters/pymysql/config.py +++ b/sqlspec/adapters/pymysql/config.py @@ -118,15 +118,15 @@ class PyMysqlConfig(SyncDatabaseConfig[PyMysqlConnection, PyMysqlConnectionPool, driver_type: "ClassVar[type[PyMysqlDriver]]" = PyMysqlDriver connection_type: "ClassVar[type[PyMysqlConnection]]" = cast("type[PyMysqlConnection]", PyMysqlConnection) - _connection_context_class: "ClassVar[type[PyMysqlConnectionContext]]" = 
PyMysqlConnectionContext - _session_factory_class: "ClassVar[type[_PyMysqlSessionConnectionHandler]]" = _PyMysqlSessionConnectionHandler - _session_context_class: "ClassVar[type[PyMysqlSessionContext]]" = PyMysqlSessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = False supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: "ClassVar[type[PyMysqlConnectionContext]]" = PyMysqlConnectionContext + _session_factory_class: "ClassVar[type[_PyMysqlSessionConnectionHandler]]" = _PyMysqlSessionConnectionHandler + _session_context_class: "ClassVar[type[PyMysqlSessionContext]]" = PyMysqlSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/spanner/config.py b/sqlspec/adapters/spanner/config.py index 61008b63b..bb13eb096 100644 --- a/sqlspec/adapters/spanner/config.py +++ b/sqlspec/adapters/spanner/config.py @@ -167,6 +167,10 @@ class SpannerSyncConfig(SyncDatabaseConfig["SpannerConnection", "AbstractSession supports_native_parquet_export: ClassVar[bool] = False supports_native_parquet_import: ClassVar[bool] = False requires_staging_for_load: ClassVar[bool] = False + _connection_context_class: "ClassVar[type[SpannerConnectionContext]]" = SpannerConnectionContext + _session_factory_class: "ClassVar[type[_SpannerSessionConnectionHandler]]" = _SpannerSessionConnectionHandler + _session_context_class: "ClassVar[type[SpannerSessionContext]]" = SpannerSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/adapters/sqlite/_typing.py b/sqlspec/adapters/sqlite/_typing.py index 78e6e21b3..7ef3baa3c 100644 --- a/sqlspec/adapters/sqlite/_typing.py +++ b/sqlspec/adapters/sqlite/_typing.py @@ -19,11 +19,11 @@ from 
sqlspec.core import StatementConfig SqliteConnection: TypeAlias = _SqliteConnection - SqliteCursorType: TypeAlias = sqlite3.Cursor + SqliteRawCursor: TypeAlias = sqlite3.Cursor if not TYPE_CHECKING: SqliteConnection = _SqliteConnection - SqliteCursorType = sqlite3.Cursor + SqliteRawCursor = sqlite3.Cursor class SqliteCursor: @@ -41,9 +41,9 @@ def __init__(self, connection: "SqliteConnection") -> None: connection: SQLite database connection """ self.connection = connection - self.cursor: Any = None + self.cursor: SqliteRawCursor | None = None - def __enter__(self) -> Any: + def __enter__(self) -> "SqliteRawCursor": """Create and return a new cursor. Returns: @@ -120,4 +120,4 @@ def __exit__( return None -__all__ = ("SqliteConnection", "SqliteCursor", "SqliteCursorType", "SqliteSessionContext") +__all__ = ("SqliteConnection", "SqliteCursor", "SqliteRawCursor", "SqliteSessionContext") diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index de35298e7..b6f0a2305 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -118,15 +118,15 @@ class SqliteConfig(SyncDatabaseConfig[SqliteConnection, SqliteConnectionPool, Sq driver_type: "ClassVar[type[SqliteDriver]]" = SqliteDriver connection_type: "ClassVar[type[SqliteConnection]]" = SqliteConnection - _connection_context_class: "ClassVar[type[SqliteConnectionContext]]" = SqliteConnectionContext - _session_factory_class: "ClassVar[type[_SqliteSessionConnectionHandler]]" = _SqliteSessionConnectionHandler - _session_context_class: "ClassVar[type[SqliteSessionContext]]" = SqliteSessionContext - _default_statement_config = default_statement_config supports_transactional_ddl: "ClassVar[bool]" = True supports_native_arrow_export: "ClassVar[bool]" = True supports_native_arrow_import: "ClassVar[bool]" = True supports_native_parquet_export: "ClassVar[bool]" = True supports_native_parquet_import: "ClassVar[bool]" = True + _connection_context_class: 
"ClassVar[type[SqliteConnectionContext]]" = SqliteConnectionContext + _session_factory_class: "ClassVar[type[_SqliteSessionConnectionHandler]]" = _SqliteSessionConnectionHandler + _session_context_class: "ClassVar[type[SqliteSessionContext]]" = SqliteSessionContext + _default_statement_config = default_statement_config def __init__( self, diff --git a/sqlspec/storage/backends/base.py b/sqlspec/storage/backends/base.py index 60a90c2b3..9adbf81e9 100644 --- a/sqlspec/storage/backends/base.py +++ b/sqlspec/storage/backends/base.py @@ -1,6 +1,7 @@ """Base class for storage backends.""" import asyncio +import contextlib from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator from typing import TYPE_CHECKING, Any, NoReturn, cast @@ -8,12 +9,12 @@ from mypy_extensions import mypyc_attr from typing_extensions import Self -if TYPE_CHECKING: - from types import TracebackType - from sqlspec.typing import ArrowRecordBatch, ArrowTable from sqlspec.utils.sync_tools import CapacityLimiter +if TYPE_CHECKING: + from types import TracebackType + __all__ = ( "AsyncArrowBatchIterator", "AsyncBytesIterator", @@ -21,7 +22,6 @@ "AsyncObStoreStreamIterator", "AsyncThreadedBytesIterator", "ObjectStoreBase", - "storage_limiter", ) # Dedicated capacity limiter for storage I/O operations (100 concurrent ops) @@ -71,6 +71,7 @@ def __init__(self, sync_iterator: "Iterator[ArrowRecordBatch]") -> None: Args: sync_iterator: The synchronous iterator to wrap. + """ self._sync_iter = sync_iterator @@ -89,6 +90,7 @@ async def __anext__(self) -> "ArrowRecordBatch": Raises: StopAsyncIteration: When the iterator is exhausted. + """ result = await asyncio.to_thread(_next_or_sentinel, self._sync_iter) if result is _EXHAUSTED: @@ -116,6 +118,7 @@ def __init__(self, sync_iterator: "Iterator[bytes]") -> None: Args: sync_iterator: The synchronous iterator to wrap. 
+ """ self._sync_iter = sync_iterator @@ -131,6 +134,7 @@ async def __anext__(self) -> bytes: Raises: StopAsyncIteration: When the iterator is exhausted. + """ try: return next(self._sync_iter) @@ -163,6 +167,7 @@ def __init__(self, data: bytes, chunk_size: int = 65536) -> None: Args: data: The bytes data to iterate over in chunks. chunk_size: Size of each chunk to yield (default: 65536 bytes). + """ self._data = data self._chunk_size = chunk_size @@ -183,8 +188,8 @@ async def __anext__(self) -> bytes: Raises: StopAsyncIteration: When all data has been yielded. - """ + """ if self._offset >= len(self._data): raise StopAsyncIteration @@ -211,6 +216,7 @@ def __init__(self, stream: Any, chunk_size: "int | None" = None) -> None: Args: stream: The native obstore async stream to wrap. chunk_size: Optional chunk size to re-chunk streamed data. + """ self._stream = stream self._buffer = bytearray() @@ -229,6 +235,7 @@ async def __anext__(self) -> bytes: Raises: StopAsyncIteration: When the stream is exhausted. + """ if self._chunk_size is None: try: @@ -265,25 +272,18 @@ class AsyncThreadedBytesIterator: allowing it to be compiled by mypyc. It offloads blocking read/close calls to a thread pool to avoid blocking the event loop. - Call aclose() or use as an async context manager to ensure cleanup when - consumers exit early. + NOTE: We specifically avoid __del__ here as it causes segmentation faults + in mypyc compiled mode during GC teardown. """ __slots__ = ("_chunk_size", "_closed", "_file_obj") def __init__(self, file_obj: Any, chunk_size: int = 65536) -> None: - """Initialize the threaded bytes iterator. - - Args: - file_obj: Synchronous file-like object supporting read() and close(). - chunk_size: Size of each chunk to read (default: 65536 bytes). 
- """ self._file_obj = file_obj self._chunk_size = chunk_size self._closed = False def __aiter__(self) -> "AsyncThreadedBytesIterator": - """Return self as the async iterator.""" return self async def __aenter__(self) -> Self: @@ -296,53 +296,35 @@ async def __aexit__( """Close the underlying file when exiting a context.""" await self.aclose() - def __del__(self) -> None: - """Best-effort cleanup for early exit.""" - self._close_sync() - - def _raise_stop(self) -> NoReturn: - raise StopAsyncIteration - - def _close_sync(self) -> None: - if self._closed: - return - self._closed = True - try: - self._file_obj.close() - except Exception: - return - - async def _close_async(self) -> None: + async def aclose(self) -> None: + """Close the underlying file object.""" if self._closed: return self._closed = True - await asyncio.to_thread(self._file_obj.close) + with contextlib.suppress(Exception): + await asyncio.to_thread(self._file_obj.close) - async def aclose(self) -> None: - """Close the underlying file object.""" - await self._close_async() + def _raise_stop(self) -> NoReturn: + raise StopAsyncIteration async def __anext__(self) -> bytes: - """Read the next chunk of bytes in a thread pool. + if self._closed: + self._raise_stop() - Returns: - The next chunk of bytes. 
- """ try: + # We use a simple while loop if we needed to retry, but here one read is enough chunk = await asyncio.to_thread(self._file_obj.read, self._chunk_size) - except EOFError: - await self._close_async() + if not chunk: + await self.aclose() + self._raise_stop() + return cast("bytes", chunk) + except (EOFError, StopAsyncIteration): + await self.aclose() self._raise_stop() - except BaseException: - await asyncio.shield(self._close_async()) + except Exception: + await asyncio.shield(self.aclose()) raise - if not chunk: - await self._close_async() - self._raise_stop() - - return cast("bytes", chunk) - @mypyc_attr(allow_interpreted_subclasses=True) class ObjectStoreBase(ABC): From dadf62d88d91e4a778c93774cbcaf7049180db1e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 00:10:32 +0000 Subject: [PATCH 33/39] feat: Allow `psycopg.sql.SQL` objects for prepared statements by updating type hints and adding casts. --- sqlspec/adapters/psycopg/core.py | 3 ++- sqlspec/adapters/psycopg/driver.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index 3b6d11b93..0f362a8f2 100644 --- a/sqlspec/adapters/psycopg/core.py +++ b/sqlspec/adapters/psycopg/core.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, cast from psycopg import sql as psycopg_sql +from typing_extensions import LiteralString from sqlspec.core import ( SQL, @@ -87,7 +88,7 @@ class PreparedStackOperation(NamedTuple): operation_index: int operation: "StackOperation" statement: "SQL" - sql: str + sql: "LiteralString | psycopg_sql.SQL" parameters: "tuple[Any, ...] 
| dict[str, Any] | None" diff --git a/sqlspec/adapters/psycopg/driver.py b/sqlspec/adapters/psycopg/driver.py index a5806c9e2..bfd09847e 100644 --- a/sqlspec/adapters/psycopg/driver.py +++ b/sqlspec/adapters/psycopg/driver.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, cast import psycopg +from psycopg import sql as psycopg_sql +from typing_extensions import LiteralString from sqlspec.adapters.psycopg._typing import ( PsycopgAsyncConnection, @@ -109,7 +111,7 @@ def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[Prepare operation_index=index, operation=operation, statement=sql_statement, - sql=sql_text, + sql=cast("LiteralString | psycopg_sql.SQL", sql_text), parameters=prepared_parameters, ) ) @@ -394,10 +396,11 @@ def _raise_pending_exception(exception_ctx: "PsycopgSyncExceptionHandler") -> No cursor = resource_stack.enter_context(self.with_cursor(self.connection)) try: + sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast] if prepared.parameters: - cursor.execute(prepared.sql, prepared.parameters) + cursor.execute(sql, prepared.parameters) else: - cursor.execute(prepared.sql) + cursor.execute(sql) except Exception as exc: stack_error = StackExecutionError( prepared.operation_index, @@ -852,10 +855,11 @@ def _raise_pending_exception(exception_ctx: "PsycopgAsyncExceptionHandler") -> N cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection)) try: + sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast] if prepared.parameters: - await cursor.execute(prepared.sql, prepared.parameters) + await cursor.execute(sql, prepared.parameters) else: - await cursor.execute(prepared.sql) + await cursor.execute(sql) except Exception as exc: stack_error = StackExecutionError( prepared.operation_index, From 42781ab77cb2ceb170f55a2203e1874695a4f4d6 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 14:11:12 +0000 Subject: [PATCH 34/39] 
feat: Implement `_read_chunk_or_sentinel` helper for efficient file reading and update `AsyncThreadedBytesIterator` to use it --- sqlspec/storage/backends/base.py | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/sqlspec/storage/backends/base.py b/sqlspec/storage/backends/base.py index 9adbf81e9..73ee19042 100644 --- a/sqlspec/storage/backends/base.py +++ b/sqlspec/storage/backends/base.py @@ -4,7 +4,7 @@ import contextlib from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator -from typing import TYPE_CHECKING, Any, NoReturn, cast +from typing import TYPE_CHECKING, Any, cast from mypy_extensions import mypyc_attr from typing_extensions import Self @@ -54,6 +54,20 @@ def _next_or_sentinel(iterator: "Iterator[Any]") -> "Any": return _EXHAUSTED +def _read_chunk_or_sentinel(file_obj: Any, chunk_size: int) -> Any: + """Read a chunk from a file-like object or return sentinel if exhausted. + + This helper is used by AsyncThreadedBytesIterator to offload blocking reads. + """ + try: + chunk = file_obj.read(chunk_size) + if not chunk: + return _EXHAUSTED + return chunk + except EOFError: + return _EXHAUSTED + + class AsyncArrowBatchIterator: """Async iterator wrapper for sync Arrow batch iterators. 
@@ -304,26 +318,18 @@ async def aclose(self) -> None: with contextlib.suppress(Exception): await asyncio.to_thread(self._file_obj.close) - def _raise_stop(self) -> NoReturn: - raise StopAsyncIteration - async def __anext__(self) -> bytes: if self._closed: - self._raise_stop() + raise StopAsyncIteration - try: - # We use a simple while loop if we needed to retry, but here one read is enough - chunk = await asyncio.to_thread(self._file_obj.read, self._chunk_size) - if not chunk: - await self.aclose() - self._raise_stop() - return cast("bytes", chunk) - except (EOFError, StopAsyncIteration): + # Offload blocking read to a thread pool + result = await asyncio.to_thread(_read_chunk_or_sentinel, self._file_obj, self._chunk_size) + + if result is _EXHAUSTED: await self.aclose() - self._raise_stop() - except Exception: - await asyncio.shield(self.aclose()) - raise + raise StopAsyncIteration # noqa: TRY301 + + return cast("bytes", result) @mypyc_attr(allow_interpreted_subclasses=True) From 5c5dd7cc3169fdcd80463d71c3800151d8ea72b7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 16:37:43 +0000 Subject: [PATCH 35/39] feat: Refactor storage backend iterators into a separate module to prevent mypyc segfaults --- pyproject.toml | 3 +- sqlspec/storage/backends/_iterators.py | 330 +++++++++++++++++++ sqlspec/storage/backends/base.py | 318 +----------------- sqlspec/storage/backends/fsspec.py | 2 +- sqlspec/storage/backends/local.py | 2 +- sqlspec/storage/backends/obstore.py | 2 +- tests/unit/storage/test_storage_iterators.py | 2 +- tools/scripts/mypyc_inventory.py | 6 +- 8 files changed, 349 insertions(+), 316 deletions(-) create mode 100644 sqlspec/storage/backends/_iterators.py diff --git a/pyproject.toml b/pyproject.toml index baea0b24f..6d54e0a99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,6 +199,7 @@ exclude = [ "sqlspec/adapters/**/data_dictionary.py", # Cross-module inheritance causes mypyc segfaults "sqlspec/observability/_formatting.py", # 
Inherits from non-compiled logging.Formatter "sqlspec/utils/arrow_helpers.py", # Arrow operations cause segfaults when compiled + "sqlspec/storage/backends/_iterators.py", # Async __anext__ + asyncio.to_thread causes mypyc segfault ] include = [ "sqlspec/core/**/*.py", # Core module @@ -209,7 +210,7 @@ include = [ "sqlspec/driver/**/*.py", # Driver module "sqlspec/storage/registry.py", # Safe storage registry/runtime routing "sqlspec/storage/errors.py", # Safe storage error normalization - "sqlspec/storage/backends/base.py", # Storage backend runtime base classes and iterators + "sqlspec/storage/backends/base.py", # Storage backend runtime base classes (iterators in _iterators.py) "sqlspec/data_dictionary/**/*.py", # Data dictionary mixin (required for adapter inheritance) "sqlspec/adapters/**/core.py", # Adapter compiled helpers "sqlspec/adapters/**/type_converter.py", # All adapters type converters diff --git a/sqlspec/storage/backends/_iterators.py b/sqlspec/storage/backends/_iterators.py new file mode 100644 index 000000000..ec6b2b1aa --- /dev/null +++ b/sqlspec/storage/backends/_iterators.py @@ -0,0 +1,330 @@ +"""Async iterator classes for storage backends. + +This module is intentionally excluded from mypyc compilation because +async __anext__ methods that use asyncio.to_thread cause segfaults +when compiled — the C coroutine state machine cannot survive the +suspend/resume cycle across thread boundaries. +""" + +import asyncio +import contextlib +from typing import TYPE_CHECKING, Any, cast + +from typing_extensions import Self + +if TYPE_CHECKING: + from collections.abc import Iterator + from types import TracebackType + + from sqlspec.typing import ArrowRecordBatch + +__all__ = ( + "AsyncArrowBatchIterator", + "AsyncBytesIterator", + "AsyncChunkedBytesIterator", + "AsyncObStoreStreamIterator", + "AsyncThreadedBytesIterator", +) + + +class _ExhaustedSentinel: + """Sentinel value to signal iterator exhaustion across thread boundaries. 
+ + StopIteration cannot be raised into asyncio Futures, so we use this sentinel + to signal iterator exhaustion from the thread pool back to the async context. + """ + + __slots__ = () + + +_EXHAUSTED = _ExhaustedSentinel() + + +def _next_or_sentinel(iterator: "Iterator[Any]") -> "Any": + """Get next item or return sentinel if exhausted. + + This helper wraps next() to catch StopIteration in the thread, + since StopIteration cannot propagate through asyncio Futures. + """ + try: + return next(iterator) + except StopIteration: + return _EXHAUSTED + + +def _read_chunk_or_sentinel(file_obj: Any, chunk_size: int) -> Any: + """Read a chunk from a file-like object or return sentinel if exhausted. + + This helper is used by AsyncThreadedBytesIterator to offload blocking reads. + """ + try: + chunk = file_obj.read(chunk_size) + except EOFError: + return _EXHAUSTED + if not chunk: + return _EXHAUSTED + return chunk + + +class AsyncArrowBatchIterator: + """Async iterator wrapper for sync Arrow batch iterators. + + This class implements the async iterator protocol without using async generators, + allowing it to be compiled by mypyc (which doesn't support async generators). + + The class wraps a synchronous iterator and exposes it as an async iterator, + enabling usage with `async for` syntax. + """ + + __slots__ = ("_sync_iter",) + + def __init__(self, sync_iterator: "Iterator[ArrowRecordBatch]") -> None: + """Initialize the async iterator wrapper. + + Args: + sync_iterator: The synchronous iterator to wrap. + + """ + self._sync_iter = sync_iterator + + def __aiter__(self) -> "AsyncArrowBatchIterator": + """Return self as the async iterator.""" + return self + + async def __anext__(self) -> "ArrowRecordBatch": + """Get the next item from the iterator asynchronously. + + Uses asyncio.to_thread to offload the blocking next() call + to a thread pool, preventing event loop blocking. + + Returns: + The next Arrow record batch. 
+ + Raises: + StopAsyncIteration: When the iterator is exhausted. + + """ + result = await asyncio.to_thread(_next_or_sentinel, self._sync_iter) + if result is _EXHAUSTED: + raise StopAsyncIteration + return cast("ArrowRecordBatch", result) + + +class AsyncBytesIterator: + """Async iterator wrapper for sync bytes iterators. + + This class implements the async iterator protocol without using async generators, + allowing it to be compiled by mypyc (which doesn't support async generators). + + The class wraps a synchronous iterator and exposes it as an async iterator, + enabling usage with `async for` syntax. + + Note: This class blocks the event loop during I/O. For non-blocking streaming, + use AsyncChunkedBytesIterator with pre-loaded data instead. + """ + + __slots__ = ("_sync_iter",) + + def __init__(self, sync_iterator: "Iterator[bytes]") -> None: + """Initialize the async iterator wrapper. + + Args: + sync_iterator: The synchronous iterator to wrap. + + """ + self._sync_iter = sync_iterator + + def __aiter__(self) -> "AsyncBytesIterator": + """Return self as the async iterator.""" + return self + + async def __anext__(self) -> bytes: + """Get the next item from the iterator asynchronously. + + Returns: + The next chunk of bytes. + + Raises: + StopAsyncIteration: When the iterator is exhausted. + + """ + try: + return next(self._sync_iter) + except StopIteration: + raise StopAsyncIteration from None + + +class AsyncChunkedBytesIterator: + """Async iterator that yields pre-loaded bytes data in chunks. + + This class implements the async iterator protocol without using async generators, + allowing it to be compiled by mypyc (which doesn't support async generators). + + Unlike AsyncBytesIterator, this class works with pre-loaded data and yields + control to the event loop between chunks via asyncio.sleep(0), ensuring + the event loop is not blocked during iteration. 
+ + Usage pattern: + # Load data in thread pool to avoid blocking + data = await asyncio.to_thread(read_bytes, path) + # Stream chunks without blocking event loop + return AsyncChunkedBytesIterator(data, chunk_size=65536) + """ + + __slots__ = ("_chunk_size", "_data", "_offset") + + def __init__(self, data: bytes, chunk_size: int = 65536) -> None: + """Initialize the chunked bytes iterator. + + Args: + data: The bytes data to iterate over in chunks. + chunk_size: Size of each chunk to yield (default: 65536 bytes). + + """ + self._data = data + self._chunk_size = chunk_size + self._offset = 0 + + def __aiter__(self) -> "AsyncChunkedBytesIterator": + """Return self as the async iterator.""" + return self + + async def __anext__(self) -> bytes: + """Get the next chunk of bytes asynchronously. + + Yields control to the event loop via asyncio.sleep(0) before returning + each chunk, ensuring other tasks can run during iteration. + + Returns: + The next chunk of bytes. + + Raises: + StopAsyncIteration: When all data has been yielded. + + """ + if self._offset >= len(self._data): + raise StopAsyncIteration + + # Yield to event loop to allow other tasks to run + await asyncio.sleep(0) + + chunk = self._data[self._offset : self._offset + self._chunk_size] + self._offset += self._chunk_size + return chunk + + +class AsyncObStoreStreamIterator: + """Async iterator wrapper for obstore streaming. + + This class wraps obstore's native async stream and ensures it yields + bytes objects while remaining compatible with mypyc. + """ + + __slots__ = ("_buffer", "_chunk_size", "_stream", "_stream_exhausted") + + def __init__(self, stream: Any, chunk_size: "int | None" = None) -> None: + """Initialize the obstore stream wrapper. + + Args: + stream: The native obstore async stream to wrap. + chunk_size: Optional chunk size to re-chunk streamed data. 
+ + """ + self._stream = stream + self._buffer = bytearray() + self._chunk_size = chunk_size if chunk_size is not None and chunk_size > 0 else None + self._stream_exhausted = False + + def __aiter__(self) -> "AsyncObStoreStreamIterator": + """Return self as the async iterator.""" + return self + + async def __anext__(self) -> bytes: + """Get the next chunk from the obstore stream asynchronously. + + Returns: + The next chunk of bytes. + + Raises: + StopAsyncIteration: When the stream is exhausted. + + """ + if self._chunk_size is None: + try: + chunk = await self._stream.__anext__() + return bytes(chunk) + except StopAsyncIteration: + raise StopAsyncIteration from None + + while not self._stream_exhausted and len(self._buffer) < self._chunk_size: + try: + chunk = await self._stream.__anext__() + except StopAsyncIteration: + self._stream_exhausted = True + break + self._buffer.extend(bytes(chunk)) + + if self._buffer: + if len(self._buffer) >= self._chunk_size: + data = bytes(self._buffer[: self._chunk_size]) + del self._buffer[: self._chunk_size] + return data + if self._stream_exhausted: + data = bytes(self._buffer) + self._buffer.clear() + return data + + raise StopAsyncIteration from None + + +class AsyncThreadedBytesIterator: + """Async iterator that reads from a synchronous file-like object in a thread pool. + + This class implements the async iterator protocol without using async generators, + allowing it to be compiled by mypyc. It offloads blocking read/close calls + to a thread pool to avoid blocking the event loop. + + NOTE: We specifically avoid __del__ here as it causes segmentation faults + in mypyc compiled mode during GC teardown. 
+ """ + + __slots__ = ("_chunk_size", "_closed", "_file_obj") + + def __init__(self, file_obj: Any, chunk_size: int = 65536) -> None: + self._file_obj = file_obj + self._chunk_size = chunk_size + self._closed = False + + def __aiter__(self) -> "AsyncThreadedBytesIterator": + return self + + async def __aenter__(self) -> Self: + """Return the iterator for async context manager usage.""" + return self + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> None: + """Close the underlying file when exiting a context.""" + await self.aclose() + + async def aclose(self) -> None: + """Close the underlying file object.""" + if self._closed: + return + self._closed = True + with contextlib.suppress(Exception): + await asyncio.to_thread(self._file_obj.close) + + async def __anext__(self) -> bytes: + if self._closed: + raise StopAsyncIteration + + # Offload blocking read to a thread pool + result = await asyncio.to_thread(_read_chunk_or_sentinel, self._file_obj, self._chunk_size) + + if result is _EXHAUSTED: + await self.aclose() + raise StopAsyncIteration + + return cast("bytes", result) diff --git a/sqlspec/storage/backends/base.py b/sqlspec/storage/backends/base.py index 73ee19042..d681e17a2 100644 --- a/sqlspec/storage/backends/base.py +++ b/sqlspec/storage/backends/base.py @@ -1,20 +1,21 @@ """Base class for storage backends.""" -import asyncio -import contextlib from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator -from typing import TYPE_CHECKING, Any, cast +from typing import Any from mypy_extensions import mypyc_attr -from typing_extensions import Self +from sqlspec.storage.backends._iterators import ( + AsyncArrowBatchIterator, + AsyncBytesIterator, + AsyncChunkedBytesIterator, + AsyncObStoreStreamIterator, + AsyncThreadedBytesIterator, +) from sqlspec.typing import ArrowRecordBatch, ArrowTable from sqlspec.utils.sync_tools import 
CapacityLimiter -if TYPE_CHECKING: - from types import TracebackType - __all__ = ( "AsyncArrowBatchIterator", "AsyncBytesIterator", @@ -29,309 +30,6 @@ storage_limiter = CapacityLimiter(100) -class _ExhaustedSentinel: - """Sentinel value to signal iterator exhaustion across thread boundaries. - - StopIteration cannot be raised into asyncio Futures, so we use this sentinel - to signal iterator exhaustion from the thread pool back to the async context. - """ - - __slots__ = () - - -_EXHAUSTED = _ExhaustedSentinel() - - -def _next_or_sentinel(iterator: "Iterator[Any]") -> "Any": - """Get next item or return sentinel if exhausted. - - This helper wraps next() to catch StopIteration in the thread, - since StopIteration cannot propagate through asyncio Futures. - """ - try: - return next(iterator) - except StopIteration: - return _EXHAUSTED - - -def _read_chunk_or_sentinel(file_obj: Any, chunk_size: int) -> Any: - """Read a chunk from a file-like object or return sentinel if exhausted. - - This helper is used by AsyncThreadedBytesIterator to offload blocking reads. - """ - try: - chunk = file_obj.read(chunk_size) - if not chunk: - return _EXHAUSTED - return chunk - except EOFError: - return _EXHAUSTED - - -class AsyncArrowBatchIterator: - """Async iterator wrapper for sync Arrow batch iterators. - - This class implements the async iterator protocol without using async generators, - allowing it to be compiled by mypyc (which doesn't support async generators). - - The class wraps a synchronous iterator and exposes it as an async iterator, - enabling usage with `async for` syntax. - """ - - __slots__ = ("_sync_iter",) - - def __init__(self, sync_iterator: "Iterator[ArrowRecordBatch]") -> None: - """Initialize the async iterator wrapper. - - Args: - sync_iterator: The synchronous iterator to wrap. 
- - """ - self._sync_iter = sync_iterator - - def __aiter__(self) -> "AsyncArrowBatchIterator": - """Return self as the async iterator.""" - return self - - async def __anext__(self) -> "ArrowRecordBatch": - """Get the next item from the iterator asynchronously. - - Uses asyncio.to_thread to offload the blocking next() call - to a thread pool, preventing event loop blocking. - - Returns: - The next Arrow record batch. - - Raises: - StopAsyncIteration: When the iterator is exhausted. - - """ - result = await asyncio.to_thread(_next_or_sentinel, self._sync_iter) - if result is _EXHAUSTED: - raise StopAsyncIteration - return cast("ArrowRecordBatch", result) - - -class AsyncBytesIterator: - """Async iterator wrapper for sync bytes iterators. - - This class implements the async iterator protocol without using async generators, - allowing it to be compiled by mypyc (which doesn't support async generators). - - The class wraps a synchronous iterator and exposes it as an async iterator, - enabling usage with `async for` syntax. - - Note: This class blocks the event loop during I/O. For non-blocking streaming, - use AsyncChunkedBytesIterator with pre-loaded data instead. - """ - - __slots__ = ("_sync_iter",) - - def __init__(self, sync_iterator: "Iterator[bytes]") -> None: - """Initialize the async iterator wrapper. - - Args: - sync_iterator: The synchronous iterator to wrap. - - """ - self._sync_iter = sync_iterator - - def __aiter__(self) -> "AsyncBytesIterator": - """Return self as the async iterator.""" - return self - - async def __anext__(self) -> bytes: - """Get the next item from the iterator asynchronously. - - Returns: - The next chunk of bytes. - - Raises: - StopAsyncIteration: When the iterator is exhausted. - - """ - try: - return next(self._sync_iter) - except StopIteration: - raise StopAsyncIteration from None - - -class AsyncChunkedBytesIterator: - """Async iterator that yields pre-loaded bytes data in chunks. 
- - This class implements the async iterator protocol without using async generators, - allowing it to be compiled by mypyc (which doesn't support async generators). - - Unlike AsyncBytesIterator, this class works with pre-loaded data and yields - control to the event loop between chunks via asyncio.sleep(0), ensuring - the event loop is not blocked during iteration. - - Usage pattern: - # Load data in thread pool to avoid blocking - data = await asyncio.to_thread(read_bytes, path) - # Stream chunks without blocking event loop - return AsyncChunkedBytesIterator(data, chunk_size=65536) - """ - - __slots__ = ("_chunk_size", "_data", "_offset") - - def __init__(self, data: bytes, chunk_size: int = 65536) -> None: - """Initialize the chunked bytes iterator. - - Args: - data: The bytes data to iterate over in chunks. - chunk_size: Size of each chunk to yield (default: 65536 bytes). - - """ - self._data = data - self._chunk_size = chunk_size - self._offset = 0 - - def __aiter__(self) -> "AsyncChunkedBytesIterator": - """Return self as the async iterator.""" - return self - - async def __anext__(self) -> bytes: - """Get the next chunk of bytes asynchronously. - - Yields control to the event loop via asyncio.sleep(0) before returning - each chunk, ensuring other tasks can run during iteration. - - Returns: - The next chunk of bytes. - - Raises: - StopAsyncIteration: When all data has been yielded. - - """ - if self._offset >= len(self._data): - raise StopAsyncIteration - - # Yield to event loop to allow other tasks to run - await asyncio.sleep(0) - - chunk = self._data[self._offset : self._offset + self._chunk_size] - self._offset += self._chunk_size - return chunk - - -class AsyncObStoreStreamIterator: - """Async iterator wrapper for obstore streaming. - - This class wraps obstore's native async stream and ensures it yields - bytes objects while remaining compatible with mypyc. 
- """ - - __slots__ = ("_buffer", "_chunk_size", "_stream", "_stream_exhausted") - - def __init__(self, stream: Any, chunk_size: "int | None" = None) -> None: - """Initialize the obstore stream wrapper. - - Args: - stream: The native obstore async stream to wrap. - chunk_size: Optional chunk size to re-chunk streamed data. - - """ - self._stream = stream - self._buffer = bytearray() - self._chunk_size = chunk_size if chunk_size is not None and chunk_size > 0 else None - self._stream_exhausted = False - - def __aiter__(self) -> "AsyncObStoreStreamIterator": - """Return self as the async iterator.""" - return self - - async def __anext__(self) -> bytes: - """Get the next chunk from the obstore stream asynchronously. - - Returns: - The next chunk of bytes. - - Raises: - StopAsyncIteration: When the stream is exhausted. - - """ - if self._chunk_size is None: - try: - chunk = await self._stream.__anext__() - return bytes(chunk) - except StopAsyncIteration: - raise StopAsyncIteration from None - - while not self._stream_exhausted and len(self._buffer) < self._chunk_size: - try: - chunk = await self._stream.__anext__() - except StopAsyncIteration: - self._stream_exhausted = True - break - self._buffer.extend(bytes(chunk)) - - if self._buffer: - if len(self._buffer) >= self._chunk_size: - data = bytes(self._buffer[: self._chunk_size]) - del self._buffer[: self._chunk_size] - return data - if self._stream_exhausted: - data = bytes(self._buffer) - self._buffer.clear() - return data - - raise StopAsyncIteration from None - - -class AsyncThreadedBytesIterator: - """Async iterator that reads from a synchronous file-like object in a thread pool. - - This class implements the async iterator protocol without using async generators, - allowing it to be compiled by mypyc. It offloads blocking read/close calls - to a thread pool to avoid blocking the event loop. - - NOTE: We specifically avoid __del__ here as it causes segmentation faults - in mypyc compiled mode during GC teardown. 
- """ - - __slots__ = ("_chunk_size", "_closed", "_file_obj") - - def __init__(self, file_obj: Any, chunk_size: int = 65536) -> None: - self._file_obj = file_obj - self._chunk_size = chunk_size - self._closed = False - - def __aiter__(self) -> "AsyncThreadedBytesIterator": - return self - - async def __aenter__(self) -> Self: - """Return the iterator for async context manager usage.""" - return self - - async def __aexit__( - self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" - ) -> None: - """Close the underlying file when exiting a context.""" - await self.aclose() - - async def aclose(self) -> None: - """Close the underlying file object.""" - if self._closed: - return - self._closed = True - with contextlib.suppress(Exception): - await asyncio.to_thread(self._file_obj.close) - - async def __anext__(self) -> bytes: - if self._closed: - raise StopAsyncIteration - - # Offload blocking read to a thread pool - result = await asyncio.to_thread(_read_chunk_or_sentinel, self._file_obj, self._chunk_size) - - if result is _EXHAUSTED: - await self.aclose() - raise StopAsyncIteration # noqa: TRY301 - - return cast("bytes", result) - - @mypyc_attr(allow_interpreted_subclasses=True) class ObjectStoreBase(ABC): """Base class for storage backends. 
diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index 1c2bee189..f8b6622e1 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -10,7 +10,7 @@ from mypy_extensions import mypyc_attr from sqlspec.storage._utils import import_pyarrow_parquet, resolve_storage_path -from sqlspec.storage.backends.base import AsyncArrowBatchIterator, AsyncThreadedBytesIterator +from sqlspec.storage.backends._iterators import AsyncArrowBatchIterator, AsyncThreadedBytesIterator from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.utils.logging import get_logger, log_with_context from sqlspec.utils.module_loader import ensure_fsspec diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 4afa023eb..d0e93e11f 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -15,7 +15,7 @@ from sqlspec.exceptions import FileNotFoundInStorageError from sqlspec.storage._utils import import_pyarrow_parquet -from sqlspec.storage.backends.base import AsyncArrowBatchIterator, AsyncThreadedBytesIterator +from sqlspec.storage.backends._iterators import AsyncArrowBatchIterator, AsyncThreadedBytesIterator from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.utils.sync_tools import async_ diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 1843008a8..57270a212 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -19,7 +19,7 @@ from sqlspec.exceptions import StorageOperationFailedError from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet, resolve_storage_path -from sqlspec.storage.backends.base import AsyncArrowBatchIterator, AsyncObStoreStreamIterator +from sqlspec.storage.backends._iterators import AsyncArrowBatchIterator, AsyncObStoreStreamIterator from sqlspec.storage.errors import execute_sync_storage_operation from 
sqlspec.typing import ArrowRecordBatch, ArrowTable from sqlspec.utils.logging import get_logger, log_with_context diff --git a/tests/unit/storage/test_storage_iterators.py b/tests/unit/storage/test_storage_iterators.py index 54edf5a1a..447dbcaff 100644 --- a/tests/unit/storage/test_storage_iterators.py +++ b/tests/unit/storage/test_storage_iterators.py @@ -2,7 +2,7 @@ import io -from sqlspec.storage.backends.base import AsyncThreadedBytesIterator +from sqlspec.storage.backends._iterators import AsyncThreadedBytesIterator async def test_async_threaded_bytes_iterator_aclose_closes_file() -> None: diff --git a/tools/scripts/mypyc_inventory.py b/tools/scripts/mypyc_inventory.py index b856778b2..5ff9195e1 100644 --- a/tools/scripts/mypyc_inventory.py +++ b/tools/scripts/mypyc_inventory.py @@ -69,7 +69,11 @@ }, "sqlspec/storage/backends/base.py": { "classification": "compile_now", - "reason": "Mypyc-safe runtime base classes and iterator wrappers.", + "reason": "ObjectStoreBase ABC and storage_limiter (iterators split to _iterators.py).", + }, + "sqlspec/storage/backends/_iterators.py": { + "classification": "keep_interpreted", + "reason": "Async __anext__ + asyncio.to_thread causes mypyc segfault on coroutine suspend/resume across threads.", }, "sqlspec/utils/arrow_helpers.py": { "classification": "keep_interpreted", From e9eb4a6ba29edc58046f2dca8c28f69bfd07a024 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 17:19:59 +0000 Subject: [PATCH 36/39] feat: Implement UUID coercions for various adapters and enhance serialization support --- sqlspec/_serialization.py | 21 +++++++++++++- sqlspec/adapters/adbc/core.py | 2 ++ sqlspec/adapters/aiosqlite/core.py | 3 +- sqlspec/adapters/asyncmy/core.py | 3 +- sqlspec/adapters/asyncpg/core.py | 2 ++ sqlspec/adapters/bigquery/core.py | 2 ++ sqlspec/adapters/duckdb/core.py | 3 +- sqlspec/adapters/mock/core.py | 3 +- sqlspec/adapters/mysqlconnector/core.py | 3 +- sqlspec/adapters/oracledb/core.py | 2 ++ 
sqlspec/adapters/psqlpy/core.py | 4 +-- sqlspec/adapters/psycopg/core.py | 9 ++++-- sqlspec/adapters/pymysql/core.py | 3 +- sqlspec/adapters/spanner/type_converter.py | 15 ++++++++-- sqlspec/adapters/sqlite/core.py | 3 +- sqlspec/utils/type_converters.py | 33 ++++++++++++++++++++++ 16 files changed, 97 insertions(+), 14 deletions(-) diff --git a/sqlspec/_serialization.py b/sqlspec/_serialization.py index 4ad689226..fe4af788e 100644 --- a/sqlspec/_serialization.py +++ b/sqlspec/_serialization.py @@ -11,15 +11,30 @@ import datetime import enum import json +import uuid as _uuid_mod from abc import ABC, abstractmethod from decimal import Decimal from typing import Any, Final, Literal, Protocol, overload -from sqlspec._typing import NUMPY_INSTALLED +from sqlspec._typing import NUMPY_INSTALLED, UUID_UTILS_INSTALLED from sqlspec.core.filters import OffsetPagination from sqlspec.typing import MSGSPEC_INSTALLED, ORJSON_INSTALLED, PYDANTIC_INSTALLED, BaseModel +def _get_uuid_utils_type() -> "type[Any] | None": + if not UUID_UTILS_INSTALLED: + return None + try: + import uuid_utils as _uuid_utils_mod # pyright: ignore[reportMissingImports] + except ImportError: + return None + else: + return _uuid_utils_mod.UUID # type: ignore[no-any-return,unused-ignore] + + +_UUID_UTILS_TYPE: "type[Any] | None" = _get_uuid_utils_type() + + def _type_to_string(value: Any) -> Any: # pragma: no cover """Convert special types to strings for JSON serialization. 
@@ -44,6 +59,10 @@ def _type_to_string(value: Any) -> Any: # pragma: no cover return str(value.value) if PYDANTIC_INSTALLED and isinstance(value, BaseModel): return value.model_dump_json() + if isinstance(value, _uuid_mod.UUID): + return str(value) + if _UUID_UTILS_TYPE is not None and isinstance(value, _UUID_UTILS_TYPE): + return str(value) if isinstance(value, OffsetPagination): return {"items": value.items, "limit": value.limit, "offset": value.offset, "total": value.total} if NUMPY_INSTALLED: diff --git a/sqlspec/adapters/adbc/core.py b/sqlspec/adapters/adbc/core.py index ef5cf8789..abc273aef 100644 --- a/sqlspec/adapters/adbc/core.py +++ b/sqlspec/adapters/adbc/core.py @@ -37,6 +37,7 @@ from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.module_loader import import_string from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount, has_sqlstate if TYPE_CHECKING: @@ -653,6 +654,7 @@ def build_profile() -> "DriverParameterProfile": tuple: _convert_array_for_postgres_adbc, list: _convert_array_for_postgres_adbc, dict: _identity, + **build_uuid_coercions(native=True), }, extras={ "type_coercion_overrides": {list: _convert_array_for_postgres_adbc, tuple: _convert_array_for_postgres_adbc} diff --git a/sqlspec/adapters/aiosqlite/core.py b/sqlspec/adapters/aiosqlite/core.py index 9801d7ba4..a23ecd2cd 100644 --- a/sqlspec/adapters/aiosqlite/core.py +++ b/sqlspec/adapters/aiosqlite/core.py @@ -18,7 +18,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter, build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount, has_sqlite_error if TYPE_CHECKING: @@ -273,6 +273,7 @@ def build_profile() -> "DriverParameterProfile": datetime: 
_TIME_TO_ISO, date: _TIME_TO_ISO, Decimal: _DECIMAL_TO_STRING, + **build_uuid_coercions(), }, default_dialect="sqlite", ) diff --git a/sqlspec/adapters/asyncmy/core.py b/sqlspec/adapters/asyncmy/core.py index e47c60bb0..5cd66d506 100644 --- a/sqlspec/adapters/asyncmy/core.py +++ b/sqlspec/adapters/asyncmy/core.py @@ -21,6 +21,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: @@ -138,7 +139,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int}, + custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, default_dialect="mysql", ) diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index 303342a37..c74c339b1 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -33,6 +33,7 @@ from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_sqlstate if TYPE_CHECKING: @@ -112,6 +113,7 @@ def _build_asyncpg_custom_type_coercions() -> "dict[type, Callable[[Any], Any]]" datetime.datetime: _convert_datetime_param, datetime.date: _convert_date_param, datetime.time: _convert_time_param, + **build_uuid_coercions(native=True), } diff --git a/sqlspec/adapters/bigquery/core.py b/sqlspec/adapters/bigquery/core.py index 10d563de9..b1a26debf 100644 --- a/sqlspec/adapters/bigquery/core.py +++ b/sqlspec/adapters/bigquery/core.py @@ -34,6 +34,7 @@ ) from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json +from 
sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_errors, has_value_attribute if TYPE_CHECKING: @@ -643,6 +644,7 @@ def build_profile() -> "DriverParameterProfile": dict: _identity, list: _identity, type(None): _return_none, + **build_uuid_coercions(), }, default_ast_transformer=build_null_pruning_transform(dialect="bigquery"), extras={"json_tuple_strategy": "tuple", "type_coercion_overrides": {list: _identity, tuple: _tuple_to_list}}, diff --git a/sqlspec/adapters/duckdb/core.py b/sqlspec/adapters/duckdb/core.py index 997e0e659..bd7149e88 100644 --- a/sqlspec/adapters/duckdb/core.py +++ b/sqlspec/adapters/duckdb/core.py @@ -21,7 +21,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import to_json -from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter, build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount if TYPE_CHECKING: @@ -144,6 +144,7 @@ def build_profile() -> "DriverParameterProfile": datetime: _TIME_TO_ISO, date: _TIME_TO_ISO, Decimal: _DECIMAL_TO_STRING, + **build_uuid_coercions(), }, default_dialect="duckdb", ) diff --git a/sqlspec/adapters/mock/core.py b/sqlspec/adapters/mock/core.py index 67e75ea69..6bf9cfd57 100644 --- a/sqlspec/adapters/mock/core.py +++ b/sqlspec/adapters/mock/core.py @@ -22,7 +22,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter, build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount, has_sqlite_error if TYPE_CHECKING: @@ -260,6 +260,7 @@ def build_profile() -> "DriverParameterProfile": datetime: _TIME_TO_ISO, date: _TIME_TO_ISO, Decimal: _DECIMAL_TO_STRING, + **build_uuid_coercions(), }, 
default_dialect="sqlite", ) diff --git a/sqlspec/adapters/mysqlconnector/core.py b/sqlspec/adapters/mysqlconnector/core.py index 29bfa90c7..1f51cbc2d 100644 --- a/sqlspec/adapters/mysqlconnector/core.py +++ b/sqlspec/adapters/mysqlconnector/core.py @@ -21,6 +21,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: @@ -121,7 +122,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int}, + custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, default_dialect="mysql", ) diff --git a/sqlspec/adapters/oracledb/core.py b/sqlspec/adapters/oracledb/core.py index b82ceaee0..f34cb9f20 100644 --- a/sqlspec/adapters/oracledb/core.py +++ b/sqlspec/adapters/oracledb/core.py @@ -32,6 +32,7 @@ ) from sqlspec.typing import NUMPY_INSTALLED from sqlspec.utils.serializers import to_json +from sqlspec.utils.type_converters import build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount, is_readable if TYPE_CHECKING: @@ -706,6 +707,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", + custom_type_coercions={**build_uuid_coercions()}, default_dialect="oracle", ) diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index 008e64379..efa6204bb 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -32,7 +32,7 @@ from sqlspec.utils.dispatch import TypeDispatcher from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import to_json -from sqlspec.utils.type_converters import build_nested_decimal_normalizer +from 
sqlspec.utils.type_converters import build_nested_decimal_normalizer, build_uuid_coercions from sqlspec.utils.type_guards import has_query_result_metadata if TYPE_CHECKING: @@ -196,7 +196,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={decimal.Decimal: float}, + custom_type_coercions={decimal.Decimal: float, **build_uuid_coercions(native=True)}, default_dialect="postgres", ) diff --git a/sqlspec/adapters/psycopg/core.py b/sqlspec/adapters/psycopg/core.py index 0f362a8f2..633540dfa 100644 --- a/sqlspec/adapters/psycopg/core.py +++ b/sqlspec/adapters/psycopg/core.py @@ -37,7 +37,7 @@ ) from sqlspec.typing import PGVECTOR_INSTALLED from sqlspec.utils.serializers import to_json -from sqlspec.utils.type_converters import build_json_list_converter, build_json_tuple_converter +from sqlspec.utils.type_converters import build_json_list_converter, build_json_tuple_converter, build_uuid_coercions from sqlspec.utils.type_guards import has_rowcount, has_sqlstate # Module-level lazy import for psycopg errors (mypyc optimization) @@ -139,7 +139,12 @@ def _identity(value: Any) -> Any: def _build_psycopg_custom_type_coercions() -> "dict[type, Callable[[Any], Any]]": """Return custom type coercions for psycopg.""" - return {datetime.datetime: _identity, datetime.date: _identity, datetime.time: _identity} + return { + datetime.datetime: _identity, + datetime.date: _identity, + datetime.time: _identity, + **build_uuid_coercions(native=True), + } def _build_psycopg_parameter_config( diff --git a/sqlspec/adapters/pymysql/core.py b/sqlspec/adapters/pymysql/core.py index e40c9672d..f2f60458c 100644 --- a/sqlspec/adapters/pymysql/core.py +++ b/sqlspec/adapters/pymysql/core.py @@ -21,6 +21,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.type_converters import build_uuid_coercions from 
sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: @@ -121,7 +122,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int}, + custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, default_dialect="mysql", ) diff --git a/sqlspec/adapters/spanner/type_converter.py b/sqlspec/adapters/spanner/type_converter.py index b327f7376..d6f621a31 100644 --- a/sqlspec/adapters/spanner/type_converter.py +++ b/sqlspec/adapters/spanner/type_converter.py @@ -21,10 +21,20 @@ from typing import TYPE_CHECKING, Any, cast from uuid import UUID +from sqlspec._typing import UUID_UTILS_INSTALLED from sqlspec.core.type_converter import CachedOutputConverter, convert_uuid from sqlspec.utils.serializers import from_json from sqlspec.utils.type_converters import should_json_encode_sequence +_UUID_TYPES: "tuple[type[Any], ...]" = (UUID,) +if UUID_UTILS_INSTALLED: + try: + import uuid_utils as _uuid_utils # pyright: ignore[reportMissingImports] + + _UUID_TYPES = (UUID, _uuid_utils.UUID) + except ImportError: + pass + if TYPE_CHECKING: from collections.abc import Callable @@ -259,8 +269,9 @@ def coerce_params_for_spanner( json_object_type = _get_json_object_type() coerced: dict[str, Any] = {} for key, value in params.items(): - if isinstance(value, UUID): - coerced[key] = bytes_to_spanner(uuid_to_spanner(value)) + if isinstance(value, _UUID_TYPES): + std_uuid = value if isinstance(value, UUID) else UUID(bytes=value.bytes) + coerced[key] = bytes_to_spanner(uuid_to_spanner(std_uuid)) elif isinstance(value, bytes): coerced[key] = bytes_to_spanner(value) elif isinstance(value, datetime) and value.tzinfo is None: diff --git a/sqlspec/adapters/sqlite/core.py b/sqlspec/adapters/sqlite/core.py index 3894e900f..99e7dc84d 100644 --- a/sqlspec/adapters/sqlite/core.py +++ 
b/sqlspec/adapters/sqlite/core.py @@ -21,7 +21,7 @@ UniqueViolationError, ) from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter +from sqlspec.utils.type_converters import build_decimal_converter, build_time_iso_converter, build_uuid_coercions from sqlspec.utils.type_guards import has_sqlite_error if TYPE_CHECKING: @@ -312,6 +312,7 @@ def build_profile() -> "DriverParameterProfile": datetime: _TIME_TO_ISO, date: _TIME_TO_ISO, Decimal: _DECIMAL_TO_STRING, + **build_uuid_coercions(), }, default_dialect="sqlite", ) diff --git a/sqlspec/utils/type_converters.py b/sqlspec/utils/type_converters.py index 7c78888a5..ce33cf6e7 100644 --- a/sqlspec/utils/type_converters.py +++ b/sqlspec/utils/type_converters.py @@ -16,6 +16,7 @@ "build_json_tuple_converter", "build_nested_decimal_normalizer", "build_time_iso_converter", + "build_uuid_coercions", "should_json_encode_sequence", ) @@ -173,3 +174,35 @@ def build_nested_decimal_normalizer(*, mode: str = DEFAULT_DECIMAL_MODE) -> "Cal def build_time_iso_converter() -> "Callable[[datetime.date | datetime.datetime | datetime.time], str]": """Return a converter that formats temporal values using ISO 8601.""" return _time_iso_convert + + +def _uuid_to_string(value: Any) -> str: + return str(value) + + +def _uuid_utils_to_stdlib(value: Any) -> Any: + import uuid as _uuid_mod + + return _uuid_mod.UUID(bytes=value.bytes) + + +def build_uuid_coercions(*, native: bool = False) -> "dict[type[Any], Callable[[Any], Any]]": + """Return coercions for ``uuid_utils.UUID`` parameter binding. + + When ``uuid_utils`` is installed, returns a dict mapping its UUID type + to either ``str`` or ``uuid.UUID`` depending on the *native* flag. + When not installed, returns an empty dict. + + Args: + native: When ``True``, convert ``uuid_utils.UUID`` → ``uuid.UUID`` + (via ``.bytes``, for drivers that bind ``uuid.UUID`` natively). 
+ When ``False`` (default), convert to ``str`` (for drivers that + need a plain string, e.g. DuckDB/SQLite). + """ + try: + import uuid_utils as _uuid_utils_mod # pyright: ignore[reportMissingImports] + except ImportError: + return {} + + converter = _uuid_utils_to_stdlib if native else _uuid_to_string + return {_uuid_utils_mod.UUID: converter} From ec5403a81ccbcb7f588fbe395a55409c1fe9bd88 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 17:22:10 +0000 Subject: [PATCH 37/39] feat: Add UUID coercion tests to validate conversion between uuid_utils.UUID and stdlib uuid.UUID --- tests/unit/utils/test_type_converters.py | 54 +++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_type_converters.py b/tests/unit/utils/test_type_converters.py index ec80723f7..b77f80658 100644 --- a/tests/unit/utils/test_type_converters.py +++ b/tests/unit/utils/test_type_converters.py @@ -1,10 +1,12 @@ """Tests for nested converter helpers.""" +import uuid from decimal import Decimal import pytest -from sqlspec.utils.type_converters import build_nested_decimal_normalizer +from sqlspec._typing import UUID_UTILS_INSTALLED +from sqlspec.utils.type_converters import build_nested_decimal_normalizer, build_uuid_coercions pytestmark = pytest.mark.xdist_group("utils") @@ -48,3 +50,53 @@ class DecimalList(list): normalized = normalizer(payload) assert normalized == {"items": [1.5, "x"]} + + +# --- UUID coercion tests --- + + +@pytest.mark.skipif(not UUID_UTILS_INSTALLED, reason="uuid_utils not installed") +def test_build_uuid_coercions_default_returns_str() -> None: + """Default mode converts uuid_utils.UUID to str.""" + import uuid_utils + + coercions = build_uuid_coercions() + assert uuid_utils.UUID in coercions + + u = uuid_utils.uuid7() + result = coercions[type(u)](u) + assert isinstance(result, str) + assert result == str(u) + + +@pytest.mark.skipif(not UUID_UTILS_INSTALLED, reason="uuid_utils not installed") +def 
test_build_uuid_coercions_native_returns_stdlib_uuid() -> None: + """Native mode converts uuid_utils.UUID to uuid.UUID via .bytes.""" + import uuid_utils + + coercions = build_uuid_coercions(native=True) + assert uuid_utils.UUID in coercions + + u = uuid_utils.uuid7() + result = coercions[type(u)](u) + assert type(result) is uuid.UUID + assert str(result) == str(u) + + +@pytest.mark.skipif(not UUID_UTILS_INSTALLED, reason="uuid_utils not installed") +def test_build_uuid_coercions_does_not_include_stdlib_uuid() -> None: + """Neither mode should include uuid.UUID as a key.""" + assert uuid.UUID not in build_uuid_coercions() + assert uuid.UUID not in build_uuid_coercions(native=True) + + +@pytest.mark.skipif(not UUID_UTILS_INSTALLED, reason="uuid_utils not installed") +def test_build_uuid_coercions_preserves_uuid_value() -> None: + """Round-trip through native coercion preserves the UUID value.""" + import uuid_utils + + coercions = build_uuid_coercions(native=True) + u = uuid_utils.uuid7() + stdlib_uuid = coercions[type(u)](u) + assert stdlib_uuid.bytes == u.bytes + assert stdlib_uuid.int == u.int From d7da67eb37137c5685ccd86d2d337ae2107dae66 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 17:26:37 +0000 Subject: [PATCH 38/39] chore: Refactor test cases to remove unnecessary asyncio markers - Removed `@pytest.mark.asyncio` from various test cases across multiple test files, as they are not required for the current test structure. 
- Updated tests in the following files: - tests/integration/adapters/mysqlconnector/extensions/events/test_queue_backend.py - tests/integration/adapters/mysqlconnector/test_driver_async.py - tests/integration/adapters/mysqlconnector/test_exceptions.py - tests/integration/adapters/mysqlconnector/test_parameter_styles.py - tests/integration/adapters/oracledb/extensions/events/test_queue_backend.py - tests/integration/adapters/oracledb/test_driver_async.py - tests/integration/adapters/oracledb/test_msgspec_clob.py - tests/integration/adapters/oracledb/test_stack.py - tests/integration/adapters/psqlpy/extensions/events/test_listen_notify.py - tests/integration/adapters/psqlpy/extensions/events/test_queue_backend.py - tests/integration/config/test_connection_injection.py - tests/integration/extensions/litestar/test_channels_backend.py - tests/integration/storage/test_streaming.py - tests/integration/test_pool_concurrency.py - tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py - tests/unit/adapters/test_asyncpg/test_cloud_connectors.py - tests/unit/adapters/test_oracledb/test_oracle_adk_store.py - tests/unit/adapters/test_pool_logging.py - tests/unit/config/test_connection_config_edge_cases.py - tests/unit/config/test_connection_config_parameters.py - tests/unit/config/test_migration_methods.py - tests/unit/driver/test_execute_script.py - tests/unit/driver/test_fetch_aliases.py - tests/unit/driver/test_stack_base.py - tests/unit/extensions/test_events/test_channel.py - tests/unit/storage/test_bridge.py --- .../extensions/events/test_queue_backend.py | 8 -------- .../adapters/aiosqlite/test_storage_bridge.py | 2 +- .../extensions/events/test_queue_backend.py | 3 --- .../adapters/asyncmy/test_explain.py | 2 +- .../adapters/asyncmy/test_storage_bridge.py | 2 +- .../extensions/events/test_listen_notify.py | 4 ---- .../extensions/events/test_queue_backend.py | 1 - .../adapters/asyncpg/test_cloud_connectors.py | 7 ------- .../adapters/asyncpg/test_explain.py | 2 +- 
.../adapters/asyncpg/test_schema_migration.py | 9 --------- .../adapters/asyncpg/test_storage_bridge.py | 1 - .../cockroach_asyncpg/test_parameter_styles.py | 18 ------------------ .../extensions/events/test_queue_backend.py | 2 +- .../mysqlconnector/test_driver_async.py | 2 +- .../adapters/mysqlconnector/test_exceptions.py | 2 +- .../mysqlconnector/test_parameter_styles.py | 10 ---------- .../extensions/events/test_queue_backend.py | 1 - .../adapters/oracledb/test_driver_async.py | 2 +- .../adapters/oracledb/test_msgspec_clob.py | 6 ------ .../adapters/oracledb/test_stack.py | 2 -- .../extensions/events/test_listen_notify.py | 2 -- .../extensions/events/test_queue_backend.py | 1 - .../config/test_connection_injection.py | 7 ------- .../litestar/test_channels_backend.py | 2 -- tests/integration/storage/test_streaming.py | 2 -- tests/integration/test_pool_concurrency.py | 1 - .../test_aiosqlite/test_pool_shutdown.py | 3 --- .../test_asyncpg/test_cloud_connectors.py | 13 ------------- .../test_oracledb/test_oracle_adk_store.py | 4 ---- tests/unit/adapters/test_pool_logging.py | 3 --- .../test_connection_config_edge_cases.py | 3 --- .../test_connection_config_parameters.py | 3 --- tests/unit/config/test_migration_methods.py | 15 --------------- tests/unit/driver/test_execute_script.py | 1 - tests/unit/driver/test_fetch_aliases.py | 8 -------- tests/unit/driver/test_stack_base.py | 3 --- .../extensions/test_events/test_channel.py | 2 -- tests/unit/storage/test_bridge.py | 7 ------- 38 files changed, 8 insertions(+), 158 deletions(-) diff --git a/tests/integration/adapters/aiosqlite/extensions/events/test_queue_backend.py b/tests/integration/adapters/aiosqlite/extensions/events/test_queue_backend.py index 21d7706b5..4c4bd89b2 100644 --- a/tests/integration/adapters/aiosqlite/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/aiosqlite/extensions/events/test_queue_backend.py @@ -8,7 +8,6 @@ @pytest.mark.integration -@pytest.mark.asyncio async def 
test_aiosqlite_event_channel_publish(tmp_path) -> None: """Aiosqlite event channel publishes events asynchronously.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -35,7 +34,6 @@ async def test_aiosqlite_event_channel_publish(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_consume(tmp_path) -> None: """Aiosqlite event channel consumes events asynchronously.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -61,7 +59,6 @@ async def test_aiosqlite_event_channel_consume(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_ack(tmp_path) -> None: """Aiosqlite event channel acknowledges events asynchronously.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -93,7 +90,6 @@ async def test_aiosqlite_event_channel_ack(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_metadata(tmp_path) -> None: """Aiosqlite event channel preserves metadata in async operations.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -122,7 +118,6 @@ async def test_aiosqlite_event_channel_metadata(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_telemetry(tmp_path) -> None: """Aiosqlite event operations are tracked in telemetry.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -150,7 +145,6 @@ async def test_aiosqlite_event_channel_telemetry(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_custom_table_name(tmp_path) -> None: """Custom queue table name is used for events.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -177,7 +171,6 @@ async def test_aiosqlite_event_channel_custom_table_name(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_multiple_channels(tmp_path) -> None: """Events are correctly filtered by 
channel.""" migrations_dir = prepare_events_migrations(tmp_path) @@ -205,7 +198,6 @@ async def test_aiosqlite_event_channel_multiple_channels(tmp_path) -> None: @pytest.mark.integration -@pytest.mark.asyncio async def test_aiosqlite_event_channel_attempts_tracked(tmp_path) -> None: """Event attempts counter is incremented on dequeue.""" migrations_dir = prepare_events_migrations(tmp_path) diff --git a/tests/integration/adapters/aiosqlite/test_storage_bridge.py b/tests/integration/adapters/aiosqlite/test_storage_bridge.py index f6bb5f69b..951620a96 100644 --- a/tests/integration/adapters/aiosqlite/test_storage_bridge.py +++ b/tests/integration/adapters/aiosqlite/test_storage_bridge.py @@ -8,7 +8,7 @@ from sqlspec.adapters.aiosqlite import AiosqliteDriver -pytestmark = [pytest.mark.asyncio, pytest.mark.xdist_group("sqlite")] +pytestmark = [pytest.mark.xdist_group("sqlite")] async def test_aiosqlite_load_from_arrow(aiosqlite_session: AiosqliteDriver) -> None: diff --git a/tests/integration/adapters/asyncmy/extensions/events/test_queue_backend.py b/tests/integration/adapters/asyncmy/extensions/events/test_queue_backend.py index 43070b00f..4b6b75f4d 100644 --- a/tests/integration/adapters/asyncmy/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/asyncmy/extensions/events/test_queue_backend.py @@ -13,7 +13,6 @@ @pytest.mark.mysql -@pytest.mark.asyncio async def test_asyncmy_event_channel_queue_fallback(mysql_service: MySQLService, tmp_path: Any) -> None: """AsyncMy configs publish, consume, and ack events via the queue backend.""" migrations = tmp_path / "migrations" @@ -58,7 +57,6 @@ async def test_asyncmy_event_channel_queue_fallback(mysql_service: MySQLService, @pytest.mark.mysql -@pytest.mark.asyncio async def test_asyncmy_event_channel_multiple_messages(mysql_service: MySQLService, tmp_path: Any) -> None: """AsyncMy queue backend handles multiple messages correctly.""" migrations = tmp_path / "migrations" @@ -104,7 +102,6 @@ async def 
test_asyncmy_event_channel_multiple_messages(mysql_service: MySQLServi @pytest.mark.mysql -@pytest.mark.asyncio async def test_asyncmy_event_channel_nack_redelivery(mysql_service: MySQLService, tmp_path: Any) -> None: """AsyncMy queue backend redelivers nacked messages.""" migrations = tmp_path / "migrations" diff --git a/tests/integration/adapters/asyncmy/test_explain.py b/tests/integration/adapters/asyncmy/test_explain.py index e9a9eb417..5ae05bfc8 100644 --- a/tests/integration/adapters/asyncmy/test_explain.py +++ b/tests/integration/adapters/asyncmy/test_explain.py @@ -9,7 +9,7 @@ from sqlspec.builder import Explain, sql from sqlspec.core import SQL -pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncio(loop_scope="function")] +pytestmark = [pytest.mark.xdist_group("mysql")] @pytest.fixture diff --git a/tests/integration/adapters/asyncmy/test_storage_bridge.py b/tests/integration/adapters/asyncmy/test_storage_bridge.py index 81c788063..09fe7bb64 100644 --- a/tests/integration/adapters/asyncmy/test_storage_bridge.py +++ b/tests/integration/adapters/asyncmy/test_storage_bridge.py @@ -8,7 +8,7 @@ from sqlspec.adapters.asyncmy import AsyncmyDriver -pytestmark = [pytest.mark.asyncio, pytest.mark.xdist_group("mysql")] +pytestmark = [pytest.mark.xdist_group("mysql")] async def _fetch_rows(asyncmy_driver: AsyncmyDriver, table: str) -> list[dict[str, object]]: diff --git a/tests/integration/adapters/asyncpg/extensions/events/test_listen_notify.py b/tests/integration/adapters/asyncpg/extensions/events/test_listen_notify.py index e98d5bdde..489d7f902 100644 --- a/tests/integration/adapters/asyncpg/extensions/events/test_listen_notify.py +++ b/tests/integration/adapters/asyncpg/extensions/events/test_listen_notify.py @@ -30,7 +30,6 @@ async def _wait_for_message(received: "list[Any]") -> None: @pytest.mark.postgres -@pytest.mark.asyncio async def test_asyncpg_listen_notify_publish_and_ack(postgres_service: "Any") -> None: """AsyncPG adapter publishes and 
acknowledges LISTEN/NOTIFY events.""" @@ -55,7 +54,6 @@ async def test_asyncpg_listen_notify_publish_and_ack(postgres_service: "Any") -> @pytest.mark.postgres -@pytest.mark.asyncio async def test_asyncpg_listen_notify_message_delivery(postgres_service: "Any") -> None: """AsyncPG adapter delivers NOTIFY payloads via EventChannel listener.""" @@ -96,7 +94,6 @@ async def _handler(message: Any) -> None: @pytest.mark.postgres -@pytest.mark.asyncio async def test_asyncpg_hybrid_listen_notify_durable(postgres_service: "Any", tmp_path: Any) -> None: """Hybrid backend stores event durably then notifies listeners.""" @@ -145,7 +142,6 @@ async def _handler(message: Any) -> None: @pytest.mark.postgres -@pytest.mark.asyncio async def test_asyncpg_listen_notify_metadata(postgres_service: "Any") -> None: """AsyncPG adapter preserves metadata in LISTEN/NOTIFY events.""" diff --git a/tests/integration/adapters/asyncpg/extensions/events/test_queue_backend.py b/tests/integration/adapters/asyncpg/extensions/events/test_queue_backend.py index d00fd8678..fab45521a 100644 --- a/tests/integration/adapters/asyncpg/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/asyncpg/extensions/events/test_queue_backend.py @@ -11,7 +11,6 @@ @pytest.mark.postgres -@pytest.mark.asyncio async def test_asyncpg_native_event_channel(postgres_service: "PostgresService") -> None: """AsyncPG configs surface native LISTEN/NOTIFY events.""" diff --git a/tests/integration/adapters/asyncpg/test_cloud_connectors.py b/tests/integration/adapters/asyncpg/test_cloud_connectors.py index 13ce27e12..3badc95b8 100644 --- a/tests/integration/adapters/asyncpg/test_cloud_connectors.py +++ b/tests/integration/adapters/asyncpg/test_cloud_connectors.py @@ -32,7 +32,6 @@ @pytest.mark.skipif(not HAS_CLOUD_SQL_CREDENTIALS, reason="Cloud SQL credentials missing") -@pytest.mark.asyncio async def test_cloud_sql_connection_basic() -> None: """Test basic Cloud SQL connection via connector.""" instance = 
os.environ["GOOGLE_CLOUD_SQL_INSTANCE"] @@ -55,7 +54,6 @@ async def test_cloud_sql_connection_basic() -> None: @pytest.mark.skipif(not HAS_CLOUD_SQL_CREDENTIALS, reason="Cloud SQL credentials missing") -@pytest.mark.asyncio async def test_cloud_sql_query_execution() -> None: """Test query execution via Cloud SQL connector.""" instance = os.environ["GOOGLE_CLOUD_SQL_INSTANCE"] @@ -79,7 +77,6 @@ async def test_cloud_sql_query_execution() -> None: @pytest.mark.skipif(not HAS_CLOUD_SQL_CREDENTIALS, reason="Cloud SQL IAM requires credentials") -@pytest.mark.asyncio async def test_cloud_sql_iam_auth() -> None: """Test Cloud SQL with IAM authentication.""" instance = os.environ["GOOGLE_CLOUD_SQL_INSTANCE"] @@ -101,7 +98,6 @@ async def test_cloud_sql_iam_auth() -> None: @pytest.mark.skipif(not HAS_CLOUD_SQL_CREDENTIALS, reason="Cloud SQL credentials missing") -@pytest.mark.asyncio async def test_cloud_sql_private_ip() -> None: """Test Cloud SQL connection using PRIVATE IP type.""" instance = os.environ["GOOGLE_CLOUD_SQL_INSTANCE"] @@ -129,7 +125,6 @@ async def test_cloud_sql_private_ip() -> None: @pytest.mark.skipif(not HAS_ALLOYDB_CREDENTIALS, reason="AlloyDB credentials missing") -@pytest.mark.asyncio async def test_alloydb_connection_basic() -> None: """Test basic AlloyDB connection via connector.""" instance_uri = os.environ["GOOGLE_ALLOYDB_INSTANCE_URI"] @@ -156,7 +151,6 @@ async def test_alloydb_connection_basic() -> None: @pytest.mark.skipif(not HAS_ALLOYDB_CREDENTIALS, reason="AlloyDB credentials missing") -@pytest.mark.asyncio async def test_alloydb_query_execution() -> None: """Test query execution via AlloyDB connector.""" instance_uri = os.environ["GOOGLE_ALLOYDB_INSTANCE_URI"] @@ -184,7 +178,6 @@ async def test_alloydb_query_execution() -> None: @pytest.mark.skipif(not HAS_ALLOYDB_CREDENTIALS, reason="AlloyDB IAM requires credentials") -@pytest.mark.asyncio async def test_alloydb_iam_auth() -> None: """Test AlloyDB with IAM authentication.""" instance_uri = 
os.environ["GOOGLE_ALLOYDB_INSTANCE_URI"] diff --git a/tests/integration/adapters/asyncpg/test_explain.py b/tests/integration/adapters/asyncpg/test_explain.py index 8052084ac..fb282cbec 100644 --- a/tests/integration/adapters/asyncpg/test_explain.py +++ b/tests/integration/adapters/asyncpg/test_explain.py @@ -9,7 +9,7 @@ from sqlspec.builder import Explain, sql from sqlspec.core import SQL -pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncio(loop_scope="function")] +pytestmark = [pytest.mark.xdist_group("postgres")] @pytest.fixture diff --git a/tests/integration/adapters/asyncpg/test_schema_migration.py b/tests/integration/adapters/asyncpg/test_schema_migration.py index 5ee9e8fa7..3580db621 100644 --- a/tests/integration/adapters/asyncpg/test_schema_migration.py +++ b/tests/integration/adapters/asyncpg/test_schema_migration.py @@ -22,7 +22,6 @@ def _create_config(postgres_service: "PostgresService") -> AsyncpgConfig: ) -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_tracker_creates_full_schema(postgres_service: "PostgresService") -> None: """Test AsyncPG tracker creates complete schema with all columns.""" @@ -58,7 +57,6 @@ async def test_asyncpg_tracker_creates_full_schema(postgres_service: "PostgresSe await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_tracker_migrates_legacy_schema(postgres_service: "PostgresService") -> None: """Test AsyncPG tracker adds missing columns to legacy schema.""" @@ -96,7 +94,6 @@ async def test_asyncpg_tracker_migrates_legacy_schema(postgres_service: "Postgre await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_tracker_migration_preserves_data(postgres_service: "PostgresService") -> None: """Test AsyncPG schema migration preserves existing records.""" @@ -132,7 +129,6 @@ async def test_asyncpg_tracker_migration_preserves_data(postgres_service: "Postg await config.close_pool() -@pytest.mark.asyncio 
@pytest.mark.postgres async def test_asyncpg_tracker_version_type_recording(postgres_service: "PostgresService") -> None: """Test AsyncPG tracker correctly records version_type.""" @@ -163,7 +159,6 @@ async def test_asyncpg_tracker_version_type_recording(postgres_service: "Postgre await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_tracker_execution_sequence(postgres_service: "PostgresService") -> None: """Test AsyncPG tracker execution_sequence auto-increments.""" @@ -198,7 +193,6 @@ async def test_asyncpg_tracker_execution_sequence(postgres_service: "PostgresSer await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_get_current_version_uses_execution_sequence(postgres_service: "PostgresService") -> None: """Test AsyncPG get_current_version uses execution order.""" @@ -221,7 +215,6 @@ async def test_asyncpg_get_current_version_uses_execution_sequence(postgres_serv await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_update_version_record_preserves_metadata(postgres_service: "PostgresService") -> None: """Test AsyncPG update preserves execution_sequence and applied_at.""" @@ -258,7 +251,6 @@ async def test_asyncpg_update_version_record_preserves_metadata(postgres_service await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_update_version_record_idempotent(postgres_service: "PostgresService") -> None: """Test AsyncPG update_version_record is idempotent.""" @@ -283,7 +275,6 @@ async def test_asyncpg_update_version_record_idempotent(postgres_service: "Postg await config.close_pool() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_migration_schema_is_idempotent(postgres_service: "PostgresService") -> None: """Test AsyncPG schema migration can be run multiple times.""" diff --git a/tests/integration/adapters/asyncpg/test_storage_bridge.py 
b/tests/integration/adapters/asyncpg/test_storage_bridge.py index b0b3b1166..87bedab1b 100644 --- a/tests/integration/adapters/asyncpg/test_storage_bridge.py +++ b/tests/integration/adapters/asyncpg/test_storage_bridge.py @@ -21,7 +21,6 @@ ] -@pytest.mark.asyncio(loop_scope="function") async def test_asyncpg_storage_bridge_with_minio( asyncpg_async_driver: AsyncpgDriver, minio_service: "MinioService", diff --git a/tests/integration/adapters/cockroach_asyncpg/test_parameter_styles.py b/tests/integration/adapters/cockroach_asyncpg/test_parameter_styles.py index 8eedc4460..326b0a97b 100644 --- a/tests/integration/adapters/cockroach_asyncpg/test_parameter_styles.py +++ b/tests/integration/adapters/cockroach_asyncpg/test_parameter_styles.py @@ -66,7 +66,6 @@ async def cockroach_asyncpg_parameter_session( class TestNumericParameterStyle: """Test NUMERIC ($1, $2) parameter style (native for CockroachDB).""" - @pytest.mark.asyncio async def test_numeric_single_parameter(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test single $1 placeholder works natively.""" result = await cockroach_asyncpg_parameter_session.execute( @@ -77,7 +76,6 @@ async def test_numeric_single_parameter(self, cockroach_asyncpg_parameter_sessio assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_numeric_multiple_parameters( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -95,7 +93,6 @@ async def test_numeric_multiple_parameters( class TestQmarkConversion: """Test QMARK (?) to NUMERIC ($1) conversion.""" - @pytest.mark.asyncio async def test_qmark_single_parameter(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test single ? 
placeholder gets converted to $1.""" result = await cockroach_asyncpg_parameter_session.execute( @@ -106,7 +103,6 @@ async def test_qmark_single_parameter(self, cockroach_asyncpg_parameter_session: assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_qmark_multiple_parameters(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test multiple ? placeholders get converted to $1, $2, etc.""" result = await cockroach_asyncpg_parameter_session.execute( @@ -117,7 +113,6 @@ async def test_qmark_multiple_parameters(self, cockroach_asyncpg_parameter_sessi assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test2" - @pytest.mark.asyncio async def test_qmark_in_insert(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test ? placeholders in INSERT statements.""" await cockroach_asyncpg_parameter_session.execute( @@ -135,7 +130,6 @@ async def test_qmark_in_insert(self, cockroach_asyncpg_parameter_session: Cockro class TestNamedColonConversion: """Test NAMED_COLON (:name) to NUMERIC ($1) conversion.""" - @pytest.mark.asyncio async def test_named_colon_single_parameter( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -148,7 +142,6 @@ async def test_named_colon_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_named_colon_multiple_parameters( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -162,7 +155,6 @@ async def test_named_colon_multiple_parameters( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test2" - @pytest.mark.asyncio async def test_named_colon_repeated_parameter( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -178,7 +170,6 @@ async def test_named_colon_repeated_parameter( class TestNamedPyformatConversion: """Test NAMED_PYFORMAT (%(name)s) 
to NUMERIC ($1) conversion.""" - @pytest.mark.asyncio async def test_named_pyformat_single_parameter( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -191,7 +182,6 @@ async def test_named_pyformat_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_named_pyformat_multiple_parameters( self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -209,7 +199,6 @@ async def test_named_pyformat_multiple_parameters( class TestSQLObjectConversion: """Test parameter conversion with SQL objects.""" - @pytest.mark.asyncio async def test_sql_object_with_numeric(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test SQL object with $n placeholders.""" sql_numeric = SQL("SELECT * FROM test_parameter_conversion WHERE value BETWEEN $1 AND $2", 150, 250) @@ -219,7 +208,6 @@ async def test_sql_object_with_numeric(self, cockroach_asyncpg_parameter_session assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test2" - @pytest.mark.asyncio async def test_sql_object_with_qmark(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test SQL object with ? placeholders.""" sql_qmark = SQL("SELECT * FROM test_parameter_conversion WHERE name = ? 
OR name = ?", "test1", "test3") @@ -235,7 +223,6 @@ async def test_sql_object_with_qmark(self, cockroach_asyncpg_parameter_session: class TestExecuteMany: """Test parameter conversion with execute_many.""" - @pytest.mark.asyncio async def test_execute_many_with_numeric(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test execute_many with $n placeholders.""" data = [("batch1", 1001, "Batch 1"), ("batch2", 1002, "Batch 2"), ("batch3", 1003, "Batch 3")] @@ -247,7 +234,6 @@ async def test_execute_many_with_numeric(self, cockroach_asyncpg_parameter_sessi assert isinstance(result, SQLResult) assert result.rows_affected == 3 - @pytest.mark.asyncio async def test_execute_many_with_qmark(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test execute_many with ? placeholders.""" data = [("qbatch1", 2001, "QBatch 1"), ("qbatch2", 2002, "QBatch 2")] @@ -263,7 +249,6 @@ async def test_execute_many_with_qmark(self, cockroach_asyncpg_parameter_session class TestEdgeCases: """Test edge cases in parameter conversion.""" - @pytest.mark.asyncio async def test_null_parameters(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test NULL parameter handling.""" result = await cockroach_asyncpg_parameter_session.execute( @@ -274,7 +259,6 @@ async def test_null_parameters(self, cockroach_asyncpg_parameter_session: Cockro assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test3" - @pytest.mark.asyncio async def test_sql_injection_prevention(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test that parameter escaping prevents SQL injection.""" malicious_input = "'; DROP TABLE test_parameter_conversion; --" @@ -291,7 +275,6 @@ async def test_sql_injection_prevention(self, cockroach_asyncpg_parameter_sessio ) assert count_result.get_data()[0]["count"] >= 3 - @pytest.mark.asyncio async def test_special_characters_in_parameters( self, 
cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver ) -> None: @@ -308,7 +291,6 @@ async def test_special_characters_in_parameters( assert len(result.data) == 1 assert result.get_data()[0]["description"] == special_value - @pytest.mark.asyncio async def test_like_with_wildcards(self, cockroach_asyncpg_parameter_session: CockroachAsyncpgDriver) -> None: """Test LIKE queries with wildcard parameters.""" result = await cockroach_asyncpg_parameter_session.execute( diff --git a/tests/integration/adapters/mysqlconnector/extensions/events/test_queue_backend.py b/tests/integration/adapters/mysqlconnector/extensions/events/test_queue_backend.py index 91d997e78..80aae555d 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/mysqlconnector/extensions/events/test_queue_backend.py @@ -9,7 +9,7 @@ from sqlspec.adapters.mysqlconnector import MysqlConnectorAsyncConfig from tests.integration.adapters._events_helpers import setup_async_event_channel -pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.asyncio] +pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector] @pytest.mark.mysql diff --git a/tests/integration/adapters/mysqlconnector/test_driver_async.py b/tests/integration/adapters/mysqlconnector/test_driver_async.py index f7f7bffc9..bb0ecb897 100644 --- a/tests/integration/adapters/mysqlconnector/test_driver_async.py +++ b/tests/integration/adapters/mysqlconnector/test_driver_async.py @@ -9,7 +9,7 @@ from sqlspec.adapters.mysqlconnector import MysqlConnectorAsyncConfig, MysqlConnectorAsyncDriver from sqlspec.utils.serializers import from_json, to_json -pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.asyncio] +pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector] @pytest.fixture diff --git a/tests/integration/adapters/mysqlconnector/test_exceptions.py 
b/tests/integration/adapters/mysqlconnector/test_exceptions.py index a52267a4c..3d7b471b2 100644 --- a/tests/integration/adapters/mysqlconnector/test_exceptions.py +++ b/tests/integration/adapters/mysqlconnector/test_exceptions.py @@ -14,7 +14,7 @@ UniqueViolationError, ) -pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.asyncio] +pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector] @pytest.fixture diff --git a/tests/integration/adapters/mysqlconnector/test_parameter_styles.py b/tests/integration/adapters/mysqlconnector/test_parameter_styles.py index 86b1c8cbd..fda5cb112 100644 --- a/tests/integration/adapters/mysqlconnector/test_parameter_styles.py +++ b/tests/integration/adapters/mysqlconnector/test_parameter_styles.py @@ -293,7 +293,6 @@ def test_sql_injection_prevention(self, mysqlconnector_sync_parameter_session: M class TestAsyncQmarkConversion: """Test QMARK (?) to POSITIONAL_PYFORMAT (%s) conversion for async driver.""" - @pytest.mark.asyncio async def test_qmark_single_parameter( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -306,7 +305,6 @@ async def test_qmark_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_qmark_multiple_parameters( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -324,7 +322,6 @@ async def test_qmark_multiple_parameters( class TestAsyncNamedColonConversion: """Test NAMED_COLON (:name) to POSITIONAL_PYFORMAT (%s) conversion for async driver.""" - @pytest.mark.asyncio async def test_named_colon_single_parameter( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -337,7 +334,6 @@ async def test_named_colon_single_parameter( assert len(result.data) == 1 assert result.get_data()[0]["name"] == "test1" - @pytest.mark.asyncio async def test_named_colon_multiple_parameters( self, 
mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -355,7 +351,6 @@ async def test_named_colon_multiple_parameters( class TestAsyncNamedPyformatConversion: """Test NAMED_PYFORMAT (%(name)s) to POSITIONAL_PYFORMAT (%s) conversion for async driver.""" - @pytest.mark.asyncio async def test_named_pyformat_parameters( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -373,7 +368,6 @@ async def test_named_pyformat_parameters( class TestAsyncPositionalPyformatNative: """Test POSITIONAL_PYFORMAT (%s) works natively for async driver.""" - @pytest.mark.asyncio async def test_pyformat_parameters(self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver) -> None: """Test %s placeholders work directly.""" result = await mysqlconnector_async_parameter_session.execute( @@ -388,7 +382,6 @@ async def test_pyformat_parameters(self, mysqlconnector_async_parameter_session: class TestAsyncSQLObject: """Test parameter conversion with SQL objects for async driver.""" - @pytest.mark.asyncio async def test_sql_object_with_qmark( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -403,7 +396,6 @@ async def test_sql_object_with_qmark( class TestAsyncExecuteMany: """Test parameter conversion with execute_many for async driver.""" - @pytest.mark.asyncio async def test_execute_many_with_qmark( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: @@ -421,7 +413,6 @@ async def test_execute_many_with_qmark( class TestAsyncEdgeCases: """Test edge cases for async driver.""" - @pytest.mark.asyncio async def test_boolean_parameters(self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver) -> None: """Test boolean parameters are converted to integers for MySQL.""" await mysqlconnector_async_parameter_session.execute_script(""" @@ -445,7 +436,6 @@ async def test_boolean_parameters(self, mysqlconnector_async_parameter_session: await 
mysqlconnector_async_parameter_session.execute_script("DROP TABLE IF EXISTS test_bools_async") - @pytest.mark.asyncio async def test_sql_injection_prevention( self, mysqlconnector_async_parameter_session: MysqlConnectorAsyncDriver ) -> None: diff --git a/tests/integration/adapters/oracledb/extensions/events/test_queue_backend.py b/tests/integration/adapters/oracledb/extensions/events/test_queue_backend.py index b9f612fd1..6687b2a99 100644 --- a/tests/integration/adapters/oracledb/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/oracledb/extensions/events/test_queue_backend.py @@ -124,7 +124,6 @@ def test_oracle_sync_event_channel_queue_fallback(tmp_path: "Path") -> None: @pytest.mark.oracle -@pytest.mark.asyncio async def test_oracle_async_event_channel_queue_fallback(tmp_path: "Path") -> None: """Async Oracle configs also use the queue fallback.""" diff --git a/tests/integration/adapters/oracledb/test_driver_async.py b/tests/integration/adapters/oracledb/test_driver_async.py index 2aa2cb59a..86c28ac14 100644 --- a/tests/integration/adapters/oracledb/test_driver_async.py +++ b/tests/integration/adapters/oracledb/test_driver_async.py @@ -11,7 +11,7 @@ from sqlspec.core import SQLResult from sqlspec.exceptions import SQLSpecError -pytestmark = [pytest.mark.xdist_group("oracle"), pytest.mark.asyncio(loop_scope="function")] +pytestmark = [pytest.mark.xdist_group("oracle")] ParamStyle = Literal["positional_binds", "dict_binds"] diff --git a/tests/integration/adapters/oracledb/test_msgspec_clob.py b/tests/integration/adapters/oracledb/test_msgspec_clob.py index 254eb2ef6..08c64cad2 100644 --- a/tests/integration/adapters/oracledb/test_msgspec_clob.py +++ b/tests/integration/adapters/oracledb/test_msgspec_clob.py @@ -55,7 +55,6 @@ class BinaryDocumentRecord(msgspec.Struct): LARGE_BINARY_CONTENT = b"\x00\x01\x02\x03" * 2000 -@pytest.mark.asyncio async def test_oracle_async_clob_msgspec_hydration(oracle_async_session: OracleAsyncDriver) -> None: 
"""Test async CLOB automatic hydration into msgspec struct. @@ -130,7 +129,6 @@ def test_oracle_sync_clob_msgspec_hydration(oracle_sync_session: OracleSyncDrive ) -@pytest.mark.asyncio async def test_oracle_async_mixed_clob_varchar2_msgspec(oracle_async_session: OracleAsyncDriver) -> None: """Test msgspec hydration with mixed CLOB and VARCHAR2 columns. @@ -206,7 +204,6 @@ def test_oracle_sync_mixed_clob_varchar2_msgspec(oracle_sync_session: OracleSync ) -@pytest.mark.asyncio async def test_oracle_async_json_in_clob_detection(oracle_async_session: OracleAsyncDriver) -> None: """Test JSON detection in CLOB with msgspec hydration. @@ -283,7 +280,6 @@ def test_oracle_sync_json_in_clob_detection(oracle_sync_session: OracleSyncDrive ) -@pytest.mark.asyncio async def test_oracle_async_blob_remains_bytes(oracle_async_session: OracleAsyncDriver) -> None: """Test that BLOB columns still return bytes, not strings. @@ -356,7 +352,6 @@ def test_oracle_sync_blob_remains_bytes(oracle_sync_session: OracleSyncDriver) - ) -@pytest.mark.asyncio async def test_oracle_async_multiple_clob_columns(oracle_async_session: OracleAsyncDriver) -> None: """Test msgspec hydration with multiple CLOB columns. @@ -471,7 +466,6 @@ class MultiClobRecord(msgspec.Struct): ) -@pytest.mark.asyncio async def test_oracle_async_empty_clob_msgspec(oracle_async_session: OracleAsyncDriver) -> None: """Test msgspec hydration with NULL CLOB values. 
diff --git a/tests/integration/adapters/oracledb/test_stack.py b/tests/integration/adapters/oracledb/test_stack.py index 215716504..1013d6941 100644 --- a/tests/integration/adapters/oracledb/test_stack.py +++ b/tests/integration/adapters/oracledb/test_stack.py @@ -41,7 +41,6 @@ def _reset_sync_table(driver: OracleSyncDriver, table_name: str) -> None: driver.execute_script(CREATE_TEMPLATE.format(table_name=table_name)) -@pytest.mark.asyncio(loop_scope="function") async def test_async_statement_stack_native_pipeline( monkeypatch: pytest.MonkeyPatch, oracle_async_session: OracleAsyncDriver ) -> None: @@ -85,7 +84,6 @@ async def tracking_execute_stack_native( await oracle_async_session.execute_script(DROP_TEMPLATE.format(table_name=table_name)) -@pytest.mark.asyncio(loop_scope="function") async def test_async_statement_stack_continue_on_error_pipeline(oracle_async_session: OracleAsyncDriver) -> None: """Ensure continue-on-error surfaces failures while executing remaining operations.""" diff --git a/tests/integration/adapters/psqlpy/extensions/events/test_listen_notify.py b/tests/integration/adapters/psqlpy/extensions/events/test_listen_notify.py index 6a1622f5d..0bdb01f0b 100644 --- a/tests/integration/adapters/psqlpy/extensions/events/test_listen_notify.py +++ b/tests/integration/adapters/psqlpy/extensions/events/test_listen_notify.py @@ -17,7 +17,6 @@ def _dsn(service: "Any") -> str: return f"postgres://{service.user}:{service.password}@{service.host}:{service.port}/{service.database}" -@pytest.mark.asyncio async def test_psqlpy_listen_notify_native(postgres_service: "Any") -> None: """Native LISTEN/NOTIFY path delivers payloads.""" @@ -56,7 +55,6 @@ async def _handler(message: Any) -> None: await config.close_pool() -@pytest.mark.asyncio async def test_psqlpy_listen_notify_hybrid(postgres_service: "Any", tmp_path) -> None: """Hybrid backend persists then signals via NOTIFY.""" diff --git a/tests/integration/adapters/psqlpy/extensions/events/test_queue_backend.py 
b/tests/integration/adapters/psqlpy/extensions/events/test_queue_backend.py index 6151f5d32..2f7b31111 100644 --- a/tests/integration/adapters/psqlpy/extensions/events/test_queue_backend.py +++ b/tests/integration/adapters/psqlpy/extensions/events/test_queue_backend.py @@ -13,7 +13,6 @@ @pytest.mark.postgres -@pytest.mark.asyncio async def test_psqlpy_event_channel_queue_fallback(tmp_path, postgres_service: "PostgresService") -> None: """Psqlpy adapters consume events via the queue backend.""" diff --git a/tests/integration/config/test_connection_injection.py b/tests/integration/config/test_connection_injection.py index a979be5cc..79214fdc9 100644 --- a/tests/integration/config/test_connection_injection.py +++ b/tests/integration/config/test_connection_injection.py @@ -20,7 +20,6 @@ pytestmark = pytest.mark.xdist_group("config") -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_connection_instance_with_pre_created_pool(asyncpg_connection_config: dict) -> None: """Test AsyncpgConfig with connection_instance using pre-created pool.""" @@ -44,7 +43,6 @@ async def test_asyncpg_connection_instance_with_pre_created_pool(asyncpg_connect await pool.close() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_connection_instance_bypasses_pool_creation(asyncpg_connection_config: dict) -> None: """Test that connection_instance bypasses _create_pool logic.""" @@ -69,7 +67,6 @@ async def test_asyncpg_connection_instance_bypasses_pool_creation(asyncpg_connec await pool.close() -@pytest.mark.asyncio async def test_aiosqlite_connection_instance_with_pre_created_pool(tmp_path: Path) -> None: """Test AiosqliteConfig with connection_instance using pre-created pool.""" from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool @@ -171,7 +168,6 @@ def test_sqlite_connection_instance_none_creates_new_pool(tmp_path: Path) -> Non config.close_pool() -@pytest.mark.asyncio async def test_aiosqlite_connection_instance_none_creates_new_pool(tmp_path: 
Path) -> None: """Test that connection_instance=None causes new pool creation for async.""" db_path = tmp_path / "test.db" @@ -244,7 +240,6 @@ def test_connection_instance_with_empty_connection_config() -> None: pool.close() -@pytest.mark.asyncio @pytest.mark.postgres async def test_asyncpg_connection_instance_overrides_connection_config_pool_params( asyncpg_connection_config: dict, @@ -310,7 +305,6 @@ def test_sqlite_connection_instance_after_close_pool() -> None: assert config.connection_instance is None -@pytest.mark.asyncio async def test_aiosqlite_connection_instance_after_close_pool() -> None: """Test that connection_instance can be closed via config.""" from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool @@ -336,7 +330,6 @@ def test_connection_instance_with_mock_pool() -> None: assert config.connection_instance is mock_pool -@pytest.mark.asyncio async def test_connection_instance_with_async_mock_pool() -> None: """Test that connection_instance accepts async mock pools for testing.""" mock_pool = MagicMock() diff --git a/tests/integration/extensions/litestar/test_channels_backend.py b/tests/integration/extensions/litestar/test_channels_backend.py index 3b1736757..41da347bf 100644 --- a/tests/integration/extensions/litestar/test_channels_backend.py +++ b/tests/integration/extensions/litestar/test_channels_backend.py @@ -3,7 +3,6 @@ from typing import Any, cast import msgspec.json -import pytest from litestar.channels.plugin import ChannelsPlugin from sqlspec.adapters.aiosqlite.config import AiosqliteConfig @@ -19,7 +18,6 @@ async def _next_event(subscriber: "Any") -> bytes: raise RuntimeError(msg) -@pytest.mark.asyncio async def test_litestar_channels_backend_database_roundtrip(tmp_path: "Any") -> None: migrations = tmp_path / "migrations" migrations.mkdir() diff --git a/tests/integration/storage/test_streaming.py b/tests/integration/storage/test_streaming.py index f69f1d123..d501a701a 100644 --- a/tests/integration/storage/test_streaming.py 
+++ b/tests/integration/storage/test_streaming.py @@ -17,7 +17,6 @@ def test_sync_stream_read_local(test_file): assert content == b"hello world" * 1000 -@pytest.mark.asyncio async def test_async_stream_read_local(test_file): pipeline = AsyncStoragePipeline() stream = await pipeline.stream_read_async(test_file, chunk_size=10) @@ -27,7 +26,6 @@ async def test_async_stream_read_local(test_file): assert content == b"hello world" * 1000 -@pytest.mark.asyncio async def test_async_stream_read_local_explicit_uri(test_file): pipeline = AsyncStoragePipeline() uri = f"file://{test_file}" diff --git a/tests/integration/test_pool_concurrency.py b/tests/integration/test_pool_concurrency.py index 48d7016f8..e0aa2a133 100644 --- a/tests/integration/test_pool_concurrency.py +++ b/tests/integration/test_pool_concurrency.py @@ -16,7 +16,6 @@ from sqlspec.adapters.duckdb import DuckDBConnectionPool -@pytest.mark.asyncio async def test_asyncpg_pool_concurrency(postgres_service: PostgresService) -> None: """Verify that multiple concurrent calls to provide_pool result in a single pool.""" config_params = { diff --git a/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py b/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py index 3918980e8..42cb561bd 100644 --- a/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py +++ b/tests/unit/adapters/test_aiosqlite/test_pool_shutdown.py @@ -65,7 +65,6 @@ async def _resolve() -> _FakeAiosqliteConnection: return _resolve().__await__() -@pytest.mark.asyncio async def test_create_connection_sets_daemon_for_legacy_proxy(monkeypatch: pytest.MonkeyPatch) -> None: """Pool should set daemon mode for pre-0.22 thread-based connect proxy.""" from sqlspec.adapters.aiosqlite import pool as pool_module @@ -83,7 +82,6 @@ async def test_create_connection_sets_daemon_for_legacy_proxy(monkeypatch: pytes await pool._retire_connection(pool_connection, reason="test_cleanup") -@pytest.mark.asyncio async def 
test_create_connection_sets_daemon_for_modern_proxy(monkeypatch: pytest.MonkeyPatch) -> None: """Pool should set daemon mode for 0.22+ connect proxy internal worker thread.""" from sqlspec.adapters.aiosqlite import pool as pool_module @@ -101,7 +99,6 @@ async def test_create_connection_sets_daemon_for_modern_proxy(monkeypatch: pytes await pool._retire_connection(pool_connection, reason="test_cleanup") -@pytest.mark.asyncio async def test_pool_close_uses_force_stop_when_close_times_out(monkeypatch: pytest.MonkeyPatch) -> None: """Pool should trigger force-stop fallback when graceful close times out.""" from sqlspec.adapters.aiosqlite import pool as pool_module diff --git a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py index 55d7e49c9..22acd0a40 100644 --- a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py +++ b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py @@ -179,7 +179,6 @@ def test_normal_config_without_connectors() -> None: assert config.driver_features.get("enable_alloydb", False) is not True -@pytest.mark.asyncio async def test_cloud_sql_connector_initialization(mock_cloud_sql_module) -> None: """Cloud SQL connector should be initialized correctly in create_pool.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -209,7 +208,6 @@ async def test_cloud_sql_connector_initialization(mock_cloud_sql_module) -> None assert "user" not in call_kwargs -@pytest.mark.asyncio async def test_cloud_sql_iam_auth_enabled(mock_cloud_sql_module) -> None: """Cloud SQL IAM authentication should configure enable_iam_auth=True.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -240,7 +238,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_cloud_sql_iam_auth_disabled(mock_cloud_sql_module) -> None: """Cloud SQL with IAM disabled should configure enable_iam_auth=False.""" 
with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -271,7 +268,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_cloud_sql_ip_type_configuration(mock_cloud_sql_module) -> None: """Cloud SQL IP type should be passed to connector.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -302,7 +298,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_cloud_sql_default_ip_type(mock_cloud_sql_module) -> None: """Cloud SQL should default to PRIVATE IP type.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -329,7 +324,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_alloydb_connector_initialization(mock_alloydb_module) -> None: """AlloyDB connector should be initialized correctly in create_pool.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", True): @@ -362,7 +356,6 @@ async def test_alloydb_connector_initialization(mock_alloydb_module) -> None: assert "user" not in call_kwargs -@pytest.mark.asyncio async def test_alloydb_iam_auth_enabled(mock_alloydb_module) -> None: """AlloyDB IAM authentication should configure enable_iam_auth=True.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", True): @@ -393,7 +386,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_alloydb_ip_type_configuration(mock_alloydb_module) -> None: """AlloyDB IP type should be passed to connector.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", True): @@ -424,7 +416,6 @@ async def mock_connect(**kwargs): await get_conn_func() -@pytest.mark.asyncio async def test_cloud_sql_connector_cleanup(mock_cloud_sql_module) -> None: """Cloud SQL connector should be closed on pool close.""" with 
patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -450,7 +441,6 @@ async def test_cloud_sql_connector_cleanup(mock_cloud_sql_module) -> None: assert config._cloud_sql_connector is None -@pytest.mark.asyncio async def test_alloydb_connector_cleanup(mock_alloydb_module) -> None: """AlloyDB connector should be closed on pool close.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", True): @@ -479,7 +469,6 @@ async def test_alloydb_connector_cleanup(mock_alloydb_module) -> None: assert config._alloydb_connector is None -@pytest.mark.asyncio async def test_connection_factory_pattern_cloud_sql(mock_cloud_sql_module) -> None: """Cloud SQL should use connection factory pattern with connect parameter.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", True): @@ -503,7 +492,6 @@ async def test_connection_factory_pattern_cloud_sql(mock_cloud_sql_module) -> No assert callable(call_kwargs["connect"]) -@pytest.mark.asyncio async def test_connection_factory_pattern_alloydb(mock_alloydb_module) -> None: """AlloyDB should use connection factory pattern with connect parameter.""" with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", True): @@ -530,7 +518,6 @@ async def test_connection_factory_pattern_alloydb(mock_alloydb_module) -> None: assert callable(call_kwargs["connect"]) -@pytest.mark.asyncio async def test_pool_close_without_connectors() -> None: """Closing pool without connectors should not raise errors.""" config = AsyncpgConfig(connection_config={"dsn": "postgresql://localhost/test"}) diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py index c206ec1cf..ec618e94e 100644 --- a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py +++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py @@ -2,12 +2,9 @@ from decimal import Decimal -import pytest - from 
sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore -@pytest.mark.asyncio async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None: store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] @@ -18,7 +15,6 @@ async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None assert result == {"value": 1.25, "nested": {"score": 0.5}} -@pytest.mark.asyncio async def test_oracle_async_adk_store_deserialize_state_dict_coerces_decimal() -> None: store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] diff --git a/tests/unit/adapters/test_pool_logging.py b/tests/unit/adapters/test_pool_logging.py index a13718d38..0775b0273 100644 --- a/tests/unit/adapters/test_pool_logging.py +++ b/tests/unit/adapters/test_pool_logging.py @@ -143,7 +143,6 @@ def test_duckdb_pool_database_name_property(self) -> None: class TestAiosqliteConnectionPoolLogging: """Tests for aiosqlite pool logging structure.""" - @pytest.mark.asyncio async def test_aiosqlite_pool_uses_pool_logger(self) -> None: """Test that aiosqlite pool imports and uses POOL_LOGGER_NAME.""" from sqlspec.adapters.aiosqlite.pool import _ADAPTER_NAME @@ -152,7 +151,6 @@ async def test_aiosqlite_pool_uses_pool_logger(self) -> None: assert aiosqlite_pool_logger_name == "sqlspec.pool" assert _ADAPTER_NAME == "aiosqlite" - @pytest.mark.asyncio async def test_aiosqlite_pool_has_pool_id(self) -> None: """Test that aiosqlite pool generates a pool_id.""" from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool @@ -161,7 +159,6 @@ async def test_aiosqlite_pool_has_pool_id(self) -> None: assert hasattr(pool, "_pool_id") assert len(pool._pool_id) == 8 # UUID prefix - @pytest.mark.asyncio async def test_aiosqlite_pool_database_name_property(self) -> None: """Test that aiosqlite pool has _database_name property for logging.""" from sqlspec.adapters.aiosqlite.pool import AiosqliteConnectionPool diff --git 
a/tests/unit/config/test_connection_config_edge_cases.py b/tests/unit/config/test_connection_config_edge_cases.py index de2428003..fd434d137 100644 --- a/tests/unit/config/test_connection_config_edge_cases.py +++ b/tests/unit/config/test_connection_config_edge_cases.py @@ -4,8 +4,6 @@ standardized parameter naming. """ -import pytest - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.asyncpg.config import AsyncpgConfig from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -310,7 +308,6 @@ def test_connection_config_with_bytes_values() -> None: assert config.connection_config["ssl_cert"] == b"certificate data" -@pytest.mark.asyncio async def test_aiosqlite_connection_config_with_pathlib_path() -> None: """Test that connection_config accepts pathlib.Path objects.""" from pathlib import Path diff --git a/tests/unit/config/test_connection_config_parameters.py b/tests/unit/config/test_connection_config_parameters.py index 4082e7e11..e3ecdcd26 100644 --- a/tests/unit/config/test_connection_config_parameters.py +++ b/tests/unit/config/test_connection_config_parameters.py @@ -15,8 +15,6 @@ 5. 
Configuration merging and overrides """ -import pytest - from sqlspec.adapters.adbc.config import AdbcConfig from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.asyncmy.config import AsyncmyConfig @@ -464,7 +462,6 @@ def test_connection_instance_parameter_naming_consistency() -> None: assert config.connection_instance is None -@pytest.mark.asyncio async def test_asyncpg_config_with_pre_created_pool() -> None: """Test AsyncpgConfig with connection_instance set to pre-created pool.""" from unittest.mock import AsyncMock, MagicMock diff --git a/tests/unit/config/test_migration_methods.py b/tests/unit/config/test_migration_methods.py index 4eb204394..b31f86c0e 100644 --- a/tests/unit/config/test_migration_methods.py +++ b/tests/unit/config/test_migration_methods.py @@ -19,8 +19,6 @@ from pathlib import Path from unittest.mock import patch -import pytest - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.asyncpg.config import AsyncpgConfig from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -194,7 +192,6 @@ def test_sqlite_config_fix_migrations_calls_commands(tmp_path: Path) -> None: mock_fix.assert_called_once_with(True, False, True) -@pytest.mark.asyncio async def test_asyncpg_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" migration_dir = tmp_path / "migrations" @@ -210,7 +207,6 @@ async def test_asyncpg_config_migrate_up_calls_commands(tmp_path: Path) -> None: mock_upgrade.assert_called_once_with("0002", False, True, False, use_logger=False, echo=None, summary_only=None) -@pytest.mark.asyncio async def test_asyncpg_config_migrate_down_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.migrate_down() delegates to AsyncMigrationCommands.downgrade().""" migration_dir = tmp_path / "migrations" @@ -226,7 +222,6 @@ async def test_asyncpg_config_migrate_down_calls_commands(tmp_path: 
Path) -> Non mock_downgrade.assert_called_once_with("base", dry_run=False, use_logger=False, echo=None, summary_only=None) -@pytest.mark.asyncio async def test_asyncpg_config_get_current_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.get_current_migration() delegates to AsyncMigrationCommands.current().""" migration_dir = tmp_path / "migrations" @@ -243,7 +238,6 @@ async def test_asyncpg_config_get_current_migration_calls_commands(tmp_path: Pat assert result == "0002" -@pytest.mark.asyncio async def test_asyncpg_config_create_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.create_migration() delegates to AsyncMigrationCommands.revision().""" migration_dir = tmp_path / "migrations" @@ -259,7 +253,6 @@ async def test_asyncpg_config_create_migration_calls_commands(tmp_path: Path) -> mock_revision.assert_called_once_with("add users table", "sql") -@pytest.mark.asyncio async def test_asyncpg_config_init_migrations_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.init_migrations() delegates to AsyncMigrationCommands.init().""" migration_dir = tmp_path / "migrations" @@ -275,7 +268,6 @@ async def test_asyncpg_config_init_migrations_calls_commands(tmp_path: Path) -> mock_init.assert_called_once_with(str(migration_dir), True) -@pytest.mark.asyncio async def test_asyncpg_config_stamp_migration_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.stamp_migration() delegates to AsyncMigrationCommands.stamp().""" migration_dir = tmp_path / "migrations" @@ -291,7 +283,6 @@ async def test_asyncpg_config_stamp_migration_calls_commands(tmp_path: Path) -> mock_stamp.assert_called_once_with("0003") -@pytest.mark.asyncio async def test_asyncpg_config_fix_migrations_calls_commands(tmp_path: Path) -> None: """Test that AsyncpgConfig.fix_migrations() delegates to AsyncMigrationCommands.fix().""" migration_dir = tmp_path / "migrations" @@ -336,7 +327,6 @@ def 
test_duckdb_pooled_config_get_current_migration_calls_commands(tmp_path: Pat assert result is None -@pytest.mark.asyncio async def test_aiosqlite_async_config_migrate_up_calls_commands(tmp_path: Path) -> None: """Test that AiosqliteConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" migration_dir = tmp_path / "migrations" @@ -367,7 +357,6 @@ def test_migrate_up_default_parameters_sync(tmp_path: Path) -> None: mock_upgrade.assert_called_once_with("head", False, True, False, use_logger=False, echo=None, summary_only=None) -@pytest.mark.asyncio async def test_migrate_up_default_parameters_async(tmp_path: Path) -> None: """Test that migrate_up() uses correct default parameter values for async configs.""" migration_dir = tmp_path / "migrations" @@ -398,7 +387,6 @@ def test_migrate_down_default_parameters_sync(tmp_path: Path) -> None: mock_downgrade.assert_called_once_with("-1", dry_run=False, use_logger=False, echo=None, summary_only=None) -@pytest.mark.asyncio async def test_migrate_down_default_parameters_async(tmp_path: Path) -> None: """Test that migrate_down() uses correct default parameter values for async configs.""" migration_dir = tmp_path / "migrations" @@ -429,7 +417,6 @@ def test_create_migration_default_file_type_sync(tmp_path: Path) -> None: mock_revision.assert_called_once_with("test migration", "sql") -@pytest.mark.asyncio async def test_create_migration_default_file_type_async(tmp_path: Path) -> None: """Test that create_migration() defaults to 'sql' file type for async configs.""" migration_dir = tmp_path / "migrations" @@ -460,7 +447,6 @@ def test_init_migrations_default_package_sync(tmp_path: Path) -> None: mock_init.assert_called_once_with(str(migration_dir), True) -@pytest.mark.asyncio async def test_init_migrations_default_package_async(tmp_path: Path) -> None: """Test that init_migrations() defaults to package=True for async configs.""" migration_dir = tmp_path / "migrations" @@ -491,7 +477,6 @@ def 
test_fix_migrations_default_parameters_sync(tmp_path: Path) -> None: mock_fix.assert_called_once_with(False, True, False) -@pytest.mark.asyncio async def test_fix_migrations_default_parameters_async(tmp_path: Path) -> None: """Test that fix_migrations() uses correct default parameter values for async configs.""" migration_dir = tmp_path / "migrations" diff --git a/tests/unit/driver/test_execute_script.py b/tests/unit/driver/test_execute_script.py index f9c8347c3..0152ac927 100644 --- a/tests/unit/driver/test_execute_script.py +++ b/tests/unit/driver/test_execute_script.py @@ -18,7 +18,6 @@ def test_sync_execute_script_tracks_all_successful_statements(mock_sync_driver) @requires_interpreted -@pytest.mark.asyncio async def test_async_execute_script_tracks_all_successful_statements(mock_async_driver) -> None: """Async execute_script should report all statements as successful.""" result = await mock_async_driver.execute_script("SELECT 1; SELECT 2; SELECT 3;") diff --git a/tests/unit/driver/test_fetch_aliases.py b/tests/unit/driver/test_fetch_aliases.py index 6abef4ba9..f46660841 100644 --- a/tests/unit/driver/test_fetch_aliases.py +++ b/tests/unit/driver/test_fetch_aliases.py @@ -351,7 +351,6 @@ def test_sync_fetch_with_total_delegates_to_select_with_total() -> None: @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_delegates_to_select() -> None: """Test that async fetch() delegates to async select() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -368,7 +367,6 @@ async def test_async_fetch_delegates_to_select() -> None: @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_one_delegates_to_select_one() -> None: """Test that async fetch_one() delegates to async select_one() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -385,7 +383,6 @@ async def test_async_fetch_one_delegates_to_select_one() -> None: @requires_interpreted -@pytest.mark.asyncio async def 
test_async_fetch_one_or_none_delegates_to_select_one_or_none() -> None: """Test that async fetch_one_or_none() delegates to async select_one_or_none() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -402,7 +399,6 @@ async def test_async_fetch_one_or_none_delegates_to_select_one_or_none() -> None @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_value_delegates_to_select_value() -> None: """Test that async fetch_value() delegates to async select_value() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -417,7 +413,6 @@ async def test_async_fetch_value_delegates_to_select_value() -> None: @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_value_or_none_delegates_to_select_value_or_none() -> None: """Test that async fetch_value_or_none() delegates to async select_value_or_none() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -434,7 +429,6 @@ async def test_async_fetch_value_or_none_delegates_to_select_value_or_none() -> @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_to_arrow_delegates_to_select_to_arrow() -> None: """Test that async fetch_to_arrow() delegates to async select_to_arrow() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -463,7 +457,6 @@ async def test_async_fetch_to_arrow_delegates_to_select_to_arrow() -> None: @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_with_total_delegates_to_select_with_total() -> None: """Test that async fetch_with_total() delegates to async select_with_total() with identical arguments.""" mock_driver = AsyncMock(spec=AsyncDriverAdapterBase) @@ -506,7 +499,6 @@ def __init__(self, **kwargs: Any) -> None: @requires_interpreted -@pytest.mark.asyncio async def test_async_fetch_one_with_schema_type_argument() -> None: """Test that async fetch_one() correctly passes schema_type to select_one().""" diff 
--git a/tests/unit/driver/test_stack_base.py b/tests/unit/driver/test_stack_base.py index 9714af063..822c4ffcb 100644 --- a/tests/unit/driver/test_stack_base.py +++ b/tests/unit/driver/test_stack_base.py @@ -10,7 +10,6 @@ @requires_interpreted -@pytest.mark.asyncio async def test_async_execute_stack_fail_fast_rolls_back(mock_async_driver) -> None: original_execute = mock_async_driver.execute @@ -31,7 +30,6 @@ async def failing_execute(self, statement, *params, **kwargs): # type: ignore[n @requires_interpreted -@pytest.mark.asyncio async def test_async_execute_stack_continue_on_error(mock_async_driver) -> None: original_execute = mock_async_driver.execute @@ -53,7 +51,6 @@ async def failing_execute(self, statement, *params, **kwargs): # type: ignore[n @requires_interpreted -@pytest.mark.asyncio async def test_async_execute_stack_execute_arrow(mock_async_driver) -> None: sentinel = object() diff --git a/tests/unit/extensions/test_events/test_channel.py b/tests/unit/extensions/test_events/test_channel.py index a22e432b7..ea2cf80b7 100644 --- a/tests/unit/extensions/test_events/test_channel.py +++ b/tests/unit/extensions/test_events/test_channel.py @@ -73,7 +73,6 @@ def test_event_channel_publish_and_ack_sync(tmp_path) -> None: assert snapshot.get("SqliteConfig.events.ack") == pytest.approx(1.0) -@pytest.mark.asyncio async def test_event_channel_async_iteration(tmp_path) -> None: """Async adapters can publish and drain events via the iterator helper.""" @@ -148,7 +147,6 @@ def test_event_channel_backend_fallback(tmp_path) -> None: assert row["status"] == "acked" -@pytest.mark.asyncio async def test_event_channel_portal_bridge_sync_api(tmp_path) -> None: """Async adapters publish and consume events via the event_channel helper.""" diff --git a/tests/unit/storage/test_bridge.py b/tests/unit/storage/test_bridge.py index 2868fd3b7..9efbd84a7 100644 --- a/tests/unit/storage/test_bridge.py +++ b/tests/unit/storage/test_bridge.py @@ -93,7 +93,6 @@ def cursor(self) -> 
DummyAsyncmyCursorImpl: return DummyAsyncmyCursorImpl(self.operations) -@pytest.mark.asyncio async def test_asyncpg_load_from_storage(monkeypatch: pytest.MonkeyPatch) -> None: arrow_table = pa.table({"id": [1, 2], "name": ["alpha", "beta"]}) @@ -140,7 +139,6 @@ def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, obje assert job.telemetry["destination"] == "ingest_target" -@pytest.mark.asyncio async def test_psqlpy_load_from_arrow_overwrite() -> None: arrow_table = pa.table({"id": [7, 8], "name": ["east", "west"]}) dummy_connection = DummyPsqlpyConnection() @@ -164,7 +162,6 @@ async def test_psqlpy_load_from_arrow_overwrite() -> None: assert job.telemetry["rows_processed"] == arrow_table.num_rows -@pytest.mark.asyncio async def test_psqlpy_load_from_storage_merges_telemetry(monkeypatch: pytest.MonkeyPatch) -> None: arrow_table = pa.table({"id": [1, 2], "name": ["north", "south"]}) dummy_connection = DummyPsqlpyConnection() @@ -187,7 +184,6 @@ async def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str assert job.telemetry["extra"]["source"]["destination"] == "s3://bucket/part-2.parquet" # type: ignore[index] -@pytest.mark.asyncio async def test_aiosqlite_load_from_arrow_overwrite() -> None: connection = await aiosqlite.connect(":memory:") try: @@ -213,7 +209,6 @@ async def test_aiosqlite_load_from_arrow_overwrite() -> None: await connection.close() -@pytest.mark.asyncio async def test_aiosqlite_load_from_storage_includes_source(monkeypatch: pytest.MonkeyPatch) -> None: connection = await aiosqlite.connect(":memory:") try: @@ -292,7 +287,6 @@ def _fake_read(self, *_: object, **__: object) -> tuple[pa.Table, dict[str, obje connection.close() -@pytest.mark.asyncio async def test_asyncmy_load_from_arrow_overwrite() -> None: connection = DummyAsyncmyConnection() driver = AsyncmyDriver( @@ -309,7 +303,6 @@ async def test_asyncmy_load_from_arrow_overwrite() -> None: assert job.telemetry["destination"] == "analytics.scores" 
-@pytest.mark.asyncio async def test_asyncmy_load_from_storage_merges_source(monkeypatch: pytest.MonkeyPatch) -> None: connection = DummyAsyncmyConnection() driver = AsyncmyDriver( From 055b4a27e6ceb6c50d5bd6466acf82bb5026ef08 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 16 Mar 2026 21:19:20 +0000 Subject: [PATCH 39/39] feat: Refactor driver parameter profiles to use a unified coercions dictionary for custom type coercions --- sqlspec/adapters/asyncmy/core.py | 8 ++++---- sqlspec/adapters/mysqlconnector/core.py | 8 ++++---- sqlspec/adapters/psqlpy/core.py | 4 ++-- sqlspec/adapters/pymysql/core.py | 8 ++++---- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sqlspec/adapters/asyncmy/core.py b/sqlspec/adapters/asyncmy/core.py index 5cd66d506..54e5a06b9 100644 --- a/sqlspec/adapters/asyncmy/core.py +++ b/sqlspec/adapters/asyncmy/core.py @@ -1,6 +1,6 @@ """AsyncMy adapter compiled helpers.""" -from collections.abc import Sized +from collections.abc import Callable, Sized from typing import TYPE_CHECKING, Any from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile @@ -25,7 +25,7 @@ from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Mapping, Sequence __all__ = ( "apply_driver_features", @@ -126,7 +126,7 @@ def normalize_execute_many_parameters(parameters: Any) -> Any: def build_profile() -> "DriverParameterProfile": """Create the AsyncMy driver parameter profile.""" - + coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int, **build_uuid_coercions()} return DriverParameterProfile( name="AsyncMy", default_style=ParameterStyle.QMARK, @@ -139,7 +139,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - 
custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, + custom_type_coercions=coercions, default_dialect="mysql", ) diff --git a/sqlspec/adapters/mysqlconnector/core.py b/sqlspec/adapters/mysqlconnector/core.py index 1f51cbc2d..3b4495b62 100644 --- a/sqlspec/adapters/mysqlconnector/core.py +++ b/sqlspec/adapters/mysqlconnector/core.py @@ -1,6 +1,6 @@ """MysqlConnector adapter compiled helpers.""" -from collections.abc import Sized +from collections.abc import Callable, Sized from typing import TYPE_CHECKING, Any from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile @@ -25,7 +25,7 @@ from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Mapping, Sequence __all__ = ( "apply_driver_features", @@ -109,7 +109,7 @@ def normalize_execute_many_parameters(parameters: Any) -> Any: def build_profile() -> "DriverParameterProfile": """Create the mysql-connector driver parameter profile.""" - + coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int, **build_uuid_coercions()} return DriverParameterProfile( name="mysql-connector", default_style=ParameterStyle.QMARK, @@ -122,7 +122,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, + custom_type_coercions=coercions, default_dialect="mysql", ) diff --git a/sqlspec/adapters/psqlpy/core.py b/sqlspec/adapters/psqlpy/core.py index efa6204bb..98de9cf00 100644 --- a/sqlspec/adapters/psqlpy/core.py +++ b/sqlspec/adapters/psqlpy/core.py @@ -183,7 +183,7 @@ def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> "tuple[Any, ...]": def build_profile() -> "DriverParameterProfile": """Create the psqlpy driver parameter 
profile.""" - + coercions: dict[type, Callable[[Any], Any]] = {decimal.Decimal: float, **build_uuid_coercions(native=True)} return DriverParameterProfile( name="Psqlpy", default_style=ParameterStyle.NUMERIC, @@ -196,7 +196,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={decimal.Decimal: float, **build_uuid_coercions(native=True)}, + custom_type_coercions=coercions, default_dialect="postgres", ) diff --git a/sqlspec/adapters/pymysql/core.py b/sqlspec/adapters/pymysql/core.py index f2f60458c..a55f6ed21 100644 --- a/sqlspec/adapters/pymysql/core.py +++ b/sqlspec/adapters/pymysql/core.py @@ -1,6 +1,6 @@ """PyMySQL adapter compiled helpers.""" -from collections.abc import Sized +from collections.abc import Callable, Sized from typing import TYPE_CHECKING, Any from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile @@ -25,7 +25,7 @@ from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Mapping, Sequence __all__ = ( "apply_driver_features", @@ -109,7 +109,7 @@ def normalize_execute_many_parameters(parameters: Any) -> Any: def build_profile() -> "DriverParameterProfile": """Create the PyMySQL driver parameter profile.""" - + coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int, **build_uuid_coercions()} return DriverParameterProfile( name="PyMySQL", default_style=ParameterStyle.QMARK, @@ -122,7 +122,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, + custom_type_coercions=coercions, default_dialect="mysql", )