"""Pipeline lifecycle hooks protocol and infrastructure."""
from __future__ import annotations
import logging
from typing import Any, Protocol
from pyspark_pipeline_framework.core.config.component import ComponentConfig
from pyspark_pipeline_framework.core.config.pipeline import PipelineConfig
from pyspark_pipeline_framework.core.resilience.circuit_breaker import CircuitState
from pyspark_pipeline_framework.core.utils import safe_call
# Module-scoped logger; CompositeHooks hands it to safe_call for hook errors.
logger: logging.Logger = logging.getLogger(__name__)
class PipelineHooks(Protocol):
    """Structural interface for observers of pipeline execution.

    An implementation is notified at these lifecycle points:
    - around the whole pipeline;
    - around each component;
    - when a component fails;
    - before each retry;
    - whenever a circuit breaker changes state.

    The protocol is intentionally *not* decorated with
    ``@runtime_checkable``, so ``isinstance(obj, PipelineHooks)`` raises
    ``TypeError``. Rely on static structural typing, or probe individual
    callbacks with ``hasattr``.
    """

    def before_pipeline(self, config: PipelineConfig) -> None:
        """Hook invoked just before the pipeline described by *config* starts."""

    def after_pipeline(self, config: PipelineConfig, result: Any) -> None:
        """Hook invoked once the pipeline has finished, on success or failure.

        *result* is typed ``Any`` here. Its contents are decided by whoever
        fires the hook.
        """

    def before_component(self, config: ComponentConfig, index: int, total: int) -> None:
        """Hook invoked before a component runs.

        *index* locates the component among *total* components. Whether the
        index is 0- or 1-based is set by the caller; confirm there.
        """

    def after_component(self, config: ComponentConfig, index: int, total: int, duration_ms: int) -> None:
        """Hook invoked after a component has completed successfully.

        *duration_ms* is the reported duration, in milliseconds going by the
        parameter name.
        """

    def on_component_failure(self, config: ComponentConfig, index: int, error: Exception) -> None:
        """Hook invoked when a component raises; *error* is the exception."""

    def on_retry_attempt(
        self,
        config: ComponentConfig,
        attempt: int,
        max_attempts: int,
        delay_ms: int,
        error: Exception,
    ) -> None:
        """Hook invoked ahead of a retry attempt for *config*'s component.

        *attempt* and *max_attempts* describe progress through the retry
        budget. *delay_ms* is the pause before the attempt (milliseconds, per
        the name). *error* is presumably the failure that prompted the retry;
        verify against the caller.
        """

    def on_circuit_breaker_state_change(
        self,
        component_name: str,
        old_state: CircuitState,
        new_state: CircuitState,
    ) -> None:
        """Hook invoked when a circuit breaker moves from *old_state* to *new_state*."""
class NoOpHooks:
    """Inert hooks implementation.

    Every callback accepts its arguments and returns ``None`` without doing
    anything. Handy as a default or a placeholder.
    """

    def before_pipeline(self, config: PipelineConfig) -> None:
        """Ignore the pipeline-start notification."""

    def after_pipeline(self, config: PipelineConfig, result: Any) -> None:
        """Ignore the pipeline-finish notification."""

    def before_component(self, config: ComponentConfig, index: int, total: int) -> None:
        """Ignore the component-start notification."""

    def after_component(self, config: ComponentConfig, index: int, total: int, duration_ms: int) -> None:
        """Ignore the component-success notification."""

    def on_component_failure(self, config: ComponentConfig, index: int, error: Exception) -> None:
        """Ignore the component-failure notification."""

    def on_retry_attempt(
        self,
        config: ComponentConfig,
        attempt: int,
        max_attempts: int,
        delay_ms: int,
        error: Exception,
    ) -> None:
        """Ignore the retry notification."""

    def on_circuit_breaker_state_change(
        self,
        component_name: str,
        old_state: CircuitState,
        new_state: CircuitState,
    ) -> None:
        """Ignore the circuit-breaker transition notification."""
class CompositeHooks:
    """Fan lifecycle events out to several hooks implementations.

    Hooks are notified in the order they were given to the constructor.
    Every individual call is routed through ``safe_call``, so an exception
    from one hook is caught and logged rather than propagated. A faulty hook
    can therefore neither break the pipeline nor prevent the hooks after it
    from being notified.
    """

    def __init__(self, *hooks: PipelineHooks) -> None:
        """Store *hooks* (as the tuple collected by ``*hooks``) for dispatch."""
        self._hooks: tuple[PipelineHooks, ...] = hooks

    def before_pipeline(self, config: PipelineConfig) -> None:
        """Forward ``before_pipeline`` to each registered hook."""
        self._call_all("before_pipeline", config)

    def after_pipeline(self, config: PipelineConfig, result: Any) -> None:
        """Forward ``after_pipeline`` to each registered hook."""
        self._call_all("after_pipeline", config, result)

    def before_component(self, config: ComponentConfig, index: int, total: int) -> None:
        """Forward ``before_component`` to each registered hook."""
        self._call_all("before_component", config, index, total)

    def after_component(self, config: ComponentConfig, index: int, total: int, duration_ms: int) -> None:
        """Forward ``after_component`` to each registered hook."""
        self._call_all(
            "after_component",
            config,
            index,
            total,
            duration_ms,
        )

    def on_component_failure(self, config: ComponentConfig, index: int, error: Exception) -> None:
        """Forward ``on_component_failure`` to each registered hook."""
        self._call_all("on_component_failure", config, index, error)

    def on_retry_attempt(
        self,
        config: ComponentConfig,
        attempt: int,
        max_attempts: int,
        delay_ms: int,
        error: Exception,
    ) -> None:
        """Forward ``on_retry_attempt`` to each registered hook."""
        self._call_all(
            "on_retry_attempt",
            config,
            attempt,
            max_attempts,
            delay_ms,
            error,
        )

    def on_circuit_breaker_state_change(
        self,
        component_name: str,
        old_state: CircuitState,
        new_state: CircuitState,
    ) -> None:
        """Forward ``on_circuit_breaker_state_change`` to each registered hook."""
        self._call_all("on_circuit_breaker_state_change", component_name, old_state, new_state)

    def _call_all(self, method: str, *args: Any, **kwargs: Any) -> None:
        """Call ``method(*args, **kwargs)`` on every hook, isolating failures.

        The attribute lookup and the call both happen inside the thunk given
        to ``safe_call``. Any exception from either step is therefore handled
        there, with the module logger and the hook's class name in the
        message.
        """
        for target in self._hooks:
            # Bind ``target`` as a default so the thunk always refers to this
            # iteration's hook, independent of the loop variable.
            def _thunk(bound: PipelineHooks = target) -> None:
                callback = getattr(bound, method)
                callback(*args, **kwargs)

            hook_name = type(target).__name__
            safe_call(_thunk, logger, "Hook %s.%s raised an exception", hook_name, method)