# Source code for baybe.symmetries.permutation

"""Permutation symmetry."""

from __future__ import annotations

import gc
from collections.abc import Iterable, Sequence
from itertools import combinations
from typing import TYPE_CHECKING, Any

import pandas as pd
from attrs import define, field
from attrs.validators import deep_iterable, instance_of, min_len
from typing_extensions import override

from baybe.symmetries.base import Symmetry
from baybe.utils.augmentation import df_apply_permutation_augmentation

if TYPE_CHECKING:
    from baybe.parameters.base import Parameter
    from baybe.searchspace import SearchSpace


def _convert_groups(
    groups: Sequence[Sequence[str]],
) -> tuple[tuple[str, ...], ...]:
    """Freeze nested sequences of parameter names into a tuple of tuples.

    Strings are sequences themselves, so they are rejected explicitly on both
    nesting levels instead of being split into individual characters.

    Args:
        groups: The permutation groups, each given as a sequence of names.

    Returns:
        The groups as a tuple of tuples, in the order they were provided.

    Raises:
        ValueError: If ``groups`` itself is a string.
        ValueError: If any element of ``groups`` is a string.
    """

    def _freeze(group: Sequence[str]) -> tuple[str, ...]:
        # Only non-string iterables are accepted as a single group.
        if not isinstance(group, str):
            return tuple(group)
        raise ValueError(
            "Each element in 'permutation_groups' must be a sequence of "
            "parameter names, not a string."
        )

    if not isinstance(groups, str):
        # Groups are processed strictly in order, so the first offending
        # element is the one that triggers the error.
        return tuple(map(_freeze, groups))
    raise ValueError(
        "The 'permutation_groups' argument must be a sequence of sequences, "
        "not a string."
    )


@define(frozen=True)
class PermutationSymmetry(Symmetry):
    """Class for representing permutation symmetries.

    A permutation symmetry expresses that certain parameters can be permuted without
    affecting the outcome of the model. For instance, this is the case if
    $f(x,y) = f(y,x)$.
    """

    permutation_groups: tuple[tuple[str, ...], ...] = field(
        converter=_convert_groups,
        validator=(
            # At least one group is required ...
            min_len(1),
            # ... and every group must hold at least two string entries.
            deep_iterable(
                member_validator=deep_iterable(
                    member_validator=instance_of(str),
                    iterable_validator=min_len(2),
                ),
            ),
        ),
    )
    """The permutation groups. Each group contains parameter names that are permuted
    in lockstep. All groups must have the same length."""

    @permutation_groups.validator
    def _validate_permutation_groups(  # noqa: DOC101, DOC103
        self, _: Any, groups: tuple[tuple[str, ...], ...]
    ) -> None:
        """Validate the permutation groups.

        Raises:
            ValueError: If the groups have different lengths.
            ValueError: If any group contains duplicate parameters.
            ValueError: If any parameter name appears in multiple groups.
        """
        # Ensure all groups have the same length
        if len({len(g) for g in groups}) != 1:
            # Report the lengths keyed by the 1-based group position
            lengths = {k: len(g) for k, g in enumerate(groups, start=1)}
            raise ValueError(
                f"In a '{self.__class__.__name__}', all permutation groups "
                f"must have the same length. Got group lengths: {lengths}."
            )

        # Ensure parameter names in each group are unique
        for group in groups:
            if len(set(group)) != len(group):
                raise ValueError(
                    f"In a '{self.__class__.__name__}', all parameters being "
                    f"permuted with each other must be unique. However, the "
                    f"following group contains duplicates: {group}."
                )

        # Ensure there is no overlap between any permutation group
        for a, b in combinations(groups, 2):
            if overlap := set(a) & set(b):
                raise ValueError(
                    f"In a '{self.__class__.__name__}', parameter names cannot "
                    f"appear in multiple permutation groups. However, the "
                    f"following parameter names appear in several groups: "
                    f"{overlap}."
                )

    @override
    @property
    def parameter_names(self) -> tuple[str, ...]:
        # Flatten the groups into a single tuple, keeping the group order and
        # the order of the names within each group.
        return tuple(name for group in self.permutation_groups for name in group)

    @override
    def augment_measurements(
        self,
        measurements: pd.DataFrame,
        parameters: Iterable[Parameter] | None = None,
    ) -> pd.DataFrame:
        # See base class.
        # Note: `parameters` is not used by this implementation.

        # With augmentation disabled, the input object is handed back as is.
        if not self.use_data_augmentation:
            return measurements

        return df_apply_permutation_augmentation(
            measurements, self.permutation_groups
        )

    @override
    def validate_searchspace_context(self, searchspace: SearchSpace) -> None:
        """See base class.

        Args:
            searchspace: The searchspace to validate against.

        Raises:
            ValueError: If parameters within a permutation group are not equivalent
                (i.e., differ in type or specification).
        """
        super().validate_searchspace_context(searchspace)

        # Ensure permuted parameters all have the same specification.
        # Without this, it could be attempted to read in data that is not allowed
        # for parameters that only allow a subset or different values compared to
        # parameters they are being permuted with.
        for group in self.permutation_groups:
            params = searchspace.get_parameters_by_name(group)
            # Comparing everything against the first parameter of the group
            ref = params[0]
            for p in params[1:]:
                if not ref.is_equivalent(p):
                    raise ValueError(
                        f"In a '{self.__class__.__name__}', all parameters "
                        f"within a permutation group must be equivalent. "
                        f"However, '{ref.name}' and '{p.name}' differ in "
                        f"their specification."
                    )
# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()