Source code for ndonnx._dtypes

# Copyright (c) QuantCo 2023-2026
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from abc import ABC, abstractmethod
from types import NotImplementedType
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np
from spox import Var
from typing_extensions import deprecated

if TYPE_CHECKING:
    from ndonnx.types import NestedSequence, OnnxShape, PyScalar

    from ._schema import DTypeInfoV1
    from ._typed_array import TyArrayBase, onnx


TY_ARRAY_BASE = TypeVar("TY_ARRAY_BASE", bound="TyArrayBase", covariant=True)


[docs] class DType(ABC, Generic[TY_ARRAY_BASE]): def __ndx_create__( self, val: PyScalar | np.ndarray | TyArrayBase | Var | NestedSequence, ) -> TY_ARRAY_BASE | NotImplementedType: return NotImplemented @abstractmethod def __ndx_result_type__( self, other: DType | PyScalar ) -> DType | NotImplementedType: ... @abstractmethod def __ndx_cast_from__(self, arr: TyArrayBase) -> TY_ARRAY_BASE: """Convert the given array to this data type. This function is used to implement ``TyArrayBase.astype`` and should not be called directly. """ ... @abstractmethod def __ndx_argument__(self, shape: OnnxShape) -> TY_ARRAY_BASE: ... @property @abstractmethod def __ndx_infov1__(self) -> DTypeInfoV1: raise NotImplementedError # Construction functions def __ndx_arange__( self, start: int | float | TyArrayBase, stop: int | float | TyArrayBase, step: int | float | TyArrayBase = 1, ) -> TY_ARRAY_BASE: return NotImplemented def __ndx_eye__( self, n_rows: int, n_cols: int | None = None, /, *, k: int = 0, ) -> TY_ARRAY_BASE: return NotImplemented def __ndx_ones__(self, shape: tuple[int, ...] | onnx.TyArrayInt64) -> TY_ARRAY_BASE: return NotImplemented def __ndx_zeros__( self, shape: tuple[int, ...] | onnx.TyArrayInt64 ) -> TY_ARRAY_BASE: return NotImplemented def __eq__(self, other) -> bool: if type(self) is not type(other): return False return (self.__dict__ == other.__dict__) and (self.__slots__ == other.__slots__) def __hash__(self) -> int: return hash((tuple(sorted(self.__dict__.items())), self.__slots__)) def __repr__(self) -> str: return self.__class__.__name__
[docs] @deprecated( "'to_numpy_dtype' is deprecated. Use the 'unwrap_numpy' method instead", ) def to_numpy_dtype(self) -> np.dtype: return self.unwrap_numpy()
[docs] def unwrap_numpy(self) -> np.dtype: raise ValueError(f"`{self}` provides no corresponding NumPy data type")