Source code for pyspark_pipeline_framework.core.quality.checks

"""Built-in data quality check factories."""

from __future__ import annotations

from typing import TYPE_CHECKING

from pyspark_pipeline_framework.core.quality.types import CheckResult, CheckTiming, DataQualityCheck
from pyspark_pipeline_framework.core.schema.definition import SchemaDefinition

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


[docs] def row_count_check( table: str, min_rows: int, timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Check that a table has at least *min_rows* rows.""" def check(spark: SparkSession) -> CheckResult: count = spark.table(table).count() passed = count >= min_rows return CheckResult( check_name=f"row_count_{table}", passed=passed, message=f"Table {table} has {count} rows (min: {min_rows})", details={"count": count, "min_rows": min_rows}, ) return DataQualityCheck( name=f"row_count_{table}", timing=timing, check_fn=check, component_name=component_name, )
[docs] def null_check( table: str, column: str, max_null_pct: float = 0.0, timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Check that *column* has at most *max_null_pct* % null values.""" def check(spark: SparkSession) -> CheckResult: df = spark.table(table) total = df.count() nulls = df.filter(f"{column} IS NULL").count() pct = (nulls / total * 100) if total > 0 else 0.0 passed = pct <= max_null_pct return CheckResult( check_name=f"null_check_{table}_{column}", passed=passed, message=f"Column {column} has {pct:.2f}% nulls (max: {max_null_pct}%)", details={"null_count": nulls, "total": total, "pct": pct}, ) return DataQualityCheck( name=f"null_check_{table}_{column}", timing=timing, check_fn=check, component_name=component_name, )
[docs] def unique_check( table: str, columns: str | list[str], timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Check that *columns* form a unique key (no duplicates). Args: table: Registered temp view or table name. columns: Single column name or list of columns for composite uniqueness. timing: When to run the check. component_name: For ``AFTER_COMPONENT`` timing, which component to attach the check to. """ cols = [columns] if isinstance(columns, str) else list(columns) col_label = "_".join(cols) def check(spark: SparkSession) -> CheckResult: df = spark.table(table) total = df.count() distinct = df.select(*cols).distinct().count() duplicates = total - distinct passed = duplicates == 0 return CheckResult( check_name=f"unique_{table}_{col_label}", passed=passed, message=(f"Columns [{', '.join(cols)}] in {table}: {duplicates} duplicate rows out of {total}"), details={"total": total, "distinct": distinct, "duplicates": duplicates}, ) return DataQualityCheck( name=f"unique_{table}_{col_label}", timing=timing, check_fn=check, component_name=component_name, )
[docs] def range_check( table: str, column: str, min_value: float | int | None = None, max_value: float | int | None = None, timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Check that all values in *column* fall within [min_value, max_value]. At least one of *min_value* or *max_value* must be provided. Args: table: Registered temp view or table name. column: Numeric column to validate. min_value: Minimum allowed value (inclusive), or ``None`` for no lower bound. max_value: Maximum allowed value (inclusive), or ``None`` for no upper bound. timing: When to run the check. component_name: For ``AFTER_COMPONENT`` timing, which component to attach the check to. Raises: ValueError: If neither *min_value* nor *max_value* is provided. """ if min_value is None and max_value is None: msg = "At least one of min_value or max_value must be provided" raise ValueError(msg) def check(spark: SparkSession) -> CheckResult: df = spark.table(table) conditions: list[str] = [] if min_value is not None: conditions.append(f"{column} < {min_value}") if max_value is not None: conditions.append(f"{column} > {max_value}") filter_expr = " OR ".join(conditions) violations = df.filter(filter_expr).count() total = df.count() passed = violations == 0 if min_value is not None and max_value is not None: bound_desc = f"[{min_value}, {max_value}]" elif min_value is not None: bound_desc = f"[{min_value}, ∞)" else: bound_desc = f"(-∞, {max_value}]" return CheckResult( check_name=f"range_{table}_{column}", passed=passed, message=(f"Column {column} in {table}: {violations} out-of-range rows (expected {bound_desc})"), details={ "violations": violations, "total": total, "min_value": min_value, "max_value": max_value, }, ) return DataQualityCheck( name=f"range_{table}_{column}", timing=timing, check_fn=check, component_name=component_name, )
# Mapping from DataType enum values to Spark type name strings. _DATATYPE_TO_SPARK: dict[str, str] = { "string": "string", "integer": "int", "long": "bigint", "float": "float", "double": "double", "boolean": "boolean", "timestamp": "timestamp", "date": "date", "binary": "binary", "array": "array", "map": "map", "struct": "struct", }
[docs] def schema_check( table: str, schema: SchemaDefinition, check_types: bool = True, timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Check that *table* matches the expected *schema*. Verifies that all fields defined in *schema* exist in the table. When *check_types* is ``True`` (the default), also verifies that each field's data type matches. Args: table: Registered temp view or table name. schema: Expected schema definition. check_types: Whether to validate data types in addition to column names. timing: When to run the check. component_name: For ``AFTER_COMPONENT`` timing, which component to attach the check to. """ def check(spark: SparkSession) -> CheckResult: df = spark.table(table) actual_dtypes = dict(df.dtypes) # {"col": "type_string"} missing: list[str] = [] type_mismatches: list[str] = [] for field in schema.fields: if field.name not in actual_dtypes: missing.append(field.name) elif check_types: dt_key = field.data_type.value if hasattr(field.data_type, "value") else str(field.data_type) expected_type = _DATATYPE_TO_SPARK.get(dt_key, dt_key) actual_type = actual_dtypes[field.name] if actual_type != expected_type: type_mismatches.append(f"{field.name}: expected {expected_type}, got {actual_type}") passed = len(missing) == 0 and len(type_mismatches) == 0 issues: list[str] = [] if missing: issues.append(f"missing columns: {missing}") if type_mismatches: issues.append(f"type mismatches: {type_mismatches}") return CheckResult( check_name=f"schema_{table}", passed=passed, message=(f"Schema check for {table}: {'OK' if passed else '; '.join(issues)}"), details={ "missing": missing, "type_mismatches": type_mismatches, }, ) return DataQualityCheck( name=f"schema_{table}", timing=timing, check_fn=check, component_name=component_name, )
[docs] def custom_sql_check( name: str, sql: str, timing: CheckTiming = CheckTiming.AFTER_PIPELINE, component_name: str | None = None, ) -> DataQualityCheck: """Run an arbitrary SQL expression as a data quality check. The *sql* statement must return a single row with at least a ``passed`` column (boolean). An optional ``message`` column provides a custom result message. Example:: custom_sql_check( name="no_negative_amounts", sql="SELECT COUNT(*) = 0 AS passed FROM orders WHERE amount < 0", ) Args: name: Unique check name. sql: SQL query returning a row with a ``passed`` boolean column. timing: When to run the check. component_name: For ``AFTER_COMPONENT`` timing, which component to attach the check to. """ def check(spark: SparkSession) -> CheckResult: row = spark.sql(sql).head() if row is None: return CheckResult( check_name=name, passed=False, message=f"Custom SQL check '{name}': query returned no rows", details={"sql": sql}, ) passed = bool(row["passed"]) message = str(row["message"]) if "message" in row.asDict() else "" return CheckResult( check_name=name, passed=passed, message=message or f"Custom SQL check '{name}': {'PASSED' if passed else 'FAILED'}", details={"sql": sql}, ) return DataQualityCheck( name=name, timing=timing, check_fn=check, component_name=component_name, )