# Source code for ndonnx._elementwise

# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
"""Element-wise free functions.

Each function directly dispatches to the inner typed array.
"""

import builtins
from collections.abc import Callable
from functools import wraps
from typing import TypeVar

import ndonnx as ndx
from ndonnx import Array, DType

from ._array_tyarray_interop import unwrap_tyarray
from ._typed_array import funcs as tyfuncs

# Type variable for callables returning an ``Array``; it lets
# ``_ensure_array_in_args`` be annotated as returning the same callable type
# it receives.
F = TypeVar("F", bound=Callable[..., Array])


def _ensure_array_in_args(fn: F) -> F:
    """Decorator for element-wise binary functions.

    The wrapper forwards its two positional operands to ``fn`` unchanged when
    at least one of them is an ``Array``.  Otherwise it raises a ``TypeError``
    that names the decorated function.
    """

    @wraps(fn)
    def wrapped(a, b) -> Array:
        # Guard clause: reject the call when neither operand is an Array.
        if not isinstance(a, Array) and not isinstance(b, Array):
            raise TypeError(
                f"at least one argument to '{fn.__name__}' must be of type 'ndonnx.Array'"
            )
        return fn(a, b)

    return wrapped  # type: ignore


@_ensure_array_in_args
def add(a: Array | int | float, b: Array | int | float) -> Array:
    """Element-wise ``a + b``.

    The ``+`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    total = a + b
    return ndx.asarray(total)
def abs(array: Array, /) -> Array:
    """Element-wise absolute value.

    Applies ``builtins.abs`` to the wrapped typed array (this function shadows
    the builtin inside the module) and re-wraps the result as an ``Array``.
    """
    inner = builtins.abs(array._tyarray)
    return Array._from_tyarray(inner)
def acos(array: Array, /) -> Array:
    """Element-wise ``acos``, delegated to the wrapped typed array."""
    inner = array._tyarray.acos()
    return Array._from_tyarray(inner)
def acosh(array: Array, /) -> Array:
    """Element-wise ``acosh``, delegated to the wrapped typed array."""
    inner = array._tyarray.acosh()
    return Array._from_tyarray(inner)
def asin(array: Array, /) -> Array:
    """Element-wise ``asin``, delegated to the wrapped typed array."""
    inner = array._tyarray.asin()
    return Array._from_tyarray(inner)
def asinh(array: Array, /) -> Array:
    """Element-wise ``asinh``, delegated to the wrapped typed array."""
    inner = array._tyarray.asinh()
    return Array._from_tyarray(inner)
def atan(array: Array, /) -> Array:
    """Element-wise ``atan``, delegated to the wrapped typed array."""
    inner = array._tyarray.atan()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def atan2(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``atan2``; a placeholder that is not implemented.

    Raises:
        TypeError: If neither argument is an ``Array`` (raised by the
            decorator before this body runs).
        NotImplementedError: On every call that passes the argument check.
    """
    # Requires special operator to meet standards precision requirements
    # TODO: Add upstream tracking issue
    raise NotImplementedError("'atan2' is not implemented")
def atanh(array: Array, /) -> Array:
    """Element-wise ``atanh``, delegated to the wrapped typed array."""
    inner = array._tyarray.atanh()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def bitwise_and(x1: Array | int | bool, x2: Array | int | bool, /) -> Array:
    """Element-wise ``x1 & x2``.

    The ``&`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    combined = x1 & x2
    return ndx.asarray(combined)
@_ensure_array_in_args
def bitwise_left_shift(x1: Array | int, x2: Array | int, /) -> Array:
    """Element-wise ``x1 << x2``.

    The ``<<`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    shifted = x1 << x2
    return ndx.asarray(shifted)
def bitwise_invert(x: Array, /) -> Array:
    """Element-wise ``~x``, evaluated through ``x``'s own ``~`` operator."""
    inverted = ~x
    return inverted
@_ensure_array_in_args
def bitwise_or(x1: Array | int | bool, x2: Array | int | bool, /) -> Array:
    """Element-wise ``x1 | x2``.

    The ``|`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    combined = x1 | x2
    return ndx.asarray(combined)
@_ensure_array_in_args
def bitwise_right_shift(x1: Array | int, x2: Array | int, /) -> Array:
    """Element-wise ``x1 >> x2``.

    The ``>>`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    shifted = x1 >> x2
    return ndx.asarray(shifted)
@_ensure_array_in_args
def bitwise_xor(x1: Array | int | bool, x2: Array | int | bool, /) -> Array:
    """Element-wise ``x1 ^ x2``.

    The ``^`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    combined = x1 ^ x2
    return ndx.asarray(combined)
def ceil(array: Array, /) -> Array:
    """Element-wise ``ceil``, delegated to the wrapped typed array."""
    inner = array._tyarray.ceil()
    return Array._from_tyarray(inner)
def clip(
    x: Array,
    /,
    min: None | int | float | Array = None,
    max: None | int | float | Array = None,
) -> Array:
    """Clip ``x`` to the optional bounds ``min`` and ``max``.

    A common data type is obtained from ``tyfuncs.result_type`` using
    ``x.dtype`` and every bound that was supplied: an ``Array`` bound
    contributes its ``dtype``, any other bound contributes itself.  ``x`` and
    the supplied bounds are converted to that data type before the wrapped
    typed array's ``clip`` is called.  A bound left as ``None`` is forwarded
    as ``None``.
    """
    # Only supplied bounds take part in the promotion; ``min`` precedes ``max``.
    promotion_args: list[int | float | DType] = [
        bound.dtype if isinstance(bound, Array) else bound
        for bound in (min, max)
        if bound is not None
    ]
    dtype = tyfuncs.result_type(x.dtype, *promotion_args)

    def _as_typed_bound(bound):
        # Convert a supplied bound to the common dtype; keep ``None`` as is.
        if bound is None:
            return None
        return ndx.asarray(bound, dtype=dtype)._tyarray

    lower = _as_typed_bound(min)
    upper = _as_typed_bound(max)
    clipped = x._tyarray.astype(dtype).clip(min=lower, max=upper)
    return Array._from_tyarray(clipped)
def cos(x: Array, /) -> Array:
    """Element-wise ``cos``, delegated to the wrapped typed array."""
    inner = x._tyarray.cos()
    return Array._from_tyarray(inner)
def cosh(x: Array, /) -> Array:
    """Element-wise ``cosh``, delegated to the wrapped typed array."""
    inner = x._tyarray.cosh()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def copysign(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``copysign``; a placeholder that is not implemented.

    Raises:
        TypeError: If neither argument is an ``Array`` (raised by the
            decorator before this body runs).
        NotImplementedError: On every call that passes the argument check.
    """
    raise NotImplementedError("'copysign' is not implemented")
@_ensure_array_in_args
def divide(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 / x2``.

    The ``/`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    quotient = x1 / x2
    return ndx.asarray(quotient)
def exp(array: Array, /) -> Array:
    """Element-wise ``exp``, delegated to the wrapped typed array."""
    inner = array._tyarray.exp()
    return Array._from_tyarray(inner)
def expm1(array: Array, /) -> Array:
    """Element-wise ``expm1``, delegated to the wrapped typed array."""
    inner = array._tyarray.expm1()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def equal(x1: Array | int | float | bool, x2: Array | int | float | bool, /) -> Array:
    """Element-wise ``x1 == x2``.

    The ``==`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 == x2
    return ndx.asarray(comparison)
def floor(array: Array, /) -> Array:
    """Element-wise ``floor``, delegated to the wrapped typed array."""
    inner = array._tyarray.floor()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def floor_divide(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 // x2``.

    The ``//`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    quotient = x1 // x2
    return ndx.asarray(quotient)
@_ensure_array_in_args
def greater(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 > x2``.

    The ``>`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 > x2
    return ndx.asarray(comparison)
@_ensure_array_in_args
def greater_equal(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 >= x2``.

    The ``>=`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 >= x2
    return ndx.asarray(comparison)
@_ensure_array_in_args
def hypot(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``hypot``; a placeholder that is not implemented.

    Raises:
        TypeError: If neither argument is an ``Array`` (raised by the
            decorator before this body runs).
        NotImplementedError: On every call that passes the argument check.
    """
    raise NotImplementedError("'hypot' is not implemented")
def isfinite(array: Array, /) -> Array:
    """Element-wise ``isfinite``, delegated to the wrapped typed array."""
    inner = array._tyarray.isfinite()
    return Array._from_tyarray(inner)
def isinf(array: Array, /) -> Array:
    """Element-wise ``isinf``, delegated to the wrapped typed array."""
    inner = array._tyarray.isinf()
    return Array._from_tyarray(inner)
def isnan(array: Array, /) -> Array:
    """Element-wise ``isnan``, delegated to the wrapped typed array."""
    inner = array._tyarray.isnan()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def less(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 < x2``.

    The ``<`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 < x2
    return ndx.asarray(comparison)
@_ensure_array_in_args
def less_equal(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 <= x2``.

    The ``<=`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 <= x2
    return ndx.asarray(comparison)
def log(x: Array, /) -> Array:
    """Element-wise ``log``, delegated to the wrapped typed array."""
    inner = x._tyarray.log()
    return Array._from_tyarray(inner)
def log1p(x: Array, /) -> Array:
    """Element-wise ``log1p``; a placeholder that is not implemented.

    Raises:
        NotImplementedError: On every call.
    """
    raise NotImplementedError("'log1p' is not implemented")
def log2(x: Array, /) -> Array:
    """Element-wise ``log2``, delegated to the wrapped typed array."""
    inner = x._tyarray.log2()
    return Array._from_tyarray(inner)
def log10(x: Array, /) -> Array:
    """Element-wise ``log10``, delegated to the wrapped typed array."""
    inner = x._tyarray.log10()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def logaddexp(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``logaddexp`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.logaddexp``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.logaddexp(lhs, rhs))
@_ensure_array_in_args
def logical_and(x1: Array | bool, x2: Array | bool, /) -> Array:
    """Element-wise ``logical_and`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.logical_and``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.logical_and(lhs, rhs))
def logical_not(x: Array, /) -> Array:
    """Element-wise ``logical_not``, delegated to the wrapped typed array."""
    inner = x._tyarray.logical_not()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def logical_or(x1: Array | bool, x2: Array | bool, /) -> Array:
    """Element-wise ``logical_or`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.logical_or``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.logical_or(lhs, rhs))
@_ensure_array_in_args
def logical_xor(x1: Array | bool, x2: Array | bool, /) -> Array:
    """Element-wise ``logical_xor`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.logical_xor``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.logical_xor(lhs, rhs))
@_ensure_array_in_args
def maximum(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``maximum`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.maximum``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.maximum(lhs, rhs))
@_ensure_array_in_args
def minimum(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``minimum`` of ``x1`` and ``x2``.

    Both operands go through ``unwrap_tyarray`` and are handed to
    ``tyfuncs.minimum``; the result is re-wrapped as an ``Array``.  The
    decorator raises ``TypeError`` when neither operand is an ``Array``.
    """
    lhs = unwrap_tyarray(x1)
    rhs = unwrap_tyarray(x2)
    return Array._from_tyarray(tyfuncs.minimum(lhs, rhs))
@_ensure_array_in_args
def multiply(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 * x2``.

    The ``*`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    product = x1 * x2
    return ndx.asarray(product)
def negative(x: Array, /) -> Array:
    """Element-wise negation, via unary ``-`` on the wrapped typed array."""
    negated = -x._tyarray
    return Array._from_tyarray(negated)
@_ensure_array_in_args
def nextafter(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``nextafter``; a placeholder that is not implemented.

    Raises:
        TypeError: If neither argument is an ``Array`` (raised by the
            decorator before this body runs).
        NotImplementedError: On every call that passes the argument check.
    """
    # Requires special ONNX operator
    # TODO: Add upstream tracking issue
    raise NotImplementedError("'nextafter' is not implemented")
@_ensure_array_in_args
def not_equal(
    x1: Array | int | float | bool, x2: Array | int | float | bool, /
) -> Array:
    """Element-wise ``x1 != x2``.

    The ``!=`` operator performs the comparison and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    comparison = x1 != x2
    return ndx.asarray(comparison)
def positive(x: Array, /) -> Array:
    """Element-wise unary plus, via unary ``+`` on the wrapped typed array."""
    result = +x._tyarray
    return Array._from_tyarray(result)
@_ensure_array_in_args
def pow(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 ** x2``.

    The ``**`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    power = x1**x2
    return ndx.asarray(power)
def real(x: Array, /) -> Array:
    """Element-wise ``real``; a placeholder that is not implemented.

    Raises:
        NotImplementedError: On every call.
    """
    raise NotImplementedError("'real' is not implemented")
def reciprocal(x: Array, /) -> Array:
    """Element-wise ``reciprocal``, delegated to the wrapped typed array."""
    inner = x._tyarray.reciprocal()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def remainder(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 % x2``.

    The ``%`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    rest = x1 % x2
    return ndx.asarray(rest)
def round(x: Array, /) -> Array:
    """Element-wise ``round``, delegated to the wrapped typed array."""
    inner = x._tyarray.round()
    return Array._from_tyarray(inner)
def sign(x: Array, /) -> Array:
    """Element-wise ``sign``, delegated to the wrapped typed array."""
    inner = x._tyarray.sign()
    return Array._from_tyarray(inner)
def signbit(x: Array, /) -> Array:
    """Element-wise ``signbit``; a placeholder that is not implemented.

    Raises:
        NotImplementedError: On every call.
    """
    raise NotImplementedError("'signbit' is not implemented")
def sin(x: Array, /) -> Array:
    """Element-wise ``sin``, delegated to the wrapped typed array."""
    inner = x._tyarray.sin()
    return Array._from_tyarray(inner)
def sinh(x: Array, /) -> Array:
    """Element-wise ``sinh``, delegated to the wrapped typed array."""
    inner = x._tyarray.sinh()
    return Array._from_tyarray(inner)
def square(x: Array, /) -> Array:
    """Element-wise square, computed as the wrapped typed array times itself."""
    squared = x._tyarray * x._tyarray
    return Array._from_tyarray(squared)
def sqrt(x: Array, /) -> Array:
    """Element-wise ``sqrt``, delegated to the wrapped typed array."""
    inner = x._tyarray.sqrt()
    return Array._from_tyarray(inner)
@_ensure_array_in_args
def subtract(x1: Array | int | float, x2: Array | int | float, /) -> Array:
    """Element-wise ``x1 - x2``.

    The ``-`` operator performs the computation and its result is passed
    through ``ndx.asarray``.  The decorator raises ``TypeError`` when neither
    operand is an ``Array``.
    """
    difference = x1 - x2
    return ndx.asarray(difference)
def tan(x: Array, /) -> Array:
    """Element-wise ``tan``, delegated to the wrapped typed array."""
    inner = x._tyarray.tan()
    return Array._from_tyarray(inner)
def tanh(x: Array, /) -> Array:
    """Element-wise ``tanh``, delegated to the wrapped typed array."""
    inner = x._tyarray.tanh()
    return Array._from_tyarray(inner)
def trunc(x: Array, /) -> Array:
    """Element-wise ``trunc``, delegated to the wrapped typed array."""
    inner = x._tyarray.trunc()
    return Array._from_tyarray(inner)