diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index dd1f95a5..9c99a303 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -51,8 +51,9 @@ def model_to_dict(model: pydantic.BaseModel, include: Optional[List[str]] = None) -> dict: """Utility function to convert a pydantic model to a dictionary.""" - keys = model.model_fields.keys() - keys = keys if include is None else [item for item in include if item in model.model_fields] + model_cls = model if isinstance(model, type) else type(model) # Handles the possibility that sometimes model is a class not instance + keys = model_cls.model_fields.keys() + keys = keys if include is None else [item for item in include if item in model_cls.model_fields] return {key: getattr(model, key) for key in keys} @@ -76,7 +77,8 @@ def subset_model( """ new_fields = {} - for name, field_info in model.model_fields.items(): + model_cls = model if isinstance(model, type) else type(model) # Handles the possibility that sometimes model is a class not instance + for name, field_info in model_cls.model_fields.items(): if name in fields: # copy directly # TODO -- handle cross-field validation @@ -114,7 +116,8 @@ def model_from_state(model: Type[ModelType], state: State) -> ModelType: :param state: state object to create from :return: model object """ - keys = [item for item in model.model_fields.keys() if item in state] + model_cls = model if isinstance(model, type) else type(model) # Handles the possibility that sometimes model is a class not instance + keys = [item for item in model_cls.model_fields.keys() if item in state] return model(**{key: state[key] for key in keys}) @@ -153,7 +156,8 @@ def _validate_and_extract_signature_types( def _validate_keys(model: Type[pydantic.BaseModel], keys: List[str], fn: Callable) -> None: - missing_keys = [key for key in keys if key not in model.model_fields] + model_cls = model if isinstance(model, type) else type(model) # Handles the possibility that sometimes model is a class not instance + missing_keys = [key for key in keys if key not in model_cls.model_fields] if missing_keys: raise ValueError( f"Function fn: {fn.__qualname__} is not a valid pydantic action. " diff --git a/pyproject.toml b/pyproject.toml index 3cb698ed..af464f36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ tests = [ "langchain_core", "langchain_community", "pandas", - "pydantic[email]", + "pydantic[email]>=2.11", "pyarrow", "apache-burr[aiosqlite]", "apache-burr[asyncpg]", @@ -120,7 +120,7 @@ documentation = [ ] tracking-client = [ - "pydantic>1" + "pydantic>=2.11" ] tracking-client-s3 = [ @@ -141,7 +141,7 @@ tracking-server = [ "click", "fastapi", "uvicorn", - "pydantic", + "pydantic>=2.11", "pydantic-settings", "fastapi-pagination", "fastapi-utils", @@ -153,7 +153,7 @@ tracking-server = [ ] pydantic = [ - "pydantic" + "pydantic>=2.11" ] haystack = [ diff --git a/tests/integrations/test_burr_pydantic.py b/tests/integrations/test_burr_pydantic.py index 8fbb94c0..7edc26a2 100644 --- a/tests/integrations/test_burr_pydantic.py +++ b/tests/integrations/test_burr_pydantic.py @@ -22,6 +22,7 @@ import pytest from pydantic import BaseModel, ConfigDict, EmailStr, Field from pydantic.fields import FieldInfo +import warnings from burr.core import expr from burr.core.action import ( @@ -133,6 +134,12 @@ class MyModelWithConfig(pydantic.BaseModel): assert SubsetModel.__name__ == "MyModelWithConfigSubset" assert SubsetModel.model_config == {"arbitrary_types_allowed": True} +def test_pydantic_version(): + """Ensure pydantic >= 2.11 is installed (required for class-level model_fields access).""" + from packaging.version import Version + assert float(float(".".join(pydantic.__version__.split(".")[:2]))) >= float("2.11"), ( + f"pydantic >= 2.11 required, got {pydantic.__version__}" + ) def test_merge_to_state(): model = OriginalModel(