Source code for baybe.utils.augmentation

"""Utilities related to data augmentation."""

from collections.abc import Collection, Sequence
from itertools import permutations, product

import pandas as pd


[docs] def df_apply_permutation_augmentation( df: pd.DataFrame, permutation_groups: Sequence[Sequence[str]], ) -> pd.DataFrame: """Augment a dataframe if permutation invariant columns are present. Each group in ``permutation_groups`` contains the names of columns that are permuted in lockstep. All groups must have the same length, and that length must be at least 2 (otherwise there is nothing to permute). Args: df: The dataframe that should be augmented. permutation_groups: Groups of column names that are permuted in lockstep. For example, ``[["A1", "A2"], ["B1", "B2"]]`` means that the columns ``A1`` and ``A2`` are permuted, and ``B1`` and ``B2`` are permuted in the same way. Returns: The augmented dataframe containing the original one. Augmented row indices are identical with the index of their original row. Raises: ValueError: If no permutation groups are given. ValueError: If any permutation group has fewer than two entries. ValueError: If the permutation groups have differing amounts of entries. Examples: >>> df = pd.DataFrame({'A1':[1,2],'A2':[3,4], 'B1': [5, 6], 'B2': [7, 8]}) >>> df A1 A2 B1 B2 0 1 3 5 7 1 2 4 6 8 >>> groups = [['A1', 'A2']] >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 0 3 1 5 7 1 2 4 6 8 1 4 2 6 8 >>> groups = [['A1', 'A2'], ['B1', 'B2']] >>> dfa = df_apply_permutation_augmentation(df, groups) >>> dfa A1 A2 B1 B2 0 1 3 5 7 0 3 1 7 5 1 2 4 6 8 1 4 2 8 6 """ # Validation if len(permutation_groups) < 1: raise ValueError("Permutation augmentation requires at least one group.") if len({len(seq) for seq in permutation_groups}) != 1: raise ValueError( "Permutation augmentation can only work if all groups have the same " "number of entries." ) if len(permutation_groups[0]) < 2: raise ValueError( "Permutation augmentation can only work if each group has at " "least two entries." ) # Augmentation Loop n_positions = len(permutation_groups[0]) new_rows: list[pd.DataFrame] = [] idx_permutation = list(permutations(range(n_positions))) for _, row in df.iterrows(): # For each row in the original df, collect all its permutations to_add = [] for perm in idx_permutation: new_row = row.copy() # Permute columns within each group according to the permutation for group in permutation_groups: cols = list(group) new_row[cols] = row[[cols[k] for k in perm]] to_add.append(new_row) # Store augmented rows, keeping the index of their original row new_rows.append( pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add)) ) return pd.concat(new_rows)
[docs] def df_apply_mirror_augmentation( df: pd.DataFrame, column: str, *, mirror_point: float = 0.0, ) -> pd.DataFrame: """Augment a dataframe for a mirror invariant column. Args: df: The dataframe that should be augmented. column: The name of the affected column. mirror_point: The point along which to mirror the values. Points that have exactly this value will not be augmented. Returns: The augmented dataframe containing the original one. Augmented row indices are identical with the index of their original row. Examples: >>> df = pd.DataFrame({'A':[1, 0, -2], 'B': [3, 4, 5]}) >>> df A B 0 1 3 1 0 4 2 -2 5 >>> dfa = df_apply_mirror_augmentation(df, "A") >>> dfa A B 0 1 3 0 -1 3 1 0 4 2 -2 5 2 2 5 >>> dfa = df_apply_mirror_augmentation(df, "A", mirror_point=1) >>> dfa A B 0 1 3 1 0 4 1 2 4 2 -2 5 2 4 5 """ new_rows: list[pd.DataFrame] = [] for _, row in df.iterrows(): to_add = [row] # Always keep original row # Create the augmented row by mirroring the point at the mirror point. # x_mirrored = mirror_point + (mirror_point - x) = 2*mirror_point - x if row[column] != mirror_point: row_new = row.copy() row_new[column] = 2.0 * mirror_point - row[column] to_add.append(row_new) # Store augmented rows, keeping the index of their original row new_rows.append( pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add)) ) return pd.concat(new_rows)
[docs] def df_apply_dependency_augmentation( df: pd.DataFrame, causing: tuple[str, Sequence], affected: Collection[tuple[str, Sequence]], ) -> pd.DataFrame: """Augment a dataframe if dependency invariant columns are present. This works with the concept of column-values pairs for causing and affected column. Any row present where the specified causing column has one of the provided values will trigger an augmentation on the affected columns. The latter are augmented by going through all their invariant values and adding respective new rows. Args: df: The dataframe that should be augmented. causing: Causing column name and its causing values. affected: Affected columns and their invariant values. Returns: The augmented dataframe containing the original one. Augmented row indices are identical with the index of their original row. Examples: >>> df = pd.DataFrame({'A':[0,1],'B':[2,3], 'C': [5, 5], 'D': [6, 7]}) >>> df A B C D 0 0 2 5 6 1 1 3 5 7 >>> causing = ('A', [0]) >>> affected = [('B', [2,3,4])] >>> dfa = df_apply_dependency_augmentation(df, causing, affected) >>> dfa A B C D 0 0 2 5 6 0 0 3 5 6 0 0 4 5 6 1 1 3 5 7 >>> causing = ('A', [0]) >>> affected = [('B', [2,3,4])] >>> dfa = df_apply_dependency_augmentation(df, causing, affected) >>> dfa A B C D 0 0 2 5 6 0 0 3 5 6 0 0 4 5 6 1 1 3 5 7 >>> causing = ('A', [0, 1]) >>> affected = [('B', [2,3])] >>> dfa = df_apply_dependency_augmentation(df, causing, affected) >>> dfa A B C D 0 0 2 5 6 0 0 3 5 6 1 1 2 5 7 1 1 3 5 7 >>> causing = ('A', [0]) >>> affected = [('B', [2,3]), ('C', [5, 6])] >>> dfa = df_apply_dependency_augmentation(df, causing, affected) >>> dfa A B C D 0 0 2 5 6 0 0 2 6 6 0 0 3 5 6 0 0 3 6 6 1 1 3 5 7 """ new_rows: list[pd.DataFrame] = [] col_causing, vals_causing = causing affected_cols, affected_inv_vals = zip(*affected) affected_inv_vals_combinations = list(product(*affected_inv_vals)) # Iterate through all rows that have a causing value in the respective column. for _, row in df.iterrows(): to_add = [] if row[col_causing] not in vals_causing: # Just keep unaffected rows without augmentation to_add.append(row) else: # Create augmented rows by assigning the affected columns all possible # values to_add += [ pd.Series({**row.to_dict(), **dict(zip(affected_cols, values))}) for values in affected_inv_vals_combinations ] # Store augmented rows, keeping the index of their original row new_rows.append( pd.DataFrame(to_add, columns=df.columns, index=[row.name] * len(to_add)) ) return pd.concat(new_rows)