# Copyright (c) QuantCo 2023-2026
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import builtins
import math
from collections.abc import Sequence
from typing import Literal, NamedTuple
from warnings import warn
import numpy as np
from spox import Var
import ndonnx as ndx
from ndonnx.types import NestedSequence, OnnxShape, PyScalar
from ._array import Array, DType
from ._array_tyarray_interop import unwrap_tyarray
from ._namespace_info import Device
from ._typed_array import funcs as tyfuncs
from ._typed_array import onnx
[docs]
def argument(
    *,
    shape: OnnxShape,
    dtype: ndx.DType,
) -> Array:
    """Create a new lazy ndonnx array that serves as an input of an ONNX model.

    Parameters
    ----------
    shape
        Shape of the input. String entries name symbolic dimensions, which must
        be used consistently across the whole graph; ``None`` entries mark
        dimensions of unknown size.
    dtype
        Data type of the input.

    Returns
    -------
    Array
        Lazy array standing for an input of the computational graph.
    """
    graph_input = Array._argument(shape=shape, dtype=dtype)
    return graph_input
[docs]
def asarray(
    obj: Array | PyScalar | np.ndarray | NestedSequence | Var,
    /,
    *,
    dtype: ndx.DType | None = None,
    device: None | Device = None,
    copy: bool | None = None,
) -> Array:
    """Convert ``obj`` into an ndonnx ``Array``.

    Parameters
    ----------
    obj
        Object to convert. A (possibly nested) sequence whose leaves are all
        ``Array`` objects is stacked into one array; that form is not part of
        the array API standard and emits a warning.
    dtype
        Target data type. ``None`` lets the data type be inferred.
    device
        Accepted for array API compatibility; not used.
    copy
        If falsy but not ``None``, ``obj`` must already be an ``Array`` with
        the requested data type, otherwise ``ValueError`` is raised.

    Raises
    ------
    ValueError
        If a copy-free conversion was requested but is not possible.
    """
    if copy is not None and not copy:
        # A copy-free conversion only works from an existing Array that
        # already has the requested data type.
        if not isinstance(obj, Array):
            raise ValueError(
                f"cannot create 'Array' from object of type `{type(obj)}` without copying"
            )
        dtype_matches = dtype is None or obj.dtype == dtype
        if not dtype_matches:
            raise ValueError(
                f"cannot create Array with data type `{dtype}` from array "
                f"of data type `{obj.dtype}` without copying data"
            )
        return Array._from_tyarray(obj._tyarray)

    if isinstance(obj, Array):
        return Array._from_tyarray(tyfuncs.astyarray(obj._tyarray, dtype=dtype))

    # `obj` is known not to be an Array at this point.
    if isinstance(obj, Sequence) and not isinstance(obj, str):
        nested = np.asarray(obj, dtype=object)
        leaves = nested.flatten()
        if nested.size > 0 and builtins.all(
            isinstance(leaf, Array) for leaf in leaves
        ):
            # This is a (nested) list of ndonnx arrays.
            warn(
                "providing a sequence of 'Array's to 'asarray' is not defined by the array-api-standard and may be removed from ndonnx in the future",
                stacklevel=2,
            )
            if dtype is None:
                dtype = result_type(*leaves)
            # Stack the leaves along a new leading axis, then restore the
            # nesting structure as the leading dimensions.
            stacked = concat([leaf.astype(dtype)[None, ...] for leaf in leaves])
            full_shape = concat([asarray(nested.shape), stacked.dynamic_shape[1:]])
            return reshape(stacked, full_shape)

    return Array._from_tyarray(tyfuncs.astyarray(obj, dtype=dtype))
[docs]
def from_dlpack(
    x: object, /, *, device: None = None, copy: bool | None = None
) -> Array:
    """Array API ``from_dlpack``; not supported by ndonnx.

    Raises
    ------
    BufferError
        Always.
    """
    message = "ndonnx does not (yet) support the export of array data"
    raise BufferError(message)
[docs]
def all(
    x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> Array:
    """Array API ``all`` reduction; delegates to the typed-array implementation."""
    reduced = x._tyarray.all(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def any(
    x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> Array:
    """Array API ``any`` reduction; delegates to the typed-array implementation."""
    reduced = x._tyarray.any(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def arange(
    start: int | float | Array,
    /,
    stop: int | float | Array | None = None,
    step: int | float | Array = 1,
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Array API ``arange``.

    ``start``, ``stop`` and ``step`` may be Python numbers or rank-0 arrays.
    When ``stop`` is omitted, ``start`` is used as the stop value and the
    range starts at ``0``.

    Raises
    ------
    ValueError
        If an array argument is not of rank 0.
    TypeError
        If an argument is neither a number nor an ``Array``.
    """
    for bound in (start, stop, step):
        if bound is None:
            continue
        if isinstance(bound, Array):
            if bound.ndim != 0:
                raise ValueError("array arguments to 'arange' must be of rank 0")
        elif not isinstance(bound, int | float):
            raise TypeError(
                f"unexpected type for 'start', 'stop', or 'step': `{type(bound)}`"
            )

    if stop is None:
        # Single-argument form: `arange(stop)`.
        start, stop = 0, start

    unwrapped = tuple(
        value._tyarray if isinstance(value, Array) else value
        for value in (start, stop, step)
    )
    return Array._from_tyarray(tyfuncs.arange(dtype, *unwrapped))
[docs]
def argmax(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
    """Array API ``argmax``; delegates to the typed-array implementation."""
    indices = x._tyarray.argmax(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(indices)
[docs]
def argmin(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
    """Array API ``argmin``; delegates to the typed-array implementation."""
    indices = x._tyarray.argmin(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(indices)
[docs]
def count_nonzero(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
) -> Array:
    """Array API ``count_nonzero``; delegates to the typed-array implementation."""
    counts = x._tyarray.count_nonzero(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(counts)
[docs]
def argsort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """Array API ``argsort``; delegates to the typed-array implementation."""
    order = x._tyarray.argsort(axis=axis, descending=descending, stable=stable)
    return Array._from_tyarray(order)
[docs]
def nonzero(x: Array, /) -> tuple[Array, ...]:
    """Array API ``nonzero``: one index array per item returned by the typed array."""
    return tuple(map(Array._from_tyarray, x._tyarray.nonzero()))
[docs]
def astype(
    x: Array, dtype: DType, /, *, copy: bool = True, device: None | Device = None
) -> Array:
    """Cast ``x`` to ``dtype``.

    With ``copy=False`` and a matching data type, ``x`` itself is returned;
    otherwise the result of ``x.astype(dtype)`` is returned. ``device`` is
    accepted for array API compatibility and not used.
    """
    if copy or not x.dtype == dtype:
        return x.astype(dtype)
    return x
[docs]
def broadcast_arrays(*arrays: Array) -> list[Array]:
    """Broadcast the given arrays against each other.

    Returns
    -------
    list[Array]
        One array per input, each broadcast to the common shape. With fewer
        than two inputs, copies of the inputs are returned unchanged.
    """
    if len(arrays) < 2:
        return [a.copy() for a in arrays]

    # The common shape is obtained by summing int64 zero-arrays shaped like
    # each input and reading the dynamic shape of the sum. Using int64 zeros
    # for every input keeps the addition independent of the inputs' own data
    # types (which may be non-numeric or lack a common promoted type).
    acc = zeros_like(arrays[0], dtype=ndx.int64)
    for arr in arrays[1:]:
        acc = acc + zeros_like(arr, dtype=ndx.int64)
    target_shape = acc.dynamic_shape
    return [broadcast_to(a, target_shape) for a in arrays]
[docs]
def broadcast_to(x: Array, /, shape: tuple[int, ...] | Array) -> Array:
    """Broadcast ``x`` to ``shape``.

    ``shape`` may be a static tuple or an int64 ``Array`` holding the target
    shape. For a static shape the result is passed through ``.copy()``
    (NOTE(review): presumably so the result never aliases ``x``; confirm).

    Raises
    ------
    ValueError
        If ``shape`` is an ``Array`` whose data type is not int64.
    """
    if not isinstance(shape, Array):
        return Array._from_tyarray(x._tyarray.broadcast_to(shape)).copy()
    if not isinstance(shape._tyarray, onnx.TyArrayInt64):
        raise ValueError(
            f"dynamic shape must be of data type int64, found `{shape.dtype}`"
        )
    return Array._from_tyarray(x._tyarray.broadcast_to(shape._tyarray))
[docs]
def can_cast(from_: DType | Array, to: DType, /) -> bool:
    """Determine whether ``from_`` can be cast to ``to`` under type promotion rules.

    Per the array API standard, the cast is allowed only if promoting ``from_``
    with ``to`` yields ``to`` itself. The mere existence of *some* common type
    is not sufficient: e.g. ``float64`` and ``float32`` promote to ``float64``,
    yet ``float64 -> float32`` is not a permitted cast.

    Parameters
    ----------
    from_
        Source data type, or an array whose data type is used.
    to
        Target data type.

    Returns
    -------
    bool
        ``True`` if the cast is permitted, ``False`` otherwise.
    """
    try:
        promoted = result_type(from_, to)
    except TypeError:
        # No promotion rule exists between the two data types.
        return False
    return promoted == to
[docs]
def concat(
    arrays: tuple[Array, ...] | list[Array], /, *, axis: None | int = 0
) -> Array:
    """Array API ``concat``: join ``arrays`` along ``axis``."""
    typed_inputs = [arr._tyarray for arr in arrays]
    return Array._from_tyarray(tyfuncs.concat(typed_inputs, axis=axis))
[docs]
def cumulative_prod(
    x: Array,
    /,
    *,
    axis: int | None = None,
    dtype: DType | None = None,
    include_initial: bool = False,
) -> Array:
    """Array API ``cumulative_prod``; delegates to the typed-array implementation."""
    return Array._from_tyarray(
        x._tyarray.cumulative_prod(
            axis=axis,
            dtype=dtype,
            include_initial=include_initial,
        )
    )
[docs]
def cumulative_sum(
    x: Array,
    /,
    *,
    axis: int | None = None,
    dtype: DType | None = None,
    include_initial: bool = False,
) -> Array:
    """Array API ``cumulative_sum``; delegates to the typed-array implementation."""
    return Array._from_tyarray(
        x._tyarray.cumulative_sum(
            axis=axis,
            dtype=dtype,
            include_initial=include_initial,
        )
    )
[docs]
def max(
    x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> Array:
    """Calculate the maximum value of the input array ``x``.

    Reductions over zero-sized inputs return the minimum possible value for the
    input data type.
    """
    reduced = x._tyarray.max(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def mean(
    x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> Array:
    """Array API ``mean``; delegates to the typed-array implementation."""
    reduced = x._tyarray.mean(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def meshgrid(*arrays: Array, indexing: str = "xy") -> list[Array]:
    """Build coordinate arrays from coordinate vectors.

    Parameters
    ----------
    arrays
        Coordinate arrays; each is reshaped so that its elements lie along its
        own axis.
    indexing
        ``"xy"`` (Cartesian) or ``"ij"`` (matrix) indexing. With ``"xy"`` and
        at least two inputs, the roles of the first two axes are swapped.

    Returns
    -------
    list[Array]
        One array per input, all broadcast to the common grid shape.

    Raises
    ------
    ValueError
        If ``indexing`` is neither ``"xy"`` nor ``"ij"``.
    """
    if indexing not in ("xy", "ij"):
        raise ValueError(
            f"'indexing' argument must be one of 'xy' or 'ij', found `{indexing}`"
        )

    ndim = len(arrays)
    base_shape = (1,) * ndim
    # The i-th input gets length -1 (inferred) on axis i and 1 everywhere else.
    output = [
        reshape(x, base_shape[:i] + (-1,) + base_shape[i + 1 :])
        for i, x in enumerate(arrays)
    ]
    if indexing == "xy" and ndim > 1:
        # Cartesian indexing: switch the first and second axis.
        output[0] = reshape(output[0], (1, -1) + base_shape[2:])
        output[1] = reshape(output[1], (-1, 1) + base_shape[2:])
    return broadcast_arrays(*output)
[docs]
def moveaxis(
    x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /
) -> Array:
    """Array API ``moveaxis``; delegates to the typed-array implementation."""
    moved = x._tyarray.moveaxis(source, destination)
    return Array._from_tyarray(moved)
[docs]
def min(
    x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> Array:
    """Calculate the minimum value of the input array ``x``.

    Reductions over zero-sized inputs return the maximum possible value for the
    input data type.
    """
    reduced = x._tyarray.min(axis=axis, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def prod(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    dtype: DType | None = None,
    keepdims: bool = False,
) -> Array:
    """Array API ``prod``; delegates to the typed-array implementation."""
    reduced = x._tyarray.prod(axis=axis, dtype=dtype, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def sort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """Array API ``sort``; delegates to the typed-array implementation."""
    ordered = x._tyarray.sort(axis=axis, descending=descending, stable=stable)
    return Array._from_tyarray(ordered)
[docs]
def std(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    correction: int | float = 0.0,
    keepdims: bool = False,
) -> Array:
    """Array API ``std``; delegates to the typed-array implementation."""
    reduced = x._tyarray.std(axis=axis, correction=correction, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def sum(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    dtype: DType | None = None,
    keepdims: bool = False,
) -> Array:
    """Array API ``sum``; delegates to the typed-array implementation."""
    reduced = x._tyarray.sum(axis=axis, dtype=dtype, keepdims=keepdims)
    return Array._from_tyarray(reduced)
[docs]
def var(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    correction: int | float = 0.0,
    keepdims: bool = False,
) -> Array:
    """Array API ``var``; delegates to the typed array's ``variance``."""
    reduced = x._tyarray.variance(
        axis=axis, correction=correction, keepdims=keepdims
    )
    return Array._from_tyarray(reduced)
[docs]
def empty(
    shape: int | tuple[int, ...],
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Array API ``empty``.

    Implemented via ``zeros``, so the returned array is zero-filled.
    ``device`` is accepted for array API compatibility and not used.
    """
    return zeros(dtype=dtype, shape=shape)
[docs]
def empty_like(
    x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
    """Array API ``empty_like``; implemented via ``zeros_like`` (zero-filled)."""
    result = zeros_like(x, dtype=dtype)
    return result
def equal(x1: Array, x2: Array, /) -> Array:
    """Element-wise equality; equivalent to ``x1 == x2``."""
    result = x1 == x2
    return result
[docs]
def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
    """Insert a new axis of length one into ``x`` at position ``axis``.

    Parameters
    ----------
    x
        Input array.
    axis
        Position of the new axis in the output; must lie in
        ``[-x.ndim - 1, x.ndim]``. Negative values count from the end of the
        output shape.

    Raises
    ------
    IndexError
        If ``axis`` is out of range.
    """
    if not -x.ndim - 1 <= axis <= x.ndim:
        raise IndexError(
            f"'axis' must be within `[{-x.ndim - 1}, {x.ndim}]`, found `{axis}`"
        )
    if axis < 0:
        # Normalize against the output rank (x.ndim + 1).
        axis = x.ndim + axis + 1
    # `None` in an index inserts a new axis; every other position is kept whole.
    key = tuple(None if el == axis else slice(None, None) for el in range(x.ndim + 1))
    return x[key]
[docs]
def expm1(x: Array, /) -> Array:
    """Array API ``expm1``; not implemented in ndonnx.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Requires special operator to meet standards precision requirements
    # TODO: Add upstream tracking issue
    raise NotImplementedError(
        "'expm1' is not implemented: meeting the standard's precision "
        "requirements needs a dedicated operator"
    )
[docs]
def log1p(x: Array, /) -> Array:
    """Array API ``log1p``; not implemented in ndonnx.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Requires special operator to meet standards precision requirements
    # TODO: Add upstream tracking issue
    raise NotImplementedError(
        "'log1p' is not implemented: meeting the standard's precision "
        "requirements needs a dedicated operator"
    )
[docs]
def conj(x: Array, /) -> Array:
    """Array API ``conj``; not implemented in ndonnx.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Support for complex numbers is very broken in the ONNX standard
    raise NotImplementedError(
        "'conj' is not implemented: complex data types are not supported"
    )
[docs]
def imag(x: Array, /) -> Array:
    """Array API ``imag``; not implemented in ndonnx.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Support for complex numbers is very broken in the ONNX standard
    raise NotImplementedError(
        "'imag' is not implemented: complex data types are not supported"
    )
[docs]
def eye(
    n_rows: int,
    n_cols: int | None = None,
    /,
    *,
    k: int = 0,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Array API ``eye``: ones on the ``k``-th diagonal, zeros elsewhere.

    Parameters
    ----------
    n_rows
        Number of rows.
    n_cols
        Number of columns; ``None`` means ``n_rows``.
    k
        Diagonal offset passed to ``numpy.eye``.
    dtype
        Output data type. When ``None``, ``ndx._default_float`` is used, the
        same default as ``zeros``, ``ones`` and ``linspace``.
    device
        Accepted for array API compatibility; not used.
    """
    dtype = dtype or ndx._default_float
    # The pattern is built eagerly with NumPy and then converted via `asarray`.
    nparr = np.eye(n_rows, n_cols, k=k)
    return asarray(nparr, dtype=dtype)
[docs]
def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
    """Reverse the order of elements of ``x`` along the given axes.

    ``axis=None`` flips every axis; an ``int`` or a tuple selects specific
    axes (negative values count from the end).
    """
    if axis is None:
        flipped_axes = range(x.ndim)
    elif isinstance(axis, int):
        flipped_axes = (axis,)
    else:
        flipped_axes = axis

    index = [slice(None, None, 1)] * x.ndim
    for ax in flipped_axes:
        index[ax] = slice(None, None, -1)
    return x[tuple(index)]
[docs]
def full(
    shape: int | tuple[int, ...] | Array,
    fill_value: bool | int | float | str,
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Create an array of the given ``shape`` filled with ``fill_value``.

    Parameters
    ----------
    shape
        An ``int`` (1-D result), a tuple of ints, or an ``Array`` holding the
        shape (a rank-0 ``Array`` is promoted to 1-D).
    fill_value
        Scalar written to every element.
    dtype
        Output data type; inferred from ``fill_value`` when ``None``.
    device
        Accepted for array API compatibility; not used.
    """
    if dtype is None:
        dtype = tyfuncs._infer_dtype(fill_value)

    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, Array) and shape.ndim == 0:
        # Ensure shape is 1D
        shape = shape[None]

    is_static = isinstance(shape, tuple)
    if is_static and len(shape) == 0:
        # Rank-0 result: the fill value itself.
        return asarray(fill_value, dtype=dtype)
    if is_static and math.prod(shape) == 0:
        # using `broadcast_to` to create an array with fewer elements
        # is technically supported by the ONNX standard, but quite odd
        # and seemingly broken in the onnxruntime.
        return reshape(asarray([], dtype=dtype), shape=shape)
    return broadcast_to(asarray(fill_value, dtype=dtype), shape)
[docs]
def full_like(
    x: Array,
    /,
    fill_value: bool | int | float | str,
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Create an array shaped like ``x`` filled with ``fill_value``.

    The output data type is ``dtype`` if given, else that of ``x``. The fill
    value is broadcast to the dynamic shape of ``x``.
    """
    target_shape = x.dynamic_shape
    fill_dtype = dtype or x.dtype
    return broadcast_to(asarray(fill_value, dtype=fill_dtype), target_shape)
[docs]
def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool:
    """Return whether ``dtype`` belongs to ``kind``.

    Parameters
    ----------
    dtype
        Data type to test.
    kind
        A data type (compared for equality); one of the kind strings
        ``"bool"``, ``"signed integer"``, ``"unsigned integer"``,
        ``"integral"``, ``"real floating"``, ``"complex floating"``,
        ``"numeric"``; or a tuple of these (``True`` if any entry matches).

    Raises
    ------
    TypeError
        If ``kind`` is an unknown kind string or of an unsupported type.
    """
    if isinstance(kind, str):
        if kind == "bool":
            return dtype == ndx.bool
        if kind == "signed integer":
            return dtype in (ndx.int8, ndx.int16, ndx.int32, ndx.int64)
        if kind == "unsigned integer":
            return dtype in (ndx.uint8, ndx.uint16, ndx.uint32, ndx.uint64)
        if kind == "integral":
            return isinstance(dtype, onnx.IntegerDTypes)
        if kind == "real floating":
            return isinstance(dtype, onnx.Floating)
        if kind == "complex floating":
            # 'complex floating' is not supported by ndonnx
            return False
        if kind == "numeric":
            return isinstance(dtype, onnx.NumericDTypes)
        # Report unknown kind strings explicitly; the generic message below
        # would misleadingly complain about the argument's type.
        raise TypeError(
            f"unknown data type kind `{kind}`; expected one of 'bool', "
            "'signed integer', 'unsigned integer', 'integral', "
            "'real floating', 'complex floating', 'numeric'"
        )
    if isinstance(kind, DType):
        return dtype == kind
    if isinstance(kind, tuple):
        return builtins.any(isdtype(dtype, k) for k in kind)
    raise TypeError(f"kind must be a string or a dtype, not {type(kind)}")
[docs]
def linspace(
    start: int | float | complex,
    stop: int | float | complex,
    /,
    num: int,
    *,
    dtype: DType | None = None,
    device: None | Device = None,
    endpoint: bool = True,
) -> Array:
    """Array API ``linspace``.

    Samples are computed eagerly with ``numpy.linspace`` and converted to
    ``dtype`` (``ndx._default_float`` when ``None``).

    Raises
    ------
    ValueError
        If the target data type is not a primitive ONNX data type.
    """
    target = dtype or ndx._default_float
    if not isinstance(target, onnx.DTypes):
        raise ValueError(f"only primitive data types are supported, found `{target}`")
    samples = np.linspace(start, stop, num=num, endpoint=endpoint)
    return asarray(samples, dtype=target)
[docs]
def matmul(x1: Array, x2: Array, /) -> Array:
    """Matrix product; equivalent to ``x1 @ x2``."""
    product = x1 @ x2
    return product
[docs]
def matrix_transpose(x: Array, /) -> Array:
    """Array API ``matrix_transpose``; uses the typed array's ``mT``."""
    transposed = x._tyarray.mT
    return Array._from_tyarray(transposed)
[docs]
def ones(
    shape: int | tuple[int, ...],
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Array API ``ones``; ``dtype`` defaults to ``ndx._default_float``."""
    if isinstance(shape, int):
        shape = (shape,)
    return Array._from_tyarray(tyfuncs.ones(dtype or ndx._default_float, shape))
[docs]
def ones_like(
    x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
    """Array API ``ones_like``; ``dtype`` defaults to that of ``x``."""
    return full_like(x, 1, dtype=dtype or x.dtype)
[docs]
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
    """Array API ``permute_dims``; delegates to the typed-array implementation."""
    return Array._from_tyarray(x._tyarray.permute_dims(axes=axes))
[docs]
def reshape(
    x: Array, /, shape: tuple[int, ...] | Array, *, copy: bool | None = None
) -> Array:
    """Reshape ``x`` to ``shape``.

    ``shape`` may be a tuple or a 1-D int64 ``Array`` of static length (needed
    so that the rank of the result stays known).

    Raises
    ------
    ValueError
        If ``copy`` is falsy but not ``None``, or if an ``Array`` shape is not
        1-D with a static shape.
    TypeError
        If an ``Array`` shape is not of data type int64.
    """
    if copy is not None and not copy:
        raise ValueError("avoiding a copy in reshape operations is unsupported")

    if not isinstance(shape, Array):
        return Array._from_tyarray(x._tyarray.reshape(shape))

    shape_data = shape._tyarray
    if not isinstance(shape_data, onnx.TyArrayInt64):
        raise TypeError(
            "'shape' must be of data type int64 if provided as an Array"
        )
    static_1d = shape.ndim == 1 and builtins.all(
        isinstance(dim, int) for dim in shape.shape
    )
    if not static_1d:
        # Otherwise, we lose the rank information which we assume
        # to always be available.
        raise ValueError(
            "'shape' must be a 1D tensor of static shape if provided as an 'Array'"
        )
    return Array._from_tyarray(x._tyarray.reshape(shape_data))
[docs]
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
    """Array API ``repeat``.

    Parameters
    ----------
    x
        Input array.
    repeats
        A Python ``int`` or an ``Array`` with an integer data type (cast to
        int64 before delegation).
    axis
        Axis along which to repeat; passed through to the typed array.

    Raises
    ------
    TypeError
        If ``repeats`` is neither an ``int`` nor an integer-typed ``Array``.
    """
    repeats_: int | onnx.TyArrayInt64
    if isinstance(repeats, int):
        repeats_ = repeats
    elif isinstance(repeats, Array) and isinstance(
        repeats._tyarray, onnx.TyArrayInteger
    ):
        repeats_ = repeats._tyarray.astype(onnx.int64)
    else:
        # Checking for `Array` first avoids an AttributeError on `._tyarray`
        # for arbitrary objects; report the data type rather than the data.
        found = repeats.dtype if isinstance(repeats, Array) else type(repeats)
        raise TypeError(
            "'repeats' argument must be of type 'int' or an array with an "
            f"integer data type, found `{found}`"
        )
    return Array._from_tyarray(x._tyarray.repeat(repeats_, axis=axis))
[docs]
def result_type(*arrays_and_dtypes: Array | DType | PyScalar) -> DType:
    """Determine the data type obtained by type promotion of the arguments.

    At least one argument must be an ``Array`` or ``DType``. Arrays and data
    types are handed to the promotion logic before Python scalars; relative
    order within each group is preserved.

    Raises
    ------
    ValueError
        If no arguments are given, or none is an ``Array`` or ``DType``.
    """
    if not arrays_and_dtypes:
        raise ValueError("at least one array or dtype is required")

    typed = [el for el in arrays_and_dtypes if isinstance(el, Array | DType)]
    scalars = [el for el in arrays_and_dtypes if not isinstance(el, Array | DType)]
    if not typed:
        raise ValueError(
            "arguments to 'result_type' must contain at least one 'Array' or 'DType' object"
        )

    def as_dtype_or_scalar(obj: Array | DType | PyScalar) -> DType | PyScalar:
        return obj.dtype if isinstance(obj, Array) else obj

    first, *rest = typed
    return tyfuncs.result_type(
        as_dtype_or_scalar(first),
        *(as_dtype_or_scalar(el) for el in [*rest, *scalars]),
    )
[docs]
def roll(
    x: Array,
    /,
    shift: int | tuple[int, ...],
    *,
    axis: int | tuple[int, ...] | None = None,
) -> Array:
    """Array API ``roll``; delegates to the typed-array implementation."""
    rolled = x._tyarray.roll(shift=shift, axis=axis)
    return Array._from_tyarray(rolled)
[docs]
def searchsorted(
    x1: Array,
    x2: Array,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Array | None = None,
) -> Array:
    """Array API ``searchsorted``.

    Raises
    ------
    TypeError
        If ``sorter`` is given but does not have an integer data type.
    """
    sorter_ = None
    if sorter is not None:
        if not isinstance(sorter._tyarray, onnx.TyArrayInteger):
            raise TypeError(
                f"'sorter' must have an integer data type, found `{sorter.dtype}`"
            )
        sorter_ = sorter._tyarray
    positions = tyfuncs.searchsorted(x1._tyarray, x2._tyarray, side=side, sorter=sorter_)
    return Array._from_tyarray(positions)
[docs]
def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array:
    """Join ``arrays`` along a new axis inserted at position ``axis``.

    Each input gets a length-one axis via ``expand_dims``; the results are then
    concatenated along that axis.
    """
    expanded = [expand_dims(arr, axis=axis) for arr in arrays]
    return concat(expanded, axis=axis)
[docs]
def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array:
    """Array API ``squeeze``; delegates to the typed-array implementation."""
    squeezed = x._tyarray.squeeze(axis)
    return Array._from_tyarray(squeezed)
[docs]
def take(x: Array, indices: Array, /, *, axis: int | None = None) -> Array:
    """Array API ``take``.

    Raises
    ------
    TypeError
        If ``indices`` is not of data type int64.
    ValueError
        If ``indices`` is not 1-D, or ``axis`` is omitted for multi-dimensional ``x``.
    """
    if not isinstance(indices._tyarray, onnx.TyArrayInt64):
        raise TypeError(
            f"'indices' must be of data type 'int64' found `{indices.dtype}`"
        )
    if indices.ndim != 1:
        raise ValueError("'indices' must be a 1D array")
    if x.ndim > 1 and axis is None:
        raise ValueError(
            "'axis' argument must be provided if 'x' has more than one axis"
        )
    gathered = x._tyarray.take(indices._tyarray, axis=axis)
    return Array._from_tyarray(gathered)
[docs]
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
    """Array API ``take_along_axis``.

    Raises
    ------
    TypeError
        If ``indices`` is not of data type int64.
    ValueError
        If ``x`` and ``indices`` differ in rank or ``axis`` is out of range.
    """
    if not isinstance(indices._tyarray, onnx.TyArrayInt64):
        raise TypeError(
            f"'indices' must be of data type 'int64' found `{indices.dtype}`"
        )
    if indices.ndim != x.ndim:
        raise ValueError("'x' and 'indices' must have the same number of axes")
    if axis < -x.ndim or axis >= x.ndim:
        raise ValueError(
            "'axis' argument must be compatible with number of axes in 'x'"
        )
    gathered = x._tyarray.take_along_axis(indices._tyarray, axis=axis)
    return Array._from_tyarray(gathered)
[docs]
def tile(x: Array, repetitions: tuple[int, ...], /) -> Array:
    """Array API ``tile``; delegates to the typed-array implementation."""
    tiled = x._tyarray.tile(repetitions)
    return Array._from_tyarray(tiled)
[docs]
def tril(x: Array, /, *, k: int = 0) -> Array:
    """Array API ``tril``; delegates to the typed-array implementation."""
    lower = x._tyarray.tril(k=k)
    return Array._from_tyarray(lower)
[docs]
def triu(x: Array, /, *, k: int = 0) -> Array:
    """Array API ``triu``; delegates to the typed-array implementation."""
    upper = x._tyarray.triu(k=k)
    return Array._from_tyarray(upper)
[docs]
def tensordot(
    x1: Array, x2: Array, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2
) -> Array:
    """Array API ``tensordot``.

    Delegates to the typed array's ``__ndx_tensordot__`` hook, which returns
    ``NotImplemented`` for unsupported data type combinations.

    Raises
    ------
    TypeError
        If the operand data types are not supported.
    """
    res = x1._tyarray.__ndx_tensordot__(x2._tyarray, axes=axes)
    # `NotImplemented` is a singleton and must be compared by identity; `==`
    # would dispatch to the result's own `__eq__` on success.
    if res is NotImplemented:
        raise TypeError(
            f"unsupported operand data types for 'tensordot': `{x1.dtype}` and `{x2.dtype}`"
        )
    return Array._from_tyarray(res)
class UniqueAll(NamedTuple):
    """Result of ``unique_all``."""

    values: Array  # unique values
    indices: Array  # indices as returned by the typed array's `unique_all`
    inverse_indices: Array  # inverse indices as returned by `unique_all`
    counts: Array  # counts as returned by `unique_all`
[docs]
def unique_all(x: Array, /) -> UniqueAll:
    """Array API ``unique_all``; wraps the typed array's four results as ``Array``s."""
    values, indices, inverse, counts = (
        Array._from_tyarray(part) for part in x._tyarray.unique_all()
    )
    return UniqueAll(
        values=values, indices=indices, inverse_indices=inverse, counts=counts
    )
class UniqueCounts(NamedTuple):
    """Result of ``unique_counts``."""

    values: Array  # unique values
    counts: Array  # count per unique value, taken from `unique_all`
[docs]
def unique_counts(x: Array, /) -> UniqueCounts:
    """Array API ``unique_counts``; a projection of ``unique_all``."""
    full_result = unique_all(x)
    return UniqueCounts(values=full_result.values, counts=full_result.counts)
class UniqueInverse(NamedTuple):
    """Result of ``unique_inverse``."""

    values: Array  # unique values
    inverse_indices: Array  # inverse indices, taken from `unique_all`
[docs]
def unique_inverse(x: Array, /) -> UniqueInverse:
    """Array API ``unique_inverse``; a projection of ``unique_all``."""
    full_result = unique_all(x)
    return UniqueInverse(
        values=full_result.values, inverse_indices=full_result.inverse_indices
    )
[docs]
def unique_values(x: Array, /) -> Array:
    """Array API ``unique_values``; the ``values`` field of ``unique_all``."""
    return unique_all(x).values
[docs]
def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]:
    """Split ``x`` into a tuple of arrays along ``axis``.

    Raises
    ------
    ValueError
        If the length of ``axis`` is not statically known.
    """
    # Only possible for statically known dimensions
    if not isinstance(x.shape[axis], int):
        raise ValueError(
            f"'unstack' can only be applied to statically known dimensions, but axis `{axis}` has dynamic length."
        )
    return tuple(moveaxis(x, axis, 0))
[docs]
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """Vector dot product of ``x1`` and ``x2`` along ``axis``.

    Computed as ``sum(x1 * x2, axis=axis)`` in the data type of the product.

    Raises
    ------
    ValueError
        If both contracted dimensions are statically known and differ.
    """
    dim1 = x1._tyarray.shape[axis]
    dim2 = x2._tyarray.shape[axis]
    # Only statically known (int) sizes can be compared here; unknown or
    # symbolic dimensions are left for the runtime to validate.
    if isinstance(dim1, int) and isinstance(dim2, int) and dim1 != dim2:
        raise ValueError("summed over dimensions must match")
    prod = x1 * x2
    return sum(prod, axis=axis, dtype=prod.dtype)
[docs]
def where(
    cond: Array,
    a: Array | int | float | bool | str,
    b: Array | int | float | bool | str,
) -> Array:
    """Select elements from ``a`` where ``cond`` holds and from ``b`` elsewhere.

    Raises
    ------
    TypeError
        If ``cond`` is not of data type bool.
    """
    cond_data = cond._tyarray
    if not isinstance(cond_data, onnx.TyArrayBool):
        raise TypeError(f"'cond' must be of data type 'bool', found `{cond.dtype}`")
    selected = tyfuncs.where(cond_data, unwrap_tyarray(a), unwrap_tyarray(b))
    return Array._from_tyarray(selected)
[docs]
def zeros(
    shape: int | tuple[int, ...],
    *,
    dtype: DType | None = None,
    device: None | Device = None,
) -> Array:
    """Array API ``zeros``; ``dtype`` defaults to ``ndx._default_float``."""
    if isinstance(shape, int):
        shape = (shape,)
    return Array._from_tyarray(tyfuncs.zeros(dtype or ndx._default_float, shape))
[docs]
def zeros_like(
    x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
    """Array API ``zeros_like``; ``dtype`` defaults to that of ``x``."""
    return full_like(x, 0, dtype=dtype or x.dtype)
[docs]
def diff(
    a: Array,
    /,
    *,
    axis: int = -1,
    n: int = 1,
    prepend: Array | None = None,
    append: Array | None = None,
) -> Array:
    """Array API ``diff``; delegates to the typed-array implementation.

    ``prepend``/``append`` are unwrapped to typed arrays when given.
    """
    prepend_data = prepend._tyarray if prepend is not None else None
    append_data = append._tyarray if append is not None else None
    differences = a._tyarray.diff(
        axis=axis, n=n, prepend=prepend_data, append=append_data
    )
    return Array._from_tyarray(differences)