Source code for openmnglab.datamodel.pandas.verification

from __future__ import annotations

from typing import Optional, Any

import pandera as pa
from pandera.api.base.schema import BaseSchema
from pandera.api.pandas.array import ArraySchema
from pandera.api.pandas.types import PandasDtypeInputTypes
from pandera.dtypes import is_subdtype


[docs]class ComparisonError(Exception): def __init__(self, subject: str, base_val: Any, other_val: Any, reason: str = "not equal"): self.subject = subject self.base_val = base_val self.other_val = other_val self.reason = reason def __str__(self): return f"Failed comparison on subject {self.subject}: {self.reason} (base: {self.base_val}, other: {self.other_val})"
[docs]def compare_dtype(base: Optional[PandasDtypeInputTypes], other: Optional[PandasDtypeInputTypes]) -> bool: if base is not None and (other is None or not is_subdtype(other, base)): raise ComparisonError("datatype", base, other, reason="not a subtype") return True
[docs]def compare_baseschema(base: BaseSchema, other: BaseSchema) -> bool: if base.name is not None and (other.name is None or base.name != other.name): raise ComparisonError("name", base.name, other.name) return compare_dtype(base.dtype, other.dtype)
[docs]def compare_arrayschema(base: ArraySchema, other: ArraySchema) -> bool: if base.nullable != other.nullable: raise ComparisonError("nullability", base.nullable, other.nullable) if base.unique != other.unique: raise ComparisonError("unique elements", base.unique, other.unique) return compare_baseschema(base, other)
[docs]def compare_index(base: Optional[pa.Index | pa.MultiIndex], other: Optional[pa.Index | pa.MultiIndex]) -> bool: if base is None: return True if isinstance(base, pa.Index) and isinstance(other, pa.Index): return compare_arrayschema(base, other) if isinstance(base, pa.MultiIndex) and isinstance(other, pa.MultiIndex): return all( compare_arrayschema(base_i, other_i) for base_i, other_i in zip(base.indexes, other.indexes)) if not isinstance(other, type(base)): raise ComparisonError("type", type(base), type(other)) return True
[docs]def compare_column(base: pa.Column, other: pa.Column) -> bool: if base.required and not other.required: raise ComparisonError("required", base.required, other.required) return compare_arrayschema(base, other)
[docs]def compare_columns(base: Optional[dict[Any, pa.Column]], other: Optional[dict[Any, pa.Column]]) -> bool: if base is None: return True for key, column in base.items(): if key not in other: raise ComparisonError("column", key, None, "exists") assert (compare_column(column, other[key])) return True
[docs]def compare_dataframe_schemas(base: pa.DataFrameSchema, other: pa.DataFrameSchema) -> bool: if base.ordered and not other.ordered: raise ComparisonError("ordered", base.ordered, other.ordered) if base.unique != other.unique: raise ComparisonError("uniqueness", base.unique, other.unique) return compare_baseschema(base, other) and \ compare_index(base.index, other.index) and \ compare_columns(base.columns, other.columns)
[docs]def compare_series_schemas(base: pa.SeriesSchema, other: pa.SeriesSchema) -> bool: return compare_arrayschema(base, other) and compare_index(base.index, other.index)
[docs]def compare_schemas(base, other) -> bool: if isinstance(base, pa.SeriesSchema): if not isinstance(other, pa.SeriesSchema): raise ComparisonError("schema type", type(base), type(other)) return compare_series_schemas(base, other) elif isinstance(base, pa.DataFrameSchema): if not isinstance(other, pa.DataFrameSchema): raise ComparisonError("schema type", type(base), type(other)) return compare_dataframe_schemas(base, other) raise TypeError("provided base schema is neither a DataFrameSchema, nor a SeriesSchema") return False