Source code for ndonnx._array

# Copyright (c) QuantCo 2023-2026
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import math
import operator as std_ops
from collections.abc import Callable
from enum import Enum
from types import EllipsisType
from typing import Any

import numpy as np
from spox import Var
from typing_extensions import deprecated

from ndonnx import DType

from ._namespace_info import Device, device
from ._typed_array import TyArrayBase, onnx
from ._typed_array import funcs as tyfuncs
from ._typed_array.masked_onnx import TyMaArray
from .extensions import get_mask
from .types import GetItemKey, OnnxShape, PyScalar, SetitemKey

# Signature shared by the binary dunder methods generated further down in this
# module: the first parameter is the `Array` the method is bound to, the second
# is the other operand (Python scalar, `Array`, or NumPy array/scalar).
_BinaryOp = Callable[
    ["Array", "PyScalar | Array | np.ndarray | np.generic"],
    "Array",
]
# Accepted values for an `axis` parameter: a single axis, a tuple of axes, or
# `None`.
_Axisparam = int | tuple[int, ...] | None


def _build_forward(
    std_op: Callable[[TyArrayBase, PyScalar], TyArrayBase],
    sigil: str,
    this_name: str,
    reflected_name: str,
) -> _BinaryOp:
    """Create the non-reflected dunder method (``self <sigil> rhs``).

    Parameters
    ----------
    std_op
        Callable used when the right operand is a Python scalar; it is
        invoked as ``std_op(self._tyarray, rhs)``.
    sigil
        Operator symbol; only used in the error message.
    this_name
        Name of the dunder looked up on the left typed array.
    reflected_name
        Name of the dunder looked up on the right typed array if the first
        attempt returned ``NotImplemented``.

    Returns
    -------
    The generated method. It returns ``NotImplemented`` for operands that are
    neither scalars, NumPy objects nor ``Array`` instances, and raises
    ``TypeError`` if both typed-array attempts return ``NotImplemented``.
    """

    def fun(self, rhs: PyScalar | Array | np.ndarray | np.generic) -> Array:
        # NumPy objects must be wrapped before the scalar check:
        # NumPy generics are subclasses of Python scalars in np1x.
        if isinstance(rhs, np.ndarray | np.generic):
            operand = Array._constant(value=np.asarray(rhs), dtype=None)
        else:
            operand = rhs

        if isinstance(operand, PyScalar):
            return Array._from_tyarray(std_op(self._tyarray, operand))
        if not isinstance(operand, Array):
            return NotImplemented

        # Try the forward method on the left operand first and fall back to
        # the reflected method of the right operand.
        outcome = getattr(self._tyarray, this_name)(operand._tyarray)
        if outcome is NotImplemented:
            outcome = getattr(operand._tyarray, reflected_name)(self._tyarray)
            if outcome is NotImplemented:
                raise TypeError(
                    f"unsupported operand (data) types for `{sigil}`: `{self.dtype}` and `{operand.dtype}`"
                )
        return Array._from_tyarray(outcome)

    return fun


def _build_backward(
    std_op: Callable[[TyArrayBase | PyScalar, TyArrayBase | PyScalar], TyArrayBase],
    sigil: str,
    this_name: str,
    reflected_name: str,
) -> _BinaryOp:
    """Create the reflected dunder method (``lhs <sigil> self``).

    Parameters
    ----------
    std_op
        Callable used when the left operand is a Python scalar; it is
        invoked as ``std_op(lhs, self._tyarray)``.
    sigil
        Operator symbol; only used in the error message.
    this_name
        Name of the (reflected) dunder looked up on ``self``'s typed array.
    reflected_name
        Name of the dunder looked up on ``lhs``'s typed array if the first
        attempt returned ``NotImplemented``.

    Returns
    -------
    The generated method. It returns ``NotImplemented`` for operands that are
    neither scalars, NumPy objects nor ``Array`` instances, and raises
    ``TypeError`` if both typed-array attempts return ``NotImplemented``.
    """

    def fun(self, lhs: PyScalar | Array | np.ndarray | np.generic) -> Array:
        if isinstance(lhs, np.ndarray | np.generic):
            lhs = Array._constant(value=np.asarray(lhs), dtype=None)
        if isinstance(lhs, PyScalar):
            # Note: NumPy generic are subclasses of Python scalars in np1x
            return Array._from_tyarray(std_op(lhs, self._tyarray))
        if not isinstance(lhs, Array):
            return NotImplemented
        res = getattr(self._tyarray, this_name)(lhs._tyarray)
        if res is NotImplemented:
            res = getattr(lhs._tyarray, reflected_name)(self._tyarray)
        if res is NotImplemented:
            # The operation evaluated here is `lhs <sigil> self`, so the
            # operands are reported in that order.
            raise TypeError(
                f"unsupported operand (data) types for `{sigil}`: `{lhs.dtype}` and `{self.dtype}`"
            )
        return Array._from_tyarray(res)

    return fun


def _make_binary_dunder(
    std_op: Callable[[TyArrayBase | PyScalar, TyArrayBase | PyScalar], TyArrayBase],
    sigil: str,
    forward_name: str,
    backward_name: str,
) -> tuple[_BinaryOp, _BinaryOp]:
    """Create a forward and reflected version for a binary dunder method.

    The result is the pair ``(forward, reflected)``.
    """
    # Design rationale
    # ----------------
    # Returning 'NotImplemented' from a method such as __add__ makes the
    # interpreter produce an error message that does not mention the
    # arrays' data types, e.g. `"TypeError: ... +: Not Implemented for
    # 'Array' and 'Other'`. That is acceptable when `Other` is an
    # unrelated class, but a message such as `"TypeError: ... +: Not
    # Implemented for 'Array' and 'Array'` is unhelpful when the real
    # problem is a pair of incompatible data types.
    #
    # Scenarios to cover for incompatible operands:
    #
    # Other() + Array:
    #   -> Array.__radd__:
    #     -> return NotImplemented
    # Array + Other():
    #   -> Array.__add__:
    #     -> return NotImplemented
    # Array(dtype1) + Array(dtype2):
    #   -> Option 1:
    #     -> operator.add(TyArray(dtype1), TyArray(dtype2)):
    #        - Tries __add__ and __radd__ on the TyArray objects
    #        - Raises TypeError with a bad error message
    #        - Catch it and raise a new error with a better message
    #   -> Option 2:
    #     -> Manually try __add__ and __radd__ on the TyArray objects
    #       -> Pass through the returned NotImplemented object
    #       -> Still a bad error message for the user, but internal class
    #          names are not exposed
    #   -> Option 3 (taken in the current implementation):
    #     -> Manually try __add__ and __radd__ on the TyArray objects
    #       -> Check for NotImplemented
    #       -> Raise an error with a nice error message
    #       -> A bit more cumbersome to implement
    forward = _build_forward(std_op, sigil, forward_name, backward_name)
    # The reflected variant swaps which dunder is tried first.
    reflected = _build_backward(std_op, sigil, backward_name, forward_name)
    return forward, reflected


class Array:
    """User-facing array object that makes no assumptions about any data type related logic.

    All operations are delegated to the wrapped typed array stored in
    ``_tyarray``.
    """

    _tyarray: TyArrayBase

    # `__array_priority__` governs which operand is first called in
    # operations involving a NumPy array/generic and a custom class
    # (like this one). If not set, the first operand is always called
    # first. This is problematic since NumPy does not directly return
    # NotImplemented if called with an ndonnx.Array object. Setting
    # this priority higher than that of a NumPy array (i.e. 0) ensures
    # that a situation such as `np.ndarray + ndx.Array` will first
    # call `ndx.Array.__radd__`.
    __array_priority__ = 1

    def __init__(self, *args, **kwargs) -> None:
        raise TypeError(
            "'Array' cannot be instantiated directly. Use the 'ndonnx.array' or 'ndonnx.asarray' functions instead"
        )

    @classmethod
    def _argument(cls, /, *, shape: OnnxShape, dtype: DType) -> Array:
        """Create an array from ``dtype.__ndx_argument__(shape)``."""
        inst = cls.__new__(cls)
        inst._tyarray = dtype.__ndx_argument__(shape)
        return inst

    @classmethod
    def _constant(
        cls, /, *, value: PyScalar | np.ndarray, dtype: DType | None
    ) -> Array:
        """Create an array from a Python scalar or NumPy array value."""
        return cls._from_tyarray(tyfuncs.astyarray(value, dtype=dtype))

    @classmethod
    def _from_tyarray(cls, tyarray: TyArrayBase, /) -> Array:
        """Wrap an existing typed array.

        Raises
        ------
        TypeError:
            If ``tyarray`` is not a ``TyArrayBase`` instance.
        """
        if not isinstance(tyarray, TyArrayBase):
            raise TypeError(f"expected 'TypedArrayBase', found `{type(tyarray)}`")
        # Bypass `__init__`, which unconditionally raises.
        inst = cls.__new__(cls)
        inst._tyarray = tyarray
        return inst

    @property
    def device(self) -> Device:
        return device

    @property
    def dtype(self) -> DType:
        return self._tyarray.dtype

    @property
    def dynamic_shape(self) -> Array:
        """Runtime shape of this array as a 1D int64 tensor."""
        shape = self._tyarray.dynamic_shape
        return Array._from_tyarray(shape)

    @property
    def mT(self) -> Array:  # noqa: N802
        return Array._from_tyarray(self._tyarray.mT)

    @property
    def ndim(self) -> int:
        return len(self.shape)

    @property
    def shape(self) -> tuple[int | None, ...]:
        """Static shape; dimensions given as strings are reported as ``None``."""
        shape = self._tyarray.shape
        return tuple(None if isinstance(item, str) else item for item in shape)

    @property
    def size(self) -> int | None:
        """Number of elements, or ``None`` if any dimension is ``None``."""
        static_dims = []
        for el in self.shape:
            if el is None:
                return None
            static_dims.append(el)
        return math.prod(static_dims)

    @property
    def dynamic_size(self) -> Array:
        """Return the size of an array as scalar array.

        Contrary to `Array.size` this function also works on dynamically
        sized arrays.
        """
        # Special cases that allow for shortcuts
        if self.ndim == 0:
            return Array._from_tyarray(onnx.const(1, dtype=onnx.int64))
        size = self._tyarray.dynamic_size
        return Array._from_tyarray(size)

    @property
    def T(self) -> Array:  # noqa: N802
        return Array._from_tyarray(self._tyarray.T)

    @property
    @deprecated(
        "'Array.null' is deprecated in favor of 'ndonnx.extensions.get_mask'",
    )
    def null(self) -> None | Array:
        return get_mask(self)

    @property
    @deprecated(
        "'Array.values' is deprecated in favor of 'ndonnx.extensions.get_data'",
    )
    def values(self) -> Array:
        if isinstance(self._tyarray, TyMaArray):
            return Array._from_tyarray(self._tyarray.data)
        if isinstance(self._tyarray, onnx.TyArray):
            return Array._from_tyarray(self._tyarray)
        raise ValueError(f"`{self.dtype}` is not a nullable built-in type")

    def astype(self, dtype: DType, *, copy=True) -> Array:
        """Cast to ``dtype`` by delegating to the wrapped typed array."""
        new_data = self._tyarray.astype(dtype, copy=copy)
        return Array._from_tyarray(new_data)

    def copy(self) -> Array:
        """Return an array wrapping a copy of the underlying typed array."""
        return Array._from_tyarray(self._tyarray.copy())

    @deprecated(
        "'Array.to_numpy' is deprecated in favor of 'Array.unwrap_numpy'",
    )
    def to_numpy(self) -> np.ndarray | None:
        """Like `unwrap_numpy`, but return ``None`` instead of raising ``ValueError``."""
        try:
            return self.unwrap_numpy()
        except ValueError:
            return None

    @deprecated(
        "'Array.spox_var' is deprecated in favor of 'Array.disassemble' or 'Array.unwrap_spox'",
    )
    def spox_var(self) -> Var:
        """Unwrap the underlying ``spox.Var`` object if ``self`` is of primitive data
        type.

        Otherwise, raise an exception.
        """
        return self.unwrap_spox()

    def unwrap_spox(self) -> Var:
        """Unwrap the underlying ``spox.Var`` object if ``self`` is of primitive data
        type.

        Otherwise, raise an exception.
        """
        if isinstance(self._tyarray, onnx.TyArray):
            return self._tyarray.disassemble()
        raise TypeError(
            "cannot safely unwrap underlying 'spox.Var' object(s) "
            f"from array of data type `{self.dtype}`"
        )

    def to_device(self, device: Any, /, *, stream: int | Any | None = None) -> Array:
        raise ValueError("ONNX provides no control over the used device")

    def unwrap_numpy(self) -> np.ndarray:
        """Return the propagated value as a NumPy array if available.

        Raises
        ------
        ValueError:
            If no propagated value is available.
        """
        return self._tyarray.unwrap_numpy()

    def disassemble(self) -> dict[str, Var] | Var:
        """Disassemble into the constituent ``spox.Var`` objects.

        The particular layout depends on the data type.
        """
        return self._tyarray.disassemble()

    def __dlpack__(
        self,
        *,
        stream: int | Any | None = None,
        max_version: tuple[int, int] | None = None,
        dl_device: tuple[Enum, int] | None = None,
        copy: bool | None = None,
    ) -> Any:
        raise BufferError("ndonnx does not support the export of array data")

    def __dlpack_device__(self) -> tuple[Enum, int]:
        raise ValueError("ONNX provides no control over the used device")

    def __iter__(self):
        """Iterate over the first dimension.

        Raises
        ------
        ValueError:
            If the array is 0-d or if the first dimension has no static
            length.
        """
        shape = self.shape
        # Unpacking an empty tuple raises `ValueError`, not `IndexError`, so
        # the rank is tested explicitly to produce the intended message.
        if not shape:
            raise ValueError("iteration over 0-d array")
        n = shape[0]
        if isinstance(n, int):
            return (self[i, ...] for i in range(n))
        raise ValueError(
            "iteration requires dimension of static length, but dimension 0 is dynamic"
        )

    def __getitem__(self, key: GetItemKey, /) -> Array:
        idx = _normalize_arrays_in_getitem_key(key)
        data = self._tyarray[idx]
        return type(self)._from_tyarray(data)

    def __setitem__(
        self,
        key: SetitemKey,
        value: str | int | float | bool | Array,
        /,
    ) -> None:
        # Specs say that the data type of self must not be changed.
        updates = (
            value._tyarray.astype(self.dtype)
            if isinstance(value, Array)
            else tyfuncs.astyarray(value, dtype=self.dtype)
        )
        idx = _normalize_arrays_in_setitem_key(key)
        self._tyarray[idx] = updates

    def __bool__(self, /) -> bool:
        return bool(self.unwrap_numpy())

    def __complex__(self, /) -> complex:
        return self.unwrap_numpy().__complex__()

    def __float__(self, /) -> float:
        return float(self.unwrap_numpy())

    def __index__(self, /) -> int:
        return self.unwrap_numpy().__index__()

    def __int__(self, /) -> int:
        return int(self.unwrap_numpy())

    def __array_namespace__(self, /, *, api_version: str | None = None) -> Any:
        # Imported lazily inside the method rather than at module level.
        import ndonnx as ndx

        return ndx

    # We spell out __eq__ and __ne__ so that mypy may pick up the
    # change in return type (Array rather than bool)
    def __eq__(self, other: PyScalar | Array | np.ndarray | np.generic) -> Array:  # type: ignore[override]
        if not isinstance(other, PyScalar | Array | np.ndarray | np.generic):
            return NotImplemented
        return Array._from_tyarray(self._tyarray == _astyarray_or_pyscalar(other))

    def __ne__(self, other: PyScalar | Array | np.ndarray | np.generic) -> Array:  # type: ignore[override]
        if not isinstance(other, PyScalar | Array | np.ndarray | np.generic):
            return NotImplemented
        return Array._from_tyarray(self._tyarray != _astyarray_or_pyscalar(other))

    __add__, __radd__ = _make_binary_dunder(std_ops.add, "+", "__add__", "__radd__")
    __and__, __rand__ = _make_binary_dunder(std_ops.and_, "&", "__and__", "__rand__")
    __floordiv__, __rfloordiv__ = _make_binary_dunder(
        std_ops.floordiv, "//", "__floordiv__", "__rfloordiv__"
    )
    # For comparisons the "reflected" method is the mirrored comparison.
    __ge__, __le__ = _make_binary_dunder(std_ops.ge, ">=", "__ge__", "__le__")
    __gt__, __lt__ = _make_binary_dunder(std_ops.gt, ">", "__gt__", "__lt__")
    __lshift__, __rlshift__ = _make_binary_dunder(
        std_ops.lshift, "<<", "__lshift__", "__rlshift__"
    )
    __matmul__, __rmatmul__ = _make_binary_dunder(
        std_ops.matmul, "@", "__matmul__", "__rmatmul__"
    )
    __mod__, __rmod__ = _make_binary_dunder(std_ops.mod, "%", "__mod__", "__rmod__")
    __mul__, __rmul__ = _make_binary_dunder(std_ops.mul, "*", "__mul__", "__rmul__")
    __or__, __ror__ = _make_binary_dunder(std_ops.or_, "|", "__or__", "__ror__")
    __pow__, __rpow__ = _make_binary_dunder(std_ops.pow, "**", "__pow__", "__rpow__")
    __rshift__, __rrshift__ = _make_binary_dunder(
        std_ops.rshift, ">>", "__rshift__", "__rrshift__"
    )
    __sub__, __rsub__ = _make_binary_dunder(std_ops.sub, "-", "__sub__", "__rsub__")
    __truediv__, __rtruediv__ = _make_binary_dunder(
        std_ops.truediv, "/", "__truediv__", "__rtruediv__"
    )
    __xor__, __rxor__ = _make_binary_dunder(std_ops.xor, "^", "__xor__", "__rxor__")

    def __abs__(self, /) -> Array:
        data = self._tyarray.__abs__()
        return Array._from_tyarray(data)

    def __invert__(self, /) -> Array:
        return Array._from_tyarray(~self._tyarray)

    def __neg__(self, /) -> Array:
        return Array._from_tyarray(-self._tyarray)

    def __pos__(self, /) -> Array:
        return Array._from_tyarray(+self._tyarray)

    def __repr__(self) -> str:
        value_repr = ", ".join(
            [f"{k}: {v}" for k, v in self._tyarray.__ndx_value_repr__().items()]
        )
        # We only add shape information if we don't have a constant value to show
        shape_info = (
            "" if self._tyarray.is_constant else f" shape={self._tyarray.shape},"
        )
        return f"array({value_repr},{shape_info} dtype={self.dtype})"

    # Non-standard functions exposed by NumPy and ndonnx <=0.9

    def sum(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.sum` for documentation."""
        return Array._from_tyarray(self._tyarray.sum(axis=axis, keepdims=keepdims))

    def prod(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.prod` for documentation."""
        return Array._from_tyarray(self._tyarray.prod(axis=axis, keepdims=keepdims))

    def max(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.max` for documentation."""
        return Array._from_tyarray(self._tyarray.max(axis=axis, keepdims=keepdims))

    def min(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.min` for documentation."""
        return Array._from_tyarray(self._tyarray.min(axis=axis, keepdims=keepdims))

    def all(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.all` for documentation."""
        return Array._from_tyarray(self._tyarray.all(axis=axis, keepdims=keepdims))

    def any(self, axis: _Axisparam = 0, keepdims: bool = False) -> Array:
        """See :py:func:`ndonnx.any` for documentation."""
        return Array._from_tyarray(self._tyarray.any(axis=axis, keepdims=keepdims))
def _astyarray_or_pyscalar(
    val: PyScalar | Array | Var | np.ndarray | np.generic,
) -> TyArrayBase | PyScalar:
    """Convert an operand into a typed array, or pass a Python scalar through.

    NumPy scalars are turned into NumPy arrays first; ``Array`` objects are
    unwrapped; ``int``/``float``/``str`` instances (which includes ``bool``)
    are returned unchanged; everything else goes through
    ``tyfuncs.astyarray``.
    """
    if isinstance(val, np.generic):
        val = np.asarray(val)
    if isinstance(val, Array):
        return val._tyarray
    if isinstance(val, int | float | str):
        return val
    return tyfuncs.astyarray(val)


def _normalize_arrays_in_getitem_key(key: GetItemKey) -> onnx.GetitemIndex:
    """Replace ``Array`` objects in a ``__getitem__`` key by typed arrays.

    Raises
    ------
    IndexError:
        If ``key`` (or one of its items) is of an unsupported type.
    """
    if isinstance(key, Array):
        # A lone array whose dtype is an `onnx.Bool` instance is cast to
        # `onnx.bool_` and returned as the index; other arrays fall through
        # to the per-item handling below.
        if isinstance(key._tyarray.dtype, onnx.Bool):
            return key._tyarray.astype(onnx.bool_)
    if isinstance(key, int | slice | EllipsisType | Array | None):
        return _normalize_getitem_key_item(key)
    if isinstance(key, tuple):
        return tuple(_normalize_getitem_key_item(el) for el in key)
    raise IndexError(f"unexpected key type: `{type(key)}`")


def _normalize_arrays_in_setitem_key(key: SetitemKey) -> onnx.SetitemIndex:
    """Replace ``Array`` objects in a ``__setitem__`` key by typed arrays.

    Raises
    ------
    IndexError:
        If ``key`` (or one of its items) is of an unsupported type.
    """
    if isinstance(key, Array):
        # Same `onnx.Bool` special case as in the getitem variant.
        if isinstance(key._tyarray.dtype, onnx.Bool):
            return key._tyarray.astype(onnx.bool_)
    if isinstance(key, int | slice | EllipsisType | Array | None):
        return _normalize_setitem_key_item(key)
    if isinstance(key, tuple):
        return tuple(_normalize_setitem_key_item(el) for el in key)
    raise IndexError(f"unexpected key type: `{type(key)}`")


def _normalize_getitem_key_item(
    item: int | slice | EllipsisType | Array | None,
) -> int | slice | EllipsisType | onnx.TyArrayInt64 | None:
    """Normalize one key item; ``None`` is passed through unchanged."""
    return None if item is None else _normalize_setitem_key_item(item)


def _normalize_setitem_key_item(
    item: int | slice | EllipsisType | Array,
) -> int | slice | EllipsisType | onnx.TyArrayInt64:
    """Normalize a single item of an indexing key.

    Integers and ``...`` are returned as-is, integer arrays are cast to
    ``onnx.int64``, and slices get each of their components normalized.

    Raises
    ------
    IndexError:
        If ``item`` is an array of non-integer data type, a slice containing
        an array that is not a rank-0 integer array, or of an unsupported
        type.
    """
    if isinstance(item, int | EllipsisType):
        return item
    if isinstance(item, Array):
        if isinstance(item.dtype, onnx.Integer):
            return item._tyarray.astype(onnx.int64)
        raise IndexError(
            f"indexing arrays must be of integer or boolean data type; found `{item.dtype}`"
        )

    def _normalize_slice_arg(el: int | Array | None) -> int | onnx.TyArrayInt64 | None:
        # Normalize one of `start`/`stop`/`step`.
        if isinstance(el, int | None):
            return el
        if not isinstance(el.dtype, onnx.Integer):
            raise IndexError(
                f"arrays in 'slice' objects must be of integer data types; found `{el.dtype}`"
            )
        if el.ndim != 0:
            raise IndexError(
                f"arrays in 'slice' objects must be rank-0; found `{el.ndim}`"
            )
        return el._tyarray.astype(onnx.int64)

    if isinstance(item, slice):
        start = _normalize_slice_arg(item.start)
        stop = _normalize_slice_arg(item.stop)
        step = _normalize_slice_arg(item.step)
        return slice(start, stop, step)
    raise IndexError(f"invalid index key type: `{type(item)}`")