kups.core.utils.ops
Array operations with broadcasting utilities.
Provides helper functions for array dimension expansion, conditional selection with automatic broadcasting, and axis-specific padding.
expand_last_dims(operand, other)
Expand trailing dimensions of operand to match other's rank.
Appends size-1 dimensions so that operand can broadcast against an
array (or shape tuple) with more axes.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `operand` | `Array` | Input array to expand. | required |
| `other` | `Array \| tuple[int, ...]` | Reference array or shape tuple whose rank is the target. | required |

Returns:

| Type | Description |
|---|---|
| `Array` | `operand` with size-1 dimensions appended until its rank matches that of `other`. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If … |

Source code in src/kups/core/utils/ops.py
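The behaviour described above can be sketched as follows. This is written from the docstring alone, not from the kups source: `expand_last_dims_sketch` is a hypothetical name, and the rank assertion is an assumption about what triggers the `AssertionError`.

```python
from __future__ import annotations

import jax.numpy as jnp
from jax import Array


def expand_last_dims_sketch(operand: Array, other: Array | tuple[int, ...]) -> Array:
    """Hypothetical sketch: append size-1 axes to `operand` up to the rank of `other`."""
    # `other` may be an array or a bare shape tuple; only its rank matters.
    target_ndim = len(other) if isinstance(other, tuple) else other.ndim
    # Assumed precondition: axes can only be added, never removed.
    assert operand.ndim <= target_ndim, "operand has more axes than other"
    return operand.reshape(operand.shape + (1,) * (target_ndim - operand.ndim))


mask = jnp.array([True, False, True])  # shape (3,)
feats = jnp.ones((3, 4, 2))            # shape (3, 4, 2)
expanded = expand_last_dims_sketch(mask, feats)
print(expanded.shape)  # (3, 1, 1)
```

Appending axes on the right matters because NumPy-style broadcasting aligns shapes from the right: a `(3,)` mask would otherwise be matched against the trailing axis of size 2, whereas `(3, 1, 1)` lines up with the leading axis of `feats`.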
pad_axis(operand, to_pad, axis)
Pad a single axis of an array with zeros.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `operand` | `Array` | Array to pad. | required |
| `to_pad` | `tuple[int, int]` | Number of zeros to add before and after the data along `axis`. | required |
| `axis` | `int` | Axis index to pad. | required |

Returns:

| Type | Description |
|---|---|
| `Array` | Padded array with the same dtype as `operand`. |

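A minimal sketch of single-axis zero padding with `jnp.pad`, assuming `to_pad` is a `(before, after)` pair; `pad_axis_sketch` is a hypothetical name, not the kups implementation.

```python
from __future__ import annotations

import jax.numpy as jnp
from jax import Array


def pad_axis_sketch(operand: Array, to_pad: tuple[int, int], axis: int) -> Array:
    """Hypothetical sketch: zero-pad only `axis` of `operand`."""
    # Leave every axis unpadded except the requested one.
    pad_width = [(0, 0)] * operand.ndim
    pad_width[axis] = tuple(to_pad)  # assumed (before, after) convention
    # jnp.pad defaults to mode="constant" with value 0, keeping operand's dtype.
    return jnp.pad(operand, pad_width)


x = jnp.ones((2, 3), dtype=jnp.int32)
y = pad_axis_sketch(x, (1, 2), axis=1)
print(y.shape, y.dtype)  # (2, 6) int32
```

Because Python list indexing accepts negative indices, `axis=-1` pads the last axis in this sketch.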
select_n(which, *cands)
Like `jax.lax.select_n`, but short-circuits when all candidates are identical.
At trace time, if every candidate is the same tracer (an `is` identity check),
the selection is a no-op: the single candidate is returned directly,
avoiding an unnecessary `select_n` primitive in the jaxpr.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `which` | `Array` | Integer array indexing into `cands`. | required |
| `*cands` | `Array` | Candidate arrays to select from. | `()` |

Returns:

| Type | Description |
|---|---|
| `Array` | The selected array, or the shared candidate itself if all `cands` are the same object. |

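The identity short-circuit can be sketched as below. `select_n_sketch` is a hypothetical stand-in written from the description above, not the kups source; the boolean `which` and the demo arrays are illustrative assumptions. Inspecting the traced program with `jax.make_jaxpr` shows whether a `select_n` equation was emitted.

```python
import jax
import jax.numpy as jnp
from jax import Array


def select_n_sketch(which: Array, *cands: Array) -> Array:
    """Hypothetical sketch of an identity short-circuit around lax.select_n."""
    # Passing the same tracer several times yields the very same Python object,
    # so an `is` check proves the result cannot depend on `which`.
    if all(c is cands[0] for c in cands[1:]):
        return cands[0]
    return jax.lax.select_n(which, *cands)


which = jnp.array([True, False, True])  # False -> cands[0], True -> cands[1]
x = jnp.arange(3.0)

# Same tracer twice: no select_n equation appears in the jaxpr.
same = jax.make_jaxpr(lambda w, a: select_n_sketch(w, a, a))(which, x)
# Distinct tracers: select_n is emitted as usual.
diff = jax.make_jaxpr(lambda w, a, b: select_n_sketch(w, a, b))(which, x, x + 1)
print("select_n" in str(same), "select_n" in str(diff))  # False True
```

Because the check is object identity rather than value equality, two distinct arrays that merely hold equal values still go through `select_n`; only a literally shared candidate is short-circuited.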
where_broadcast_last(condition, x, y)
Element-wise jnp.where with condition broadcast on trailing dims.
Expands condition to match the shapes of x and y before
selecting, so a lower-rank condition naturally broadcasts over trailing
feature dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `condition` | `Array` | Boolean array used for selection. | required |
| `x` | `Array \| ArrayLike` | Values selected where `condition` is `True`. | required |
| `y` | `Array \| ArrayLike` | Values selected where `condition` is `False`. | required |

Returns:

| Type | Description |
|---|---|
| `Array` | Array with shape … |
