# Source code for baybe.targets.binary
"""Binary targets."""
import gc
import warnings
from typing import TypeAlias
import numpy as np
import pandas as pd
from attrs import define, field
from attrs.validators import instance_of
from typing_extensions import override
from baybe.exceptions import InvalidTargetValueError
from baybe.serialization import SerialMixin
from baybe.targets.base import Target
from baybe.utils.validation import validate_not_nan
ChoiceValue: TypeAlias = bool | int | float | str
"""Types of values that a :class:`BinaryTarget` can take."""
_SUCCESS_VALUE_COMP = 1.0
"""Computational representation of the success value."""
_FAILURE_VALUE_COMP = 0.0
"""Computational representation of the failure value."""
@define(frozen=True)
class BinaryTarget(Target, SerialMixin):
    """Class for binary targets.

    A binary target accepts exactly two experimental values, one denoting success
    and one denoting failure. :meth:`transform` maps them to the computational
    values ``1.0`` (success) and ``0.0`` (failure).
    """

    # FIXME[typing]: https://github.com/python-attrs/attrs/issues/1336
    success_value: ChoiceValue = field(
        default=True,
        validator=[instance_of(ChoiceValue), validate_not_nan],
        kw_only=True,
    )
    """Experimental representation of the success value."""

    failure_value: ChoiceValue = field(
        default=False,
        validator=[instance_of(ChoiceValue), validate_not_nan],
        kw_only=True,
    )
    """Experimental representation of the failure value."""

    @failure_value.validator
    def _validate_values(self, _, value):
        """Validate that the two choice values of the target are different.

        The check uses ``==``, so values of different types that compare equal
        (e.g. ``True`` and ``1``) are rejected as well. This is consistent with
        :meth:`transform`, which also identifies values via ``==``.

        Raises:
            ValueError: If ``value`` equals the success value.
        """
        if value == self.success_value:
            raise ValueError(
                f"The two choice values of a '{BinaryTarget.__name__}' must be "
                f"different but the following value was provided for both choices of "
                f"target '{self.name}': {value}"
            )

    @override
    def transform(
        self, series: pd.Series | None = None, /, *, data: pd.DataFrame | None = None
    ) -> pd.Series:
        """Map experimental target values to their computational representation.

        Args:
            series: The target values to be transformed.
            data: Deprecated. A single-column dataframe holding the target values.
                Pass the values as ``series`` instead.

        Raises:
            ValueError: If not exactly one of ``series`` and ``data`` is given, or
                if ``data`` does not have exactly one column.
            InvalidTargetValueError: If any value is neither the success value nor
                the failure value.

        Returns:
            A float series with the index and name of the input, containing
            ``1.0`` where the input equals the success value and ``0.0`` elsewhere.
        """
        # >>>>>>>>>> Deprecation
        if not ((series is None) ^ (data is None)):
            raise ValueError(
                "Provide the data to be transformed as first positional argument."
            )

        if data is not None:
            # An explicit check rather than `assert`, since asserts are stripped
            # when Python runs with `-O`.
            if data.shape[1] != 1:
                raise ValueError(
                    f"The dataframe provided via the `data` argument must have "
                    f"exactly one column, but it has {data.shape[1]}."
                )
            series = data.iloc[:, 0]
            warnings.warn(
                "Providing a dataframe via the `data` argument is deprecated and "
                "will be removed in a future version. Please pass your data "
                "in form of a series as positional argument instead.",
                DeprecationWarning,
                # Attribute the warning to the caller rather than this module.
                stacklevel=2,
            )

        # Mypy does not infer from the above that `series` must be a series here
        assert isinstance(series, pd.Series)
        # <<<<<<<<<< Deprecation

        # Validate target values. `.to_numpy()` turns the mask into a plain array,
        # so the selection is positional instead of relying on index alignment.
        invalid = series[
            ~series.isin([self.success_value, self.failure_value]).to_numpy()
        ]
        if len(invalid) > 0:
            raise InvalidTargetValueError(
                f"The following values entered for target '{self.name}' are not in the "
                f"set of accepted choice values "
                f"{set((self.success_value, self.failure_value))}: {set(invalid)}"
            )

        # Transform. After validation, every value that is not the success value
        # must be the failure value.
        success_idx = series == self.success_value
        return pd.Series(
            np.where(success_idx, _SUCCESS_VALUE_COMP, _FAILURE_VALUE_COMP),
            index=series.index,
            name=series.name,
        )

    @override
    def summary(self) -> dict:
        """Return a dictionary with the type, name and choice values of the target."""
        return dict(
            Type=self.__class__.__name__,
            Name=self.name,
            Success_value=self.success_value,
            Failure_value=self.failure_value,
        )
# Classes decorated with `attrs.define` (slotted by default) are rebuilt as new
# class objects, so the original, pre-processing class objects remain behind
# until the garbage collector reclaims them. Run a full collection (generation 2,
# i.e. the same as calling `gc.collect()` without arguments) so these leftovers
# are cleaned up right at import time.
# NOTE(review): presumably this keeps class introspection such as
# `__subclasses__()` free of stale duplicates — confirm.
gc.collect(2)