Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ratapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def get_handle(self, index: int):
"""
custom_file = self.files[index]
full_path = os.path.join(custom_file["path"], custom_file["filename"])

if not os.path.isfile(full_path):
raise FileNotFoundError(f"The custom file ({custom_file['name']}) does not have a valid path.")

if not custom_file["function_name"] and custom_file["language"] != Languages.Matlab:
raise ValueError(f"The custom file ({custom_file['name']}) does not have a valid function name.")

if custom_file["language"] == Languages.Python:
file_handle = get_python_handle(custom_file["filename"], custom_file["function_name"], custom_file["path"])
elif custom_file["language"] == Languages.Matlab:
Expand Down
84 changes: 65 additions & 19 deletions ratapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pathlib
import warnings
from itertools import count
from contextlib import suppress
from typing import Any

import numpy as np
Expand All @@ -18,14 +18,41 @@


# Create a counter for each model
background_number = count(1)
contrast_number = count(1)
custom_file_number = count(1)
data_number = count(1)
domain_contrast_number = count(1)
layer_number = count(1)
parameter_number = count(1)
resolution_number = count(1)
background_number = ["Background", 0]
contrast_number = ["Contrast", 0]
custom_file_number = ["Custom File", 0]
data_number = ["Data", 0]
domain_contrast_number = ["Domain Contrast", 0]
layer_number = ["Layer", 0]
parameter_number = ["Parameter", 0]
resolution_number = ["Resolution", 0]

_model_counter = {
"Background": background_number,
"Contrast": contrast_number,
"ContrastWithRatio": contrast_number,
"CustomFile": custom_file_number,
"Data": data_number,
"DomainContrast": domain_contrast_number,
"Layer": layer_number,
"AbsorptionLayer": layer_number,
"Parameter": parameter_number,
"ProtectedParameter": parameter_number,
"Resolution": resolution_number,
}


def _model_name_factory(model_name: str) -> str:
"""Generate a unique name for model using a global counter.

Parameters
----------
model_name : str
The name of the model class.
"""
title, number = _model_counter[model_name]
_model_counter[model_name][1] += 1
return f"New {title} {(number + 1)}"


class RATModel(BaseModel, validate_assignment=True, extra="forbid"):
Expand All @@ -38,6 +65,25 @@ def __repr__(self):
)
return f"{self.__repr_name__()}({fields_repr})"

@field_validator("name", mode="after", check_fields=False)
@classmethod
def update_counter(cls, name: str) -> str:
"""Update the auto name counter if a similar name is manually given.

Parameters
----------
name : str
The name of the model.
"""
title, number = _model_counter[cls.__name__]
prefix = f"New {title} "
if name.startswith(prefix):
with suppress(ValueError):
new_number = int(name[len(prefix) :])
if new_number > number:
_model_counter[cls.__name__][1] = new_number
return name

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.display_fields]
Expand Down Expand Up @@ -116,7 +162,7 @@ class Background(Signal):

"""

name: str = Field(default_factory=lambda: f"New Background {next(background_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Background"), min_length=1)

@model_validator(mode="after")
def check_unsupported_parameters(self):
Expand Down Expand Up @@ -173,7 +219,7 @@ class Contrast(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Contrast"), min_length=1)
data: str = ""
background: str = ""
background_action: BackgroundActions = BackgroundActions.Add
Expand Down Expand Up @@ -255,7 +301,7 @@ class ContrastWithRatio(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Contrast {next(contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("ContrastWithRatio"), min_length=1)
data: str = ""
background: str = ""
background_action: BackgroundActions = BackgroundActions.Add
Expand Down Expand Up @@ -309,7 +355,7 @@ class CustomFile(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Custom File {next(custom_file_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("CustomFile"), min_length=1)
filename: str = ""
function_name: str = ""
language: Languages = Languages.Python
Expand Down Expand Up @@ -348,7 +394,7 @@ class Data(RATModel, arbitrary_types_allowed=True):

"""

name: str = Field(default_factory=lambda: f"New Data {next(data_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Data"), min_length=1)
data: np.ndarray = np.empty([0, 3])
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)
Expand Down Expand Up @@ -453,7 +499,7 @@ class DomainContrast(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Domain Contrast {next(domain_contrast_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("DomainContrast"), min_length=1)
model: list[str] = []

def __str__(self):
Expand Down Expand Up @@ -483,7 +529,7 @@ class Layer(RATModel, populate_by_name=True):

"""

name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Layer"), min_length=1)
thickness: str
SLD: str = Field(validation_alias="SLD_real")
roughness: str
Expand Down Expand Up @@ -522,7 +568,7 @@ class AbsorptionLayer(RATModel, populate_by_name=True):

"""

name: str = Field(default_factory=lambda: f"New Layer {next(layer_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("AbsorptionLayer"), min_length=1)
thickness: str
SLD_real: str = Field(validation_alias="SLD")
SLD_imaginary: str
Expand Down Expand Up @@ -555,7 +601,7 @@ class Parameter(RATModel):

"""

name: str = Field(default_factory=lambda: f"New Parameter {next(parameter_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Parameter"), min_length=1)
min: float = 0.0
value: float = 0.0
max: float = 0.0
Expand Down Expand Up @@ -638,7 +684,7 @@ class Resolution(Signal):

"""

name: str = Field(default_factory=lambda: f"New Resolution {next(resolution_number)}", min_length=1)
name: str = Field(default_factory=lambda: _model_name_factory("Resolution"), min_length=1)

@field_validator("type")
@classmethod
Expand Down
26 changes: 26 additions & 0 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pathlib
import pickle
import tempfile
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -675,6 +677,30 @@ def test_make_controls(standard_layers_controls) -> None:
check_controls_equal(controls, standard_layers_controls)


@patch("ratapi.wrappers.MatlabWrapper")
def test_file_handles(wrapper):
handle = FileHandles([ratapi.models.CustomFile(name="Test Custom File", filename="cpp_test.dll", language="cpp")])

with pytest.raises(FileNotFoundError, match="The custom file \\(Test Custom File\\) does not have a valid path."):
handle.get_handle(0)

with tempfile.NamedTemporaryFile("w", suffix=".dll") as f:
tmp_file = pathlib.Path(f.name)
handle.files[0]["path"] = tmp_file.parent
handle.files[0]["filename"] = tmp_file.name
handle.files[0]["function_name"] = ""
# No function name should throw exception
with pytest.raises(
ValueError, match="The custom file \\(Test Custom File\\) does not have a valid function name."
):
handle.get_handle(0)

# Matlab does not need function name
handle.files[0]["language"] = "matlab"
handle.get_handle(0)
wrapper.assert_called()


def check_problem_equal(actual_problem, expected_problem) -> None:
"""Compare two instances of the "problem" object for equality."""
scalar_fields = [
Expand Down
22 changes: 12 additions & 10 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ def test_default_names(model: Callable, model_name: str, model_params: dict) ->
format: "New <model name> <integer>".
"""
model_1 = model(**model_params)
prefix = f"New {model_name} "
assert model_1.name.startswith(prefix)
index = int(model_1.name[len(prefix) :])

model_2 = model(**model_params)
model_3 = model(name="Given Name", **model_params)
model_4 = model(**model_params)

assert model_1.name == f"New {model_name} 1"
assert model_2.name == f"New {model_name} 2"
assert model_1.name == f"New {model_name} {index}"
assert model_2.name == f"New {model_name} {index + 1}"
assert model_3.name == "Given Name"
assert model_4.name == f"New {model_name} 3"
assert model_4.name == f"New {model_name} {index + 2}"

# If user adds name in similar format. The next auto number will take it into account.
model(name=f"{prefix}{index + 20}", **model_params)
model_5 = model(**model_params)
assert model_5.name == f"New {model_name} {index + 21}"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,13 +109,6 @@ def test_initialise_with_extra_fields(self, model: Callable, model_params: dict)
model(new_field=1, **model_params)


# def test_custom_file_path_is_absolute() -> None:
# """If we use provide a relative path to the custom file model, it should be converted to an absolute path."""
# relative_path = pathlib.Path("./relative_path")
# custom_file = ratapi.models.CustomFile(path=relative_path)
# assert custom_file.path.is_absolute()


def test_data_eq() -> None:
"""If we use the Data.__eq__ method with an object that is not a pydantic BaseModel, we should return
"NotImplemented".
Expand Down