From d0ed75a3afaf5050d87ca7b9512460baa7a8752d Mon Sep 17 00:00:00 2001 From: Stephen Nneji Date: Mon, 16 Feb 2026 13:38:16 +0000 Subject: [PATCH] Fixes auto name bug and adds clearer error for missing custom file or function --- ratapi/inputs.py | 7 ++++ ratapi/models.py | 84 ++++++++++++++++++++++++++++++++++---------- tests/test_inputs.py | 26 ++++++++++++++ tests/test_models.py | 22 ++++++------ 4 files changed, 110 insertions(+), 29 deletions(-) diff --git a/ratapi/inputs.py b/ratapi/inputs.py index 921890b..bdfadda 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -77,6 +77,13 @@ def get_handle(self, index: int): """ custom_file = self.files[index] full_path = os.path.join(custom_file["path"], custom_file["filename"]) + + if not os.path.isfile(full_path): + raise FileNotFoundError(f"The custom file ({custom_file['name']}) does not have a valid path.") + + if not custom_file["function_name"] and custom_file["language"] != Languages.Matlab: + raise ValueError(f"The custom file ({custom_file['name']}) does not have a valid function name.") + if custom_file["language"] == Languages.Python: file_handle = get_python_handle(custom_file["filename"], custom_file["function_name"], custom_file["path"]) elif custom_file["language"] == Languages.Matlab: diff --git a/ratapi/models.py b/ratapi/models.py index 755ea97..ba0a1e9 100644 --- a/ratapi/models.py +++ b/ratapi/models.py @@ -2,7 +2,7 @@ import pathlib import warnings -from itertools import count +from contextlib import suppress from typing import Any import numpy as np @@ -18,14 +18,41 @@ # Create a counter for each model -background_number = count(1) -contrast_number = count(1) -custom_file_number = count(1) -data_number = count(1) -domain_contrast_number = count(1) -layer_number = count(1) -parameter_number = count(1) -resolution_number = count(1) +background_number = ["Background", 0] +contrast_number = ["Contrast", 0] +custom_file_number = ["Custom File", 0] +data_number = ["Data", 0] +domain_contrast_number = ["Domain Contrast", 0] +layer_number = ["Layer", 0] +parameter_number = ["Parameter", 0] +resolution_number = ["Resolution", 0] + +_model_counter = { + "Background": background_number, + "Contrast": contrast_number, + "ContrastWithRatio": contrast_number, + "CustomFile": custom_file_number, + "Data": data_number, + "DomainContrast": domain_contrast_number, + "Layer": layer_number, + "AbsorptionLayer": layer_number, + "Parameter": parameter_number, + "ProtectedParameter": parameter_number, + "Resolution": resolution_number, +} + + +def _model_name_factory(model_name: str) -> str: + """Generate a unique name for model using a global counter. + + Parameters + ---------- + model_name : str + The name of the model class. + """ + title, number = _model_counter[model_name] + _model_counter[model_name][1] += 1 + return f"New {title} {(number + 1)}" class RATModel(BaseModel, validate_assignment=True, extra="forbid"): @@ -38,6 +65,25 @@ def __repr__(self): ) return f"{self.__repr_name__()}({fields_repr})" + @field_validator("name", mode="after", check_fields=False) + @classmethod + def update_counter(cls, name: str) -> str: + """Update the auto name counter if a similar name is manually given. + + Parameters + ---------- + name : str + The name of the model. + """ + title, number = _model_counter[cls.__name__] + prefix = f"New {title} " + if name.startswith(prefix): + with suppress(ValueError): + new_number = int(name[len(prefix) :]) + if new_number > number: + _model_counter[cls.__name__][1] = new_number + return name + def __str__(self): table = prettytable.PrettyTable() table.field_names = [key.replace("_", " ") for key in self.display_fields] @@ -116,7 +162,7 @@ class Background(Signal): """ - name: str = Field(default_factory=lambda: f"New Background {next(background_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Background"), min_length=1) @model_validator(mode="after") def check_unsupported_parameters(self): @@ -173,7 +219,7 @@ class Contrast(RATModel): """ - name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Contrast"), min_length=1) data: str = "" background: str = "" background_action: BackgroundActions = BackgroundActions.Add @@ -255,7 +301,7 @@ class ContrastWithRatio(RATModel): """ - name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("ContrastWithRatio"), min_length=1) data: str = "" background: str = "" background_action: BackgroundActions = BackgroundActions.Add @@ -309,7 +355,7 @@ class CustomFile(RATModel): """ - name: str = Field(default_factory=lambda: f"New Custom File {next(custom_file_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("CustomFile"), min_length=1) filename: str = "" function_name: str = "" language: Languages = Languages.Python @@ -348,7 +394,7 @@ class Data(RATModel, arbitrary_types_allowed=True): """ - name: str = Field(default_factory=lambda: f"New Data {next(data_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Data"), min_length=1) data: np.ndarray = np.empty([0, 3]) data_range: list[float] = Field(default=[], min_length=2, max_length=2) simulation_range: list[float] = Field(default=[], min_length=2, max_length=2) @@ -453,7 +499,7 @@ class DomainContrast(RATModel): """ - name: str = Field(default_factory=lambda: f"New Domain Contrast {next(domain_contrast_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("DomainContrast"), min_length=1) model: list[str] = [] def __str__(self): @@ -483,7 +529,7 @@ class Layer(RATModel, populate_by_name=True): """ - name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Layer"), min_length=1) thickness: str SLD: str = Field(validation_alias="SLD_real") roughness: str @@ -522,7 +568,7 @@ class AbsorptionLayer(RATModel, populate_by_name=True): """ - name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("AbsorptionLayer"), min_length=1) thickness: str SLD_real: str = Field(validation_alias="SLD") SLD_imaginary: str @@ -555,7 +601,7 @@ class Parameter(RATModel): """ - name: str = Field(default_factory=lambda: f"New Parameter {next(parameter_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Parameter"), min_length=1) min: float = 0.0 value: float = 0.0 max: float = 0.0 @@ -638,7 +684,7 @@ class Resolution(Signal): """ - name: str = Field(default_factory=lambda: f"New Resolution {next(resolution_number)}", min_length=1) + name: str = Field(default_factory=lambda: _model_name_factory("Resolution"), min_length=1) @field_validator("type") @classmethod diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 244fd4d..242d748 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2,6 +2,8 @@ import pathlib import pickle +import tempfile +from unittest.mock import patch import numpy as np import pytest @@ -675,6 +677,30 @@ def test_make_controls(standard_layers_controls) -> None: check_controls_equal(controls, standard_layers_controls) +@patch("ratapi.wrappers.MatlabWrapper") +def test_file_handles(wrapper): + handle = FileHandles([ratapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")]) + + with pytest.raises(FileNotFoundError, match="The custom file \\(Test Custom File\\) does not have a valid path."): + handle.get_handle(0) + + with tempfile.NamedTemporaryFile("w", suffix=".dll") as f: + tmp_file = pathlib.Path(f.name) + handle.files[0]["path"] = tmp_file.parent + handle.files[0]["filename"] = tmp_file.name + handle.files[0]["function_name"] = "" + # No function name should throw exception + with pytest.raises( + ValueError, match="The custom file \\(Test Custom File\\) does not have a valid function name." + ): + handle.get_handle(0) + + # Matlab does not need function name + handle.files[0]["language"] = "matlab" + handle.get_handle(0) + wrapper.assert_called() + + def check_problem_equal(actual_problem, expected_problem) -> None: """Compare two instances of the "problem" object for equality.""" scalar_fields = [ diff --git a/tests/test_models.py b/tests/test_models.py index 949cf5a..4343b44 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -32,14 +32,23 @@ def test_default_names(model: Callable, model_name: str, model_params: dict) -> format: "New ". """ model_1 = model(**model_params) + prefix = f"New {model_name} " + assert model_1.name.startswith(prefix) + index = int(model_1.name[len(prefix) :]) + model_2 = model(**model_params) model_3 = model(name="Given Name", **model_params) model_4 = model(**model_params) - assert model_1.name == f"New {model_name} 1" - assert model_2.name == f"New {model_name} 2" + assert model_1.name == f"New {model_name} {index}" + assert model_2.name == f"New {model_name} {index + 1}" assert model_3.name == "Given Name" - assert model_4.name == f"New {model_name} 3" + assert model_4.name == f"New {model_name} {index + 2}" + + # If user adds name in similar format. The next auto number will take it into account. + model(name=f"{prefix}{index + 20}", **model_params) + model_5 = model(**model_params) + assert model_5.name == f"New {model_name} {index + 21}" @pytest.mark.parametrize( @@ -100,13 +109,6 @@ def test_initialise_with_extra_fields(self, model: Callable, model_params: dict) model(new_field=1, **model_params) -# def test_custom_file_path_is_absolute() -> None: -# """If we use provide a relative path to the custom file model, it should be converted to an absolute path.""" -# relative_path = pathlib.Path("./relative_path") -# custom_file = ratapi.models.CustomFile(path=relative_path) -# assert custom_file.path.is_absolute() - - def test_data_eq() -> None: """If we use the Data.__eq__ method with an object that is not a pydantic BaseModel, we should return "NotImplemented".