kups.core.assertion

JAX-compatible assertion tracing system with optional automatic fixing.

This module provides an assertion system that works under JAX transformations, including JIT compilation, automatic differentiation, and vectorization. An assertion can carry an optional fix function for automatic error recovery.

Key Components:

NO_ARGS = _NO_ARGS() module-attribute

Sentinel instance to distinguish between no arguments and None as an argument.
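
A minimal sketch of the sentinel pattern this attribute relies on, in plain Python. The names `_NoArgs`, `NO_ARGS_SKETCH`, and `describe` are hypothetical stand-ins and not part of kups; the sketch only shows how an `isinstance` check separates "nothing was passed" from an explicit `None`.

```python
class _NoArgs:
    """Hypothetical stand-in for the module's private sentinel class."""

    def __repr__(self) -> str:
        return "NO_ARGS"


# Single shared instance used as the default value.
NO_ARGS_SKETCH = _NoArgs()


def describe(fix_args: object = NO_ARGS_SKETCH) -> str:
    # The isinstance check mirrors the one used in RuntimeAssertion.fix.
    if isinstance(fix_args, _NoArgs):
        return "no fix arguments supplied"
    return f"fix arguments supplied: {fix_args!r}"


print(describe())      # the caller passed nothing
print(describe(None))  # the caller passed None on purpose
```

A plain `None` default could not make this distinction, because `None` may be a legitimate argument for a fix function.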

RuntimeAssertion

A runtime assertion that validates computations with optional automatic fixing.

This class encapsulates a predicate that should evaluate to True, along with metadata for error reporting and optional repair mechanisms. Assertions are designed to work seamlessly with JAX transformations while providing rich debugging information.

Class Type Parameters:

Name Bound or Constraints Description Default
State

The type of state that can be modified by the fix function

required
FixArgs

The types of arguments passed to the fix function (can be a PyTree)

required

Attributes:

Name Type Description
predicate Array

A scalar boolean array indicating whether the assertion passes

message str

Human-readable error message with optional format placeholders

fmt_args dict[str, Array]

Dictionary of values to substitute into the message format string

exception_type type[Exception]

Type of exception to raise on assertion failure

static_info dict[str, Any]

Additional metadata for debugging (not traced by JAX)

fix_fn Fix[State, FixArgs] | None

Optional function to repair the state when assertion fails

fix_args FixArgs | _NO_ARGS

Arguments to pass to the fix function (can be complex PyTree structures)

Example
assertion = RuntimeAssertion(
    predicate=jnp.array(x > 0),
    message="Value must be positive, got {val}",
    fmt_args={"val": x},
    fix_fn=lambda state, threshold: jnp.maximum(state, threshold),
    fix_args=0.1
)
Note

The fix_args can be complex PyTree structures including nested dictionaries, tuples, and arrays. The assertion system properly handles flattening and unflattening these structures during JAX transformations.
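
The flattening and unflattening mentioned in the note can be illustrated with JAX's own tree utilities. This is a sketch of the general mechanism, not of the kups internals: it uses `jax.tree_util.tree_flatten` and `tree_unflatten`, and the `fix_args` value is an invented example.

```python
import jax
import jax.numpy as jnp

# Illustrative nested fix_args: a dict holding a dict and a tuple of arrays.
fix_args = {
    "thresholds": {"min": jnp.array(0.1), "max": jnp.array(10.0)},
    "multipliers": (jnp.array(2.0), jnp.array(3.0)),
}

# Flatten into a list of array leaves plus a static structure description.
leaves, treedef = jax.tree_util.tree_flatten(fix_args)

# Rebuild the original nesting from the leaves and the structure.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
num_leaves = len(leaves)
```

Only the leaves need to be traced by JAX; the tree structure is static metadata that can be stored alongside them and used to rebuild the arguments later.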

Source code in src/kups/core/assertion.py
@dataclass
class RuntimeAssertion[State, FixArgs]:
    """
    A runtime assertion that validates computations with optional automatic fixing.

    This class encapsulates a predicate that should evaluate to True, along with
    metadata for error reporting and optional repair mechanisms. Assertions are
    designed to work seamlessly with JAX transformations while providing rich
    debugging information.

    Type Parameters:
        State: The type of state that can be modified by the fix function
        FixArgs: The types of arguments passed to the fix function (can be a PyTree)

    Attributes:
        predicate: A scalar boolean array indicating whether the assertion passes
        message: Human-readable error message with optional format placeholders
        fmt_args: Dictionary of values to substitute into the message format string
        exception_type: Type of exception to raise on assertion failure
        static_info: Additional metadata for debugging (not traced by JAX)
        fix_fn: Optional function to repair the state when assertion fails
        fix_args: Arguments to pass to the fix function (can be complex PyTree structures)

    Example:
        ```python
        assertion = RuntimeAssertion(
            predicate=jnp.array(x > 0),
            message="Value must be positive, got {val}",
            fmt_args={"val": x},
            fix_fn=lambda state, threshold: jnp.maximum(state, threshold),
            fix_args=0.1
        )
        ```

    Note:
        The fix_args can be complex PyTree structures including nested dictionaries,
        tuples, and arrays. The assertion system properly handles flattening and
        unflattening these structures during JAX transformations.
    """

    predicate: Array
    message: str = field(static=True)
    fmt_args: dict[str, Array] = field(default_factory=dict)
    exception_type: type[Exception] = field(static=True, default=AssertionError)
    static_info: dict[str, Any] = field(static=True, default_factory=dict)
    fix_fn: Fix[State, FixArgs] | None = field(static=True, default=None)
    fix_args: FixArgs | _NO_ARGS = field(default=NO_ARGS)

    def valid(self) -> bool:
        """Check if the assertion is valid (i.e., passes)."""
        return bool(self.predicate)

    def failed(self) -> bool:
        """Check if the assertion has failed."""
        return not self.valid()

    def __str__(self) -> str:
        """Return the formatted assertion message."""
        return self.message.format(**self.fmt_args)

    def check(self):
        """
        Check the assertion and raise an exception if it fails.

        Raises:
            Exception: The configured exception type if the assertion fails
        """
        if not bool(jnp.all(self.predicate)):
            raise self.exception_type(self.message.format(**self.fmt_args))

    @property
    def exception(self) -> Exception:
        """
        Create the exception instance that would be raised on assertion failure.

        Returns:
            An exception instance with the formatted error message
        """
        return self.exception_type(
            self.message.format(**(self.fmt_args | self.static_info))
        )

    def fix(self, state: State) -> State:
        """
        Attempt to fix the assertion failure by modifying the provided state.

        Args:
            state: The current state that needs to be repaired

        Returns:
            The modified state after applying the fix function

        Raises:
            Exception: The configured exception type if no fix function is available
            AssertionError: If fix arguments are missing when a fix function exists
        """
        if self.fix_fn is None:
            raise self.exception
        assert not isinstance(self.fix_args, _NO_ARGS), (
            "Fix arguments were not provided."
        )
        return self.fix_fn(state, self.fix_args)

exception property

Create the exception instance that would be raised on assertion failure.

Returns:

Type Description
Exception

An exception instance with the formatted error message
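
In the source listing, this property formats the message with `fmt_args | static_info`, whereas `__str__` and `check()` format it with `fmt_args` only. The plain-Python sketch below shows what that difference means for a message that references a static key. All values here are invented for illustration.

```python
message = "Value {val} failed in step '{step}'"
fmt_args = {"val": 3.5}          # would be traced arrays in the real class
static_info = {"step": "relax"}  # static metadata

# Merging both dictionaries makes every placeholder resolvable.
merged = fmt_args | static_info
exc = ValueError(message.format(**merged))

# Formatting with fmt_args alone fails for the static placeholder.
try:
    message.format(**fmt_args)
    missing_key = None
except KeyError as err:
    missing_key = err.args[0]
```

Under this reading, placeholders that refer to `static_info` keys are only safe when the message is rendered through the `exception` property.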

__str__()

Return the formatted assertion message.

Source code in src/kups/core/assertion.py
def __str__(self) -> str:
    """Return the formatted assertion message."""
    return self.message.format(**self.fmt_args)

check()

Check the assertion and raise an exception if it fails.

Raises:

Type Description
Exception

The configured exception type if the assertion fails

Source code in src/kups/core/assertion.py
def check(self):
    """
    Check the assertion and raise an exception if it fails.

    Raises:
        Exception: The configured exception type if the assertion fails
    """
    if not bool(jnp.all(self.predicate)):
        raise self.exception_type(self.message.format(**self.fmt_args))

failed()

Check if the assertion has failed.

Source code in src/kups/core/assertion.py
def failed(self) -> bool:
    """Check if the assertion has failed."""
    return not self.valid()

fix(state)

Attempt to fix the assertion failure by modifying the provided state.

Parameters:

Name Type Description Default
state State

The current state that needs to be repaired

required

Returns:

Type Description
State

The modified state after applying the fix function

Raises:

Type Description
Exception

The configured exception type (the instance returned by the exception property) if no fix function is available

AssertionError

If fix arguments are missing when a fix function exists

Source code in src/kups/core/assertion.py
def fix(self, state: State) -> State:
    """
    Attempt to fix the assertion failure by modifying the provided state.

    Args:
        state: The current state that needs to be repaired

    Returns:
        The modified state after applying the fix function

    Raises:
        Exception: The configured exception type if no fix function is available
        AssertionError: If fix arguments are missing when a fix function exists
    """
    if self.fix_fn is None:
        raise self.exception
    assert not isinstance(self.fix_args, _NO_ARGS), (
        "Fix arguments were not provided."
    )
    return self.fix_fn(state, self.fix_args)
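
The control flow of `fix` can be sketched with a small stand-in class in plain Python. `MiniAssertion` and `_NoArgs` are hypothetical names that mimic only the fields used by `fix`; they are not the kups classes, and the predicate is omitted.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class _NoArgs:
    """Hypothetical sentinel type marking 'no fix arguments given'."""


_NO = _NoArgs()


@dataclass
class MiniAssertion:
    message: str
    exception_type: type = AssertionError
    fix_fn: Optional[Callable[[Any, Any], Any]] = None
    fix_args: Any = _NO

    def fix(self, state: Any) -> Any:
        # Without a fix function, the configured exception is raised.
        if self.fix_fn is None:
            raise self.exception_type(self.message)
        assert not isinstance(self.fix_args, _NoArgs), (
            "Fix arguments were not provided."
        )
        return self.fix_fn(state, self.fix_args)


# A fixable assertion clamps the state to a lower bound.
clamp = MiniAssertion(
    "value too small", fix_fn=lambda s, lo: max(s, lo), fix_args=0.1
)
fixed = clamp.fix(-2.0)

# An assertion without fix_fn raises its configured exception type.
unfixable = MiniAssertion("value too small", exception_type=ValueError)
try:
    unfixable.fix(-2.0)
    raised = None
except ValueError as err:
    raised = str(err)
```

Callers that loop over failed assertions can therefore test `fix_fn is not None` first, as the recovery example for `with_runtime_assertions` does, to avoid the raise.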

valid()

Check if the assertion is valid (i.e., passes).

Source code in src/kups/core/assertion.py
def valid(self) -> bool:
    """Check if the assertion is valid (i.e., passes)."""
    return bool(self.predicate)

check_assertions(like=None)

A primitive that returns a scalar bool. When not wrapped in with_runtime_assertions, it always returns True. When wrapped, it will return the conjunction of all runtime assertions in the current context.

Parameters:

Name Type Description Default
like Array | None

An array whose device placement and sharding will be used for the output. If None, the output will be placed on the default device.

None

Returns: A scalar boolean array indicating whether all assertions pass.

Source code in src/kups/core/assertion.py
def check_assertions(like: Array | None = None) -> Array:
    """A primitive that returns a scalar bool. When not wrapped in with_runtime_assertions,
    it always returns True. When wrapped, it will return the conjunction of all
    runtime assertions in the current context.

    Args:
        like: An array whose device placement and sharding will be used for the output.
            If None, the output will be placed on the default device.
    Returns:
        A scalar boolean array indicating whether all assertions pass.
    """
    if like is None:
        like = jnp.array(True, dtype=jnp.bool)
    return check_assertion_p.bind(like)

cond_handler(interpreter, ctx, eqn, invals)

Custom cond handler that normalizes traceback strings before branch comparison.

runtime_assert appends source-location tracebacks to assertion messages. Since messages are static pytree fields, assertions at different source lines produce different treedefs. This handler strips traceback suffixes so that branches with the same base assertion compare as equal.

Source code in src/kups/core/assertion.py
def cond_handler(
    interpreter: Interpreter[AssertionContext],
    ctx: AssertionContext,
    eqn: JaxprEqn,
    invals: list[TracerValue],
) -> HandlerResult[AssertionContext]:
    """Custom cond handler that normalizes traceback strings before branch comparison.

    runtime_assert appends source-location tracebacks to assertion messages. Since
    messages are static pytree fields, assertions at different source lines produce
    different treedefs. This handler strips traceback suffixes so that branches with
    the same base assertion compare as equal.
    """
    _, bind_params = get_bind_params(eqn)
    branches = bind_params["branches"]

    context_aware_branch_fns = [
        reinterpret(jaxpr_as_fun(jaxpr), interpreter) for jaxpr in branches
    ]

    # Check that all branches produce the same context structure by tracing them with a dummy context
    branch_ctx_trees = [
        jax.jit(branch_fn).trace(ctx.push(), *invals[1:]).out_info
        for branch_fn in context_aware_branch_fns
    ]
    assert len(branches) > 0, "cond must have at least one branch"
    # Normalize by stripping tracebacks before structural comparison
    normalized = [_normalize_ctx_for_comparison(t) for t in branch_ctx_trees]
    try:
        for tree in normalized[1:]:
            _assert_same_tree(normalized[0], tree)
    except ValueError as e:
        raise ValueError("Cond branches return inconsistent contexts.") from e

    # Wrap branches to strip tracebacks from output contexts so that
    # jax.lax.cond sees matching pytree structures across branches.
    def _wrap_branch(fn: Callable) -> Callable:
        def wrapped(ctx: AssertionContext, *args: Any) -> Any:
            outvals, ctx_out = fn(ctx, *args)
            return outvals, _strip_ctx_tracebacks(
                ctx_out, note="\n(traceback unavailable: assertion inside cond branch)"
            )

        return wrapped

    normalized_branch_fns = [_wrap_branch(fn) for fn in context_aware_branch_fns]

    outvals, ctx_out = jax.lax.cond(
        invals[0], *reversed(normalized_branch_fns), ctx, *invals[1:]
    )
    return HandlerResult(ctx_out, outvals)
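
Why differing tracebacks break branch comparison can be shown with a small custom pytree. This is a sketch of the general JAX behavior only: `Note`, `MARKER`, and `strip` are hypothetical names, and the real `_normalize_ctx_for_comparison` and `_strip_ctx_tracebacks` helpers are not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

MARKER = "\n<traceback>"  # hypothetical separator between message and traceback


@register_pytree_node_class
class Note:
    """Pytree node with one array leaf and a static message string."""

    def __init__(self, flag, message):
        self.flag = flag
        self.message = message

    def tree_flatten(self):
        # The message is auxiliary (static) data, so it is part of the treedef.
        return (self.flag,), self.message

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)


def strip(note):
    # Drop everything from the marker onwards, keeping the base message.
    return Note(note.flag, note.message.split(MARKER, 1)[0])


a = Note(jnp.array(True), "x must be positive" + MARKER + "file.py:10")
b = Note(jnp.array(False), "x must be positive" + MARKER + "file.py:20")

structure = jax.tree_util.tree_structure
same_before = structure(a) == structure(b)
same_after = structure(strip(a)) == structure(strip(b))
```

Because `jax.lax.cond` requires all branches to return the same output structure, two branches that assert the same condition at different source lines would otherwise be rejected.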

runtime_assert(predicate, message='', fmt_args=None, exception_type=AssertionError, static_info=None, fix_fn=None, fix_args=NO_ARGS)

Create a runtime assertion that integrates with JAX transformations.

This function creates assertions that can be traced through JAX transformations including JIT compilation, automatic differentiation, and vectorization. The call returns nothing and leaves the surrounding computation unchanged; it only records assertion metadata for later inspection.

Parameters:

Name Type Description Default
predicate Array

A boolean array indicating whether the assertion passes

required
message str

Error message with optional format placeholders (e.g., "Value {val} is invalid")

''
fmt_args dict[str, Array] | None

Dictionary mapping format placeholder names to values

None
exception_type type[Exception]

Type of exception to raise if assertion fails during checking

AssertionError
static_info dict[str, Any] | None

Additional metadata for debugging (not traced by JAX)

None
fix_fn Fix[State, FixArgs] | None

Optional function to repair state when assertion fails

None
fix_args FixArgs | _NO_ARGS

Arguments for the fix function (can be complex PyTree structures)

NO_ARGS

Type Parameters:

Name Bound or Constraints Description Default
State

Type of state that can be modified by the fix function

required
FixArgs

Type of arguments passed to the fix function (supports PyTree structures)

required
Example

Basic assertion:

x = jnp.array(5.0)
runtime_assert(
    predicate=x > 0,
    message="Value must be positive, got {val}",
    fmt_args={"val": x}
)

Assertion with fixing:

runtime_assert(
    predicate=x > threshold,
    message="Value {val} below threshold {thresh}",
    fmt_args={"val": x, "thresh": threshold},
    fix_fn=lambda state, args: jnp.maximum(state, args["min_val"]),
    fix_args={"min_val": jnp.array(1.0)}
)

Complex PyTree fix_args:

complex_args = {
    "thresholds": {"min": jnp.array(0.1), "max": jnp.array(10.0)},
    "multipliers": (jnp.array(2.0), jnp.array(3.0))
}
runtime_assert(
    predicate=x > 0,
    message="Invalid value",
    fix_args=complex_args
)

Note

The fix_args parameter supports arbitrarily nested PyTree structures including dictionaries, tuples, and arrays. These are automatically flattened and unflattened during JAX transformations.

Source code in src/kups/core/assertion.py
def runtime_assert[State, FixArgs](
    predicate: Array,
    message: str = "",
    fmt_args: dict[str, Array] | None = None,
    exception_type: type[Exception] = AssertionError,
    static_info: dict[str, Any] | None = None,
    fix_fn: Fix[State, FixArgs] | None = None,
    fix_args: FixArgs | _NO_ARGS = NO_ARGS,
):
    """
    Create a runtime assertion that integrates with JAX transformations.

    This function creates assertions that can be traced through JAX transformations
    including JIT compilation, automatic differentiation, and vectorization. The
    assertion acts as an identity function during execution but records assertion
    metadata for later inspection.

    Args:
        predicate: A boolean array indicating whether the assertion passes
        message: Error message with optional format placeholders (e.g., "Value {val} is invalid")
        fmt_args: Dictionary mapping format placeholder names to values
        exception_type: Type of exception to raise if assertion fails during checking
        static_info: Additional metadata for debugging (not traced by JAX)
        fix_fn: Optional function to repair state when assertion fails
        fix_args: Arguments for the fix function (can be complex PyTree structures)

    Type Parameters:
        State: Type of state that can be modified by the fix function
        FixArgs: Type of arguments passed to the fix function (supports PyTree structures)

    Example:
        Basic assertion:
        ```python
        x = jnp.array(5.0)
        runtime_assert(
            predicate=x > 0,
            message="Value must be positive, got {val}",
            fmt_args={"val": x}
        )
        ```

        Assertion with fixing:
        ```python
        runtime_assert(
            predicate=x > threshold,
            message="Value {val} below threshold {thresh}",
            fmt_args={"val": x, "thresh": threshold},
            fix_fn=lambda state, args: jnp.maximum(state, args["min_val"]),
            fix_args={"min_val": jnp.array(1.0)}
        )
        ```

        Complex PyTree fix_args:
        ```python
        complex_args = {
            "thresholds": {"min": jnp.array(0.1), "max": jnp.array(10.0)},
            "multipliers": (jnp.array(2.0), jnp.array(3.0))
        }
        runtime_assert(
            predicate=x > 0,
            message="Invalid value",
            fix_args=complex_args
        )
        ```

    Note:
        The fix_args parameter supports arbitrarily nested PyTree structures including
        dictionaries, tuples, and arrays. These are automatically flattened and
        unflattened during JAX transformations.
    """
    if fmt_args is None:
        fmt_args = {}
    if static_info is None:
        static_info = {}

    tb = _capture_traceback().replace("{", "{{").replace("}", "}}")
    # Convert static_info to hashable format (tuple of key-value pairs)
    static_info_hashable = tuple(sorted(static_info.items()))

    # Prepare inputs: predicate, fmt_args values, fix_args flattened (if present)
    inputs = [predicate, *fmt_args.values()]
    fix_args_tree = None
    if not isinstance(fix_args, _NO_ARGS):
        # Flatten fix_args PyTree and add to inputs
        fix_args_flat, fix_args_tree = jax.tree.flatten(fix_args)
        inputs.extend(fix_args_flat)

    assertion_p.bind(
        *inputs,
        fmt_arg_names=tuple(fmt_args.keys()),
        message=message + f"{_TRACEBACK_MARKER}{tb}",
        exception_type=exception_type,
        static_info_hashable=static_info_hashable,
        fix_fn=fix_fn,
        fix_args_tree=fix_args_tree,
    )
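
Two small preparation steps in the listing above are plain Python and can be tried in isolation: escaping braces in the captured traceback so that a later `str.format` call does not treat them as placeholders, and converting `static_info` into a sorted tuple so that it is hashable. The traceback text and dictionary values below are invented for illustration.

```python
# A traceback line that happens to contain literal braces.
traceback_text = 'File "model.py", line 12, in step\n    cfg = {"dt": 0.1}'

# Doubling the braces makes str.format emit them literally.
escaped = traceback_text.replace("{", "{{").replace("}", "}}")

message = "Value {val} is invalid" + "\n" + escaped
rendered = message.format(val=3)

# A dict is not hashable; a sorted tuple of its items is, as long as
# the values themselves are hashable.
static_info = {"step": "relax", "attempt": 2}
static_info_hashable = tuple(sorted(static_info.items()))
restored = dict(static_info_hashable)
```

Sorting the items also makes the tuple independent of the insertion order of the dictionary, so two calls with the same metadata produce equal parameters.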

with_runtime_assertions(fn, policy=InterpreterPolicy.RAISE, context_sharding=None)

Decorator that enables runtime assertion tracing for JAX functions.

This decorator wraps a function to intercept and collect all runtime assertions created with runtime_assert during execution. The wrapped function returns both the original result and a tuple of all assertions that were evaluated, allowing for post-execution analysis, debugging, and optional error recovery.

Parameters:

Name Type Description Default
fn Callable[P, R]

The function to wrap with assertion tracing capabilities

required
policy InterpreterPolicy

Controls interpreter behavior on unhandled operations:
- RAISE: Raise exception on unknown operations (default, safest)
- WARN: Issue warning and continue with original function
- SKIP: Silently continue with original function

RAISE
context_sharding PartitionSpec | None

Optional sharding specification for distributed contexts. When provided, assertion contexts are sharded according to this spec for multi-device computations.

None

Returns:

Type Description
Callable[P, tuple[R, tuple[RuntimeAssertion, ...]]]

A wrapped function that returns a tuple of (original_result, assertions_tuple). The assertions tuple contains all RuntimeAssertion instances encountered during execution, preserving order and enabling post-hoc analysis.

Type Parameters:

Name Bound or Constraints Description Default
P

Parameter specification of the wrapped function (ParamSpec)

required
R

Return type of the wrapped function

required
Example

Basic usage:

@with_runtime_assertions
def validate_computation(x):
    runtime_assert(x > 0, "x must be positive")
    return x ** 2

result, assertions = validate_computation(jnp.array(5.0))
# result = 25.0, assertions contains one RuntimeAssertion

With custom policy:

traced_fn = with_runtime_assertions(
    my_function,
    policy=InterpreterPolicy.WARN
)
result, assertions = traced_fn(inputs)

Distributed computation:

sharded_fn = with_runtime_assertions(
    distributed_computation,
    context_sharding=jax.sharding.PartitionSpec('data', None)
)

Error analysis and recovery:

result, assertions = traced_fn(initial_state)

# Check for failures
failed_assertions = [a for a in assertions if a.failed()]
if failed_assertions:
    # Attempt automatic fixing
    fixed_state = initial_state
    for assertion in failed_assertions:
        if assertion.fix_fn is not None:
            fixed_state = assertion.fix(fixed_state)

    # Re-run with fixed state
    result, _ = traced_fn(fixed_state)

Note

The decorator integrates seamlessly with JAX transformations including jit, vmap, grad, and scan. Assertions are properly threaded through control flow operations and maintain correct semantics under transformations.

Source code in src/kups/core/assertion.py
def with_runtime_assertions[**P, R](
    fn: Callable[P, R],
    policy: InterpreterPolicy = InterpreterPolicy.RAISE,
    context_sharding: jax.sharding.PartitionSpec | None = None,
) -> Callable[P, tuple[R, tuple[RuntimeAssertion, ...]]]:
    """
    Decorator that enables runtime assertion tracing for JAX functions.

    This decorator wraps a function to intercept and collect all runtime assertions
    created with `runtime_assert` during execution. The wrapped function returns
    both the original result and a tuple of all assertions that were evaluated,
    allowing for post-execution analysis, debugging, and optional error recovery.

    Args:
        fn: The function to wrap with assertion tracing capabilities
        policy: Controls interpreter behavior on unhandled operations:
            - RAISE: Raise exception on unknown operations (default, safest)
            - WARN: Issue warning and continue with original function
            - SKIP: Silently continue with original function
        context_sharding: Optional sharding specification for distributed contexts.
            When provided, assertion contexts are sharded according to this spec
            for multi-device computations.

    Returns:
        A wrapped function that returns a tuple of (original_result, assertions_tuple).
        The assertions tuple contains all RuntimeAssertion instances encountered
        during execution, preserving order and enabling post-hoc analysis.

    Type Parameters:
        P: Parameter specification of the wrapped function (ParamSpec)
        R: Return type of the wrapped function

    Example:
        Basic usage:
        ```python
        @with_runtime_assertions
        def validate_computation(x):
            runtime_assert(x > 0, "x must be positive")
            return x ** 2

        result, assertions = validate_computation(jnp.array(5.0))
        # result = 25.0, assertions contains one RuntimeAssertion
        ```

        With custom policy:
        ```python
        traced_fn = with_runtime_assertions(
            my_function,
            policy=InterpreterPolicy.WARN
        )
        result, assertions = traced_fn(inputs)
        ```

        Distributed computation:
        ```python
        sharded_fn = with_runtime_assertions(
            distributed_computation,
            context_sharding=jax.sharding.PartitionSpec('data', None)
        )
        ```

        Error analysis and recovery:
        ```python
        result, assertions = traced_fn(initial_state)

        # Check for failures
        failed_assertions = [a for a in assertions if a.failed()]
        if failed_assertions:
            # Attempt automatic fixing
            fixed_state = initial_state
            for assertion in failed_assertions:
                if assertion.fix_fn is not None:
                    fixed_state = assertion.fix(fixed_state)

            # Re-run with fixed state
            result, _ = traced_fn(fixed_state)
        ```

    Note:
        The decorator integrates seamlessly with JAX transformations including
        jit, vmap, grad, and scan. Assertions are properly threaded through
        control flow operations and maintain correct semantics under
        transformations.
    """
    dispatcher = Dispatcher(
        handlers={
            "assertion": assertion_handler,
            "check_assertion": check_assertion_handler,
            "jit": default_jit_handler,
            "scan": scan_handler,
            "while": while_handler,
            "cond": cond_handler,
            "shard_map": partial(shard_map_handler, context_sharding=context_sharding),
        }
    ).register_custom_matching_rule(_contains_assertion_primitive)

    interpreter = Interpreter(dispatcher, policy=policy, label="assertion_interpreter")

    reinterpreted = reinterpret(fn, interpreter=interpreter)

    def wrapped(
        *args: P.args, **kwargs: P.kwargs
    ) -> tuple[R, tuple[RuntimeAssertion, ...]]:
        ctx = AssertionContext()
        outvals, ctx = reinterpreted(ctx, *args, **kwargs)
        return outvals, ctx.assertions

    return wrapped