Source code for pyspark_pipeline_framework.runtime.session.wrapper
"""Spark session lifecycle management."""
from __future__ import annotations
import logging
import threading
import warnings
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pyspark.sql import SparkSession
from pyspark_pipeline_framework.core.config import SparkConfig
logger = logging.getLogger(__name__)
[docs]
class SparkSessionWrapper:
"""Manages SparkSession lifecycle for pipeline execution.
Supports:
- Local mode (master="local[*]")
- Cluster mode (master="yarn", "spark://...")
- Spark Connect (connect_string="sc://...")
Example:
>>> wrapper = SparkSessionWrapper.get_or_create(config)
>>> df = wrapper.spark.read.parquet("data.parquet")
>>> wrapper.stop()
# Or as context manager:
>>> with SparkSessionWrapper(config) as wrapper:
... df = wrapper.spark.read.parquet("data.parquet")
"""
_instance: SparkSessionWrapper | None = None
_lock = threading.Lock()
[docs]
def __init__(self, config: SparkConfig | None = None) -> None:
"""Initialize wrapper with optional config.
Args:
config: Spark configuration. If None, uses defaults with
app_name="pyspark-pipeline".
"""
if config is None:
config = SparkConfig(app_name="pyspark-pipeline")
self._config = config
self._spark: SparkSession | None = None
self._owns_session = False
self._session_lock = threading.Lock()
[docs]
@classmethod
def get_or_create(cls, config: SparkConfig | None = None) -> SparkSessionWrapper:
"""Get singleton instance or create new one.
Thread-safe singleton access. First call creates the instance,
subsequent calls return the same instance (ignoring config).
Args:
config: Spark configuration for first initialization.
Returns:
The singleton SparkSessionWrapper instance.
"""
with cls._lock:
if cls._instance is None:
cls._instance = cls(config)
return cls._instance
[docs]
@classmethod
def reset(cls) -> None:
"""Reset singleton instance.
Stops any owned session and clears the singleton.
Primarily used for testing.
"""
with cls._lock:
if cls._instance is not None:
cls._instance.stop()
cls._instance = None
@property
def spark(self) -> SparkSession:
"""Get or create SparkSession.
Lazily creates the session on first access. Thread-safe.
Returns:
Active SparkSession instance.
"""
with self._session_lock:
if self._spark is None:
self._spark = self._create_session()
self._owns_session = True
return self._spark
@property
def spark_context(self) -> Any:
"""Get SparkContext from session.
Note: Not available when using Spark Connect mode.
Returns:
SparkContext instance.
Raises:
RuntimeError: If using Spark Connect (no SparkContext available).
"""
if self._config.connect_string:
raise RuntimeError(
"SparkContext is not available in Spark Connect mode. "
"Use spark session directly for DataFrame operations."
)
return self.spark.sparkContext
@property
def sql_context(self) -> Any:
"""Get SQLContext from session.
.. deprecated::
SQLContext is deprecated since Spark 2.0. Use SparkSession directly.
This property may be removed in future versions.
Returns:
SQLContext instance.
Raises:
RuntimeError: If using Spark Connect or PySpark 4.0+.
"""
warnings.warn(
"sql_context is deprecated. Use spark session directly.",
DeprecationWarning,
stacklevel=2,
)
if self._config.connect_string:
raise RuntimeError("SQLContext is not available in Spark Connect mode.")
try:
from pyspark.sql import SQLContext
return SQLContext(self.spark_context)
except ImportError:
raise RuntimeError("SQLContext is not available in PySpark 4.0+. Use SparkSession directly.") from None
@property
def is_connect_mode(self) -> bool:
"""Check if using Spark Connect mode."""
return self._config.connect_string is not None
[docs]
def set_spark_session(self, spark: SparkSession) -> None:
"""Inject an existing SparkSession.
Use when running in an existing Spark environment
(e.g., Databricks, EMR notebooks).
Args:
spark: Existing SparkSession to use.
"""
with self._session_lock:
if self._spark is not None and self._owns_session:
logger.warning("Replacing owned SparkSession with injected one")
self._spark.stop()
self._spark = spark
self._owns_session = False
def _create_session(self) -> SparkSession:
"""Create new SparkSession from config."""
from pyspark.sql import SparkSession
# Spark Connect mode
if self._config.connect_string:
logger.info("Connecting via Spark Connect: %s", self._config.connect_string)
return SparkSession.builder.remote(self._config.connect_string).getOrCreate()
# Standard mode - use to_spark_conf_dict()
builder = SparkSession.builder
for key, value in self._config.to_spark_conf_dict().items():
builder = builder.config(key, value)
logger.info(
"Creating SparkSession: app=%s, master=%s",
self._config.app_name,
self._config.master,
)
return builder.getOrCreate()
[docs]
def stop(self) -> None:
"""Stop SparkSession if we own it."""
with self._session_lock:
if self._spark is not None and self._owns_session:
logger.info("Stopping SparkSession")
self._spark.stop()
self._spark = None
self._owns_session = False
def __enter__(self) -> SparkSessionWrapper:
"""Enter context manager."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Exit context manager, stopping session if owned."""
self.stop()