# Source code for baybe.symmetries.base
"""Base class for symmetries."""
from __future__ import annotations
import gc
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import pandas as pd
from attrs import define, field
from attrs.validators import instance_of
from baybe.exceptions import IncompatibleSearchSpaceError
from baybe.serialization import SerialMixin
if TYPE_CHECKING:
from baybe.parameters.base import Parameter
from baybe.searchspace import SearchSpace
[docs]
@define(frozen=True)
class Symmetry(SerialMixin, ABC):
    """Abstract base class for symmetries.

    A ``Symmetry`` is a concept for configuring the modeling process when the
    modeled system exhibits invariances.
    """

    use_data_augmentation: bool = field(
        default=True,
        validator=instance_of(bool),
        kw_only=True,
    )
    """Flag indicating whether data augmentation is to be used."""

    @property
    @abstractmethod
    def parameter_names(self) -> tuple[str, ...]:
        """The names of the parameters affected by the symmetry."""

    def summary(self) -> dict:
        """Return a custom summarization of the symmetry.

        Returns:
            A dictionary holding the symmetry's class name, the names of the
            affected parameters, and the data augmentation flag.
        """
        return {
            "Type": type(self).__name__,
            "Affected_Parameters": self.parameter_names,
            "Data_Augmentation": self.use_data_augmentation,
        }

    @abstractmethod
    def augment_measurements(
        self,
        measurements: pd.DataFrame,
        parameters: Iterable[Parameter] | None = None,
    ) -> pd.DataFrame:
        """Augment the given measurements according to the symmetry.

        Args:
            measurements: The dataframe containing the measurements to be
                augmented.
            parameters: Optional parameter objects carrying additional
                information. Only required by specific augmentation
                implementations.

        Returns:
            The augmented dataframe including the original measurements.
        """

    def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
        """Validate that the symmetry is compatible with the given searchspace.

        Args:
            searchspace: The searchspace to validate against.

        Raises:
            IncompatibleSearchSpaceError: If the symmetry affects parameters not
                present in the searchspace.
        """
        # Names the symmetry refers to that the search space does not know.
        unknown_names = set(self.parameter_names) - set(searchspace.parameter_names)

        # Nothing missing means the symmetry fits the search space.
        if not unknown_names:
            return

        raise IncompatibleSearchSpaceError(
            f"The symmetry of type '{type(self).__name__}' was set up with the "
            f"following parameters that are not present in the search space: "
            f"{unknown_names}."
        )
# Collect leftover original slotted classes processed by `attrs.define`.
# NOTE(review): presumably needed because building a slotted class makes attrs
# create a replacement class object, leaving the originally declared class
# behind until the garbage collector reclaims it -- confirm against the attrs
# documentation. The collection is forced once, when this module is imported.
gc.collect()