Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 81 additions & 6 deletions src/easyscience/variable/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,33 @@ def __init__(

@classmethod
def from_dependency(
cls, name: str, dependency_expression: str, dependency_map: Optional[dict] = None, **kwargs
cls,
name: str,
dependency_expression: str,
dependency_map: Optional[dict] = None,
desired_unit: str | sc.Unit | None = None,
**kwargs,
) -> Parameter: # noqa: E501
"""
Create a dependent Parameter directly from a dependency expression.

:param name: The name of the parameter
:param dependency_expression: The dependency expression to evaluate. This should be a string which can be evaluated by the ASTEval interpreter.
:param dependency_map: A dictionary of dependency expression symbol name and dependency object pairs. This is inserted into the asteval interpreter to resolve dependencies.
:param desired_unit: The desired unit of the dependent parameter.
:param kwargs: Additional keyword arguments to pass to the Parameter constructor.
:return: A new dependent Parameter object.
""" # noqa: E501
# Set default values for required parameters for the constructor, they get overwritten by the dependency anyways
default_kwargs = {'value': 0.0, 'unit': '', 'variance': 0.0, 'min': -np.inf, 'max': np.inf}
default_kwargs = {'value': 0.0, 'variance': 0.0, 'min': -np.inf, 'max': np.inf}
# Update with user-provided kwargs, to avoid errors.
default_kwargs.update(kwargs)
parameter = cls(name=name, **default_kwargs)
parameter.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map)
parameter.make_dependent_on(
dependency_expression=dependency_expression,
dependency_map=dependency_map,
desired_unit=desired_unit,
)
return parameter

def _update(self) -> None:
Expand All @@ -158,11 +168,20 @@ def _update(self) -> None:
) # noqa: E501
self._min.unit = temporary_parameter.unit
self._max.unit = temporary_parameter.unit

if self._desired_unit is not None:
self._convert_unit(self._desired_unit)

Comment on lines +173 to +174
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_convert_unit doesn't catch any exceptions and we aren't checking anything here. Maybe add some exception handling so it doesn't bubble up potentially leaving the parameter in a weird state?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is on purpose. To keep the _update method as fast as possible, all the checks are made in the make_dependent_on method. If this method succeeds, then the _update will also always succeed.
This is also how we did it for the other updates :)

self._notify_observers()
else:
warnings.warn('This parameter is not dependent. It cannot be updated.')

def make_dependent_on(self, dependency_expression: str, dependency_map: Optional[dict] = None) -> None:
def make_dependent_on(
self,
dependency_expression: str,
dependency_map: Optional[dict] = None,
desired_unit: str | sc.Unit | None = None,
) -> None:
"""
Make this parameter dependent on another parameter. This will overwrite the current value, unit, variance, min and max.

Expand All @@ -183,6 +202,9 @@ def make_dependent_on(self, dependency_expression: str, dependency_map: Optional
A dictionary of dependency expression symbol name and dependency object pairs.
This is inserted into the asteval interpreter to resolve dependencies.

:param desired_unit:
The desired unit of the dependent parameter. If None, the default unit of the dependency expression result is used.

""" # noqa: E501
if not isinstance(dependency_expression, str):
raise TypeError('`dependency_expression` must be a string representing a valid dependency expression.')
Expand Down Expand Up @@ -212,13 +234,17 @@ def make_dependent_on(self, dependency_expression: str, dependency_map: Optional
'_dependency_map': self._dependency_map,
'_dependency_interpreter': self._dependency_interpreter,
'_clean_dependency_string': self._clean_dependency_string,
'_desired_unit': self._desired_unit,
}
for dependency in self._dependency_map.values():
dependency._detach_observer(self)

self._independent = False
self._dependency_string = dependency_expression
self._dependency_map = dependency_map if dependency_map is not None else {}
if desired_unit is not None and not (isinstance(desired_unit, str) or isinstance(desired_unit, sc.Unit)):
raise TypeError('`desired_unit` must be a string representing a valid unit.')
self._desired_unit = desired_unit
# List of allowed python constructs for the asteval interpreter
asteval_config = {
'import': False,
Expand Down Expand Up @@ -289,6 +315,17 @@ def make_dependent_on(self, dependency_expression: str, dependency_map: Optional
raise error
# Update the parameter with the dependency result
self._fixed = False

if self._desired_unit is not None:
try:
dependency_result._convert_unit(self._desired_unit)
except Exception as e:
desired_unit_for_error_message = self._desired_unit
self._revert_dependency() # also deletes self._desired_unit
raise UnitError(
f'Failed to convert unit from {dependency_result.unit} to {desired_unit_for_error_message}: {e}'
)

self._update()

def make_independent(self) -> None:
Expand All @@ -306,6 +343,7 @@ def make_independent(self) -> None:
del self._dependency_interpreter
del self._dependency_string
del self._clean_dependency_string
del self._desired_unit
else:
raise AttributeError('This parameter is already independent.')

Expand Down Expand Up @@ -470,6 +508,28 @@ def convert_unit(self, unit_str: str) -> None:
"""
self._convert_unit(unit_str)

def set_desired_unit(self, unit_str: str | sc.Unit | None) -> None:
"""
Set the desired unit for a dependent Parameter. This will convert the parameter to the desired unit.

:param unit_str: The desired unit as a string.
"""

if self._independent:
raise AttributeError('This is an independent parameter, desired unit can only be set for dependent parameters.')
if not (isinstance(unit_str, str) or isinstance(unit_str, sc.Unit) or unit_str is None):
raise TypeError('`unit_str` must be a string representing a valid unit.')

if unit_str is not None:
try:
old_unit_for_message = self.unit
self._convert_unit(unit_str)
except Exception as e:
raise UnitError(f'Failed to convert unit from {old_unit_for_message} to {unit_str}: {e}')

self._desired_unit = unit_str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing a check on the type of unit_str

self._update()

@property
def min(self) -> numbers.Number:
"""
Expand Down Expand Up @@ -580,6 +640,9 @@ def as_dict(self, skip: Optional[List[str]] = None) -> Dict[str, Any]:
# Save the dependency expression
raw_dict['_dependency_string'] = self._clean_dependency_string

if self._desired_unit is not None:
raw_dict['_desired_unit'] = self._desired_unit

# Mark that this parameter is dependent
raw_dict['_independent'] = self._independent

Expand Down Expand Up @@ -648,6 +711,7 @@ def from_dict(cls, obj_dict: dict) -> 'Parameter':
dependency_string = raw_dict.pop('_dependency_string', None)
dependency_map_serializer_ids = raw_dict.pop('_dependency_map_serializer_ids', None)
is_independent = raw_dict.pop('_independent', True)
desired_unit = raw_dict.pop('_desired_unit', None)
# Note: Keep _serializer_id in the dict so it gets passed to __init__

# Create the parameter using the base class method (serializer_id is now handled in __init__)
Expand All @@ -659,6 +723,7 @@ def from_dict(cls, obj_dict: dict) -> 'Parameter':
param._pending_dependency_map_serializer_ids = dependency_map_serializer_ids
# Keep parameter as independent initially - will be made dependent after all objects are loaded
param._independent = True
param._pending_desired_unit = desired_unit

return param

Expand Down Expand Up @@ -874,7 +939,12 @@ def __truediv__(self, other: Union[DescriptorNumber, Parameter, numbers.Number])
elif self.max <= 0:
combinations = [self.max / other.min, np.inf]
else:
combinations = [self.min / other.min, self.max / other.max, self.min / other.max, self.max / other.min]
combinations = [
self.min / other.min,
self.max / other.max,
self.min / other.max,
self.max / other.min,
]
else:
combinations = [self.min / other.value, self.max / other.value]
else:
Expand Down Expand Up @@ -1017,13 +1087,18 @@ def resolve_pending_dependencies(self) -> None:

# Establish the dependency relationship
try:
self.make_dependent_on(dependency_expression=dependency_string, dependency_map=dependency_map)
self.make_dependent_on(
dependency_expression=dependency_string,
dependency_map=dependency_map,
desired_unit=self._pending_desired_unit,
)
except Exception as e:
raise ValueError(f"Error establishing dependency '{dependency_string}': {e}")

# Clean up temporary attributes
delattr(self, '_pending_dependency_string')
delattr(self, '_pending_dependency_map_serializer_ids')
delattr(self, '_pending_desired_unit')

def _find_parameter_by_serializer_id(self, serializer_id: str) -> Optional['DescriptorNumber']:
"""Find a parameter by its serializer_id from all parameters in the global map."""
Expand Down
Loading
Loading