# Source code for pyspark_pipeline_framework.core.resilience.retry

"""Retry execution with exponential backoff and jitter."""

from __future__ import annotations

import functools
import logging
import random
import time
from collections.abc import Callable
from typing import TypeVar

from pyspark_pipeline_framework.core.config.retry import RetryConfig

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Generic return type of the callables that the retry helpers wrap.
T = TypeVar("T")


class RetryExecutor:
    """Executes callables with configurable retry logic.

    Uses exponential backoff with jitter based on a ``RetryConfig``.

    Args:
        config: Retry configuration specifying attempts, delays, and
            retryable exceptions.
        jitter_factor: Random jitter multiplier applied to each delay
            (0 disables jitter).
        sleep_func: Injectable sleep function for testing. Defaults to
            ``time.sleep``.
    """

    def __init__(
        self,
        config: RetryConfig,
        jitter_factor: float = 0.25,
        sleep_func: Callable[[float], None] | None = None,
    ) -> None:
        self._config = config
        self._jitter_factor = jitter_factor
        self._sleep = sleep_func or time.sleep

    @property
    def config(self) -> RetryConfig:
        """Return the retry configuration."""
        return self._config

    def calculate_delay(self, attempt: int) -> float:
        """Calculate the delay in seconds for a given attempt number.

        Uses exponential backoff:
        ``min(initial * multiplier^attempt, max) * (1 + jitter)``.

        Jitter is added *after* the cap is applied, so the returned value
        may exceed ``max_delay_seconds`` by up to ``jitter_factor`` times
        the capped delay.

        Args:
            attempt: Zero-based attempt index (0 = first retry).

        Returns:
            Delay in seconds.
        """
        cfg = self._config
        try:
            base = cfg.initial_delay_seconds * (cfg.backoff_multiplier**attempt)
        except OverflowError:
            # The exponential term no longer fits in a float; the delay
            # would be capped below anyway, so clamp instead of crashing.
            base = cfg.max_delay_seconds
        base = min(base, cfg.max_delay_seconds)
        if self._jitter_factor > 0:
            base += base * self._jitter_factor * random.random()
        return base

    def is_retryable(self, error: Exception) -> bool:
        """Check whether an exception should be retried.

        Matches against ``retry_on_exceptions`` from config. A name without
        a dot is compared with the class ``__name__``; a name with a dot is
        compared with the fully-qualified ``module.class`` path. In both
        forms the exception's own class and every class in its MRO are
        considered, so configuring a parent class also covers subclasses.

        Args:
            error: The exception to check.

        Returns:
            True if the exception is retryable.
        """
        simple_names: set[str] = set()
        qualified_names: set[str] = set()
        for cls in type(error).__mro__:
            simple_names.add(cls.__name__)
            qualified_names.add(f"{cls.__module__}.{cls.__name__}")

        return any(
            (name in qualified_names) if "." in name else (name in simple_names)
            for name in self._config.retry_on_exceptions
        )

    def execute(
        self,
        func: Callable[[], T],
        on_retry: Callable[[int, Exception, float], None] | None = None,
    ) -> T:
        """Execute a callable with retry logic.

        Args:
            func: Zero-argument callable to execute.
            on_retry: Optional callback invoked before each retry with
                ``(attempt, exception, delay)`` where attempt is 1-based.

        Returns:
            The return value of *func*.

        Raises:
            ValueError: If the configured ``max_attempts`` is less than 1,
                since *func* could then never be called.
            Exception: The last exception if all attempts are exhausted,
                or immediately if the exception is not retryable.
        """
        max_attempts = self._config.max_attempts
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

        for attempt in range(max_attempts):
            try:
                return func()
            except Exception as exc:
                is_last = attempt == max_attempts - 1
                if is_last or not self.is_retryable(exc):
                    # Bare raise keeps the original traceback intact.
                    raise
                delay = self.calculate_delay(attempt)
                logger.debug(
                    "Attempt %d/%d failed (%s), retrying in %.3fs",
                    attempt + 1,
                    max_attempts,
                    type(exc).__name__,
                    delay,
                )
                if on_retry is not None:
                    on_retry(attempt + 1, exc, delay)
                self._sleep(delay)

        # Every iteration either returns or raises on the final attempt;
        # this line exists only to satisfy type checkers.
        raise AssertionError("retry loop exited without returning or raising")
def with_retry(
    config: RetryConfig,
    jitter_factor: float = 0.25,
    sleep_func: Callable[[float], None] | None = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Build a decorator that runs the decorated function with retries.

    A single ``RetryExecutor`` is created when ``with_retry`` is called and
    is shared by every function the returned decorator is applied to.

    Args:
        config: Retry configuration.
        jitter_factor: Jitter multiplier (0 disables).
        sleep_func: Injectable sleep for testing.

    Returns:
        A decorator that adds retry behavior.
    """
    shared_executor = RetryExecutor(
        config,
        jitter_factor=jitter_factor,
        sleep_func=sleep_func,
    )

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def retrying_call(*args: object, **kwargs: object) -> T:
            # Freeze the call's arguments into a zero-argument callable,
            # which is the shape the executor expects.
            bound_call = functools.partial(func, *args, **kwargs)
            return shared_executor.execute(bound_call)

        return retrying_call

    return decorator