Source code for pyspark_pipeline_framework.runner.simple_runner

"""Simple pipeline runner with resilience and hooks."""

from __future__ import annotations

import logging
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any

from pyspark_pipeline_framework.core.config.loader import load_from_file
from pyspark_pipeline_framework.core.config.pipeline import PipelineConfig
from pyspark_pipeline_framework.core.config.validator import validate_pipeline
from pyspark_pipeline_framework.core.resilience.circuit_breaker import CircuitBreaker
from pyspark_pipeline_framework.runner.component_executor import ComponentExecutor
from pyspark_pipeline_framework.runner.hooks import NoOpHooks, PipelineHooks
from pyspark_pipeline_framework.runner.result import ComponentResult, PipelineResult, PipelineResultStatus
from pyspark_pipeline_framework.runtime.loader import validate_component_class
from pyspark_pipeline_framework.runtime.session.wrapper import SparkSessionWrapper

logger = logging.getLogger(__name__)


class SimplePipelineRunner:
    """Executes a pipeline by running components in topological order.

    Supports retry, circuit breaker, hooks, and Spark session injection.

    The runner does **not** manage SparkSession lifecycle — the caller owns
    start/stop.  If no ``spark_wrapper`` is provided, one is created from
    ``config.spark`` but never stopped automatically.

    Args:
        config: Pipeline configuration.
        spark_wrapper: Optional pre-built session wrapper.
        hooks: Lifecycle hooks (default: ``NoOpHooks``).
        fail_fast: Stop on first component failure (default: ``True``).
        clock: Injectable monotonic clock for testing.
        sleep_func: Injectable sleep for testing retry delays.
        validate_before_run: Run static config validation before execution
            (default: ``True``).  Set to ``False`` to skip pre-flight checks.
    """

    def __init__(
        self,
        config: PipelineConfig,
        spark_wrapper: SparkSessionWrapper | None = None,
        hooks: PipelineHooks | None = None,
        fail_fast: bool = True,
        clock: Callable[[], float] | None = None,
        sleep_func: Callable[[float], None] | None = None,
        validate_before_run: bool = True,
    ) -> None:
        self._config = config
        # Explicit ``is None`` checks: a caller-supplied collaborator must be
        # used as given, even if it happens to be falsy (e.g. defines __len__).
        self._spark_wrapper = (
            spark_wrapper
            if spark_wrapper is not None
            else SparkSessionWrapper(config.spark)
        )
        self._hooks: PipelineHooks = hooks if hooks is not None else NoOpHooks()
        self._fail_fast = fail_fast
        self._clock = clock if clock is not None else time.monotonic
        self._sleep_func = sleep_func
        self._validate_before_run = validate_before_run
        # Passed to the executor by reference, so any breakers it stores here
        # live as long as this runner (i.e. across repeated run() calls).
        self._circuit_breakers: dict[str, CircuitBreaker] = {}
        self._executor = ComponentExecutor(
            hooks=self._hooks,
            circuit_breakers=self._circuit_breakers,
            clock=self._clock,
            sleep_func=self._sleep_func,
        )

    @classmethod
    def from_file(cls, path: str | Path, **kwargs: Any) -> SimplePipelineRunner:
        """Create a runner from a HOCON configuration file.

        Args:
            path: Path to the HOCON file.
            **kwargs: Forwarded to the constructor.

        Returns:
            Configured ``SimplePipelineRunner``.
        """
        config = load_from_file(str(path), PipelineConfig)
        return cls(config, **kwargs)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(self, completed_components: set[str] | None = None) -> PipelineResult:
        """Execute the pipeline.

        Enabled components run in the config's execution order.  Components
        named in ``completed_components`` are skipped.  They still occupy
        their position, so the ``index``/``total`` passed to the executor
        always count every enabled component.

        Args:
            completed_components: Optional set of component names to skip
                (already completed in a prior run).  Pass ``None`` to run
                all components.

        Returns:
            ``PipelineResult`` with per-component outcomes and overall status:

            * ``SUCCESS`` when no executed component failed (including when
              nothing was executed).
            * ``FAILURE`` when every executed component failed.
            * ``PARTIAL_SUCCESS`` otherwise.

            If pre-flight validation fails, a ``FAILURE`` result with no
            component results is returned and no hooks are invoked.
        """
        start = self._clock()

        if self._validate_before_run and not self._preflight_ok():
            return PipelineResult(
                status=PipelineResultStatus.FAILURE,
                pipeline_name=self._config.name,
                total_duration_ms=self._elapsed_ms(start),
            )

        self._call_hook("before_pipeline", self._config)

        enabled = self._resolve_enabled_components()
        total = len(enabled)
        skip = completed_components or set()
        results: list[ComponentResult] = []

        # NOTE(review): if the executor or the Spark session accessor raises,
        # the exception propagates out of run() and after_pipeline is not
        # invoked.
        for index, (name, comp_config) in enumerate(enabled):
            if name in skip:
                logger.debug("Skipping already-completed component '%s'", name)
                continue
            result = self._executor.execute(
                comp_config,
                self._spark_wrapper.spark,
                index,
                total,
            )
            results.append(result)
            if not result.success and self._fail_fast:
                break

        pipeline_result = PipelineResult(
            status=self._determine_status(results),
            pipeline_name=self._config.name,
            component_results=results,
            total_duration_ms=self._elapsed_ms(start),
        )
        self._call_hook("after_pipeline", self._config, pipeline_result)
        return pipeline_result

    def dry_run(self) -> list[str]:
        """Validate component classes without executing them.

        Only enabled components are checked.  A class that cannot be loaded
        is reported as a warning instead of raising.

        Returns:
            A list of warning strings.  Empty means all classes are valid.
        """
        warnings: list[str] = []
        for comp_config in self._config.components:
            if not comp_config.enabled:
                continue
            try:
                warnings.extend(validate_component_class(comp_config.class_path))
            except Exception as exc:  # report, don't abort the dry run
                warnings.append(f"Cannot load '{comp_config.class_path}': {exc}")
        return warnings

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _preflight_ok(self) -> bool:
        """Run static config validation, log its findings, and report validity.

        Returns:
            ``True`` when the configuration is valid, ``False`` otherwise.
        """
        validation = validate_pipeline(self._config)
        for warning in validation.warnings:
            logger.warning("Validation warning: %s", warning)
        if validation.is_valid:
            return True
        errors_msg = "; ".join(e.message for e in validation.errors)
        logger.error("Pipeline validation failed: %s", errors_msg)
        return False

    def _resolve_enabled_components(self) -> list[tuple[str, Any]]:
        """Return ``(name, component_config)`` pairs for enabled components.

        Pairs are returned in execution order.  Names with no matching
        component config, or whose config is disabled, are dropped.
        Carrying the config forward avoids a second lookup (and a None
        check) when executing.
        """
        enabled: list[tuple[str, Any]] = []
        for name in self._config.get_execution_order():
            comp_config = self._config.get_component(name)
            if comp_config is not None and comp_config.enabled:
                enabled.append((name, comp_config))
        return enabled

    @staticmethod
    def _determine_status(results: list[ComponentResult]) -> PipelineResultStatus:
        """Derive the overall pipeline status from the executed components."""
        failures = sum(1 for r in results if not r.success)
        if failures == 0:
            return PipelineResultStatus.SUCCESS
        if failures == len(results):
            return PipelineResultStatus.FAILURE
        return PipelineResultStatus.PARTIAL_SUCCESS

    def _elapsed_ms(self, start: float) -> int:
        """Whole milliseconds elapsed on the injected clock since ``start`` (truncated)."""
        return int((self._clock() - start) * 1000)

    def _call_hook(self, method: str, *args: Any) -> None:
        """Invoke a hook method defensively — errors are logged, not raised."""
        # Imported lazily; presumably to avoid an import cycle — TODO confirm.
        from pyspark_pipeline_framework.core.utils import safe_call

        safe_call(
            lambda: getattr(self._hooks, method)(*args),
            logger,
            "Hook %s.%s raised an exception",
            type(self._hooks).__name__,
            method,
        )