Module panama.feature_engineering.preprocessing

Classes

class DataPreprocessor
Source code
class DataPreprocessor:
    def __init__(self):
        super().__init__()

    def remap_col_values(
        self,
        sdf: DataFrame,
        input_cols: Union[str, List[str]],
        output_cols: Union[str, List[str]],
        mappings: Union[Dict, List[Dict]],
        keep_original: bool = True,
    ) -> DataFrame:
        """
        Remaps the values of one or more columns in a Spark DataFrame according to a mapping or a list of mappings.

        Args:
            sdf: The Spark DataFrame to modify.
            input_cols: The name of the column or a list of names of columns to remap.
                If a list is provided, it must have the same length as `output_cols` and `mappings`.
            output_cols: The name of the column or a list of names of columns to output the remapped values.
                If a list is provided, it must have the same length as `input_cols` and `mappings`.
            mappings: The mapping or list of mappings to use for remapping the values of the columns.
                A mapping is a dictionary that maps input values to output values.
            keep_original: Whether values of `input_cols` that are absent from the mapping keep their original value
                in `output_cols`; if False, unmapped values become null. Defaults to True.

        Returns:
            The modified Spark DataFrame.

        Raises:
            TypeError: If the arguments do not form one of the two supported patterns: `input_cols`, `output_cols`,
                and `mappings` given as a string, a string, and a dictionary, or all three given as lists of equal length.
        """
        if isinstance(input_cols, str) and isinstance(output_cols, str) and isinstance(mappings, dict):
            sdf = self._remap_col_values(sdf, input_cols, output_cols, mappings, keep_original)
        elif isinstance(input_cols, list) and isinstance(output_cols, list) and isinstance(mappings, list):
            for input_col, output_col, mapping in zip(input_cols, output_cols, mappings):
                sdf = self._remap_col_values(sdf, input_col, output_col, mapping, keep_original)
        else:
            raise TypeError(
                f"Can not remap values with types {type(input_cols)}, {type(output_cols)} and {type(mappings)}"
            )
        return sdf

    @staticmethod
    def _remap_col_values(
        sdf: DataFrame, input_col: str, output_col: str, mapping: Dict, keep_original: bool = True
    ) -> DataFrame:
        """
        Remaps the values of a column in a Spark DataFrame according to a mapping.

        Args:
            sdf: The Spark DataFrame to modify.
            input_col: The name of the column to remap.
            output_col: The name of the column to output the remapped values.
            mapping: The mapping to use for remapping the values of the column. A mapping is a dictionary that maps input values to output values.
            keep_original: Whether values of `input_col` that are absent from the mapping keep their original value
                in `output_col`; if False, unmapped values become null. Defaults to True.

        Returns:
            The modified Spark DataFrame.

        Raises:
            ValueError: If `input_col` is not a column of `sdf`.
        """
        if input_col in sdf.columns:
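            # Build a Spark map literal from the Python dict; create_map takes an alternating flat list of key and value literals.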
            spark_mapping = F.create_map([F.lit(x) for x in chain(*mapping.items())])
            if keep_original:
                return sdf.withColumn(output_col, F.coalesce(spark_mapping[F.col(input_col)], input_col))
            else:
                return sdf.withColumn(output_col, spark_mapping[F.col(input_col)])
        else:
            raise ValueError(f"{input_col} is not a column of sdf")

    def one_hot_encode(self, sdf: DataFrame, input_cols: Union[str, List[str]], drop_last: bool = False) -> DataFrame:
        """
        Performs one-hot encoding on one or more columns of a Spark DataFrame.

        Args:
            sdf: The Spark DataFrame to encode.
            input_cols: The name of the column or a list of names of columns to encode. If a list is provided, each column will be encoded separately.
            drop_last: Whether to drop the last category in each encoded column. Defaults to False.

        Returns:
            The modified Spark DataFrame.

        Raises:
            TypeError: If `input_cols` is not a string or a list of strings.
        """
        if isinstance(input_cols, str):
            sdf = self._one_hot_encode(sdf, input_cols, drop_last)
        elif isinstance(input_cols, list):
            for input_col in input_cols:
                sdf = self._one_hot_encode(sdf, input_col, drop_last)
        else:
            raise TypeError(f"Can not encode {input_cols} of type {type(input_cols)}")
        return sdf

    @staticmethod
    def _one_hot_encode(sdf: DataFrame, input_col: str, drop_last: bool = False) -> DataFrame:
        """
        Performs one-hot encoding on a column of a Spark DataFrame.

        Args:
            sdf: The Spark DataFrame to encode.
            input_col: The name of the column to encode.
            drop_last: Whether to drop the last category in the encoded column. Defaults to False.

        Returns:
            The modified Spark DataFrame.

        Raises:
            ValueError: If `input_col` is not a column of `sdf`, or if `drop_last` is True and `input_col` contains missing values.
        """
        if input_col in sdf.columns:
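            # Pivot trick: group by all other columns, pivot on the prefixed target values, and aggregate a constant 1
            # so each distinct value becomes a 0/1 indicator column.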
            group_cols = [col for col in sdf.columns if col != input_col]
            sdf = sdf.withColumn("aux_col", F.lit(1))
            sdf = (
                sdf.withColumn(input_col, F.lower(F.concat_ws("_", F.lit(input_col), F.col(input_col))))
                .groupBy(group_cols)
                .pivot(input_col)
                .agg(F.mean("aux_col").cast("int"))
            )
            output_cols = [col for col in sdf.columns if col.startswith(input_col)]
            sdf = sdf.fillna(0, subset=output_cols)
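            # concat_ws skips nulls, so rows with a missing value pivot into a column named exactly after input_col;
            # rename it to flag missingness.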
            if input_col in sdf.columns:
                sdf = sdf.withColumnRenamed(input_col, input_col + "_missing")
                if drop_last:
                    raise ValueError(
                        f"Can not encode with drop_last set to True and missing values found in {input_col}"
                    )
            if drop_last:
                sdf = sdf.drop(output_cols[-1])
        else:
            raise ValueError(f"{input_col} is not a column of sdf")
        return sdf

    @staticmethod
    def add_prefix_col_names(sdf: DataFrame, prefix: str, exclude_cols: List = []) -> DataFrame:
        """Add a prefix to all column names in a DataFrame.

        Args:
            sdf (DataFrame): DataFrame with columns to rename.
            prefix (str): Prefix to use in new column names.
            exclude_cols (List, optional): List of columns not to rename. Defaults to [].

        Returns:
            DataFrame: The final DataFrame with columns renamed.
        """
        mapping = {c: prefix + c.lower() for c in sdf.columns if c not in exclude_cols}
        sdf = sdf.select([F.col(c).alias(mapping.get(c, c)) for c in sdf.columns])
        return sdf

    @staticmethod
    def remap_col_names(sdf: DataFrame, mapping: Dict, exclude_columns: List = []) -> DataFrame:
        """Remap column names in a DataFrame.

        Args:
            sdf (DataFrame): DataFrame with columns to rename.
            mapping (Dict): The mapping to use for renaming columns. A mapping is a dictionary that maps input name
                substrings to output substrings; each key is replaced only where it matches as a whole word.
            exclude_columns (List, optional): List of columns not to rename. Defaults to [].

        Returns:
            DataFrame: The final DataFrame with columns renamed.
        """
        col_list = sdf.columns

        _mapping = {}  # Init internal empty columns mapping dict
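        # The loop below rewrites each column name by applying every mapping entry as a whole-word (\b-delimited) substring replacement.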

        for col in col_list:
            orig_col = col
            for pattern, repl in mapping.items():
                if col not in exclude_columns:
                    col = re.sub(string=col, repl=repl, pattern=r"\b" + re.escape(pattern) + r"\b")
            _mapping.update({orig_col: col})

        return sdf.select([F.col(c).alias(_mapping.get(c, c)) for c in sdf.columns])

    @staticmethod
    def keep_last_row(
        sdf: DataFrame, partition_cols: List[str], order_col: str, drop_order_col: bool = False, reverse: bool = True
    ) -> DataFrame:
        """Filters a DataFrame to keep only the last row for each partition.

        Args:
            sdf (DataFrame): the input DataFrame to filter.
            partition_cols (List[str]): list of column names used for partitioning.
            order_col (str): the column name used for ordering within each partition.
            drop_order_col (bool, optional): whether to drop `order_col` from the result (default: False).
            reverse (bool, optional): boolean indicating whether to sort in descending order (default: True).

        Returns:
            DataFrame: The resulting DataFrame containing the filtered data with only the last row for each partition.
        """
        sdf = sdf.repartition(*partition_cols)
        w = Window.partitionBy(partition_cols)
        if reverse:
            w = w.orderBy(F.desc(order_col))
        else:
            w = w.orderBy(order_col)
        sdf = sdf.withColumn("rn", F.row_number().over(w)).filter(F.col("rn") == 1).drop("rn")
        if drop_order_col:
            sdf = sdf.drop(order_col)

        return sdf

Static methods

def add_prefix_col_names(sdf: pyspark.sql.dataframe.DataFrame, prefix: str, exclude_cols: List = []) ‑> pyspark.sql.dataframe.DataFrame

Add a prefix to all column names in a DataFrame.

Args

sdf : DataFrame
DataFrame with columns to rename.
prefix : str
Prefix to use in new column names.
exclude_cols : List, optional
List of columns not to rename.

Returns

DataFrame
The final DataFrame with columns renamed.
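
A minimal usage sketch (the data and column names below are hypothetical, and an active SparkSession `spark` is assumed):

    from panama.feature_engineering.preprocessing import DataPreprocessor

    sdf = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "category"])
    out = DataPreprocessor.add_prefix_col_names(sdf, prefix="raw_", exclude_cols=["id"])
    out.columns  # ['id', 'raw_category']

Note that renamed columns are also lowercased (prefix + c.lower()).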
def keep_last_row(sdf: pyspark.sql.dataframe.DataFrame, partition_cols: List[str], order_col: str, drop_order_col: bool = False, reverse: bool = True) ‑> pyspark.sql.dataframe.DataFrame

Filters a DataFrame to keep only the last row for each partition.

Args

sdf : DataFrame
the input DataFrame to filter.
partition_cols : List[str]
list of column names used for partitioning.
order_col : str
the column name used for ordering within each partition.
drop_order_col : bool, optional
whether to drop order_col from the result (default: False).
reverse : bool, optional
boolean indicating whether to sort in descending order (default: True).

Returns

DataFrame
The resulting DataFrame containing the filtered data with only the last row for each partition.
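
For example, keeping the most recent reading per device might look like this (hypothetical data; `spark` and the import as in the sketch above):

    sdf = spark.createDataFrame(
        [("dev1", 1, 10.0), ("dev1", 2, 11.5), ("dev2", 1, 7.2)],
        ["device", "ts", "value"],
    )
    latest = DataPreprocessor.keep_last_row(sdf, partition_cols=["device"], order_col="ts")
    # One row per device: ("dev1", 2, 11.5) and ("dev2", 1, 7.2)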
def remap_col_names(sdf: pyspark.sql.dataframe.DataFrame, mapping: Dict, exclude_columns: List = []) ‑> pyspark.sql.dataframe.DataFrame

Remap column names in a DataFrame.

Args

sdf : DataFrame
DataFrame with columns to rename.
mapping : Dict
The mapping to use for renaming columns. A mapping is a dictionary that maps input name substrings to output substrings; each key is replaced only where it matches as a whole word.
exclude_columns : List, optional
List of columns not to rename.

Returns

DataFrame
The final DataFrame with columns renamed.
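
Because keys are matched at word boundaries, a key renames only names where it appears as a whole word; a sketch with hypothetical column names:

    sdf = spark.createDataFrame([(20.5, 21.0)], ["temp", "temperature"])
    renamed = DataPreprocessor.remap_col_names(sdf, mapping={"temp": "temp_c"})
    renamed.columns  # ['temp_c', 'temperature'] -- no whole-word match in 'temperature'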

Methods

def one_hot_encode(self, sdf: pyspark.sql.dataframe.DataFrame, input_cols: Union[str, List[str]], drop_last: bool = False) ‑> pyspark.sql.dataframe.DataFrame

Performs one-hot encoding on one or more columns of a Spark DataFrame.

Args

sdf
The Spark DataFrame to encode.
input_cols
The name of the column or a list of names of columns to encode. If a list is provided, each column will be encoded separately.
drop_last
Whether to drop the last category in each encoded column. Defaults to False.

Returns

The modified Spark DataFrame.

Raises

TypeError
If input_cols is not a string or a list of strings.
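
A minimal sketch with hypothetical data (an instance is required, as this is an instance method; `spark` as above):

    dp = DataPreprocessor()
    sdf = spark.createDataFrame([(1, "red"), (2, "blue"), (3, "red")], ["id", "color"])
    encoded = dp.one_hot_encode(sdf, input_cols="color")
    # Resulting columns: id, color_blue, color_red (one 0/1 indicator per distinct value)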
def remap_col_values(self, sdf: pyspark.sql.dataframe.DataFrame, input_cols: Union[str, List[str]], output_cols: Union[str, List[str]], mappings: Union[Dict, List[Dict]], keep_original: bool = True) ‑> pyspark.sql.dataframe.DataFrame

Remaps the values of one or more columns in a Spark DataFrame according to a mapping or a list of mappings.

Args

sdf
The Spark DataFrame to modify.
input_cols
The name of the column or a list of names of columns to remap. If a list is provided, it must have the same length as output_cols and mappings.
output_cols
The name of the column or a list of names of columns to output the remapped values. If a list is provided, it must have the same length as input_cols and mappings.
mappings
The mapping or list of mappings to use for remapping the values of the columns. A mapping is a dictionary that maps input values to output values.
keep_original
Whether values of input_cols that are absent from the mapping keep their original value in output_cols; if False, unmapped values become null. Defaults to True.

Returns

The modified Spark DataFrame.

Raises

TypeError
If the arguments do not form one of the two supported patterns: input_cols, output_cols, and mappings given as a string, a string, and a dictionary, or all three given as lists of equal length.
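
A minimal sketch with hypothetical data, showing the effect of keep_original:

    dp = DataPreprocessor()
    sdf = spark.createDataFrame([("y",), ("n",), ("?",)], ["flag"])
    out = dp.remap_col_values(sdf, input_cols="flag", output_cols="flag_std", mappings={"y": "yes", "n": "no"})
    # With keep_original=True (the default) the unmapped "?" is carried over unchanged;
    # with keep_original=False it would become null instead.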