kups.core.assertion
JAX-compatible assertion tracing system with optional automatic fixing.
This module provides an assertion system that works under JAX transformations, including JIT compilation, automatic differentiation, and vectorization. Assertions can include optional fix functions for automatic error recovery.
Key Components:
- RuntimeAssertion: Core assertion dataclass with optional fixing capabilities
- runtime_assert: Function to create assertions that work with JAX transformations
- with_runtime_assertions: Decorator to enable assertion tracing in functions
NO_ARGS = _NO_ARGS() (module attribute)
Sentinel instance that distinguishes "no arguments were passed" from None passed as an argument.
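The sentinel pattern works because the check uses object identity (`is`), so None stays available as an ordinary argument value. The sketch below mirrors the names on this page, but the class body and the helper `describe_fix_args` are hypothetical illustrations, not the kups source.

```python
class _NO_ARGS:
    """Sentinel type (sketch): its single instance means 'no fix arguments given'."""

    def __repr__(self) -> str:
        return "NO_ARGS"


NO_ARGS = _NO_ARGS()


def describe_fix_args(fix_args=NO_ARGS) -> str:
    # Hypothetical helper: identity check against the sentinel, so None is still a legal value.
    if fix_args is NO_ARGS:
        return "no fix arguments"
    if fix_args is None:
        return "fix arguments explicitly set to None"
    return f"fix arguments: {fix_args!r}"


print(describe_fix_args())          # no fix arguments
print(describe_fix_args(None))      # fix arguments explicitly set to None
print(describe_fix_args({"k": 1}))  # fix arguments: {'k': 1}
```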
RuntimeAssertion
A runtime assertion that validates computations with optional automatic fixing.
This class encapsulates a predicate that should evaluate to True, along with metadata for error reporting and optional repair mechanisms. Assertions are designed to work seamlessly with JAX transformations while providing rich debugging information.
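The following pure-Python sketch shows how predicate, message, fmt_args and exception_type fit together. `SimpleAssertion` is a hypothetical stand-in, not the kups class. It uses a plain bool instead of a JAX array and omits static_info and the fix machinery. It also assumes that placeholders are filled with `str.format`, which is what the `{val}` examples on this page suggest.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type


@dataclass(frozen=True)
class SimpleAssertion:
    """Hypothetical stand-in for RuntimeAssertion (no JAX, no fix support)."""

    predicate: bool
    message: str = ""
    fmt_args: Dict[str, Any] = field(default_factory=dict)
    exception_type: Type[Exception] = AssertionError

    def failed(self) -> bool:
        # The assertion fails exactly when its predicate is False.
        return not bool(self.predicate)

    @property
    def exception(self) -> Exception:
        # Assumption: placeholders are substituted with str.format.
        return self.exception_type(self.message.format(**self.fmt_args))

    def check(self) -> None:
        # Raise the configured exception only if the predicate does not hold.
        if self.failed():
            raise self.exception


ok = SimpleAssertion(predicate=True, message="never shown")
ok.check()  # passes silently

bad = SimpleAssertion(
    predicate=False,
    message="Value must be positive, got {val}",
    fmt_args={"val": -2.0},
    exception_type=ValueError,
)
try:
    bad.check()
except ValueError as err:
    print(err)  # Value must be positive, got -2.0
```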
Class Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State |  | The type of state that can be modified by the fix function | required |
| FixArgs |  | The types of arguments passed to the fix function (can be a PyTree) | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| predicate | Array | A scalar boolean array indicating whether the assertion passes |
| message | str | Human-readable error message with optional format placeholders |
| fmt_args | dict[str, Array] | Dictionary of values to substitute into the message format string |
| exception_type | type[Exception] | Type of exception to raise on assertion failure |
| static_info | dict[str, Any] | Additional metadata for debugging (not traced by JAX) |
| fix_fn | Fix[State, FixArgs] \| None | Optional function to repair the state when the assertion fails |
| fix_args | FixArgs \| _NO_ARGS | Arguments to pass to the fix function (can be complex PyTree structures) |
Example
Note
The fix_args can be complex PyTree structures including nested dictionaries, tuples, and arrays. The assertion system properly handles flattening and unflattening these structures during JAX transformations.
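Flattening turns a nested fix_args structure into a flat list of leaves plus a description of its structure. Unflattening rebuilds the container from the two. JAX exposes this as `jax.tree_util.tree_flatten` / `tree_unflatten`. The pure-Python sketch below imitates the idea for dicts, tuples and lists only, so it runs without JAX. `flatten` and `unflatten` are hypothetical helpers, not what kups calls internally.

```python
from typing import Any, List, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split a nested dict/tuple/list into (leaves, structure)."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # sorted key order, as JAX uses for dicts
        children = [flatten(tree[k]) for k in keys]
        leaves = [leaf for sub_leaves, _ in children for leaf in sub_leaves]
        return leaves, ("dict", keys, [s for _, s in children])
    if isinstance(tree, (tuple, list)):
        children = [flatten(item) for item in tree]
        leaves = [leaf for sub_leaves, _ in children for leaf in sub_leaves]
        return leaves, (type(tree).__name__, None, [s for _, s in children])
    # Anything else (numbers, arrays, ...) is treated as a single leaf.
    return [tree], ("leaf", None, None)


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    """Rebuild the nested container from its structure and a flat leaf list."""
    remaining = iter(leaves)

    def build(node: Any) -> Any:
        kind, keys, children = node
        if kind == "leaf":
            return next(remaining)
        parts = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, parts))
        return tuple(parts) if kind == "tuple" else parts

    return build(structure)


fix_args = {"bounds": (0.0, 1.0), "scale": 2.0, "extra": [3, {"z": 4}]}
leaves, structure = flatten(fix_args)
print(leaves)  # [0.0, 1.0, 3, 4, 2.0]
print(unflatten(structure, leaves) == fix_args)  # True

# Transformations act on the flat leaves; the structure is reattached afterwards.
doubled = unflatten(structure, [leaf * 2 for leaf in leaves])
print(doubled["extra"])  # [6, {'z': 8}]
```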
Source code in src/kups/core/assertion.py
exception (property)
Create the exception instance that would be raised on assertion failure.
Returns:
| Type | Description |
|---|---|
| Exception | An exception instance with the formatted error message |
__str__()
check()
Check the assertion and raise an exception if it fails.
Raises:
| Type | Description |
|---|---|
| Exception | The configured exception type if the assertion fails |
failed()
fix(state)
Attempt to fix the assertion failure by modifying the provided state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | The current state that needs to be repaired | required |

Returns:
| Type | Description |
|---|---|
| State | The modified state after applying the fix function |

Raises:
| Type | Description |
|---|---|
| NotImplementedError | If no fix function is available |
| AssertionError | If fix arguments are missing when a fix function exists |
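Here is a minimal sketch of the two documented error cases of fix(). `apply_fix` and `_NoArgs` are hypothetical names, and the call form `fix_fn(state, fix_args)` is an assumption taken from the lambda in the runtime_assert example below.

```python
from typing import Any, Callable, Optional


class _NoArgs:
    """Hypothetical sentinel type standing in for _NO_ARGS."""


NO_ARGS = _NoArgs()


def apply_fix(
    state: Any,
    fix_fn: Optional[Callable[[Any, Any], Any]],
    fix_args: Any = NO_ARGS,
) -> Any:
    """Sketch of RuntimeAssertion.fix's documented behavior (not the kups source)."""
    if fix_fn is None:
        raise NotImplementedError("No fix function is available for this assertion")
    if fix_args is NO_ARGS:
        raise AssertionError("A fix function exists but its fix arguments are missing")
    # Assumed calling convention: fix_fn(state, fix_args).
    return fix_fn(state, fix_args)


def clamp(state: float, args: dict) -> float:
    # Example fix: raise the state to at least args["min_val"].
    return max(state, args["min_val"])


print(apply_fix(0.25, clamp, {"min_val": 1.0}))  # 1.0
```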
check_assertions(like=None)
A primitive that returns a scalar bool. Outside a with_runtime_assertions wrapper it always returns True; inside one, it returns the conjunction of the predicates of all runtime assertions in the current context.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| like | Array \| None | An array whose device placement and sharding will be used for the output. If None, the output will be placed on the default device. | None |
Returns: A scalar boolean array indicating whether all assertions pass.
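This pure-Python sketch shows the documented semantics: the result is always True outside a collecting context, and inside one it is the conjunction of the recorded predicates. The names `collecting`, `record` and `check_all` are hypothetical. The real primitive returns a JAX scalar array placed like `like`, which this sketch ignores.

```python
import contextvars
from contextlib import contextmanager

# Predicates recorded in the current context; None means "not inside a wrapper".
_predicates = contextvars.ContextVar("_predicates", default=None)


@contextmanager
def collecting():
    # Hypothetical stand-in for the context that with_runtime_assertions sets up.
    token = _predicates.set([])
    try:
        yield
    finally:
        _predicates.reset(token)


def record(predicate: bool) -> None:
    active = _predicates.get()
    if active is not None:
        active.append(bool(predicate))


def check_all() -> bool:
    active = _predicates.get()
    if active is None:
        return True  # not wrapped: always True
    return all(active)  # wrapped: conjunction (True when nothing was recorded)


print(check_all())  # True
with collecting():
    record(True)
    record(False)
    print(check_all())  # False
```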
cond_handler(interpreter, ctx, eqn, invals)
Custom cond handler that normalizes traceback strings before branch comparison.
runtime_assert appends source-location tracebacks to assertion messages. Because messages are static pytree fields, assertions created at different source lines produce different treedefs even when their base messages are identical, so the two branches of a conditional would not match. This handler strips the traceback suffixes before comparison, so branches that carry the same base assertion compare as equal.
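The normalization idea can be illustrated in a few lines. This page does not document the actual traceback suffix format, so the `TRACE_MARKER` delimiter and both helper functions below are hypothetical.

```python
# Hypothetical delimiter between the base message and the appended source location.
TRACE_MARKER = "\n--- assertion raised at ---\n"


def with_location(message: str, location: str) -> str:
    # Mimics runtime_assert appending a source location to the message.
    return f"{message}{TRACE_MARKER}{location}"


def strip_traceback(message: str) -> str:
    # Keep only the base message, dropping everything from the marker onward.
    return message.split(TRACE_MARKER, 1)[0]


true_branch = with_location("x must be positive", "model.py:10")
false_branch = with_location("x must be positive", "model.py:42")

print(true_branch == false_branch)  # False: raw static messages differ
print(strip_traceback(true_branch) == strip_traceback(false_branch))  # True
```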
runtime_assert(predicate, message='', fmt_args=None, exception_type=AssertionError, static_info=None, fix_fn=None, fix_args=NO_ARGS)
Create a runtime assertion that integrates with JAX transformations.
This function creates assertions that can be traced through JAX transformations including JIT compilation, automatic differentiation, and vectorization. The assertion acts as an identity function during execution but records assertion metadata for later inspection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predicate | Array | A boolean array indicating whether the assertion passes | required |
| message | str | Error message with optional format placeholders (e.g., "Value {val} is invalid") | '' |
| fmt_args | dict[str, Array] \| None | Dictionary mapping format placeholder names to values | None |
| exception_type | type[Exception] | Type of exception to raise if the assertion fails during checking | AssertionError |
| static_info | dict[str, Any] \| None | Additional metadata for debugging (not traced by JAX) | None |
| fix_fn | Fix[State, FixArgs] \| None | Optional function to repair state when the assertion fails | None |
| fix_args | FixArgs \| _NO_ARGS | Arguments for the fix function (can be complex PyTree structures) | NO_ARGS |

Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| State |  | Type of state that can be modified by the fix function | required |
| FixArgs |  | Type of arguments passed to the fix function (supports PyTree structures) | required |
Example
Basic assertion:
```python
import jax.numpy as jnp
from kups.core.assertion import runtime_assert

x = jnp.array(5.0)
runtime_assert(
    predicate=x > 0,
    message="Value must be positive, got {val}",
    fmt_args={"val": x},
)
```
Assertion with fixing:
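```python
threshold = jnp.array(1.0)  # lower bound to check x against (x as defined above)
runtime_assert(
    predicate=x > threshold,
    message="Value {val} below threshold {thresh}",
    fmt_args={"val": x, "thresh": threshold},
    # On failure, clamp the state up to the minimum allowed value
    fix_fn=lambda state, args: jnp.maximum(state, args["min_val"]),
    fix_args={"min_val": jnp.array(1.0)},
)
```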
Complex PyTree fix_args:
Note
The fix_args parameter supports arbitrarily nested PyTree structures including dictionaries, tuples, and arrays. These are automatically flattened and unflattened during JAX transformations.
with_runtime_assertions(fn, policy=InterpreterPolicy.RAISE, context_sharding=None)
Decorator that enables runtime assertion tracing for JAX functions.
This decorator wraps a function to intercept and collect all runtime assertions
created with runtime_assert during execution. The wrapped function returns
both the original result and a tuple of all assertions that were evaluated,
allowing for post-execution analysis, debugging, and optional error recovery.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[P, R] | The function to wrap with assertion tracing capabilities | required |
| policy | InterpreterPolicy | Controls interpreter behavior on unhandled operations: RAISE raises an exception on unknown operations (default, safest); WARN issues a warning and continues with the original function; SKIP silently continues with the original function. | RAISE |
| context_sharding | PartitionSpec \| None | Optional sharding specification for distributed contexts. When provided, assertion contexts are sharded according to this spec for multi-device computations. | None |
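The three policy values map onto a simple dispatch, sketched below. `Policy` and `handle_unknown_op` are hypothetical mirrors of the behavior described in the table. The actual interpreter's handling of unknown operations is not shown on this page.

```python
import enum
import warnings
from typing import Callable


class Policy(enum.Enum):
    """Hypothetical mirror of the three documented InterpreterPolicy values."""

    RAISE = "raise"
    WARN = "warn"
    SKIP = "skip"


def handle_unknown_op(op_name: str, policy: Policy, fallback: Callable[[], object]) -> object:
    # Called when the interpreter meets an operation it has no handler for.
    if policy is Policy.RAISE:
        raise NotImplementedError(f"No assertion handler for operation {op_name!r}")
    if policy is Policy.WARN:
        warnings.warn(f"No assertion handler for {op_name!r}; running the original function")
    # WARN and SKIP both continue with the original, un-instrumented function.
    return fallback()


print(handle_unknown_op("custom_op", Policy.SKIP, lambda: "original result"))
```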
Returns:
| Type | Description |
|---|---|
| Callable[P, tuple[R, tuple[RuntimeAssertion, ...]]] | A wrapped function that returns a tuple of (original_result, assertions_tuple). The assertions tuple contains all RuntimeAssertion instances encountered during execution, preserving order and enabling post-hoc analysis. |

Type Parameters:
| Name | Bound or Constraints | Description | Default |
|---|---|---|---|
| P |  | Parameter specification of the wrapped function (ParamSpec) | required |
| R |  | Return type of the wrapped function | required |
Example
Basic usage:
```python
import jax.numpy as jnp
from kups.core.assertion import runtime_assert, with_runtime_assertions

@with_runtime_assertions
def validate_computation(x):
    runtime_assert(x > 0, "x must be positive")
    return x ** 2

result, assertions = validate_computation(jnp.array(5.0))
# result = 25.0, assertions contains one RuntimeAssertion
```
With custom policy:
```python
traced_fn = with_runtime_assertions(
    my_function,
    policy=InterpreterPolicy.WARN,
)
result, assertions = traced_fn(inputs)
```
Distributed computation:
```python
import jax

sharded_fn = with_runtime_assertions(
    distributed_computation,
    context_sharding=jax.sharding.PartitionSpec("data", None),
)
```
Error analysis and recovery:
```python
result, assertions = traced_fn(initial_state)

# Check for failures
failed_assertions = [a for a in assertions if a.failed()]
if failed_assertions:
    # Attempt automatic fixing, one failed assertion at a time
    fixed_state = initial_state
    for assertion in failed_assertions:
        if assertion.fix_fn is not None:
            fixed_state = assertion.fix(fixed_state)
    # Re-run with the fixed state
    result, _ = traced_fn(fixed_state)
```
Note
The decorator integrates seamlessly with JAX transformations including jit, vmap, grad, and scan. Assertions are properly threaded through control flow operations and maintain correct semantics under transformations.