Source code for ndonnx._build

# Copyright (c) QuantCo 2023-2026
# SPDX-License-Identifier: BSD-3-Clause

import onnx
from spox import Var
from spox import build as spox_build

from ._array import Array
from ._schema import SchemaV1


def build(
    inputs: dict[str, Array], outputs: dict[str, Array], drop_unused: bool = False
) -> onnx.ModelProto:
    """Build an ONNX model from the provided argument-Arrays and outputs.

    Parameters
    ----------
    inputs
        Inputs of the model
    outputs
        Outputs of the model
    drop_unused
        Forwarded to spox's ``build`` as ``drop_unused_inputs``. Inputs whose
        components no longer appear among the built graph's inputs are also
        left out of the input schema stored in the model properties.

    Returns
    -------
    onnx.ModelProto
        ONNX model
    """
    ins = _arrays_to_vars(inputs)
    outs = _arrays_to_vars(outputs)

    mp = spox_build(ins, outs, drop_unused_inputs=drop_unused)

    # Find the names of inputs that have not been dropped, so that only these
    # are inserted into the schema. A set keeps each membership test O(1).
    graph_input_names = {graph_input.name for graph_input in mp.graph.input}

    # An input stays in the schema when at least one of its disassembled
    # components is still an input of the built graph.
    input_schema = {
        name: arr.dtype.__ndx_infov1__
        for name, arr in inputs.items()
        if any(
            component in graph_input_names
            for component in _disassemble_named_array(name, arr)
        )
    }
    output_schema = {name: arr.dtype.__ndx_infov1__ for name, arr in outputs.items()}

    # The schema is serialized to JSON and attached to the model properties
    # under the "ndonnx_schema" key.
    schema_v1 = {
        "ndonnx_schema": SchemaV1(
            input_schema=input_schema,
            output_schema=output_schema,
            version=1,
        ).to_json()
    }
    onnx.helper.set_model_props(mp, schema_v1)

    return mp
def _arrays_to_vars(dct_of_arrs: dict[str, Array]) -> dict[str, Var]:
    """Flatten a mapping of named Arrays into a single mapping of named Vars.

    Each Array is disassembled into its component Vars; the per-array results
    are merged in iteration order, so a later entry overwrites an earlier one
    that produced the same key.
    """
    flattened: dict[str, Var] = {}
    for public_name, array in dct_of_arrs.items():
        flattened.update(_disassemble_named_array(public_name, array))
    return flattened


def _disassemble_named_array(name: str, arr: Array) -> dict[str, Var]:
    """Map an Array's component parts to names prefixed with ``name``.

    If disassembling yields a single ``Var``, it is stored under ``name``
    itself. Otherwise every component ``k`` is stored under ``f"{name}_{k}"``.
    """
    # TODO: Use a different separator for the public name and the nested components?
    public_separator = "_"

    parts = arr._tyarray.disassemble()
    if isinstance(parts, Var):
        return {name: parts}
    return {
        f"{name}{public_separator}{key}": component
        for key, component in parts.items()
    }