Module panama.io.io

Classes

class IOAdls (spark)

Extends the IO interface to Azure Data Lake Storage (ADLS) / Blob storage files and to Unity Catalog tables.

Expand source code
class IOAdls(IOInterface):
    """
    Extends the IO interface to Azure Data Lake Storage (ADLS) / Blob storage files
    and to Unity Catalog tables.
    """

    def __init__(self, spark):
        super().__init__(spark)

    @staticmethod
    def _generate_absolute_path(path: str, storage_account: str, container: str, source_type: str) -> str:
        """Generate the url that points to the required path inside the data lake.

        Args:
            path (str): File or table location.
            storage_account (str): Azure storage account.
            container (str): Name of the container where to find files or tables.
            source_type (str): value extracted from connections.json indicating type of source ('adls' or 'blob').

        Raises:
            ValueError: if source_type is neither 'adls' nor 'blob'.

        Returns:
            str: the final url to read data from.
        """
        if source_type == "adls":
            return f"abfss://{container}@{storage_account}.dfs.core.windows.net/{path}"
        elif source_type == "blob":
            return f"wasb://{container}@{storage_account}.blob.core.windows.net/{path}"
        else:
            raise ValueError("Accepted source_type in connections.json are 'adls', 'blob' in IOAdls")

    @staticmethod
    def _get_storage_account_from_name(adls_name: str):
        """Get the storage account of a connection from the connection config json.

        Args:
            adls_name (str): Name of datalake used to get connection config.

        Returns:
            (str): Azure storage account (the "credentials" -> "storage_account" entry).
        """
        adls_config = IOInterface._get_connection_config_json().get(adls_name)
        return adls_config.get("credentials").get("storage_account")  # type: ignore

    @staticmethod
    def _get_source_type_from_name(adls_name: str):
        """Get the source type of a connection from the connection config json.

        Args:
            adls_name (str): Name of datalake used to get connection config.

        Returns:
            (str): source_type value (the "type" entry, e.g. 'adls' or 'blob').
        """
        adls_config = IOInterface._get_connection_config_json().get(adls_name)
        return adls_config.get("type")  # type: ignore

    @staticmethod
    def generate_connection_string(adls_name: str, container: str, path: str = "") -> str:
        """Generate the url that points to the required path inside the data lake.

        Args:
            adls_name (str): Name of datalake (or blob) used to get connection config.
            container (str): Name of the container where to find files or tables.
            path (str): File or table location. Defaults to the container root.

        Raises:
            ValueError: if the connection type is neither 'adls' nor 'blob'.

        Returns:
            str: the final url to read data from."""
        # Resolve connection details from the adls name, then build the absolute path.
        storage_account = IOAdls._get_storage_account_from_name(adls_name)
        source_type = IOAdls._get_source_type_from_name(adls_name)

        return IOAdls._generate_absolute_path(
            path=path, storage_account=storage_account, container=container, source_type=source_type
        )

    ############################### READING

    @panama.logging.log_execution(blocking=True)
    def read_file(
        self,
        adls_name: str,
        container: str,
        file_path: str,
        file_format: str = "delta",
        extra_options: Optional[dict] = None,
        create_view: Optional[bool] = False,
    ) -> DataFrame:
        """Reads a table directly from an ADLS instance's path location.

        Args:
            adls_name (str): Name of datalake (or blob) used to get connection config.
            container (str): Name of the container where to find files.
            file_path (str): Path where to find file (or files). Use folder to read multiple files, or specific file path with extension.
            file_format (str, optional): Format of the file to read. Defaults to "delta".
            extra_options (Optional[dict], optional): Additional options to pass to the spark read command as a dictionary. Defaults to None.
            create_view (bool, optional): also register a temp view containing the entire content. Defaults to False.

        Returns:
            DataFrame: containing the data from the specified location.
        """
        # Compose datalake file path from given info
        abs_file_path = self.generate_connection_string(adls_name=adls_name, container=container, path=file_path)

        sdf_reader = self.spark.read.format(file_format)
        if extra_options:
            sdf_reader = sdf_reader.options(**extra_options)

        sdf = sdf_reader.load(abs_file_path)

        if create_view:
            # The view name is derived from the relative path, not the absolute url.
            self._create_sdf_temp_view(sdf=sdf, table_name_path=file_path)

        return sdf

    @panama.logging.log_execution(blocking=True)
    def read_table(
        self,
        table_name: str,
        max_datetime: Union[str, None] = None,
        create_view: Optional[bool] = False,
        force_read_to_env: Optional[bool] = False,
    ) -> DataFrame:
        """Reads a table from unity catalog.

        ATTENTION: momentary fix will let the user read dev catalog from dev environment (instead of test catalog from dev environment).

        Args:
            table_name (str): Name of the table to read from. Format must be "<catalog>.<schema>.<table>". Catalog must be indicated without env prefix.
            max_datetime (str, optional): filter out datetime newer (inclusive) to datetime input. Format must be '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d'. Defaults to None.
            create_view (bool, optional): also register a temp view containing the entire table content. Defaults to False.
            force_read_to_env (Optional, bool): if True, force read on same env of execution. This only works on '_analytics' catalogs. Defaults to False.

        Returns:
            DataFrame: containing the data from the specified table.
        """
        # If required, look up the latest version of the delta table wrt the max_datetime input.
        time_travel_datetime = None
        if max_datetime:
            time_travel_datetime = self.get_latest_delta_datetime(
                table_name=table_name,
                max_datetime=max_datetime,
                force_read_to_env=force_read_to_env,
            )

        # Set env in catalog name
        table_name = self._assign_env_to_string(table_name, purpose="r", force_read_to_env=force_read_to_env)

        query = f"SELECT * FROM {table_name}"
        if time_travel_datetime:
            # Delta time travel to the selected table version.
            query += f' TIMESTAMP AS OF "{time_travel_datetime}"'

        sdf = self.spark.sql(query)

        if create_view:
            self._create_sdf_temp_view(sdf=sdf, table_name_path=table_name)  # type: ignore

        return sdf

    def get_latest_delta_datetime(
        self,
        table_name: str,
        max_datetime: Union[str, None] = None,
        force_read_to_env: Optional[bool] = False,
    ) -> datetime.datetime:
        """
        Get latest timestamp when the delta table was written, up to max_datetime (inclusive).
        If max_datetime is prior to the oldest table timestamp, the oldest timestamp is returned (with a warning).

        Args:
            table_name (str): Name of the table to read from. Format must be "<catalog>.<schema>.<table>". Catalog must be indicated without env prefix.
            max_datetime (str, optional): filter out datetime newer (inclusive) to datetime input. Format must be '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d'. Defaults to None.
                An empty string is treated like None.
            force_read_to_env (Optional, bool): if True, force read on same env of execution. This only works on '_analytics' catalogs. Defaults to False.

        Returns:
            datetime.datetime: latest time when the delta table was written.
        """
        # Validate max_datetime up-front (before querying the table). Any falsy value means "no upper bound";
        # previously an empty string left the filter undefined and crashed with a NameError.
        date_filter = timestamp_str_to_datetime(max_datetime) if max_datetime else None

        # Set env in catalog name
        table_name = self._assign_env_to_string(table_name, "r", force_read_to_env=force_read_to_env)

        # Query delta table's history, newest first.
        query = (
            f"select timestamp as last_ts from (select * from (describe history {table_name})) order by timestamp desc"
        )
        timestamps = [row["last_ts"] for row in self.spark.sql(query).collect()]

        if date_filter is None:
            print(f"No max_date passed for {table_name} extraction, returning last available timestamp.")
            return timestamps[0]

        # History is sorted descending, so the first match is the latest timestamp not after the filter.
        latest_allowed = next((ts for ts in timestamps if ts <= date_filter), None)  # type: ignore
        if latest_allowed is not None:
            return latest_allowed

        warnings.warn("No timestamp prior to max_date/max_datetime, returning the oldest timestamp")
        return timestamps[-1]

    def read_most_recent_file(
        self,
        container: str,
        adls_name: str,
        file_name: str,
        file_format: str,
        look_folder: str = "",
        max_date: Union[str, None] = None,
        max_iterations: int = 100,
        extra_options: Optional[dict] = None,
        create_view: Optional[bool] = False,
    ) -> DataFrame:
        """Get the most recent file path from a directory with the structure:
            look_folder:
                - "2023-01-01"
                    - filename.csv
                - "2023-02-01"
                    - filename.csv
                - "2023-03-01"
                    - filename.csv
        Then read the file and return the spark dataframe containing the data.

        Args:
            container (str): Name of the container where to find and read file.
            adls_name (str): Name of datalake used to get connection config.
            file_name (str): Name of the file that we want to get (no extension).
            file_format (str): Extension of the file we are looking for.
            look_folder (str, Optional): Main directory containing the data that needs to be selected. Default is root.
            max_date (str, optional): launch date, that works as an upper limit (inclusive) for the latest copies of the file. Format must be %Y-%m-%d. Defaults to None.
            max_iterations (int, optional): maximum number of date folders to search. Defaults to 100.
            extra_options (Optional[dict], optional): Additional options to pass to the spark read command as a dictionary. Defaults to None.
            create_view (bool, optional): also register a temp view containing the entire content. Defaults to False.

        Returns:
            DataFrame: spark dataframe containing data inside most recent version of file required.
        """
        # Search for latest version of file and get its path
        file_path = self.get_path_most_recent_file(
            container=container,
            adls_name=adls_name,
            look_folder=look_folder,
            file_name=file_name,
            file_format=file_format,
            max_date=max_date,
            max_iterations=max_iterations,
        )

        return self.read_file(
            adls_name=adls_name,
            container=container,
            file_path=file_path,
            file_format=file_format,
            extra_options=extra_options,
            create_view=create_view,
        )

    @panama.logging.log_execution(blocking=True)
    def get_path_most_recent_file(
        self,
        container: str,
        adls_name: str,
        file_name: str,
        file_format: str,
        look_folder: str = "",
        max_date: Union[str, None] = None,
        max_iterations: int = 10,
    ):
        """Get the most recent file path from a directory with the structure:
            look_folder:
                - "2023-01-01"
                    - filename.csv
                - "2023-02-01"
                    - filename.csv
                - "2023-03-01"
                    - filename.csv

        Args:
            container (str): Name of the container where to find files.
            adls_name (str): Name of datalake used to get connection config.
            file_name (str): Name of the file that we want to get (no extension).
            file_format (str): Extension of the file we are looking for.
            look_folder (str, Optional): Main directory containing the data that needs to be selected. Default is root.
            max_date (str, optional): launch date, that works as an upper limit (inclusive) for the latest copies of the file. Format must be %Y-%m-%d. Defaults to None.
            max_iterations (int, optional): maximum number of date folders to search, newest first. Defaults to 10.

        Raises:
            ValueError: if file_name contains an extension or max_date has an invalid format.
            FileNotFoundError: if no date folder is available, or the file is in none of the searched folders.
            IndexError: if the file was not found within the first max_iterations folders while more folders remain.

        Returns:
            str: path of the file, relative to the container.
        """
        if "." in file_name:
            raise ValueError("Please provide file_name input without the extension. Use the file_format parameter.")
        file_name_complete = ".".join([file_name, file_format])

        # Build absolute adls path of look_folder and list its content
        look_folder_path = self.generate_connection_string(adls_name=adls_name, container=container, path=look_folder)
        dir_list = dbutils_fs_ls_names(look_folder_path, self.spark)

        # Keep only folders starting with a %Y-%m-%d date...
        date_format = re.compile(r"\d{4}-\d{2}-\d{2}")
        date_dir_list = [d for d in dir_list if date_format.match(d)]

        # ... filter by date if requested ...
        if max_date:
            if not isinstance(max_date, str) or date_format.search(max_date) is None:
                raise ValueError("max_date must be of format %Y-%m-%d")
            # Folder names carry a trailing "/", hence the suffix for an inclusive comparison.
            date_dir_list = [d for d in date_dir_list if d <= f"{max_date}/"]

        # ... and sort from most recent
        date_dir_list = sorted(date_dir_list, reverse=True)

        if not date_dir_list:
            raise FileNotFoundError(f"{look_folder} does not contain any folder with valid date format (%Y-%m-%d)")
        print(f"{len(date_dir_list)} subfolder found, searching for file starting from more recent")

        # Search the newest folders first, at most max_iterations of them.
        for folder in date_dir_list[:max_iterations]:
            folder_content = dbutils_fs_ls_names(path=look_folder_path + "/" + folder, spark=self.spark)
            if file_name_complete in folder_content:
                print("File found in folder:", str(folder).rstrip("/"))
                return "/".join([look_folder, folder, file_name_complete])

        if len(date_dir_list) > max_iterations:
            # Some folders were never inspected: the limit was the reason for the miss.
            raise IndexError(
                "File not found! Check the file name (without extension) and file_format, the folder path, or use a higher max_iterations"
            )
        raise FileNotFoundError("File not found! Check the file name (without extension) and file_format, or the folder path")

    ############################## WRITING

    @staticmethod
    def _df_repartitioning(sdf, mode: str, partitions: int) -> DataFrame:
        """Perform either repartitioning or coalesce over input spark dataframe.

        Args:
            sdf (DataFrame): spark dataframe, main subject of the function.
            mode (str): string to control which operation to perform. Accepted values are 'repartition' and 'coalesce'.
            partitions (int): number of partitions to obtain.

        Raises:
            KeyError: when mode is not 'repartition' or 'coalesce'.

        Returns:
            DataFrame: repartitioned spark dataframe.
        """
        if mode == "repartition":
            return sdf.repartition(partitions)
        elif mode == "coalesce":
            return sdf.coalesce(partitions)
        else:
            raise KeyError("Repartitioning modality unknown. Please use 'repartition' or 'coalesce'")

    @panama.logging.log_execution(blocking=True)  # type: ignore
    def write_file(
        self,
        data: DataFrame,
        adls_name: str,
        container: str,
        file_path: str,
        mode: str = "append",
        file_format: str = "delta",
        repartitioning: Optional[list] = None,
        partition_by: Optional[list] = None,
        extra_options: Optional[dict] = None,
        save_as_table: bool = False,
    ) -> None:
        """Write data into a datalake table.
        Optionally provide "repartitioning" as a list like [str, int] to repartition the dataframe before writing.
        Accepted repartitioning modes are "repartition" and "coalesce".

        Args:
            data (DataFrame): spark dataframe containing data.
            adls_name (str): Name of datalake used to get connection config.
            container (str): Name of the container where to write files.
            file_path (str): Path where to write file. Provide path containing a table. If empty, last part of path will be the name of the new table.
            mode (str, optional): spark write mode. Defaults to "append".
            file_format (str, optional): Format of the file to write. Defaults to "delta".
            repartitioning (Optional[list], optional): [mode, partitions] for pre-writing repartitioning; ignored unless it has exactly 2 items. Defaults to None.
            partition_by (Optional[list], optional): column(s) to use as writing partitions. Defaults to None.
            extra_options (Optional[dict], optional): Additional options to pass to the spark write command as a dictionary. Defaults to None.
            save_as_table (bool, optional): if True, the write command will be saveAsTable, else it will use save. Defaults to False.
        """
        # Compose datalake file path from given info
        file_path = self.generate_connection_string(adls_name=adls_name, container=container, path=file_path)

        if repartitioning is not None and len(repartitioning) == 2:
            data = self._df_repartitioning(sdf=data, mode=repartitioning[0], partitions=repartitioning[1])

        data_writer = data.write
        if partition_by is not None:
            data_writer = data_writer.partitionBy(partition_by)
        if extra_options:
            data_writer = data_writer.options(**extra_options)
        data_writer = data_writer.format(file_format).mode(mode)

        if save_as_table:
            data_writer.saveAsTable(file_path)
        else:
            data_writer.save(file_path)
        print(f"Table written in {file_path}")

Ancestors

Static methods

def generate_connection_string(adls_name: str, container: str, path: str = '') ‑> str

Generate the url that points to the required path inside the data lake.

Args

adls_name : str
Name of datalake (or blob) used to get connection config.
container : str
Name of the container where to find files or tables.
path : str
File or table location

Returns

str
the final url to read data from.

Methods

def get_latest_delta_datetime(self, table_name: str, max_datetime: Optional[str] = None, force_read_to_env: Optional[bool] = False) ‑> datetime.datetime

Get the latest timestamp at which the delta table was written, up to max_datetime (inclusive). If max_datetime is earlier than the oldest table timestamp, the oldest timestamp is returned and a warning is emitted.

Args

table_name : str
Name of the table to read from. Format must be "<catalog>.<schema>.<table>". Catalog must be indicated without env prefix.
max_datetime : str, optional
filter out datetime newer (inclusive) to datetime input. Format must be '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d'. Defaults to None.
force_read_to_env : Optional, bool
If True, force the read to happen in the same environment as the execution. This only works on '_analytics' catalogs. Defaults to False.

Returns

datetime.datetime
latest time when the delta table was written.
def get_path_most_recent_file(self, container: str, adls_name: str, file_name: str, file_format: str, look_folder: str = '', max_date: Optional[str] = None, max_iterations: int = 10)

Get the most recent file path from a directory with the structure:

look_folder/
    2023-01-01/filename.csv
    2023-02-01/filename.csv
    2023-03-01/filename.csv

Args

container : str
Name of the container where to find files.
adls_name : str
Name of datalake used to get connection config.
file_name : str
Name of the file that we want to get (no extension).
file_format : str
Extension of the file we are looking for.
look_folder : str, Optional
Main directory containing the data that needs to be selected. Default is root.
max_date : str, optional
launch date, that works as an upper limit (inclusive) for the latest copies of the file. Format must be %Y-%m-%d. Defaults to None.
max_iterations : int, optional
Limit on the number of search attempts. Defaults to 10.
def read_file(self, adls_name: str, container: str, file_path: str, file_format: str = 'delta', extra_options: Optional[dict] = None, create_view: Optional[bool] = False) ‑> pyspark.sql.dataframe.DataFrame

Reads a table directly from an ADLS instance's path location.

Args

adls_name : str
Name of datalake (or blob) used to get connection config.
container : str
Name of the container where to find files.
file_path : str
Path where to find file (or files). Use folder to read multiple files, or specific file path with extension.
file_format : str, optional
Format of the file to read. Defaults to "delta".
extra_options : Optional[dict], optional
Additional options to pass to the spark read command as a dictionary. Defaults to None.
create_view : bool, optional
return temp view containing entire table content. Defaults to False.

Returns

DataFrame
containing the data from the specified table.
def read_most_recent_file(self, container: str, adls_name: str, file_name: str, file_format: str, look_folder: str = '', max_date: Optional[str] = None, max_iterations: int = 100, extra_options: Optional[dict] = None, create_view: Optional[bool] = False) ‑> pyspark.sql.dataframe.DataFrame

Get the most recent file path from a directory with the structure:

look_folder/
    2023-01-01/filename.csv
    2023-02-01/filename.csv
    2023-03-01/filename.csv

Then read the file and return the spark dataframe containing the data.

Args

container : str
Name of the container where to find and read file.
adls_name : str
Name of datalake used to get connection config.
file_name : str
Name of the file that we want to get.
file_format : str
Extension of the file we are looking for.
look_folder : str, Optional
Main directory containing the data that needs to be selected. Default is root.
max_date : str, optional
launch date, that works as an upper limit (inclusive) for the latest copies of the file. Format must be %Y-%m-%d. Defaults to None.
max_iterations : int, optional
Limit on the number of search attempts. Defaults to 100.
extra_options : Optional[dict], optional
Additional options to pass to the spark read command as a dictionary. Defaults to None.
create_view : bool, optional
return temp view containing entire table content. Defaults to False.

Returns

DataFrame
spark dataframe containing data inside most recent version of file required.
def read_table(self, table_name: str, max_datetime: Optional[str] = None, create_view: Optional[bool] = False, force_read_to_env: Optional[bool] = False) ‑> pyspark.sql.dataframe.DataFrame

Reads a table from unity catalog.

ATTENTION: momentary fix will let the user read dev catalog from dev environment (instead of test catalog from dev environment).

Args

table_name : str
Name of the table to read from. Format must be "<catalog>.<schema>.<table>". Catalog must be indicated without env prefix.
max_datetime : str, optional
filter out datetime newer (inclusive) to datetime input. Format must be '%Y-%m-%d %H:%M:%S' or '%Y-%m-%d'. Defaults to None.
table_format : str, optional
Format of the table to read (e.g., 'delta').. Defaults to "delta".
create_view : bool, optional
return temp view containing entire table content. Defaults to False.
force_read_to_env : Optional, bool
If True, force the read to happen in the same environment as the execution. This only works on '_analytics' catalogs. Defaults to False.

Returns

DataFrame
containing the data from the specified table.
def write_file(self, data: pyspark.sql.dataframe.DataFrame, adls_name: str, container: str, file_path: str, mode: str = 'append', file_format: str = 'delta', repartitioning: Optional[list] = None, partition_by: Optional[list] = None, extra_options: Optional[dict] = None, save_as_table: bool = False) ‑> None

Write data into a datalake table. Optionally provide "repartitioning" as a list like [str, int] to repartition the dataframe before writing. Accepted repartitioning modes are "repartition" and "coalesce".

Args

data : DataFrame
spark dataframe containing data.
adls_name : str
Name of datalake used to get connection config.
container : str
Name of the container where to write files.
file_path : str
Path where to write file. Provide path containing a table. If empty, last part of path will be the name of the new table.
mode : str, optional
spark write mode. Defaults to "append".
file_format : str, optional
Format of the file to write. Defaults to "delta".
repartitioning : Optional[list], optional
List indicating how to perform pre-writing repartitioning. Defaults to None.
partition_by : Optional[list], optional
column(s) to use as writing partitions. Defaults to None.
extra_options : Optional[dict], optional
Additional options to pass to the spark write command as a dictionary. Defaults to None.
save_as_table : bool, optional
If True, the write command will be saveAsTable, otherwise save. Defaults to False.
class IOInterface (spark)

Interface class for IO objects

Expand source code
class IOInterface:
    """
    Interface class for IO objects
    """

    def __init__(self, spark):
        self.spark = spark

    @staticmethod
    def _get_connection_config_json() -> dict:
        """Uses utils function to get connection.json from DBFS. It contains properties to every storage connection.

        Returns:
            Dict: Dictionary containing connection properties.
        """
        return get_connection_config_json()

    def _create_sdf_temp_view(self, sdf: "DataFrame", table_name_path: str):
        """Register a spark sql temp view on the given dataframe.

        The view name is table_name_path with dots replaced by slashes, wrapped in backticks
        (e.g. "cat.schema.table" -> "`cat/schema/table`").

        Args:
            sdf (DataFrame): spark dataframe.
            table_name_path (str): table name or file path, used to generate the view name.
        """
        view_name = f"`{table_name_path.replace('.', '/')}`"
        sdf.createOrReplaceTempView(view_name)
        print(f"Created spark sql view named {view_name}")

    def _set_catalog(self, catalog_name: str) -> None:
        """Set catalog on current spark session.

        Args:
            catalog_name (str): name of the catalog to use.
        """
        self.spark.sql(f"USE CATALOG {catalog_name}")

    def _assign_env_to_string(self, s: str, purpose: str, force_read_to_env: Optional[bool] = False) -> str:
        """Add catalog environment prefix to string. Could be complete table reference (catalog.schema.table) or just the catalog name.
            E.g. qdata.schema.table -> dev_qdata.schema.table

        Read ('r') prefix: the cluster env when force_read_to_env is set or the cluster is 'dev',
        'prod' when the cluster is 'test' or 'prod'. Write ('w') prefix: always the cluster env.

        Args:
            s (str): string to enrich with env prefix.
            purpose (str): string indicating if the string will be used to read or to write table. Allowed values are 'r' and 'w' only.
            force_read_to_env (Optional, bool): if True, force read on same env of execution.

        Raises:
            ValueError: on an unexpected cluster tag or purpose.

        Returns:
            str: input string with prefix.
        """
        cluster_tag = get_env_from_cluster_tag(spark=self.spark)

        if purpose == "r":
            if force_read_to_env:
                # Force read to happen in same env of execution.
                prefix = cluster_tag
            elif cluster_tag == "dev":
                # Momentary fix: dev reads its own catalog (the target design is dev reading the test catalog).
                prefix = cluster_tag
            elif cluster_tag in ["test", "prod"]:
                prefix = "prod"
            else:
                raise ValueError(f"{cluster_tag} tag on cluster not expected.")
        elif purpose == "w":
            if cluster_tag not in ["dev", "test", "prod"]:
                raise ValueError(f"{cluster_tag} tag on cluster not expected.")
            prefix = cluster_tag
        else:
            raise ValueError(f"purpose parameter must be 'r' or 'w', found '{purpose}'")

        return f"{prefix}_{s}"

    @abstractmethod
    def _get_connection_properties(self, *args, **kwargs):
        """Return connection properties for a data source. Must be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def read_table(self, *args, **kwargs):
        """Read a table. Must be implemented by subclasses."""
        raise NotImplementedError()

Subclasses

Methods

def read_table()
class IOSql (spark)

Extends the IO interface to SQL data structure.

Expand source code
class IOSql(IOInterface):
    """
    Extends the IO interface to SQL data structure.
    """

    def __init__(self, spark):
        super().__init__(spark)

    def _get_connection_properties(self, database_name: str) -> dict:
        """Returns the connection properties for the specified database.

        The database entry is read from the connections config json. Its password is resolved with
        ``get_secret_from_keyvault``, using the secret name and scope stored in the entry's credentials.

        Args:
            database_name (str): Name of the database.

        Raises:
            ValueError: Raise error if the database is missing from the connections config,
                or if its type is not sqlserver or oracle.

        Returns:
            dict: A dict containing connection configuration.
        """

        # Get database connection config from config json
        database_config = super()._get_connection_config_json().get(database_name)
        if database_config is None:
            raise ValueError(f"Database '{database_name}' not found in connections config.")

        database_type = database_config.get("type")
        database_credentials = database_config.get("credentials")  # type: ignore

        # Resolve url, driver and type-specific JDBC options first, so that an unsupported
        # database type is rejected before any secret is fetched from the key vault.
        if database_type == "sqlserver":
            url = (
                f"jdbc:sqlserver://{database_credentials.get('host')};"
                f"databaseName={database_credentials.get('database')}"
            )
            driver = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
            type_specific_options: dict = {}
        elif database_type == "oracle":
            url = (
                f"jdbc:oracle:thin:@{database_credentials.get('host')}:{database_credentials.get('port')}"
                f"/{database_credentials.get('service_name')}"
            )
            driver = "oracle.jdbc.driver.OracleDriver"
            type_specific_options = {"oracle.jdbc.timezoneAsRegion": False}
        else:
            raise ValueError(f"Unsupported database type: {database_type}")

        # Get password from KeyVault using relative secret
        pswd = get_secret_from_keyvault(
            key=database_credentials.get("password"), spark=self.spark, scope=database_credentials.get("scope")
        )

        return {
            "url": url,
            "user": database_credentials.get("username"),
            "password": pswd,
            "driver": driver,
            **type_specific_options,
            "fetchSize": 1000,
        }

    @staticmethod
    def _convert_all_decimal_columns_to_double(sdf: DataFrame) -> DataFrame:
        """Cast every column whose Spark type string starts with 'decimal' to double.

        Args:
            sdf (DataFrame): input Spark DataFrame.

        Returns:
            DataFrame: the DataFrame with decimal columns cast to double; other columns are untouched.
        """
        decimal_cols = [col_name for col_name, col_type in sdf.dtypes if col_type.startswith("decimal")]

        for col_name in decimal_cols:
            sdf = sdf.withColumn(col_name, F.col(col_name).cast("double"))
        return sdf

    @staticmethod
    def _extract_table_name_from_query(query: str) -> str:
        """Return the first table reference that follows a FROM keyword (any case) in ``query``.

        Args:
            query (str): SQL query text.

        Raises:
            ValueError: if no ``FROM <name>`` clause can be found.

        Returns:
            str: the table reference, e.g. 'schema.table'.
        """
        match = re.search(r"\bfrom\s+([\w.]+)", query, flags=re.IGNORECASE)
        if match is None:
            raise ValueError("Cannot infer a table name from query to create the temp view.")
        return match.group(1)

    @panama.logging.log_execution(blocking=True)
    def read_table(
        self,
        database_name: str,
        table_name: Optional[str] = None,
        query: Optional[str] = None,
        extra_options: Optional[dict] = None,
        create_view: Optional[bool] = False,
    ) -> DataFrame:
        """Reads data from a table in an SQLServer or Oracle database.

        Args:
            database_name (str): Name of the database to read from.
            table_name (Optional[str], optional): name of the table to read from, in the format "schema.table".
                Required unless 'query' is given. Defaults to None.
            query (Optional[str], optional): query to execute instead of reading a whole table.
                Required unless 'table_name' is given. Defaults to None.
            extra_options (Optional[dict], optional): Additional options to pass to the JDBC connector as a dictionary. Defaults to None.
            create_view (bool, optional): if True, also register a temp view with the loaded data. When reading
                through 'query', the view name is taken from the first table following a FROM keyword. Defaults to False.

        Raises:
            ValueError: if both or neither of table_name and query are given, or if create_view is requested
                with a query from which no table name can be inferred.

        Returns:
            DataFrame: A Spark DataFrame containing the data from the specified table or query.
        """

        # Exactly one of table_name and query must be provided.
        if (table_name is not None) and (query is not None):
            raise ValueError("Cannot specify both table_name and query.")
        if not table_name and not query:
            raise ValueError("Either table_name or query must be specified.")

        # Work out the view name up front so a bad query fails before touching the database.
        # re.search returns a Match object, so the table name must be extracted from it explicitly.
        view_name = None
        if create_view:
            view_name = table_name if table_name else self._extract_table_name_from_query(query)  # type: ignore

        connection_config = self._get_connection_properties(database_name)

        if table_name:
            connection_config["dbtable"] = table_name
        else:
            connection_config["query"] = query

        if extra_options:
            connection_config.update(extra_options)

        sdf = self.spark.read.format("jdbc").options(**connection_config).load()

        # Convert all decimal columns to doubleType standard
        sdf = self._convert_all_decimal_columns_to_double(sdf)

        if create_view:
            self._create_sdf_temp_view(sdf=sdf, table_name_path=view_name)  # type: ignore

        return sdf

Ancestors

Methods

def read_table(self, database_name: str, table_name: Optional[str] = None, query: Optional[str] = None, extra_options: Optional[dict] = None, create_view: Optional[bool] = False) ‑> pyspark.sql.dataframe.DataFrame

Reads data from a table in an SQLServer or Oracle database.

Args

database_name : str
Name of the database to read from.
table_name : Optional[str], optional
Name of the table to read from, in the format "schema.table". Required unless 'query' is given. Defaults to None.
query : Optional[str], optional
Query to execute instead of reading a whole table. Required unless 'table_name' is given. Defaults to None.
extra_options : Optional[dict], optional
Additional options to pass to the JDBC connector as a dictionary. Defaults to None.
create_view : bool, optional
If True, also register a temp view containing the loaded data. Defaults to False.

Raises

ValueError
Raised if the user passes both query and table_name at the same time.

Returns

DataFrame
A Spark DataFrame containing the data from the specified table or query.