Source code for baybe.objectives.botorch

"""BoTorch objectives."""

from botorch.acquisition.objective import MCAcquisitionObjective
from torch import Tensor
from typing_extensions import override

from baybe.utils.basic import compose


class ChainedMCObjective(MCAcquisitionObjective):
    """A chained Monte Carlo objective.

    Applies a sequence of Monte Carlo objectives one after another: the output
    of each objective becomes the ``samples`` input of the next one.
    """

    def __init__(self, *objectives: MCAcquisitionObjective) -> None:
        """Create the chain.

        Args:
            *objectives: The objectives to be applied, in order of application.
        """
        super().__init__()
        # NOTE(review): the objectives are kept in a plain tuple. If the base
        # class is a ``torch.nn.Module``, they are presumably not registered as
        # submodules (e.g. not moved by ``.to(device)``); consider
        # ``torch.nn.ModuleList`` if any of them carry buffers/parameters.
        self.objectives = objectives

    @override
    def forward(self, samples: Tensor, X: Tensor | None = None) -> Tensor:
        """Evaluate the chained objective.

        Args:
            samples: The samples fed into the first objective of the chain.
            X: Optional inputs associated with the samples. They are passed to
                every objective in the chain, not only to the first one, so
                that input-dependent objectives work at any position.

        Returns:
            The output of the last objective in the chain. For an empty chain,
            ``samples`` is returned unchanged.
        """
        result = samples
        for objective in self.objectives:
            # ``forward`` is called directly rather than calling the module,
            # matching the original call path (presumably to bypass any
            # wrapper-level output checks on intermediate results).
            result = objective.forward(result, X)
        return result