Module panama.feature_engineering.feature_extraction

Classes

class FeatureExtractor
class FeatureExtractor:
    def __init__(self):
        super().__init__()

    def count_over_period(
        self,
        sdf: DataFrame,
        event_date: str,
        event_names: List[str],
        interval: str,
        reference_date: Union[str, datetime, date],
        group_cols: List[str],
        filter_by: str = "1=1",
        direction="backward",
    ) -> DataFrame:
        """
        Computes the count of events over a specified time period for each group in the input DataFrame.

        Args:
            sdf: Input DataFrame containing event data.
            event_date: Name of the column in `sdf` containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            event_names: List of names of the columns containing events to count.
            interval: Time interval of the period to filter on, e.g., "1 day".
            reference_date: Date to use as the reference point for interval calculation.
            group_cols: List of column names to group by.
            filter_by: SQL-like filter expression to apply to `sdf`. Defaults to "1=1".
            direction: Direction in which to compute the time interval with respect to the reference date, either "backward" or "forward".
                Defaults to "backward".

        Returns:
            DataFrame containing counts of events over the specified time period for each group.
        """
        sdf = sdf.filter(filter_by)
        sdf = self._apply_interval(
            sdf=sdf, event_date=event_date, interval=interval, reference_date=reference_date, direction=direction
        )
        return self.aggr_over_period(
            sdf=sdf,
            event_date=event_date,
            event_names=event_names,
            interval=interval,
            reference_date=reference_date,
            group_cols=group_cols,
            filter_by=filter_by,
            aggr_fun="count",
            aggr_col=None,  # type: ignore
        )

    def _apply_interval(
        self,
        sdf: DataFrame,
        event_date: str,
        interval: str,
        reference_date: Union[str, datetime, date],
        direction: str = "backward",
    ) -> DataFrame:
        """
        Applies a time interval filter to the input DataFrame and returns the filtered DataFrame.

        Args:
            sdf: Input DataFrame containing event data.
            event_date: Name of the column in `sdf` containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            interval: Time interval of the period to filter on, e.g., "1 day".
            reference_date: Date to use as the reference point for interval calculation.
            direction: Direction in which to compute the time interval with respect to the reference date, either "backward" or "forward".
                Defaults to "backward".

        Returns:
            Filtered DataFrame containing events that fall within the specified time interval.

        Raises:
            ValueError: If `direction` is not "backward" or "forward".
            TypeError: If `reference_date` is neither a string nor a datetime/date object,
                or if it is a string that is neither the name of a column of `sdf` nor a date in one of the
                expected formats ("yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss").
        """
        if direction not in ("backward", "forward"):
            raise ValueError(f"{direction} not valid for direction")
        if isinstance(reference_date, str) and reference_date in sdf.columns:
            return self._filter_by_col_reference(
                sdf=sdf, event_date=event_date, interval=interval, reference_date=reference_date, direction=direction
            )
        elif isinstance(reference_date, str):
            try:
                reference_date = datetime.strptime(reference_date, "%Y-%m-%d")
            except ValueError:
                try:
                    reference_date = datetime.strptime(reference_date, "%Y-%m-%d %H:%M:%S")
                except ValueError:
                    raise TypeError(f"{reference_date} not valid for reference_date")
            return self._filter_by_single_reference(
                sdf=sdf, event_date=event_date, interval=interval, reference_date=reference_date, direction=direction
            )
        elif isinstance(reference_date, (datetime, date)):
            return self._filter_by_single_reference(
                sdf=sdf, event_date=event_date, interval=interval, reference_date=reference_date, direction=direction
            )
        else:
            raise TypeError(f"{reference_date} not valid for reference_date")

    def _filter_by_col_reference(
        self, sdf: DataFrame, event_date: str, interval: str, reference_date: str, direction: str
    ) -> DataFrame:
        """
        Filters the input DataFrame based on a time interval defined by a reference date column.

        Args:
            sdf: Input DataFrame containing event data.
            event_date: Name of the column in `sdf` containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            interval: Time interval of the period to filter on, e.g., "1 day".
            reference_date: Name of the column in `sdf` containing reference dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            direction: Direction in which to compute the time interval with respect to the reference date, either "backward" or "forward".

        Returns:
            Filtered DataFrame containing events that fall within the specified time interval.
        """
        if direction == "backward":
            sdf = sdf.filter(
                F.col(event_date).between(
                    F.col(reference_date) - F.expr("INTERVAL " + interval),
                    F.col(reference_date),
                )
            )
        else:
            sdf = sdf.filter(
                F.col(event_date).between(
                    F.col(reference_date),
                    F.col(reference_date) + F.expr("INTERVAL " + interval),
                )
            )
        return sdf

    def _filter_by_single_reference(
        self, sdf: DataFrame, event_date: str, interval: str, reference_date: Union[datetime, date], direction: str
    ) -> DataFrame:
        """
        Filters the input DataFrame based on a time interval defined by a single reference date.

        Args:
            sdf: Input DataFrame containing event data.
            event_date: Name of the column in `sdf` containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            interval: Time interval of the period to filter on, e.g., "1 day".
            reference_date: Reference date to use as the basis for the time interval filter.
            direction: Direction in which to compute the time interval with respect to the reference date, either "backward" or "forward".

        Returns:
            Filtered DataFrame containing events that fall within the specified time interval.
        """
        if direction == "backward":
            sdf = sdf.filter(
                F.col(event_date).between(
                    reference_date - F.expr("INTERVAL " + interval),  # type: ignore
                    reference_date,
                )
            )
        else:
            sdf = sdf.filter(
                F.col(event_date).between(
                    reference_date,
                    reference_date + F.expr("INTERVAL " + interval),  # type: ignore
                )
            )
        return sdf

    def _create_pivot_event_col(
        self, sdf: DataFrame, event_names: List[str], aggr_col: Optional[str] = None
    ) -> DataFrame:
        """
        Creates a new column in the input DataFrame by concatenating the values of multiple columns.

        Args:
            sdf: Input DataFrame containing event data.
            event_names: List of names of columns in `sdf` to concatenate.
            aggr_col: Name of a column; rows where this column is null are dropped before concatenation. Defaults to None.

        Returns:
            Modified DataFrame with a new column containing concatenated values, and original columns dropped.

        Raises:
            ValueError: If any of the `event_names` are not columns of `sdf`.
            TypeError: If `event_names` is not a list.
        """
        if aggr_col is not None:
            sdf = sdf.dropna(subset=aggr_col)
        if isinstance(event_names, list):
            for event_name in event_names:
                if event_name in sdf.columns:
                    sdf = sdf.withColumn(
                        event_name, F.concat_ws("_", F.lit(event_name), F.coalesce(F.col(event_name), F.lit("missing")))
                    )
                else:
                    raise ValueError(f"{event_name} is not a column of sdf")
            sdf = sdf.withColumn("pivot_col", F.concat_ws("_", *event_names))
        else:
            raise TypeError(f"{type(event_names)} not valid for event_names")
        return sdf.drop(*event_names)

    def aggr_over_period(
        self,
        sdf: DataFrame,
        event_date: str,
        event_names: List[str],
        interval: str,
        reference_date: Union[str, datetime, date],
        group_cols: List[str],
        aggr_col: str,
        filter_by: str = "1=1",
        aggr_fun: str = "sum",
        direction: str = "backward",
    ) -> DataFrame:
        """
        Computes the specified aggregation function for event data over a specified time period, grouped by specified columns.

        Args:
            sdf: Input DataFrame containing event data.
            event_date: Name of the column in `sdf` containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
            event_names: List of names of columns in `sdf` containing event data to aggregate.
            interval: Time interval of the period to filter on, e.g., "1 day".
            reference_date: Date to use as the reference point for interval calculation.
            group_cols: List of column names to group the results by.
            filter_by: SQL-like filter statement to apply to the input DataFrame before aggregation. Defaults to "1=1".
            aggr_fun: Aggregation function to apply to the event data. Defaults to "sum".
            aggr_col: Name of the column to aggregate; ignored when `aggr_fun` is "count".
            direction: Direction in which to expand the time period, either "backward" or "forward". Defaults to "backward".

        Returns:
            DataFrame containing aggregated event data.
        """
        sdf = sdf.filter(filter_by)
        sdf = self._apply_interval(
            sdf=sdf, event_date=event_date, interval=interval, reference_date=reference_date, direction=direction
        )
        sdf = sdf.withColumn("aux_col", F.lit(1))
        sdf = self._create_pivot_event_col(sdf=sdf, event_names=event_names, aggr_col=aggr_col)
        if aggr_fun == "count":
            aggr_col = "aux_col"
        elif aggr_col is None:
            raise ValueError("aggr_col cannot be None")
        sdf = sdf.groupBy(group_cols).pivot("pivot_col").agg({aggr_col: aggr_fun})
        output_cols = [col for col in sdf.columns if col.startswith(tuple(event_names))]
        if aggr_fun == "count":
            sdf = sdf.fillna(0, subset=output_cols)
            for col in output_cols:
                sdf = sdf.withColumn(col, F.col(col).cast("int"))
                sdf = sdf.withColumnRenamed(col, f"{aggr_fun}_{col}_{interval.replace(' ', '')}".lower())
        else:
            for col in output_cols:
                sdf = sdf.withColumnRenamed(col, f"{aggr_fun}_{col}_{aggr_col}_{interval.replace(' ', '')}".lower())
        return sdf

    def compute_col_lag_fast(
        self,
        sdf: DataFrame,
        lag_col: str,
        lag_value: int,
        value_col: str,
        join_cols: Optional[List[str]] = None,
    ) -> DataFrame:
        """
        Computes the lag of a column in a Spark DataFrame quickly, using window and lag functions.

        Args:
            sdf: Input DataFrame.
            lag_col: Name of the column to use for ordering the lag.
            lag_value: Number of rows to lag.
            value_col: Name of the column to compute the lag for.
            join_cols: List of the columns to use for partitioning the data. Defaults to None.

        Returns:
            Input DataFrame with a new column containing the lagged values.
        """
        # Treat a missing join_cols as an empty partition spec (the whole DataFrame is one partition).
        partition_cols = join_cols or []
        sdf = sdf.withColumn(
            f"{value_col}-{lag_value}",
            F.lag(F.col(value_col), lag_value).over(Window.partitionBy(*partition_cols).orderBy(F.col(lag_col))),
        )
        return sdf

    def compute_col_lag_safe(
        self,
        sdf: DataFrame,
        lag_col: str,
        lag_interval: str,
        value_col: str,
        join_cols: List[str] = [],
    ) -> DataFrame:
        """
        Computes the lag of a column in a Spark DataFrame safely, using a join to match rows exactly one lag interval apart.

        Args:
            sdf: Input DataFrame.
            lag_col: Name of the column to use for ordering the lag.
            lag_interval: Lag expressed as a number of units and a unit of measure, e.g., "1 day".
            value_col: Name of the column to compute the lag for.
            join_cols: List of the columns to use to join the data when computing the lag.

        Returns:
            Input DataFrame with a new column containing the lagged values.
        """
        if lag_col in join_cols:
            join_cols = [c for c in join_cols if c != lag_col] + [f"{lag_col}_join"]
        else:
            join_cols = join_cols + [f"{lag_col}_join"]  # avoid mutating the caller's list (or the mutable default)
        sdf = (
            sdf.withColumn(f"{lag_col}_join", F.col(lag_col) - F.expr(f"INTERVAL {lag_interval}"))
            .join(
                sdf.withColumnRenamed(lag_col, f"{lag_col}_join").select(
                    *join_cols,
                    F.col(value_col).alias(f"{value_col}-{lag_interval.replace(' ', '')}"),
                ),
                on=join_cols,
                how="left",
            )
            .drop(f"{lag_col}_join")
        )
        return sdf

    def create_categorical_col_from_bins(
        self,
        sdf: DataFrame,
        input_col: str,
        output_col: str,
        bins: Union[List[int], List[float]],
    ) -> DataFrame:
        """
        Creates a new categorical column in a Spark DataFrame by binning a numerical column based on the specified bins.

        Args:
            sdf: Input DataFrame.
            input_col: Name of the numerical column to be binned.
            output_col: Name of the new categorical column to be created.
            bins: List of bin edges. The first element must be -inf and the last element must be inf.

        Returns:
            Input DataFrame with a new categorical column containing the binned values.

        Raises:
            ValueError: If the specified bins list is not valid: the first element must be -inf and the last element must be inf.
        """
        bins = sorted(bins)  # work on a sorted copy so the caller's list is not mutated
        if bins[0] != -float("inf") or bins[-1] != float("inf"):
            raise ValueError(
                f"{bins} not valid for bins: the first element must be -inf and the last element must be inf"
            )
        bucketizer = Bucketizer(splits=bins, inputCol=input_col, outputCol=output_col, handleInvalid="keep")  # type: ignore
        sdf = bucketizer.transform(sdf)
        mapping = self._create_categorical_col_mapping(bins)
        dp = DataPreprocessor()
        sdf = dp.remap_col_values(
            sdf=sdf, input_cols=output_col, output_cols=output_col, mappings=mapping, keep_original=False
        )
        return sdf

    def create_categorical_col_from_num_bins(
        self,
        sdf: DataFrame,
        input_col: str,
        output_col: str,
        bins: int,
    ) -> DataFrame:
        """
        Creates a new categorical column in a Spark DataFrame by binning a numerical column based on the specified number of bins.

        Args:
            sdf: Input DataFrame.
            input_col: Name of the numerical column to be binned.
            output_col: Name of the new categorical column to be created.
            bins: Number of bins to use.

        Returns:
            Input DataFrame with a new categorical column containing the binned values.
        """
        quantile_discretizer = QuantileDiscretizer(
            numBuckets=bins, inputCol=input_col, outputCol=output_col, handleInvalid="keep"
        )
        sdf_fitted = quantile_discretizer.fit(sdf)
        bins = sdf_fitted.getSplits()
        sdf = sdf_fitted.transform(sdf)
        mapping = self._create_categorical_col_mapping(bins)  # type: ignore
        dp = DataPreprocessor()
        sdf = dp.remap_col_values(
            sdf=sdf, input_cols=output_col, output_cols=output_col, mappings=mapping, keep_original=False
        )
        return sdf

    def _create_categorical_col_mapping(self, bins: Union[List[int], List[float]]) -> Dict:
        """
        Creates a mapping dictionary for a categorical column based on the specified list of bin edges.

        Args:
            bins: List of bin edges.

        Returns:
            Dictionary that maps the bin index to the corresponding bin range. The last index is mapped to None.
        """
        mapping_values = []
        for i in range(len(bins) - 2):
            mapping_values.append(f"[{bins[i]}, {bins[i+1]})")
        mapping_values.append(f"[{bins[len(bins)-2]}, {bins[len(bins)-1]}]")
        mapping_values.append(None)
        mapping = dict(zip(range(len(bins) + 1), mapping_values))
        return mapping

    def get_most_least_frequent_category(
        self,
        sdf: DataFrame,
        event_name: str,
        group_cols: List[str],
        aggr_col: Optional[str] = None,
        aggr_fun: str = "count",
        most_least: str = "most",
    ) -> DataFrame:
        """
        Returns a DataFrame with the most or least frequent value of a categorical variable in each group defined by `group_cols`.
        If `aggr_fun` (different from "count") and `aggr_col` are provided, for each group it returns the value of `event_name`
        corresponding to the maximum or minimum value of `aggr_fun` computed on `aggr_col`.
        Null values in `aggr_col` and `event_name` are dropped.

        Args:
            sdf: Input DataFrame.
            event_name: Name of the column in `sdf` containing the categorical variable to determine the most or least frequent category of.
            group_cols: List of column names to group by.
            aggr_col: Name of the column to aggregate. If None, it defaults to event_name.
            aggr_fun: Aggregation function to use. Defaults to 'count'.
            most_least: Type of frequency to return: 'most' for most frequent or 'least' for least frequent. Defaults to 'most'.

        Returns:
            DataFrame with the most or least frequent value of the categorical variable in each group defined by group_cols.
        """
        if aggr_col is None:
            aggr_col = event_name
        sdf = sdf.dropna(subset=[aggr_col, event_name])
        sdf = (
            sdf.groupBy(group_cols + [event_name])
            .agg({aggr_col: aggr_fun})
            .withColumnRenamed(f"{aggr_fun}({aggr_col})", f"{aggr_fun}_{aggr_col}")
        )
        if aggr_fun == "count":
            sdf = sdf.withColumn(f"{aggr_fun}_{aggr_col}", F.col(f"{aggr_fun}_{aggr_col}").cast("int"))
        if most_least == "most":
            w = Window.partitionBy(group_cols).orderBy(F.desc(F.col(f"{aggr_fun}_{aggr_col}")))
        elif most_least == "least":
            w = Window.partitionBy(group_cols).orderBy(F.col(f"{aggr_fun}_{aggr_col}"))
        else:
            raise ValueError(f"{most_least} not valid for most_least")
        sdf = sdf.withColumn("row_number", F.rank().over(w)).filter(F.col("row_number") == 1)
        sdf = sdf.withColumn("row_number", F.count("*").over(w)).withColumn("tie", F.col("row_number") > 1)
        return sdf.withColumnRenamed(event_name, f"{most_least}_frequent_{event_name}").drop("row_number")

Methods

def aggr_over_period(self, sdf: pyspark.sql.dataframe.DataFrame, event_date: str, event_names: List[str], interval: str, reference_date: Union[str, datetime.datetime, datetime.date], group_cols: List[str], aggr_col: str, filter_by: str = '1=1', aggr_fun: str = 'sum', direction: str = 'backward') ‑> pyspark.sql.dataframe.DataFrame

Computes the specified aggregation function for event data over a specified time period, grouped by specified columns.

Args

sdf
Input DataFrame containing event data.
event_date
Name of the column in sdf containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
event_names
List of names of columns in sdf containing event data to aggregate.
interval
Time interval of the period to filter on, e.g., "1 day".
reference_date
Date to use as the reference point for interval calculation.
group_cols
List of column names to group the results by.
filter_by
SQL-like filter statement to apply to the input DataFrame before aggregation. Defaults to "1=1".
aggr_fun
Aggregation function to apply to the event data. Defaults to "sum".
aggr_col
Name of the column to aggregate; ignored when aggr_fun is "count".
direction
Direction in which to expand the time period, either "backward" or "forward". Defaults to "backward".

Returns

DataFrame containing aggregated event data.
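
A minimal usage sketch, assuming a local SparkSession and a toy transactions DataFrame; the column names (customer_id, channel, event_dt, amount) and sample values are illustrative, not part of the library:

from pyspark.sql import SparkSession
from panama.feature_engineering.feature_extraction import FeatureExtractor

spark = SparkSession.builder.master("local[1]").getOrCreate()
events = spark.createDataFrame(
    [
        ("c1", "web", "2023-01-20", 10.0),
        ("c1", "app", "2023-01-25", 30.0),
        ("c2", "web", "2023-01-10", 5.0),
    ],
    ["customer_id", "channel", "event_dt", "amount"],
)
fe = FeatureExtractor()

# Sum of amount per customer, pivoted by channel, over the 30 days before 2023-02-01.
# Output columns are named like "sum_channel_web_amount_30days".
features = fe.aggr_over_period(
    sdf=events,
    event_date="event_dt",
    event_names=["channel"],
    interval="30 days",
    reference_date="2023-02-01",
    group_cols=["customer_id"],
    aggr_col="amount",
    aggr_fun="sum",
)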

def compute_col_lag_fast(self, sdf: pyspark.sql.dataframe.DataFrame, lag_col: str, lag_value: int, value_col: str, join_cols: Optional[List[str]] = None) ‑> pyspark.sql.dataframe.DataFrame

Computes the lag of a column in a Spark DataFrame quickly, using window and lag functions.

Args

sdf
Input DataFrame.
lag_col
Name of the column to use for ordering the lag.
lag_value
Number of rows to lag.
value_col
Name of the column to compute the lag for.
join_cols
List of the columns to use for partitioning the data. Defaults to None.

Returns

Input DataFrame with a new column containing the lagged values.
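
A short sketch, reusing the spark session and fe instance from the aggr_over_period example above; the store_id, date, and sales columns are hypothetical:

from datetime import date

daily = spark.createDataFrame(
    [("s1", date(2023, 1, 1), 100.0), ("s1", date(2023, 1, 2), 120.0), ("s1", date(2023, 1, 4), 90.0)],
    ["store_id", "date", "sales"],
)

# Adds "sales-1": the previous row's sales within each store, ordered by date.
# The lag is positional, so the gap between 2023-01-02 and 2023-01-04 is not detected.
lagged = fe.compute_col_lag_fast(
    sdf=daily, lag_col="date", lag_value=1, value_col="sales", join_cols=["store_id"]
)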

def compute_col_lag_safe(self, sdf: pyspark.sql.dataframe.DataFrame, lag_col: str, lag_interval: str, value_col: str, join_cols: List[str] = []) ‑> pyspark.sql.dataframe.DataFrame

Computes the lag of a column in a Spark DataFrame safely, using a join to match rows exactly one lag interval apart.

Args

sdf
Input DataFrame.
lag_col
Name of the column to use for ordering the lag.
lag_interval
Lag expressed as a number of units and a unit of measure, e.g., "1 day".
value_col
Name of the column to compute the lag for.
join_cols
List of the columns to use to join the data when computing the lag.

Returns

Input DataFrame with a new column containing the lagged values.
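
By contrast, a sketch of the join-based variant on the same hypothetical daily DataFrame; because rows are matched exactly one interval apart, a missing day yields a null lag rather than the previous row's value:

# Adds "sales-1day": sales for the same store exactly one day earlier,
# or null where no such row exists (e.g. 2023-01-04 has no 2023-01-03 row).
lagged_safe = fe.compute_col_lag_safe(
    sdf=daily, lag_col="date", lag_interval="1 day", value_col="sales", join_cols=["store_id"]
)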

def count_over_period(self, sdf: pyspark.sql.dataframe.DataFrame, event_date: str, event_names: List[str], interval: str, reference_date: Union[str, datetime.datetime, datetime.date], group_cols: List[str], filter_by: str = '1=1', direction='backward') ‑> pyspark.sql.dataframe.DataFrame

Computes the count of events over a specified time period for each group in the input DataFrame.

Args

sdf
Input DataFrame containing event data.
event_date
Name of the column in sdf containing event dates in format "yyyy-MM-dd" or "yyyy-MM-dd HH:mm:ss".
event_names
List of names of the columns containing events to count.
interval
Time interval of the period to filter on, e.g., "1 day".
reference_date
Date to use as the reference point for interval calculation.
group_cols
List of column names to group by.
filter_by
SQL-like filter expression to apply to sdf. Defaults to "1=1".
direction
Direction in which to compute the time interval with respect to the reference date, either "backward" or "forward". Defaults to "backward".

Returns

DataFrame containing counts of events over the specified time period for each group.
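
A minimal sketch, reusing the hypothetical events DataFrame and fe instance from the aggr_over_period example; one count column is produced per observed channel value:

# Counts events per customer and channel over the 30 days before 2023-02-01.
# Output columns are named like "count_channel_web_30days"; missing combinations become 0.
counts = fe.count_over_period(
    sdf=events,
    event_date="event_dt",
    event_names=["channel"],
    interval="30 days",
    reference_date="2023-02-01",
    group_cols=["customer_id"],
)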

def create_categorical_col_from_bins(self, sdf: pyspark.sql.dataframe.DataFrame, input_col: str, output_col: str, bins: Union[List[int], List[float]]) ‑> pyspark.sql.dataframe.DataFrame

Creates a new categorical column in a Spark DataFrame by binning a numerical column based on the specified bins.

Args

sdf
Input DataFrame.
input_col
Name of the numerical column to be binned.
output_col
Name of the new categorical column to be created.
bins
List of bin edges. The first element must be -inf and the last element must be inf.

Returns

Input DataFrame with a new categorical column containing the binned values.

Raises

ValueError
If the specified bins list is not valid: the first element must be -inf and the last element must be inf.
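
A quick sketch with a hypothetical age column, reusing the spark session and fe instance from above; the bin edges must start at -inf and end at inf, and the output holds labels such as "[18.0, 65.0)":

people = spark.createDataFrame(
    [("p1", 12.0), ("p2", 35.0), ("p3", 41.0), ("p4", 58.0), ("p5", 70.0), ("p6", 83.0)],
    ["person_id", "age"],
)

# Buckets age into a categorical "age_band" column labelled with the bin edges.
banded = fe.create_categorical_col_from_bins(
    sdf=people,
    input_col="age",
    output_col="age_band",
    bins=[-float("inf"), 18.0, 65.0, float("inf")],
)
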
def create_categorical_col_from_num_bins(self, sdf: pyspark.sql.dataframe.DataFrame, input_col: str, output_col: str, bins: int) ‑> pyspark.sql.dataframe.DataFrame

Creates a new categorical column in a Spark DataFrame by binning a numerical column based on the specified number of bins.

Args

sdf
Input DataFrame.
input_col
Name of the numerical column to be binned.
output_col
Name of the new categorical column to be created.
bins
Number of bins to use.

Returns

Input DataFrame with a new categorical column containing the binned values.
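
And the quantile-based variant on the same hypothetical people DataFrame; here the bin edges are learned from the data rather than supplied:

# Splits age into 3 roughly equal-frequency buckets, labelled with the learned edges.
quantiled = fe.create_categorical_col_from_num_bins(
    sdf=people, input_col="age", output_col="age_bucket", bins=3
)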

def get_most_least_frequent_category(self, sdf: pyspark.sql.dataframe.DataFrame, event_name: str, group_cols: List[str], aggr_col: Optional[str] = None, aggr_fun: str = 'count', most_least: str = 'most') ‑> pyspark.sql.dataframe.DataFrame

Returns a DataFrame with the most or least frequent value of a categorical variable in each group defined by group_cols. If aggr_fun (different from "count") and aggr_col are provided, for each group it returns the value of event_name corresponding to the maximum or minimum value of aggr_fun computed on aggr_col. Null values in aggr_col and event_name are dropped.

Args

sdf
Input DataFrame.
event_name
Name of the column in sdf containing the categorical variable to determine the most or least frequent category of.
group_cols
List of column names to group by.
aggr_col
Name of the column to aggregate. If None, it defaults to event_name.
aggr_fun
Aggregation function to use. Defaults to 'count'.
most_least
Type of frequency to return: 'most' for most frequent or 'least' for least frequent. Defaults to 'most'.

Returns

DataFrame with the most or least frequent value of the categorical variable in each group defined by group_cols.
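
Finally, a sketch on the hypothetical events DataFrame from the aggr_over_period example; with the default aggr_fun="count" it returns, per customer, the channel with the most rows, renamed "most_frequent_channel", plus a boolean tie column:

# Most frequent channel per customer; tie is True when several channels share the top count.
top_channel = fe.get_most_least_frequent_category(
    sdf=events, event_name="channel", group_cols=["customer_id"]
)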