Source code for tubular.mixins

"""Contains mixin classes for use across transformers."""

from __future__ import annotations

import warnings

import narwhals as nw
from beartype import beartype
from narwhals.dtypes import DType  # noqa: F401 - required for nw.Schema see #455

from tubular._utils import (
    _convert_dataframe_to_narwhals,
    _return_narwhals_or_native_dataframe,
)
from tubular.types import DataFrame, NumericTypes


[docs] class CheckNumericMixin: """Mixin class with methods for numeric transformers."""
[docs] def classname(self) -> str: """Get name of the current class when called. Returns ------- str: name of class """ return type(self).__name__
[docs] @beartype def check_numeric_columns( self, X: DataFrame, return_native: bool = True, ) -> DataFrame: """Check column args are numeric for numeric transformers. Parameters ---------- X: DataFrame Data containing columns to check. return_native: bool indicates whether to return nw or pd/pl dataframe Returns ------- DataFrame: validated dataframe Raises ------ TypeError: if provided columns are non-numeric """ X = _convert_dataframe_to_narwhals(X) schema = X.collect_schema() non_numeric_columns = [ col for col in self.columns if schema[col] not in NumericTypes ] # sort as set ordering can be inconsistent non_numeric_columns.sort() if len(non_numeric_columns) > 0: msg = f"{self.classname()}: The following columns are not numeric in X; {non_numeric_columns}" raise TypeError(msg) return _return_narwhals_or_native_dataframe(X, return_native)
[docs] class DropOriginalMixin: """Mixin class to validate and apply 'drop_original' argument used by various transformers. Transformer deletes transformer input columns depending on boolean argument. """
[docs] def classname(self) -> str: """Get name of the current class when called. Returns ------- str: name of class """ return type(self).__name__
[docs] @staticmethod @beartype def drop_original_column( X: DataFrame, drop_original: bool, columns: list[str] | str | None, return_native: bool = True, ) -> DataFrame: """Drop input columns from X if drop_original set to True. Parameters ---------- X : DataFrame Data with columns to drop. drop_original : bool boolean dictating dropping the input columns from X after checks. columns: list[str] | str | None Object containing columns to drop return_native: bool controls whether mixin returns native or narwhals type Returns ------- X : DataFrame Transformed input X with columns dropped. """ X = _convert_dataframe_to_narwhals(X) if drop_original: X = X.drop(columns) return X.to_native() if return_native else X
[docs] class WeightColumnMixin: """Mixin class with weights functionality."""
[docs] def classname(self) -> str: """Get the name of the current class when called. Returns ------- str: name of class """ return type(self).__name__
@staticmethod def _create_unit_weights_column( X: DataFrame, return_native: bool = True, verbose: bool = False, ) -> tuple[DataFrame, str]: """Create unit weights column. Useful to streamline logic and just treat all cases as weighted, avoids branches for weights/non-weights. Function will check: - does 'unit_weights_column' already exist in data? (unlikely but check to be thorough) - if it does not, create unit weight 'unit_weights_column' - if it does, then reuse column - is it valid for our purposes? i.e. all unit weights - if not, raise warning (for verbose=True) Parameters ---------- X: DataFrame pandas, polars, or narwhals df return_native: bool controls whether to return nw or pd/pl dataframe verbose: controls verbosity Returns ------- DataFrame: DataFrame with added 'unit_weights_column' Raises ------ TypeError: if unit_weights_column already exists and is non numeric. """ X = _convert_dataframe_to_narwhals(X) unit_weights_column = "unit_weights_column" if unit_weights_column in X.collect_schema().names(): if X.schema[unit_weights_column] not in NumericTypes: error_msg = f"{unit_weights_column} is present in X and non-numeric, transformer logic requires this to be an all 1 value column." raise TypeError( error_msg, ) if verbose: warn_msg = f"column {unit_weights_column} is present in X, transformer logic will assume this column contains all 1 values." warnings.warn(warn_msg, stacklevel=2) else: # finally create dummy weights column if valid option not found X = X.with_columns(nw.lit(1).alias(unit_weights_column).cast(nw.Int8)) return _return_narwhals_or_native_dataframe( X, return_native, ), unit_weights_column
[docs] @beartype def check_weights_column(self, X: DataFrame, weights_column: str) -> None: """Validate weights column in dataframe. Parameters ---------- X: DataFrame input data weights_column: str name of weight column Raises ------ ValueError: if weights_column is missing from data ValueError: if weights_column is non-numeric """ X = _convert_dataframe_to_narwhals(X) # check if given weight is in columns if weights_column not in X.collect_schema().names(): msg = f"{self.classname()}: weight col ({weights_column}) is not present in columns of data" raise ValueError(msg) # check weight is numeric schema = X.collect_schema() if schema[weights_column] not in NumericTypes: msg = f"{self.classname()}: weight column must be numeric." raise ValueError(msg)
[docs] @staticmethod @beartype def get_valid_weights_filter_expr( weights_column: str, verbose: bool = False ) -> nw.Expr: """Validate weights column in dataframe. Parameters ---------- weights_column: str name of weight column verbose: bool control verbosity of method Returns ------- nw.Expr: expression to be used for filtering down to valid weights rows """ if verbose: warnings.warn( "Weights must be strictly positive, non-null, and finite - rows failing this will be filtered out.", stacklevel=2, ) expr_ge_0 = nw.col(weights_column) > 0 expr_not_null = ~nw.col(weights_column).is_null() expr_not_nan = ~nw.col(weights_column).is_nan() expr_finite = nw.col(weights_column).is_finite() return expr_ge_0 & expr_not_null & expr_not_nan & expr_finite