Source code for pywrangler.base

"""This module contains the BaseWrangler definition and the wrangler base
classes including wrangler descriptions and parameters.

"""
import inspect
from abc import ABC, abstractmethod

from pywrangler.util import _pprint
from pywrangler.util.helper import get_param_names


class BaseWrangler(ABC):
    """Defines the basic interface common to all data wranglers.

    In analogy to sklearn transformers (see link below), all wranglers have
    to implement `fit`, `transform` and `fit_transform` methods. In addition,
    parameters (e.g. column names) need to be provided via the `__init__`
    method. Furthermore, `get_params` and `set_params` methods are required
    for grid search and pipeline compatibility.

    The `fit` method contains optional fitting (e.g. compute mean and
    variance for scaling) which sets training data dependent transformation
    behaviour. The `transform` method includes the actual computational
    transformation. The `fit_transform` either applies the former methods in
    sequence or adds a new implementation of both with better performance.
    The `__init__` method should contain any logic behind parameter parsing
    and conversion.

    In contrast to sklearn, wranglers only accept dataframe-like objects
    (like pandas/pyspark/dask dataframes) as inputs to `fit` and
    `transform`. The relevant columns and their respective meanings are
    provided via the `__init__` method. In addition, wranglers may accept
    multiple input dataframes with different shapes. Also, the number of
    samples may change between input and output (which is not allowed in
    sklearn). The `preserves_sample_size` property indicates whether the
    sample size (number of rows) may change during transformation.

    The wrangler's employed computation engine is given via
    `computation_engine`.

    See also
    --------
    https://scikit-learn.org/stable/developers/contributing.html

    """

    @property
    @abstractmethod
    def preserves_sample_size(self) -> bool:
        """Whether the transformation keeps the number of rows unchanged.

        Abstract: concrete wranglers have to provide the value.

        """

        raise NotImplementedError

    @property
    @abstractmethod
    def computation_engine(self) -> str:
        """Name of the computation engine employed by the wrangler.

        Abstract: concrete wranglers have to provide the value.

        """

        raise NotImplementedError

    def get_params(self) -> dict:
        """Retrieve all wrangler parameters set within the __init__ method.

        Parameter names are gathered from the `__init__` of every class in
        the method resolution order that is a subclass of `BaseWrangler`.
        The corresponding values are read from instance attributes of the
        same name.

        Returns
        -------
        param_dict: dictionary
            Parameter names as keys and corresponding values as values

        """

        base_classes = [cls for cls in inspect.getmro(self.__class__)
                        if issubclass(cls, BaseWrangler)]

        ignore = ["self", "args", "kwargs"]

        # Walk the hierarchy from the most generic class to the most
        # specific one. `get_param_names` is a project helper; presumably it
        # returns the parameter names of the given function minus the
        # ignored ones (verify against pywrangler.util.helper).
        param_names = []
        for cls in base_classes[::-1]:
            param_names.extend(get_param_names(cls.__init__, ignore))

        # Names occurring in several `__init__` signatures collapse into a
        # single dictionary key here.
        param_dict = {x: getattr(self, x) for x in param_names}

        return param_dict

    def set_params(self, **params) -> "BaseWrangler":
        """Set wrangler parameters

        All keys are validated before any attribute is assigned. Hence, an
        invalid key never leaves the wrangler in a partially updated state.

        Parameters
        ----------
        params: dict
            Dictionary containing new values to be updated on wrangler. Keys
            have to match parameter names of wrangler.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a key does not match any parameter name of the wrangler.

        """

        valid_params = self.get_params()

        # First pass: reject unknown keys up front so nothing is mutated
        # when the call is going to fail anyway.
        for key in params:
            if key not in valid_params:
                raise ValueError('Invalid parameter {} for wrangler {}. '
                                 'Check the list of available parameters '
                                 'with `wrangler.get_params().keys()`.'
                                 .format(key, self))

        # Second pass: every key is known to be valid, apply the update.
        for key, value in params.items():
            setattr(self, key, value)

        return self

    @abstractmethod
    def fit(self, *args, **kwargs):
        """Perform optional, training data dependent fitting.

        Abstract: has to be implemented by concrete wranglers.

        """

        raise NotImplementedError

    @abstractmethod
    def transform(self, *args, **kwargs):
        """Perform the actual computational transformation.

        Abstract: has to be implemented by concrete wranglers.

        """

        raise NotImplementedError

    @abstractmethod
    def fit_transform(self, *args, **kwargs):
        """Perform fitting and transformation in a single call.

        Abstract: has to be implemented by concrete wranglers.

        """

        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a multi-line summary of the wrangler.

        The summary contains the class name, the computation engine and the
        current parameters. A note is appended if the wrangler does not
        preserve the sample size.

        """

        template = '{wrangler_name} ({computation_engine})\n\n{parameters}'

        parameters = (_pprint.header("Parameters", 3) +
                      _pprint.enumeration(self.get_params(), 3))

        _repr = template.format(wrangler_name=self.__class__.__name__,
                                computation_engine=self.computation_engine,
                                parameters=parameters)

        if not self.preserves_sample_size:
            _repr += "\n\n Note: Does not preserve sample size."

        return _repr