Source code for pyreal.transformers.impute

import numpy as np
import pandas as pd

from pyreal.transformers import TransformerBase


class MultiTypeImputer(TransformerBase):
    """
    Imputes a data set, handling columns of different types. Imputes numeric columns with the
    mean, and categorical (non-numeric) columns with the mode value.
    """

    def __init__(self, columns=None, **kwargs):
        """
        Initialize the base imputers

        Args:
            columns (column label or list-like of column labels, optional):
                Columns to consider for imputation. A single label is wrapped in a list.
                If None, every column of the data passed to `fit` is considered.
            **kwargs:
                Passed through to `TransformerBase.__init__`
        """
        if columns is not None and not isinstance(
            columns, (list, tuple, np.ndarray, pd.Index)
        ):
            columns = [columns]
        # Keep the caller's request separately from the fitted columns, so that re-fitting
        # with columns=None picks up the new dataset's columns rather than the first one's.
        self._requested_columns = columns
        self.columns = columns
        self.numeric_cols = None
        self.categorical_cols = None
        self.means = None
        self.modes = None
        super().__init__(**kwargs)

    def fit(self, x, **params):
        """
        Fit the imputer, recording per-column means (numeric) and modes (non-numeric)

        Args:
            x (DataFrame of shape (n_instances, n_features)):
                The dataset to fit to
            **params:
                Unused; accepted for interface compatibility

        Returns:
            The result of `TransformerBase.fit(x)`
        """
        # Instances restored from older pickles may lack `_requested_columns`; falling back to
        # `self.columns` reproduces the previous behavior for them.
        requested = getattr(self, "_requested_columns", self.columns)
        self.columns = x.columns if requested is None else requested

        # Columns that are entirely missing have no mean/mode, so they are left untouched.
        candidates = x[self.columns].dropna(axis="columns", how="all")
        self.numeric_cols = candidates.select_dtypes(include="number").columns
        self.categorical_cols = candidates.select_dtypes(exclude="number").columns

        self.means = x[self.numeric_cols].mean(axis=0)
        self.modes = x[self.categorical_cols].mode(axis=0)
        # mode() returns one row per tied mode value; keep the first mode of each column
        if self.modes.shape[0] > 0:
            self.modes = self.modes.iloc[0, :]
        return super().fit(x)

    def data_transform(self, x):
        """
        Imputes `x`. Numeric columns get imputed with the column mean. Categorical columns get
        imputed with the column mode.

        Args:
            x (DataFrame of shape (n_instances, n_features) or Series of shape (n_features,)):
                The dataset (or single instance) to impute

        Returns:
            DataFrame of shape (n_instances, n_transformed_features), or Series if `x` was a
            Series: The imputed data, with the input's dtypes for the fitted columns

        Raises:
            RuntimeError: if called before `fit`
        """
        if self.numeric_cols is None:
            raise RuntimeError("Must fit imputer before transforming")

        series_flag = isinstance(x, pd.Series)
        name = None
        if series_flag:
            name = x.name
            # Treat a single instance as a one-row DataFrame
            x = x.to_frame().T

        # Snapshot the dtypes so the output keeps the input's column types after filling
        types = x[self.columns].dtypes
        result = x.copy()
        result[self.numeric_cols] = result[self.numeric_cols].fillna(value=self.means)
        result[self.categorical_cols] = result[self.categorical_cols].fillna(value=self.modes)
        result = result.astype(types)

        if series_flag:
            # iloc[0] always yields a Series; squeeze() would collapse a single-feature row to a
            # scalar, on which assigning `.name` fails.
            result = result.iloc[0]
            result.name = name
        return result