# Source code for baybe.symmetries.mirror

"""Mirror symmetry."""

from __future__ import annotations

import gc
from collections.abc import Iterable
from typing import TYPE_CHECKING

import pandas as pd
from attrs import define, field
from attrs.validators import instance_of
from typing_extensions import override

from baybe.symmetries.base import Symmetry
from baybe.utils.augmentation import df_apply_mirror_augmentation
from baybe.utils.validation import validate_is_finite

if TYPE_CHECKING:
    from baybe.parameters.base import Parameter
    from baybe.searchspace import SearchSpace


@define(frozen=True)
class MirrorSymmetry(Symmetry):
    """Symmetry describing invariance under reflection at a mirror point.

    A mirror symmetry states that a parameter may be reflected at a given mirror
    point without changing the outcome of the model. For a parameter ``x`` and
    mirror point ``c``, this means that
    $f(..., c+x, ...) = f(..., c-x, ...)$.
    """

    _parameter_name: str = field(alias="parameter_name", validator=instance_of(str))
    """The name of the single parameter affected by the symmetry."""

    # object variables
    mirror_point: float = field(
        kw_only=True,
        default=0.0,
        converter=float,
        validator=validate_is_finite,
    )
    """The mirror point."""

    @override
    @property
    def parameter_names(self) -> tuple[str]:
        """The affected parameter name, wrapped in a one-element tuple."""
        return (self._parameter_name,)

    @override
    def augment_measurements(
        self,
        measurements: pd.DataFrame,
        parameters: Iterable[Parameter] | None = None,
    ) -> pd.DataFrame:
        """See base class.

        Args:
            measurements: The measurements to augment.
            parameters: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            The unmodified input if ``use_data_augmentation`` is falsy, otherwise
            the result of applying the mirror augmentation to the affected
            parameter.
        """
        # Augmentation is opt-in: without the flag, hand the data back untouched.
        if not self.use_data_augmentation:
            return measurements

        return df_apply_mirror_augmentation(
            measurements,
            self._parameter_name,
            mirror_point=self.mirror_point,
        )

    @override
    def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
        """See base class.

        Args:
            searchspace: The searchspace to validate against.

        Raises:
            TypeError: If the affected parameter is not numerical.
        """
        super().validate_searchspace_context(searchspace)

        # Exactly one parameter name is exposed, so only the first match matters.
        (affected, *_) = searchspace.get_parameters_by_name(self.parameter_names)
        if affected.is_numerical:
            return

        raise TypeError(
            f"In a '{self.__class__.__name__}', the affected parameter must "
            f"be numerical. However, the parameter '{affected.name}' is of "
            f"type '{affected.__class__.__name__}' and is not numerical."
        )
# Collect leftover original slotted classes processed by `attrs.define` gc.collect()