Source code for pyspark_pipeline_framework.runtime.schema_converter

"""Bidirectional conversion between SchemaDefinition and PySpark StructType.

Provides ``to_struct_type`` and ``from_struct_type`` for converting between
the framework's platform-independent schema model and PySpark's native types.
PySpark imports are deferred to function bodies so the module can be imported
without a Spark runtime.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from pyspark_pipeline_framework.core.schema.definition import DataType, SchemaDefinition, SchemaField

if TYPE_CHECKING:
    from pyspark.sql.types import DataType as SparkDataType
    from pyspark.sql.types import StructType

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


def _get_spark_type_map() -> dict[DataType, SparkDataType]:
    """Build the lookup table from framework ``DataType`` members to PySpark types.

    Only the scalar framework types listed below are covered; ``ARRAY``,
    ``MAP`` and ``STRUCT`` have no entry. PySpark is imported inside the
    function body so that importing this module does not itself require
    PySpark.

    Returns:
        A new dict mapping each supported ``DataType`` member to an
        instance of the matching PySpark type.
    """
    from pyspark.sql import types as spark_types

    # (framework member, PySpark class) pairs; each class is instantiated below.
    scalar_pairs = (
        (DataType.STRING, spark_types.StringType),
        (DataType.INTEGER, spark_types.IntegerType),
        (DataType.LONG, spark_types.LongType),
        (DataType.FLOAT, spark_types.FloatType),
        (DataType.DOUBLE, spark_types.DoubleType),
        (DataType.BOOLEAN, spark_types.BooleanType),
        (DataType.TIMESTAMP, spark_types.TimestampType),
        (DataType.DATE, spark_types.DateType),
        (DataType.BINARY, spark_types.BinaryType),
    )
    return {framework_type: spark_cls() for framework_type, spark_cls in scalar_pairs}


def _get_reverse_type_map() -> dict[str, DataType]:
    """Build the lookup table from PySpark type class names to framework types.

    Keys are bare class names (e.g. ``"StringType"``), so the table can be
    built without importing PySpark.

    Returns:
        A new dict mapping each supported PySpark class name to its
        ``DataType`` member.
    """
    name_pairs = (
        ("StringType", DataType.STRING),
        ("IntegerType", DataType.INTEGER),
        ("LongType", DataType.LONG),
        ("FloatType", DataType.FLOAT),
        ("DoubleType", DataType.DOUBLE),
        ("BooleanType", DataType.BOOLEAN),
        ("TimestampType", DataType.TIMESTAMP),
        ("DateType", DataType.DATE),
        ("BinaryType", DataType.BINARY),
    )
    return dict(name_pairs)


# Framework types that need nested type information; _data_type_to_spark
# rejects these with a ValueError instead of mapping them.
_COMPLEX_TYPES = frozenset((DataType.ARRAY, DataType.MAP, DataType.STRUCT))


def _data_type_to_spark(dt: DataType | str) -> SparkDataType:
    """Convert a single framework data type to its PySpark equivalent.

    A string argument is first coerced to a ``DataType`` member by value and
    then handled exactly like an enum argument, so both spellings of the same
    type produce the same result and the same error message.

    Args:
        dt: A ``DataType`` enum member or its string value.

    Returns:
        The corresponding PySpark ``DataType`` instance.

    Raises:
        ValueError: If ``dt`` is a string that is not a valid ``DataType``
            value, if it is a complex type (array/map/struct), or if it has
            no entry in the PySpark type map.
    """
    if not isinstance(dt, DataType):
        # Guard only the coercion itself: errors raised further down for a
        # valid member (e.g. a complex type) must keep their specific message
        # rather than being rewritten as a generic "cannot convert" error.
        try:
            dt = DataType(dt)
        except ValueError:
            raise ValueError(f"Cannot convert data type '{dt}' to a PySpark type") from None

    # Reject complex types before touching PySpark, so this check does not
    # require building the type map.
    if dt in _COMPLEX_TYPES:
        raise ValueError(
            f"Complex type '{dt.value}' requires nested type information "
            f"that SchemaField does not carry. Use PySpark types directly "
            f"for complex schemas."
        )

    spark_type = _get_spark_type_map().get(dt)
    if spark_type is None:
        raise ValueError(f"Unsupported DataType: {dt}")
    return spark_type


def _spark_to_data_type(spark_type: SparkDataType) -> DataType | str:
    """Map a PySpark data type onto the framework's type model.

    Lookup is by the exact class name of ``spark_type`` first; if that misses,
    array, map and struct types are recognised with ``isinstance``.

    Args:
        spark_type: A PySpark ``DataType`` instance.

    Returns:
        The matching ``DataType`` member, or the PySpark class name as a
        string when the type has no framework equivalent.
    """
    from pyspark.sql.types import ArrayType, MapType, StructType

    class_name = type(spark_type).__name__

    scalar_match = _get_reverse_type_map().get(class_name)
    if scalar_match is not None:
        return scalar_match

    complex_kinds = (
        (ArrayType, DataType.ARRAY),
        (MapType, DataType.MAP),
        (StructType, DataType.STRUCT),
    )
    for spark_cls, framework_type in complex_kinds:
        if isinstance(spark_type, spark_cls):
            return framework_type

    # No framework equivalent: fall back to the PySpark class name.
    return class_name


def to_struct_type(schema: SchemaDefinition) -> StructType:
    """Build a PySpark ``StructType`` from a ``SchemaDefinition``.

    Each schema field becomes one ``StructField`` with the same name and
    nullability. Empty field metadata is passed to PySpark as ``None``.

    Args:
        schema: The framework schema definition.

    Returns:
        A PySpark ``StructType`` with one field per schema field, in order.

    Raises:
        ValueError: If any field type cannot be mapped to a PySpark type.
    """
    from pyspark.sql.types import StructField, StructType

    return StructType(
        [
            StructField(
                field.name,
                _data_type_to_spark(field.data_type),
                nullable=field.nullable,
                metadata=field.metadata or None,
            )
            for field in schema.fields
        ]
    )
def from_struct_type(
    struct_type: StructType,
    description: str | None = None,
) -> SchemaDefinition:
    """Build a ``SchemaDefinition`` from a PySpark ``StructType``.

    Each ``StructField`` becomes one ``SchemaField`` with the same name and
    nullability. Field metadata is copied into a new dict; missing or empty
    metadata becomes an empty dict.

    Args:
        struct_type: A PySpark ``StructType``.
        description: Optional description for the resulting schema.

    Returns:
        A ``SchemaDefinition`` with one field per struct field, in order.
    """
    converted = [
        SchemaField(
            name=column.name,
            data_type=_spark_to_data_type(column.dataType),
            nullable=column.nullable,
            metadata=dict(column.metadata or {}),
        )
        for column in struct_type.fields
    ]
    return SchemaDefinition(fields=converted, description=description)