diff --git a/dali/python/nvidia/dali/experimental/dali2/__init__.py b/dali/python/nvidia/dali/experimental/dali2/__init__.py index a0fd204698..334c9b35b1 100644 --- a/dali/python/nvidia/dali/experimental/dali2/__init__.py +++ b/dali/python/nvidia/dali/experimental/dali2/__init__.py @@ -20,3 +20,11 @@ from ._eval_context import * # noqa: F401, F403 from ._type import * # noqa: F401, F403 from ._device import * # noqa: F401, F403 +from ._tensor import Tensor, tensor, as_tensor # noqa: F401 +from ._batch import Batch, batch, as_batch # noqa: F401 + +from . import _fn +from . import ops + +ops._initialize() +_fn._initialize() diff --git a/dali/python/nvidia/dali/experimental/dali2/_batch.py b/dali/python/nvidia/dali/experimental/dali2/_batch.py new file mode 100644 index 0000000000..38c863c119 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/dali2/_batch.py @@ -0,0 +1,600 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union, Sequence +from ._type import DType, dtype as _dtype, type_id as _type_id +from ._tensor import ( + Tensor, + _is_full_slice, + _try_convert_enums, + tensor as _tensor, + as_tensor as _as_tensor, +) +import nvidia.dali.backend as _backend +from ._eval_context import EvalContext as _EvalContext +from ._device import Device +from . import _eval_mode +from . import _invocation +import nvtx + + +def _backend_device(backend: Union[_backend.TensorListCPU, _backend.TensorListGPU]) -> Device: + if isinstance(backend, _backend.TensorListCPU): + return Device("cpu") + elif isinstance(backend, _backend.TensorListGPU): + return Device("gpu", backend.device_id()) + else: + raise ValueError(f"Unsupported backend type: {type(backend)}") + + +def _is_tensor_type(x): + from . import _batch + + if isinstance(x, _batch.Batch): + raise ValueError("A list of Batch objects is not a valid argument type") + if isinstance(x, Tensor): + return True + if hasattr(x, "__array__"): + return True + if hasattr(x, "__cuda_array_interface__"): + return True + if hasattr(x, "__dlpack__"): + return True + return False + + +def _get_batch_size(x): + if isinstance(x, Batch): + return x.batch_size + if isinstance(x, (_backend.TensorListCPU, _backend.TensorListGPU)): + return len(x) + return None + + +class BatchedSlice: + def __init__(self, batch: "Batch"): + self._batch = batch + + def __getitem__(self, ranges: Any) -> "Batch": + if not isinstance(ranges, tuple): + ranges = (ranges,) + if len(ranges) == 0: + return self._batch + + if all(_is_full_slice(r) for r in ranges): + return self._batch + + args = {} + d = 0 + for i, r in enumerate(ranges): + if r is Ellipsis: + d = self._batch.ndim - len(ranges) + i + 1 + elif isinstance(r, slice): + if r.start is not None: + args[f"lo_{d}"] = r.start + if r.stop is not None: + args[f"hi_{d}"] = r.stop + if r.step is not None: + args[f"step_{d}"] = r.step + d += 1 + else: + args[f"at_{d}"] = r + d += 1 + + # print(args) + + from . 
import tensor_subscript + + return tensor_subscript(self._batch, **args) + + +def _arithm_op(name, *args, **kwargs): + from . import arithmetic_generic_op + + argsstr = " ".join(f"&{i}" for i in range(len(args))) + return arithmetic_generic_op(*args, expression_desc=f"{name}({argsstr})") + + +class _TensorList: + def __init__(self, batch: "Batch", indices: Optional[Union[list[int], range]] = None): + self._batch = batch + self._indices = indices or range(batch.batch_size) + + def __getitem__(self, range): + return self.select(range) + + def __len__(self): + return len(self._indices) + + def select(self, range): + if range == slice(None, None, None): + return self + if isinstance(range, slice): + return _TensorList(self._batch, self._indices[range]) + elif isinstance(range, list): + return _TensorList(self._batch, [self._indices[i] for i in range]) + else: + return self._batch.select(range) + + def tolist(): + return [self._batch._get_tensor(i) for i in self._indices] + + def as_batch(self): + return as_batch(self) + + +class Batch: + def __init__( + self, + tensors: Optional[Any] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, + invocation_result: Optional[_invocation.InvocationResult] = None, + copy: bool = False, + ): + assert isinstance(layout, str) or layout is None + self._wraps_external_data = False + self._tensors = None + self._backend = None + self._dtype = None + self._device = None + copied = False + if tensors is not None: + if isinstance(tensors, _backend.TensorListCPU) and ( + (dtype is None or _type_id(dtype) == _backend.dtype) + and (layout is None or layout == self._backend.layout) + ): + self._backend = tensors + self._ndim = self._backend.ndim() + self._dtype = dtype = DType.from_type_id(self._backend.dtype) + self._layout = layout = self._backend.layout() + if device is None: + device = Device("cpu") + else: + if device.device_type != "cpu": + copy = True + elif isinstance(tensors, _backend.TensorListGPU) and ( + (dtype is None or _type_id(dtype) == _backend.dtype) + and (layout is None or layout == self._backend.layout) + ): + self._backend = tensors + self._ndim = self._backend.ndim() + self._dtype = dtype = DType.from_type_id(self._backend.dtype) + self._layout = layout = self._backend.layout() + if device is None: + device = Device("gpu", self._backend.device_id) + else: + if device.device_type != "gpu": + copy = True + elif _is_tensor_type(tensors): + if copy: + t = _tensor(tensors, dtype=dtype, device=device, layout=layout) + else: + t = _as_tensor(tensors, dtype=dtype, device=device, layout=layout) + if t.ndim == 0: + raise ValueError("Cannot create a batch from a scalar") + if dtype is None: + dtype = t.dtype + if device is None: + device = t.device + if layout is None: + layout = t.layout + if t._backend is not None: + if isinstance(t._backend, _backend.TensorCPU): + self._backend = _backend.TensorListCPU(t._backend, layout=layout) + elif isinstance(t._backend, _backend.TensorGPU): + self._backend = _backend.TensorListGPU(t._backend, layout=layout) + else: + raise ValueError(f"Unsupported device type: {t.device.device_type}") + self._wraps_external_data = True + else: + sh = t.shape + tensors = [t[i] for i in range(sh[0])] + self._dtype = dtype + + elif len(tensors) == 0: + if dtype is None: + raise ValueError("Element type must be specified if the list is empty") + if device is None: + device = Device("cpu") + if layout is None: + layout = "" + self._dtype = dtype + else: + self._tensors = [] + for t 
in tensors: + sample = Tensor(t, dtype=dtype, device=device, layout=layout) + if dtype is None: + dtype = sample.dtype + if device is None: + device = sample.device + if layout is None: + layout = sample.layout + self._tensors.append(sample) + if sample._wraps_external_data: + self._wraps_external_data = True + else: + if not isinstance(t, Tensor) or t._backend is not sample._backend: + copied = True + self._dtype = dtype + + if self._dtype is None: + if self._backend is not None: + self._dtype = DType.from_type_id(self._backend.dtype) + else: + self._dtype = dtype + if self._device is None: + if self._backend is not None: + self._device = _backend_device(self._backend) + else: + self._device = device + self._layout = layout + self._invocation_result = invocation_result + self._ndim = None + if self._tensors and self._tensors[0]._shape: + self._ndim = len(self._tensors[0]._shape) + + if copy and self._backend is not None and not copied: + dev = self.to_device(self.device, force_copy=True) + if dtype is not None and dev.dtype != dtype: + from . import cast + + dev = cast(dev, dtype, device=device) + self.assign(dev.evaluate()) + copied = True + else: + if self._dtype is not None and dtype is not None and self._dtype != dtype: + from . import cast + + self.assign(cast(self, dtype, device=device)) + + if _eval_mode.EvalMode.current().value >= _eval_mode.EvalMode.eager.value: + self.evaluate() + + def _is_external(self) -> bool: + return self._wraps_external_data + + @staticmethod + def broadcast(sample, batch_size: int, device: Optional[Device] = None) -> "Batch": + if isinstance(sample, Batch): + raise ValueError("Cannot broadcast a Batch") + if _is_tensor_type(sample): + # TODO(michalz): Add broadcasting in native code + return Batch([Tensor(sample, device=device)] * batch_size) + import numpy as np + + with nvtx.annotate("to numpy and stack", domain="batch"): + arr = np.array(sample) + converted_dtype_id = None + if arr.dtype == np.float64: + arr = arr.astype(np.float32) + elif arr.dtype == np.int64: + arr = arr.astype(np.int32) + elif arr.dtype == np.uint64: + arr = arr.astype(np.uint32) + elif arr.dtype == object: + arr, converted_dtype_id = _try_convert_enums(arr) + arr = np.repeat(arr[np.newaxis], batch_size, axis=0) + + with nvtx.annotate("to backend", domain="batch"): + tl = _backend.TensorListCPU(arr) + if converted_dtype_id is not None: + tl.reinterpret(converted_dtype_id) + with nvtx.annotate("create batch", domain="batch"): + return Batch(tl, device=device) + + @property + def dtype(self) -> DType: + if self._dtype is None: + if self._backend is not None: + self._dtype = DType.from_type_id(self._backend.dtype) + elif self._invocation_result is not None: + self._dtype = _dtype(self._invocation_result.dtype) + elif self._tensors: + self._dtype = self._tensors[0].dtype + else: + raise ValueError("Cannot establish the number of dimensions of an empty Batch") + return self._dtype + + @property + def device(self) -> Device: + if self._device is None: + if self._invocation_result is not None: + self._device = self._invocation_result.device + # print("From invocation result", self._device) + elif self._tensors: + self._device = self._tensors[0].device + # print("From tensors", self._device) + else: + raise ValueError("Cannot establish the number of dimensions of an empty Batch") + return self._device + + @property + def layout(self) -> str: + if self._layout is None: + if self._invocation_result is not None: + self._layout = self._invocation_result.layout + elif self._tensors: + 
self._layout = self._tensors[0].layout + else: + raise ValueError("Cannot establish the number of dimensions of an empty Batch") + return self._layout + + @property + def ndim(self) -> int: + if self._ndim is None: + if self._invocation_result is not None: + self._ndim = self._invocation_result.ndim + elif self._backend is not None: + self._ndim = self._backend.ndim() + elif self._tensors: + self._ndim = self._tensors[0].ndim + else: + raise ValueError("Cannot establish the number of dimensions of an empty Batch") + return self._ndim + + @property + def tensors(self): + return _TensorList(self) + + def to_device(self, device: Device, force_copy: bool = False) -> "Batch": + if self.device == device and not force_copy: + return self + else: + with device: + from . import copy + + return copy(self, device=device.device_type) + + def cpu(self) -> "Batch": + return self.to_device(Device("cpu")) + + def gpu(self, index: Optional[int] = None) -> "Batch": + return self.to_device(Device("gpu", index)) + + @property + def slice(self): + """Interface for samplewise slicing. + + Regular slicing selects samples first and then slices each sample with common + slicing parameters. + + Samplewise slicing interface allows the slicing parmaters to be batches (with the same + number of samples) and the slicing parameters are applied to respective samples. + + ```Python + start = Batch([1, 2, 3]) + stop = Batch([4, 5, 6]) + step = Batch([1, 1, 2]) + sliced = input[start, stop, step] + # the result is equivalent to + sliced = Batch([ + sample[start[i]:stop[i]:step[i]] + for i, sample in enumerate(input) + ]) + ``` + + If the slicing parameters are not batches, they are broadcast to all samples. + """ + return BatchedSlice(self) + + def __iter__(self): + return iter(self.tensors) + + def select(self, r): + if r is ...: + return self + if isinstance(r, slice): + return Batch(self.tensors[r]) + elif isinstance(r, list): + return Batch(self.tensors[r]) + else: + return self._get_tensor(r) + + def _get_tensor(self, i): + if self._tensors is None: + self._tensors = [None] * self.batch_size + + t = self._tensors[i] + if t is None: + t = self._tensors[i] = Tensor(batch=self, index_in_batch=i) + if self._backend: + t._backend = self._backend[i] + return t + + def _plain_slice(self, ranges): + def _is_batch(x): + return _get_batch_size(x) is not None + + for r in ranges: + is_batch_arg = _is_batch(r) + if isinstance(r, slice): + if _is_batch(r.start) or _is_batch(r.stop) or _is_batch(r.step): + is_batch_arg = True + if is_batch_arg: + raise ValueError( + "Cannot use a batch as an index or slice. in ``Batch.__getitem__``.\n" + "Use ``.slice`` property to perform samplewise slicing." 
+ ) + # print(ranges) + return self.slice.__getitem__(ranges) + + @property + def batch_size(self) -> int: + if self._backend is not None: + return len(self._backend) + elif self._tensors is not None: + return len(self._tensors) + elif self._invocation_result is not None: + return self._invocation_result.batch_size + else: + raise ValueError("Neither tensors nor invocation result are set") + + def _is_same_batch(self, other: "Batch") -> bool: + if self is other: + return True + return ( + self._backend is other._backend + and self._invocation_result is other._invocation_result + and ( + self._tensors is other._tensors + or [t._is_same_tensor(ot) for t, ot in zip(self._tensors, other._tensors)] + ) + ) + + @property + def shape(self): + if self._invocation_result is not None: + return self._invocation_result.shape + if self._backend is not None: + return self._backend.shape() + else: + assert self._tensors is not None + return [t.shape for t in self._tensors] + + def __str__(self) -> str: + return "Batch(\n" + str(self.evaluate()._backend) + ")" + + def evaluate(self): + with _EvalContext.get() as ctx: + if self._backend is None: + if self._invocation_result is not None: + self._backend = self._invocation_result.value(ctx) + else: + if self._device.device_type == "cpu": + backend_type = _backend.TensorListCPU + elif self._device.device_type == "gpu": + backend_type = _backend.TensorListGPU + else: + raise ValueError( + f"Internal error: Unsupported device type: {self._device.device_type}" + ) + self._backend = backend_type( + [t.evaluate()._backend for t in self._tensors], self.layout + ) + return self + + def __add__(self, other): + return _arithm_op("add", self, other) + + def __radd__(self, other): + return _arithm_op("add", other, self) + + def __sub__(self, other): + return _arithm_op("sub", self, other) + + def __rsub__(self, other): + return _arithm_op("sub", other, self) + + def __mul__(self, other): + return _arithm_op("mul", self, other) + + def __rmul__(self, other): + return _arithm_op("mul", other, self) + + def __pow__(self, other): + return _arithm_op("pow", self, other) + + def __rpow__(self, other): + return _arithm_op("pow", other, self) + + def __truediv__(self, other): + return _arithm_op("fdiv", self, other) + + def __rtruediv__(self, other): + return _arithm_op("fdiv", other, self) + + def __floordiv__(self, other): + return _arithm_op("div", self, other) + + def __rfloordiv__(self, other): + return _arithm_op("div", other, self) + + def __neg__(self): + return _arithm_op("minus", self) + + # Short-circuiting the execution, unary + is basically a no-op + def __pos__(self): + return self + + def __eq__(self, other): + return _arithm_op("eq", self, other) + + def __ne__(self, other): + return _arithm_op("neq", self, other) + + def __lt__(self, other): + return _arithm_op("lt", self, other) + + def __le__(self, other): + return _arithm_op("leq", self, other) + + def __gt__(self, other): + return _arithm_op("gt", self, other) + + def __ge__(self, other): + return _arithm_op("geq", self, other) + + def __and__(self, other): + return _arithm_op("bitand", self, other) + + def __rand__(self, other): + return _arithm_op("bitand", other, self) + + def __or__(self, other): + return _arithm_op("bitor", self, other) + + def __ror__(self, other): + return _arithm_op("bitor", other, self) + + def __xor__(self, other): + return _arithm_op("bitxor", self, other) + + def __rxor__(self, other): + return _arithm_op("bitxor", other, self) + + +def batch( + tensors: Union[Batch, 
Sequence[Any]], + dtype: Optional[DType] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, +): + if isinstance(tensors, Batch): + b = tensors.to_device(device, force_copy=True) + if dtype is not None and b.dtype != dtype: + from . import cast + + b = cast(b, dtype, device=device) + return b.evaluate() + else: + return Batch(tensors, dtype=dtype, device=device, layout=layout, copy=True) + + +def as_batch( + tensors: Union[Batch, Sequence[Any]], + dtype: Optional[DType] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, +): + if isinstance(tensors, Batch): + b = tensors.to_device(device) + if dtype is not None and b.dtype != dtype: + from . import cast + + b = cast(b, dtype, device=device) + return b + else: + return Batch(tensors, dtype=dtype, device=device, layout=layout) diff --git a/dali/python/nvidia/dali/experimental/dali2/_eval_context.py b/dali/python/nvidia/dali/experimental/dali2/_eval_context.py index 17e9438fe7..bb38993ec5 100644 --- a/dali/python/nvidia/dali/experimental/dali2/_eval_context.py +++ b/dali/python/nvidia/dali/experimental/dali2/_eval_context.py @@ -43,7 +43,9 @@ def __init__(self, num_threads=None, device_id=None, cuda_stream=None): if self._cuda_stream is None and self._device.device_type == "gpu": self._cuda_stream = _b.Stream(self._device.device_id) - self._thread_pool = _b._ThreadPool(num_threads or default_num_threads) + self._thread_pool = _b._ThreadPool( + num_threads or default_num_threads, self._device.device_id + ) @staticmethod def current(): diff --git a/dali/python/nvidia/dali/experimental/dali2/_fn.py b/dali/python/nvidia/dali/experimental/dali2/_fn.py new file mode 100644 index 0000000000..57a9cb0906 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/dali2/_fn.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import ops +from . import _op_builder + + +def _initialize(): + for op in ops._all_ops: + if op.op_name.startswith("_"): + continue + if op.schema.IsStateful(): + continue + + _op_builder.build_fn_wrapper(op) diff --git a/dali/python/nvidia/dali/experimental/dali2/_op_builder.py b/dali/python/nvidia/dali/experimental/dali2/_op_builder.py new file mode 100644 index 0000000000..aee56bf8c3 --- /dev/null +++ b/dali/python/nvidia/dali/experimental/dali2/_op_builder.py @@ -0,0 +1,509 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import nvidia.dali.backend as _b +from nvidia.dali.fn import _to_snake_case +import makefun +from ._batch import Batch, _get_batch_size +from ._tensor import Tensor +from . import ops +from . import _type +import types +import copy +from . import _invocation, _device, _eval_mode, _eval_context +import nvidia.dali.ops as _ops +import nvidia.dali.types +import nvtx + + +def is_external(x): + if isinstance(x, Tensor): + return x._is_external() + if isinstance(x, Batch): + return x._is_external() + return False + + +def _scalar_decay(x): + if isinstance(x, _device.Device): + return x.device_type + if isinstance(x, _type.DType): + return x.type_id + if x is str: + return nvidia.dali.types.STRING + if x is bool: + return nvidia.dali.types.BOOL + if x is int or x is float: + raise ValueError( + f"Do not use Python built-in type {x} as an argument. " + f"Use one of the DALI types instead." + ) + return x + + +def _get_input_device(x): + with nvtx.annotate("get_input_device", domain="op_builder"): + if x is None: + return None + if isinstance(x, Batch): + return x.device + if isinstance(x, Tensor): + return x.device + if isinstance(x, _b.TensorListCPU): + return _device.Device("cpu") + if isinstance(x, _b.TensorListGPU): + return _device.Device("gpu") + if hasattr(x, "__cuda_array_interface__"): + return _device.Device("gpu") + if hasattr(x, "__dlpack_device__"): + dev = x.__dlpack_device__() + if int(dev[0]) == 1 or int(dev[0]) == 3: # CPU or CPU_PINNED + return _device.Device("cpu") + elif int(dev[0]) == 2: + return _device.Device("gpu", dev[1]) + else: + raise ValueError(f"Unknown DLPack device type: {dev.type}") + if hasattr(x, "__dlpack__"): + return _device.Device("cpu") + if isinstance(x, list) and x: + return _get_input_device(x[0]) + return None + + +def _get_input_device_type(x): + dev = _get_input_device(x) + return dev.device_type if dev is not None else None + + +def _to_tensor(x, device=None): + with nvtx.annotate("to_tensor", domain="op_builder"): + if x is None: + return None + if isinstance(x, Tensor): + if device is not None: + return x.to_device(device) + return x + if isinstance(x, _invocation.InvocationResult): + if x.is_batch: + raise ValueError("Batch invocation result cannot be used as a single tensor") + return Tensor(invocation_result=x, device=device) + return Tensor(x, device=device) + + +def _to_batch(x, batch_size, device=None): + with nvtx.annotate("to_batch", domain="op_builder"): + if x is None: + return None + if isinstance(x, Batch): + if device is not None: + return x.to_device(device) + return x + if isinstance(x, _invocation.InvocationResult): + if x.is_batch: + return Batch(invocation_result=x, device=device) + else: + x = _to_tensor(x) # fall back to regular replication + actual_batch_size = _get_batch_size(x) + if actual_batch_size is not None: + if batch_size is not None and actual_batch_size != batch_size: + raise ValueError(f"Unexpected batch size: {actual_batch_size} != {batch_size}") + return Batch(x, device=device) + + return Batch.broadcast(x, batch_size, device=device) + + +_unsupported_args = {"bytes_per_sample_hint", "preserve"} + + +def _find_or_create_module(root_module, module_path): + module = root_module + for path_part in module_path: + submodule = getattr(module, path_part, None) + if submodule is None: + submodule = types.ModuleType(path_part) + setattr(module, path_part, submodule) + module = submodule + return module + + +def 
build_operator_class(schema): + class_name = schema.OperatorName() + module_path = schema.ModulePath() + is_reader = "readers" in module_path + if is_reader: + from .. import dali2 as parent + + module = parent + else: + module = ops + legacy_op_class = None + import nvidia.dali.ops + + legacy_op_module = nvidia.dali.ops + for path_part in module_path: + legacy_op_module = getattr(legacy_op_module, path_part) + module = _find_or_create_module(module, module_path) + + legacy_op_class = getattr(legacy_op_module, class_name) + base = ops.Operator + if "readers" in module.__name__: + base = ops.Reader + op_class = type(class_name, (base,), {}) + op_class.schema = schema + op_class.op_name = class_name + op_class.fn_name = _to_snake_case(class_name) + op_class.legacy_op = legacy_op_class + op_class.is_stateful = schema.IsStateful() + op_class._instance_cache = {} + op_class.__init__ = build_constructor(schema, op_class) + op_class.__call__ = build_call_function(schema, op_class) + op_class.__module__ = module.__name__ + op_class.__qualname__ = class_name + setattr(module, class_name, op_class) + return op_class + + +def build_constructor(schema, op_class): + stateful = op_class.is_stateful + function_name = "__init__" + + call_args = [] + for arg in schema.GetArgumentNames(): + if arg in _unsupported_args: + continue + if schema.IsTensorArgument(arg): + continue + if schema.IsArgumentOptional(arg): + call_args.append(f"{arg}=None") + else: + call_args.append(arg) + + if call_args: + call_args = ["*"] + call_args + header_args = [ + "self", + "max_batch_size=None", + "name=None", + 'device="cpu"', + "num_inputs=None", + "call_arg_names=None", + ] + call_args + header = f"__init__({', '.join(header_args)})" + + def init(self, max_batch_size, name, **kwargs): + kwargs = {k: _scalar_decay(v) for k, v in kwargs.items()} + op_class.__base__.__init__(self, max_batch_size, name, **kwargs) + if stateful: + self._call_id = 0 + + function = makefun.create_function(header, init) + function.__qualname__ = f"{op_class.__name__}.{function_name}" + + return function + + +def build_call_function(schema, op_class): + stateful = op_class.is_stateful + call_args = [] + for arg in schema.GetArgumentNames(): + if arg in _unsupported_args: + continue + if not schema.IsTensorArgument(arg): + continue + if schema.IsArgumentOptional(arg): + call_args.append(f"{arg}=None") + else: + call_args.append(arg) + + inputs = [] + min_inputs = schema.MinNumInput() + max_inputs = schema.MaxNumInput() + input_indices = {} + arguments = schema.GetArgumentNames() + for i in range(max_inputs): + if schema.HasInputDox(): + input_name = schema.GetInputName(i) + if input_name in arguments: + input_name += "_input" + else: + input_name = f"input_{i}" + input_indices[input_name] = i + if i < min_inputs: + inputs.append(f"{input_name}") + else: + inputs.append(f"{input_name}=None") + + call_args = ["*", "batch_size=None"] + call_args + if inputs: + inputs = inputs + ["/"] + header = f"__call__({', '.join(['self'] + inputs + call_args)})" + + def call(self, *raw_args, batch_size=None, **raw_kwargs): + with nvtx.annotate(f"__call__: {self.op_name}", domain="op_builder"): + self._pre_call(*raw_args, **raw_kwargs) + with nvtx.annotate("__call__: get batch size", domain="op_builder"): + is_batch = batch_size is not None + if batch_size is None: + for i, x in enumerate(list(raw_args) + list(raw_kwargs.values())): + x_batch_size = _get_batch_size(x) + if x_batch_size is not None: + is_batch = True + if batch_size is not None: + if x_batch_size != 
batch_size: + raise ValueError( + f"Inconsistent batch size: {x_batch_size} != {batch_size}" + ) + else: + batch_size = x_batch_size + if not is_batch: + batch_size = self._max_batch_size or 1 + + inputs = [] + kwargs = {} + + if is_batch: + with nvtx.annotate("__call__: convert to batches", domain="op_builder"): + for i, inp in enumerate(raw_args): + if inp is None: + continue + input_device = self.input_device(i, _get_input_device_type(inp)) + inp = _to_batch(inp, batch_size, device=input_device) + inputs.append(inp) + for k, v in raw_kwargs.items(): + if v is None: + continue + kwargs[k] = _to_batch(v, batch_size, device=_device.Device("cpu")) + else: + with nvtx.annotate("__call__: convert to tensors", domain="op_builder"): + for inp in raw_args: + if inp is None: + continue + inputs.append(_to_tensor(inp)) + for k, v in raw_kwargs.items(): + if v is None: + continue + kwargs[k] = _to_tensor(v) + + with nvtx.annotate("__call__: shallowcopy", domain="op_builder"): + inputs = [copy.copy(x) for x in inputs] + kwargs = {k: copy.copy(v) for k, v in kwargs.items()} + + if stateful: + call_id = self._call_id + self._call_id += 1 + else: + call_id = None + with nvtx.annotate("__call__: construct Invocation", domain="op_builder"): + invocation = _invocation.Invocation( + self, + call_id, + inputs, + kwargs, + is_batch=is_batch, + batch_size=batch_size, + previous_invocation=self._last_invocation, + ) + + if stateful: + self._last_invocation = invocation + + if ( + _eval_mode.EvalMode.current() == _eval_mode.EvalMode.eager + or _eval_mode.EvalMode.current() == _eval_mode.EvalMode.sync_cpu + or _eval_mode.EvalMode.current() == _eval_mode.EvalMode.sync_full + or ( + _eval_mode.EvalMode.current() == _eval_mode.EvalMode.default + and ( + any(is_external(x) for x in inputs) + or any(is_external(x) for x in kwargs.values()) + ) + ) + ): + # Evaluate immediately + invocation.run(_eval_context.EvalContext.get()) + else: + # Lazy evaluation + # If there's an active evaluation context, add this invocation to it. + # When leaving the context, the invocation will be evaluated if it's still alive. + ctx = _eval_context.EvalContext.current() + if ctx is not None: + ctx._add_invocation(invocation, weak=not self.is_stateful) + + if is_batch: + if len(invocation) == 1: + return Batch(invocation_result=invocation[0]) + else: + return tuple( + Batch(invocation_result=invocation[i]) for i in range(len(invocation)) + ) + else: + if len(invocation) == 1: + return Tensor(invocation_result=invocation[0]) + else: + return tuple( + Tensor(invocation_result=invocation[i]) for i in range(len(invocation)) + ) + + function = makefun.create_function(header, call) + + return function + + +def _next_pow2(x): + return 1 << (x - 1).bit_length() + + +def build_fn_wrapper(op): + schema = op.schema + module_path = schema.ModulePath() + from .. 
import dali2 as parent + + module = parent + for path_part in module_path: + new_module = getattr(module, path_part, None) + if new_module is None: + new_module = types.ModuleType(path_part) + setattr(module, path_part, new_module) + module = new_module + + fn_name = _to_snake_case(op.schema.OperatorName()) + inputs = [] + min_inputs = schema.MinNumInput() + max_inputs = schema.MaxNumInput() + input_indices = {} + arguments = schema.GetArgumentNames() + for i in range(max_inputs): + if schema.HasInputDox(): + input_name = schema.GetInputName(i) + if input_name in arguments: + input_name += "_input" + else: + input_name = f"input_{i}" + input_indices[input_name] = i + if i < min_inputs: + inputs.append(f"{input_name}") + else: + inputs.append(f"{input_name}=None") + + fixed_args = [] + tensor_args = [] + signature_args = ["batch_size=None, device=None"] + for arg in op.schema.GetArgumentNames(): + if arg in _unsupported_args: + continue + if op.schema.IsTensorArgument(arg): + tensor_args.append(arg) + else: + fixed_args.append(arg) + if op.schema.IsArgumentOptional(arg): + signature_args.append(f"{arg}=None") + else: + signature_args.append(arg) + + if signature_args: + signature_args = ["*"] + signature_args + if inputs: + inputs = inputs + ["/"] + header = f"{fn_name}({', '.join(inputs + signature_args)})" + + def fn_call(*inputs, batch_size=None, device=None, **raw_kwargs): + if batch_size is None: + for x in inputs: + x_batch_size = _get_batch_size(x) + if x_batch_size is not None: + batch_size = x_batch_size + break + if batch_size is None: + for arg in raw_kwargs.values(): + x_batch_size = _get_batch_size(arg) + if x_batch_size is not None: + batch_size = x_batch_size + break + max_batch_size = _next_pow2(batch_size or 1) + init_args = { + arg: _scalar_decay(raw_kwargs[arg]) + for arg in fixed_args + if arg != "max_batch_size" and arg in raw_kwargs and raw_kwargs[arg] is not None + } + call_args = { + arg: _scalar_decay(raw_kwargs[arg]) + for arg in tensor_args + if arg in raw_kwargs and raw_kwargs[arg] is not None + } + # If device is not specified, infer it from the inputs and call_args + if device is None: + + def _infer_device(): + for inp in inputs: + if inp is None: + continue + dev = _get_input_device(inp) + if dev is not None and dev.device_type == "gpu": + return dev + for arg in raw_kwargs.values(): + if arg is None: + continue + dev = _get_input_device(arg) + if dev is not None and dev.device_type == "gpu": + return dev + return _device.Device("cpu") + + device = _infer_device() + elif not isinstance(device, _device.Device): + device = _device.Device(device) + + # Get or create the operator instance that matches the arguments + with nvtx.annotate(f"get instance {op.op_name}", domain="op_builder"): + op_inst = op.get( + max_batch_size=max_batch_size, + name=None, + device=device, + num_inputs=len(inputs), + call_arg_names=tuple(call_args.keys()), + **init_args, + ) + + # Call the operator (the result is an Invocation object) + return op_inst(*inputs, **call_args) + + function = makefun.create_function(header, fn_call) + function.op_class = op + function.schema = schema + function._generated = True + setattr(module, fn_name, function) + return function + + +def build_operators(): + _all_ops = _ops._registry._all_registered_ops() + all_op_classes = [] + deprecated = {} + op_map = {} + for op_name in _all_ops: + if op_name.endswith("ExternalSource") or op_name.endswith("PythonFunction"): + continue + + schema = _b.GetSchema(op_name) + deprecated_in_favor = 
schema.DeprecatedInFavorOf() + if deprecated_in_favor: + deprecated[op_name] = deprecated_in_favor + cls = build_operator_class(schema) + all_op_classes.append(cls) + op_map[op_name] = cls + for deprecated, in_favor in deprecated.items(): + schema = _b.GetSchema(deprecated) + module = _find_or_create_module(ops, schema.ModulePath()) + setattr(module, deprecated, op_map[in_favor]) + + return all_op_classes diff --git a/dali/python/nvidia/dali/experimental/dali2/_tensor.py b/dali/python/nvidia/dali/experimental/dali2/_tensor.py new file mode 100644 index 0000000000..fd368c492b --- /dev/null +++ b/dali/python/nvidia/dali/experimental/dali2/_tensor.py @@ -0,0 +1,702 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Tuple, Union +from ._type import DType, dtype as _dtype, type_id as _type_id +from ._device import Device +import nvidia.dali.backend as _backend +from ._eval_context import EvalContext as _EvalContext +from . import _eval_mode +from . import _invocation +import copy +import nvidia.dali.types + + +def _volume(shape: Tuple[int, ...]) -> int: + ret = 1 + for s in shape: + ret *= s + return ret + + +def _backend_device(backend: Union[_backend.TensorCPU, _backend.TensorGPU]) -> Device: + if isinstance(backend, _backend.TensorCPU): + return Device("cpu") + elif isinstance(backend, _backend.TensorGPU): + return Device("gpu", backend.device_id()) + else: + raise ValueError(f"Unsupported backend type: {type(backend)}") + + +def _get_array_interface(data): + if a_func := getattr(data, "__array__", None): + try: + return a_func() + except TypeError: # CUDA torch tensor, CuPy array, etc. + return None + else: + return None + + +def _try_convert_enums(arr): + assert arr.dtype == object + item = arr.flat[0] + import numpy as np + + if isinstance(item, nvidia.dali.types.DALIInterpType): + return arr.astype(np.int32), nvidia.dali.types.INTERP_TYPE + elif isinstance(item, nvidia.dali.types.DALIDataType): + return arr.astype(np.int32), nvidia.dali.types.DATA_TYPE + elif isinstance(item, nvidia.dali.types.DALIImageType): + return arr.astype(np.int32), nvidia.dali.types.IMAGE_TYPE + + +class Tensor: + def __init__( + self, + data: Optional[Any] = None, + dtype: Optional[Any] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, + batch: Optional[Any] = None, + index_in_batch: Optional[int] = None, + invocation_result: Optional[_invocation.InvocationResult] = None, + copy: bool = False, + ): + if layout is None: + layout = "" + elif not isinstance(layout, str): + raise ValueError(f"Layout must be a string, got {type(layout)}") + + self._slice = None + self._backend = None + self._batch = batch + self._index_in_batch = index_in_batch + self._invocation_result = None + self._device = None + self._shape = None + self._dtype = None + self._layout = None + self._wraps_external_data = False + + copied = False + + from . 
import _fn + + if dtype is not None: + if not isinstance(dtype, DType): + dtype = _dtype(dtype) + + if batch is not None: + from . import _batch + + if not isinstance(batch, _batch.Batch): + raise ValueError("The `batch` argument must be a `Batch`") + self._batch = batch + self._index_in_batch = index_in_batch + self._dtype = batch.dtype + self._device = batch.device + self._layout = batch.layout + elif data is not None: + if isinstance(data, _backend.TensorCPU): + self._backend = data + self._wraps_external_data = True + elif isinstance(data, _backend.TensorGPU): + self._backend = data + self._wraps_external_data = True + elif isinstance(data, Tensor): + if dtype is None or _type_id(dtype) == data.dtype.type_id: + if device is None or device == data.device: + self.assign(data) + self._wraps_external_data = data._wraps_external_data + else: + dev = data.to_device(device).evaluate() + if dev is not self: + copied = True + self.assign(dev) + self._wraps_external_data = not copied + else: + from . import cast + + self.assign(cast(data, dtype, device=device).evaluate()) + elif isinstance(data, TensorSlice): + self._slice = data + elif hasattr(data, "__dlpack_device__"): + dl_device_type, device_id = data.__dlpack_device__() + if int(dl_device_type) == 1: # CPU + self._backend = _backend.TensorCPU(data.__dlpack__(), layout) + elif int(dl_device_type) == 2: # GPU + # If the current context is on the same device, use the same stream. + ctx = _EvalContext.get() + if ctx.device_id == device_id: + stream = ctx.cuda_stream + else: + stream = _backend.Stream(device_id) + args = {"stream": stream.handle} + self._backend = _backend.TensorGPU( + data.__dlpack__(**args), + layout=layout, + stream=stream, + ) + else: + raise ValueError(f"Unsupported device type: {dl_device_type}") + self._wraps_external_data = True + elif a := _get_array_interface(data): + self._backend = _backend.TensorCPU(a, layout) + self._wraps_external_data = True + else: + import numpy as np + + if dtype is not None: + # TODO(michalz): Built-in enum handling + self._backend = _backend.TensorCPU( + np.array(data, dtype=nvidia.dali.types.to_numpy_type(dtype.type_id)), + layout, + False, + ) + copied = True + self._wraps_external_data = False + self._dtype = dtype + else: + arr = np.array(data) + # DALI doesn't support int64 and float64, so we need to convert them to int32 + # and float32, respectively. 
+ converted_dtype_id = None + if arr.dtype == np.int64: + arr = arr.astype(np.int32) + elif arr.dtype == np.uint64: + arr = arr.astype(np.uint32) + elif arr.dtype == np.float64: + arr = arr.astype(np.float32) + elif arr.dtype == object: + (arr, converted_dtype_id) = _try_convert_enums(arr) + self._backend = _backend.TensorCPU(arr, layout, False) + if converted_dtype_id is not None: + self._backend.reinterpret(converted_dtype_id) + copied = True + self._wraps_external_data = False + + if self._backend is not None: + self._device = _backend_device(self._backend) + if device is None: + device = self._device + else: + if device is None: + if self._device is None: + device = Device("cpu") + else: + device = self._device + self._device = device + + if self._backend is not None: + self._shape = self._backend.shape() + self._dtype = DType.from_type_id(self._backend.dtype) + self._layout = self._backend.layout() + + if isinstance(self._backend, _backend.TensorCPU) and device != _backend_device( + self._backend + ): + self.assign(self.to_device(device).evaluate()) + elif invocation_result is not None: + self._invocation_result = invocation_result + self._device = invocation_result.device + else: + raise ValueError("Either data, expression or batch and index must be provided") + + if _eval_mode.EvalMode.current().value >= _eval_mode.EvalMode.eager.value: + self.evaluate() + + if copy and self._backend is not None and not copied: + self.assign(self.to_device(device, True).evaluate()) + + def _is_external(self) -> bool: + return self._wraps_external_data + + def cpu(self) -> "Tensor": + return self.to_device(Device("cpu")) + + def gpu(self, index: Optional[int] = None) -> "Tensor": + return self.to_device(Device("gpu", index)) + + @property + def device(self) -> Device: + if self._device is not None: + return self._device + if self._invocation_result is not None: + self._device = self._invocation_result.device + return self._device + else: + raise RuntimeError("Device not set") + + def to_device(self, device: Device, force_copy: bool = False) -> "Tensor": + if self.device == device and not force_copy: + return self + else: + with device: + from . 
import copy + + return copy(self, device=device.device_type) + + def assign(self, other: "Tensor"): + if other is self: + return + self._device = other._device + self._shape = other._shape + self._dtype = other._dtype + self._layout = other._layout + self._backend = other._backend + self._slice = other._slice + self._batch = other._batch + self._index_in_batch = other._index_in_batch + self._invocation_result = other._invocation_result + self._wraps_external_data = other._wraps_external_data + + @property + def data(self): + if not self._backend: + self.evaluate() + return self._backend + + @property + def ndim(self) -> int: + if self._backend is not None: + return self._backend.ndim() + elif self._slice is not None: + return self._slice.ndim + elif self._invocation_result is not None: + return self._invocation_result.ndim + elif self._batch is not None: + return self._batch.ndim + else: + raise RuntimeError("Cannot determine the number of dimensions of the tensor.") + + @property + def shape(self) -> Tuple[int, ...]: + if self._shape is None: + if self._invocation_result is not None: + self._shape = self._invocation_result.shape + elif self._slice: + self._shape = self._slice.shape + elif self._batch is not None: + self._shape = self._batch.shape[self._index_in_batch] + else: + self._shape = self._backend.shape() + return self._shape + + @property + def dtype(self) -> DType: + if self._dtype is None: + if self._invocation_result is not None: + self._dtype = _dtype(self._invocation_result.dtype) + elif self._slice: + self._dtype = self._slice.dtype + elif self._batch is not None: + self._dtype = self._batch.dtype + else: + self._dtype = _dtype(self._backend.dtype) + return self._dtype + + @property + def layout(self) -> str: + if self._layout is None: + if self._invocation_result is not None: + self._layout = self._invocation_result.layout + elif self._slice: + self._layout = self._slice.layout + elif self._batch is not None: + self._layout = self._batch.layout + else: + self._layout = self._backend.layout() + return self._layout + + @property + def size(self) -> int: + return _volume(self.shape) + + @property + def nbytes(self) -> int: + return self.size * self.dtype.bytes + + @property + def itemsize(self) -> int: + return self.dtype.bytes + + def item(self) -> Any: + if self.size != 1: + raise ValueError(f"Tensor has {self.size} elements, expected 1") + import numpy as np + + with _EvalContext.get(): + return np.array(self.cpu().evaluate()._backend).item() + + def evaluate(self): + if self._backend is None: + with _EvalContext.get() as ctx: + if self._slice: + self._backend = self._slice.evaluate()._backend + elif self._batch is not None: + t = self._batch._tensors[self._index_in_batch] + if t is self: + self._backend = self._batch.evaluate()._backend[self._index_in_batch] + else: + self._backend = t.evaluate()._backend + else: + assert self._invocation_result is not None + self._backend = self._invocation_result.value(ctx) + self._shape = self._backend.shape() + self._dtype = DType.from_type_id(self._backend.dtype) + self._layout = self._backend.layout() + return self + + def __getitem__(self, ranges: Any) -> "Tensor": + if not isinstance(ranges, tuple): + ranges = (ranges,) + + if all(_is_full_slice(r) or r is Ellipsis for r in ranges): + return self + else: + if self._slice: + return self._slice.__getitem__(ranges) + else: + return Tensor(TensorSlice(self, ranges)) + + def _is_same_tensor(self, other: "Tensor") -> bool: + return ( + self._backend is other._backend + and 
self._invocation_result is other._invocation_result + and self._slice is other._slice + ) + + def __str__(self) -> str: + return "Tensor(\n" + str(self.evaluate()._backend) + ")" + + def __add__(self, other): + return _arithm_op("add", self, other) + + def __radd__(self, other): + return _arithm_op("add", other, self) + + def __sub__(self, other): + return _arithm_op("sub", self, other) + + def __rsub__(self, other): + return _arithm_op("sub", other, self) + + def __mul__(self, other): + return _arithm_op("mul", self, other) + + def __rmul__(self, other): + return _arithm_op("mul", other, self) + + def __pow__(self, other): + return _arithm_op("pow", self, other) + + def __rpow__(self, other): + return _arithm_op("pow", other, self) + + def __truediv__(self, other): + return _arithm_op("fdiv", self, other) + + def __rtruediv__(self, other): + return _arithm_op("fdiv", other, self) + + def __floordiv__(self, other): + return _arithm_op("div", self, other) + + def __rfloordiv__(self, other): + return _arithm_op("div", other, self) + + def __neg__(self): + return _arithm_op("minus", self) + + # Short-circuiting the execution, unary + is basically a no-op + def __pos__(self): + return self + + def __eq__(self, other): + return _arithm_op("eq", self, other) + + def __ne__(self, other): + return _arithm_op("neq", self, other) + + def __lt__(self, other): + return _arithm_op("lt", self, other) + + def __le__(self, other): + return _arithm_op("leq", self, other) + + def __gt__(self, other): + return _arithm_op("gt", self, other) + + def __ge__(self, other): + return _arithm_op("geq", self, other) + + def __and__(self, other): + return _arithm_op("bitand", self, other) + + def __rand__(self, other): + return _arithm_op("bitand", other, self) + + def __or__(self, other): + return _arithm_op("bitor", self, other) + + def __ror__(self, other): + return _arithm_op("bitor", other, self) + + def __xor__(self, other): + return _arithm_op("bitxor", self, other) + + def __rxor__(self, other): + return _arithm_op("bitxor", other, self) + + +def _arithm_op(name, *args, **kwargs): + from . import _fn + + argsstr = " ".join(f"&{i}" for i in range(len(args))) + from . 
import arithmetic_generic_op + + return arithmetic_generic_op(*args, expression_desc=f"{name}({argsstr})") + + +def _is_int_value(tested: Any, reference: int) -> bool: + return isinstance(tested, int) and tested == reference + + +def _is_full_slice(r: Any) -> bool: + if isinstance(r, slice): + return ( + (r.start is None or _is_int_value(r.start, 0)) + and (r.stop is None) + and (r.step is None or _is_int_value(r.step, 1)) + ) + else: + return False + + +def _is_index(r: Any) -> bool: + return not isinstance(r, slice) and r is not Ellipsis + + +def _clamp(value: int, lo: int, hi: int) -> int: + return max(lo, min(value, hi)) + + +def _scalar_value(value: Any) -> int: + if isinstance(value, int): + return value + elif isinstance(value, Tensor): + return value.item() + else: + raise ValueError(f"Unsupported type: {type(value)}") + + +class TensorSlice: + def __init__(self, tensor: Tensor, ranges: Tuple[Any, ...]): + self._tensor = copy.copy(tensor) + self._ranges = [copy.copy(r) for r in ranges] + self._ndim_dropped = 0 + self._shape = None + self._absolute_ranges = None + self._layout = None + num_ranges = len(ranges) + ellipsis_found = False + for r in ranges: + if _is_index(r): + self._ndim_dropped += 1 + elif r is Ellipsis: + if ellipsis_found: + raise ValueError("Only one Ellipsis is allowed.") + num_ranges -= 1 + ellipsis_found = True + if num_ranges > tensor.ndim: + raise ValueError( + f"Number of ranges ({num_ranges}) " + f"is greater than the number of dimensions of the tensor ({tensor.ndim})" + ) + + @property + def ndim(self) -> int: + return self._tensor.ndim - self._ndim_dropped + + @property + def shape(self) -> Tuple[int, ...]: + if self._shape is None: + shape = [] + if self._absolute_ranges is None: + self._absolute_ranges = self._canonicalize_ranges(self._ranges, self._tensor.shape) + for r in self._absolute_ranges: + if isinstance(r, slice): + shape.append((r.stop + r.step - r.start - 1) // r.step) + self._shape = tuple(shape) + return self._shape + + @property + def dtype(self) -> DType: + return self._tensor.dtype + + @property + def device(self) -> Device: + return self._tensor.device + + @property + def layout(self) -> str: + if self._layout is not None: + return self._layout + input_layout = self._tensor.layout + if self._ndim_dropped == 0 or input_layout == "" or input_layout is None: + self._layout = input_layout + return self._layout + + j = 0 + layout = "" + for i, r in enumerate(self._ranges): + if isinstance(self._ranges[i], slice): + layout += input_layout[j] + j += 1 + elif r is Ellipsis: + j += self._tensor.ndim - len(self._ranges) + 1 + else: + j += 1 # skip this dimension + self._layout = layout + return self._layout + + @staticmethod + def _canonicalize_ranges(ranges, in_shape) -> Tuple[int, ...]: + d = 0 + abs_ranges = [] + for i, r in enumerate(ranges): + if r is Ellipsis: + print("in_shape", in_shape) + to_skip = len(in_shape) - len(ranges) + 1 + print("to_skip", to_skip) + for _ in range(to_skip): + abs_ranges.append(slice(0, in_shape[d], 1)) + d += 1 + continue + if isinstance(r, slice): + step = _scalar_value(r.step) if r.step is not None else 1 + if step == 0: + raise ValueError("slice step cannot be zero") + extent = in_shape[d] + start, stop = 0, extent + if r.start is not None: + start = _scalar_value(r.start) + if start < 0: + start += extent + start = _clamp(start, 0, extent) + if r.stop is not None: + stop = _scalar_value(r.stop) + if stop < 0: + stop += extent + stop = _clamp(stop, start, extent) + abs_ranges.append(slice(start, stop, step)) 
+ else: + idx = _scalar_value(r) + if idx < 0: + idx += in_shape[d] + if idx < 0 or idx >= in_shape[d]: + raise IndexError( + f"Index {idx} is out of bounds for dimension {d} with size {in_shape[d]}" + ) + abs_ranges.append(idx) + d += 1 + while d < len(in_shape): + abs_ranges.append(slice(0, in_shape[d], 1)) + d += 1 + + return tuple(abs_ranges) + + def __getitem__(self, ranges: Any) -> "Tensor": + if not isinstance(ranges, tuple): + ranges = (ranges,) + + if all(_is_full_slice(r) or r is Ellipsis for r in ranges): + return Tensor(self) + else: + ranges = self._canonicalize_ranges(ranges, self.shape) + abs_ranges = list(self._absolute_ranges) + i = 0 + for d, r in enumerate(self._absolute_ranges): + if isinstance(r, slice): + if isinstance(ranges[i], slice): + start = r.start + ranges[i].start + stop = r.start + ranges[i].stop + step = r.step * ranges[i].step + abs_ranges[d] = slice(start, stop, step) + else: + abs_ranges[d] = r.start + ranges[i] * r.step + i += 1 + result = TensorSlice(self._tensor, tuple(abs_ranges)) + if _eval_mode.EvalMode.current().value >= _eval_mode.EvalMode.eager.value: + result.evaluate() + return Tensor(result) + + def evaluate(self): + with _EvalContext.get(): + if len(self._ranges) == 0: + return self._tensor.evaluate() + + if all(_is_full_slice(r) for r in self._ranges): + return self._tensor.evaluate() + + args = {} + d = 0 + for i, r in enumerate(self._ranges): + if r is Ellipsis: + d = self._tensor.ndim - len(self._ranges) + i + 1 + elif isinstance(r, slice): + if r.start is not None: + args[f"lo_{d}"] = r.start + if r.stop is not None: + args[f"hi_{d}"] = r.stop + if r.step is not None: + args[f"step_{d}"] = r.step + d += 1 + else: + args[f"at_{d}"] = r + d += 1 + + from . import tensor_subscript + + return tensor_subscript(self._tensor, **args).evaluate() + + +def tensor( + data: Any, + dtype: Optional[Any] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, +): + """Copies an existing tensor-like object into a DALI tensor. + + @param data: A tensor-like object, a list or a scalar value. + @param dtype: The requested data type of the tensor. + @param device: The device to use for the tensor. + @param layout: The layout of the tensor. + """ + return Tensor(data, dtype=dtype, device=device, layout=layout, copy=True) + + +def as_tensor( + data: Any, + dtype: Optional[Any] = None, + device: Optional[Device] = None, + layout: Optional[str] = None, +): + """Wraps an existing tensor-like object into a DALI tensor. + + This function avoids copying the data if possible. + """ + return Tensor(data, dtype=dtype, device=device, layout=layout, copy=False) diff --git a/dali/python/nvidia/dali/experimental/dali2/ops.py b/dali/python/nvidia/dali/experimental/dali2/ops.py new file mode 100644 index 0000000000..a4fc891d3f --- /dev/null +++ b/dali/python/nvidia/dali/experimental/dali2/ops.py @@ -0,0 +1,441 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import _device +from . import _invocation +from . import _eval_context +import nvidia.dali as dali +from typing import Optional +import nvidia.dali.backend_impl as _b +from ._tensor import Tensor +from ._batch import Batch + + +class Operator: + def __init__( + self, + max_batch_size, + name=None, + device="cpu", + num_inputs=None, + call_arg_names=None, + **kwargs, + ): + self._name = name + self._max_batch_size = max_batch_size + self._init_args = kwargs + self._num_inputs = num_inputs + self._call_arg_names = None if call_arg_names is None else tuple(call_arg_names) + self._api_type = None + if isinstance(device, str): + self._device = _device.Device( + name=device, + device_id=kwargs.get("device_id", _device.Device.default_device_id(device)), + ) + else: + if not isinstance(device, _device.Device): + raise TypeError( + f"`device` must be a Device instance or a string, got {type(device)}" + ) + self._device = device + self._input_meta = [] + self._arg_meta = {} + self._num_outputs = None + self._output_devices = None + self._op_inst = None + self._op_backend = None + self._op_spec = None + self._last_invocation = None + + @classmethod + def get( + cls, + max_batch_size: int, + name: Optional[str] = None, + device: Optional[_device.Device] = None, + num_inputs: Optional[int] = None, + call_arg_names: Optional[list[str]] = None, + **init_args, + ): + if device is None: + device = _device.Device.current() + if not isinstance(device, _device.Device): + raise TypeError("device must be a Device instance") + + def freeze_arg(arg): + if isinstance(arg, list): + return tuple(arg) + return arg + + def freeze_args(args): + sorted_keys = sorted(args.keys()) + return tuple([(k, freeze_arg(args[k])) for k in sorted_keys]) + + call_arg_names = freeze_arg(call_arg_names) + key = (device, max_batch_size, num_inputs, call_arg_names, freeze_args(init_args)) + inst = cls._instance_cache.get(key, None) + if inst is None: + with device: + inst = cls( + max_batch_size, + name=name, + device=device, + num_inputs=num_inputs, + call_arg_names=call_arg_names, + **init_args, + ) + cls._instance_cache[key] = inst + return inst + + def infer_num_outputs(self, *inputs, **args): + self._init_spec(inputs, args) + return self._num_outputs + + def input_device(self, index: int, actual_device: Optional[str] = None): + default_input_device = "gpu" if self._device.device_type == "gpu" else "cpu" + dev_type = self.schema.GetInputDevice(index, actual_device, default_input_device) + if dev_type is None: + return self._device + return _device.Device(dev_type, self._device.device_id) # inherit the device id + + def infer_output_devices(self, *inputs, **args): + self._init_spec(inputs, args) + return self._output_devices + + def _pre_call(self, *inputs, **args): + pass + + def _is_backend_initialized(self): + return self._op_backend is not None + + def _reset_backend(self): + self._op_backend = None + self._op_spec = None + + def _init_spec(self, inputs, args): + if self._op_spec is None: + self._num_inputs = len(inputs) + self._call_arg_names = tuple(args.keys()) + import nvidia.dali as dali + + with self._device: + input_nodes = [ + dali.data_node.DataNode( + name=f"input_{i}", device=inputs[i].device.device_type, source=None + ) + for i in range(len(inputs)) + ] + arg_nodes = { + name: dali.data_node.DataNode(name=f"arg_{name}", device="cpu", source=None) + for name in args + } + op = self.legacy_op( + name=self._name, device=self._device.device_type, **self._init_args + ) + self._op_inst = op + out = op(*input_nodes, 
**arg_nodes) + if isinstance(out, (list, tuple)): + spec = out[0].source.spec + else: + spec = out.source.spec + + self._op_spec = spec + + if isinstance(out, (tuple, list)): + self._output_devices = [] + self._num_outputs = len(out) + for o in out: + device_type = o.device + device_id = self._device.device_id + self._output_devices.append(_device.Device(device_type, device_id)) + else: + self._num_outputs = 1 + self._output_devices = [_device.Device(out.device, self._device.device_id)] + + self._set_meta(inputs, args) + + def _init_backend(self, ctx, inputs, args): + if self._op_backend is not None: + return + + if ctx is None: + ctx = _eval_context.EvalContext.get() + with self._device: + with ctx: + self._init_spec(inputs, args) + if ctx._thread_pool is not None: + self._op_spec.AddArg("num_threads", ctx._thread_pool.num_threads) + else: + self._op_spec.AddArg("num_threads", 1) + self._op_spec.AddArg( + "device_id", + ( + self._device.device_id + if self._device.device_type == "gpu" or self._device.device_type == "mixed" + else dali.types.CPU_ONLY_DEVICE_ID + ), + ) + if self._max_batch_size is None: + self._max_batch_size = 1 + self._op_spec.AddArg("max_batch_size", self._max_batch_size) + self._op_backend = _b._Operator(self._op_spec) + + def run(self, ctx, *inputs, batch_size=None, **args): + if ( + batch_size is not None + and self._max_batch_size is not None + and batch_size > self._max_batch_size + and self.schema.IsStateful() + ): + raise RuntimeError( + f"The batch size {batch_size} is larger than the `max_batch_size` " + f"{self._max_batch_size} specified when the operator was created." + ) + + def _is_batch(): + nonlocal inputs, args + for input in inputs: + if isinstance(input, ((_b.TensorListCPU, _b.TensorListGPU))): + return True + for input in args.values(): + if isinstance(input, ((_b.TensorListCPU, _b.TensorListGPU))): + return True + return False + + is_batch = batch_size is not None or _is_batch() + if self._is_backend_initialized(): + if self.schema.IsStateful(): + # clearing the backend in a stateful op would destroy the state + self.check_compatible(inputs, batch_size, args) + elif not self.is_compatible(inputs, batch_size, args): + # we can reinitialize a stateless operator - not very efficient :( + self._reset_backend() + + self._init_backend(ctx, inputs, args) + workspace = _b._Workspace(ctx._thread_pool, ctx._cuda_stream) + for i, input in enumerate(inputs): + workspace.AddInput(self._to_batch(input).evaluate()._backend) + for name, arg in args.items(): + workspace.AddArgumentInput(name, self._to_batch(arg).evaluate()._backend) + self._op_backend.SetupAndRun(workspace, batch_size) + out = workspace.GetOutputs() + if is_batch: + return tuple(out) + else: + tensors = tuple(o[0] for o in out) + return tensors + + def _to_batch(self, x): + if not isinstance(x, Batch): + return Batch([x]) + else: + return x + + def _set_meta(self, inputs, args): + self._input_meta = [self._make_meta(input) for input in inputs] + self._arg_meta = {name: self._make_meta(arg) for name, arg in args.items()} + + def is_compatible(self, inputs, batch_size, args): + if batch_size is not None: + if batch_size > self._max_batch_size: + return False + if self._input_meta != [self._make_meta(input) for input in inputs]: + return False + if self._arg_meta != {name: self._make_meta(arg) for name, arg in args.items()}: + return False + return True + + def check_compatible(self, inputs, batch_size, args): + def error_header(): + return ( + f"The invocation of operator {self.display_name} " + f"is 
not compatible with the previous call:\n"
+            )
+
+        if batch_size is not None:
+            if batch_size > self._max_batch_size:
+                raise RuntimeError(
+                    f"{error_header()}"
+                    f"The batch size {batch_size} is larger than the `max_batch_size` "
+                    f"{self._max_batch_size} specified when the operator was created."
+                )
+
+        if len(inputs) != len(self._input_meta):
+            raise RuntimeError(
+                f"{error_header()}"
+                f"The number of inputs ({len(inputs)}) does not match the number "
+                f"of inputs used in the previous call ({len(self._input_meta)})."
+            )
+        for i, input in enumerate(inputs):
+            if self._input_meta[i] != self._make_meta(input):
+                raise RuntimeError(
+                    f"{error_header()}"
+                    f"The input {i} is not compatible with the input used in the previous call."
+                )
+        for name, arg in args.items():
+            if name not in self._arg_meta:
+                raise RuntimeError(
+                    f"{error_header()}" f"The argument `{name}` was not used in the previous call."
+                )
+            if self._arg_meta[name] != self._make_meta(arg):
+                raise RuntimeError(
+                    f"{error_header()}"
+                    f"The argument `{name}` is not compatible with the argument used in the "
+                    f"previous call."
+                )
+        for name in self._arg_meta:
+            if name not in args:
+                raise RuntimeError(
+                    f"{error_header()}"
+                    f"The argument `{name}` used in the previous call was not supplied in the "
+                    f"current one."
+                )
+
+    def _make_meta(self, x):
+        is_batch = False
+        if isinstance(x, _invocation.Invocation):
+            is_batch = x.is_batch
+        elif isinstance(x, Batch):
+            is_batch = True
+        else:
+            is_batch = False
+
+        return {
+            "is_batch": is_batch,
+            "ndim": x.ndim,
+            "layout": x.layout,
+            "dtype": x.dtype,
+        }
+
+    @property
+    def display_name(self):
+        if "display_name" in self._init_args:
+            type_name = self._init_args["display_name"]
+        else:
+            type_name = self.schema.OperatorName()
+        if self._name is not None:
+            return f'{type_name} "{self._name}"'
+        else:
+            return type_name
+
+
+class Reader(Operator):
+    def __init__(
+        self,
+        batch_size=None,
+        name=None,
+        device="cpu",
+        num_inputs=None,
+        call_arg_names=None,
+        **kwargs,
+    ):
+        if name is None:
+            name = f"Reader_{id(self)}"
+        self._actual_batch_size = batch_size
+        self._batch_size = batch_size
+        super().__init__(
+            self._actual_batch_size, name, device, num_inputs, call_arg_names, **kwargs
+        )
+
+    def _pre_call(self, *inputs, **args):
+        if self._api_type is None:
+            self._api_type = "run"
+        elif self._api_type != "run":
+            raise RuntimeError(
+                "Cannot mix `samples`, `batches` and `run`/`__call__` on the same reader."
+            )
+
+    def run(self, ctx=None, *inputs, **args):
+        if self._api_type is None:
+            self._api_type = "run"
+        elif self._api_type != "run":
+            raise RuntimeError(
+                "Cannot mix `samples`, `batches` and `run`/`__call__` on the same reader."
+            )
+
+        return super().run(ctx, *inputs, **args)
+
+    def samples(self, ctx: Optional[_eval_context.EvalContext] = None):
+        if self._api_type is None:
+            self._api_type = "samples"
+        elif self._api_type != "samples":
+            raise RuntimeError(
+                "Cannot mix `samples`, `batches` and `run`/`__call__` on the same reader."
+ ) + + if ctx is None: + ctx = _eval_context.EvalContext.get() + with ctx: + if not self._is_backend_initialized(): + if self._actual_batch_size is None: + self._actual_batch_size = 1 + if self._max_batch_size is None: + self._max_batch_size = self._actual_batch_size + self._init_backend(ctx, (), {}) + meta = self._op_backend.GetReaderMeta() + idx = 0 + while idx < meta["epoch_size_padded"]: + outputs = super().run(ctx, batch_size=self._actual_batch_size) + batch_size = len(outputs[0]) + assert batch_size == self._actual_batch_size + idx += batch_size + for x in zip(*outputs): + outs = tuple(Tensor(o) for o in x) + yield outs + + def batches(self, batch_size=None, ctx: Optional[_eval_context.EvalContext] = None): + if self._api_type is None: + self._api_type = "batches" + elif self._api_type != "batches": + raise RuntimeError("Cannot mix samples(), batches() and run() on the same reader.") + + if ctx is None: + ctx = _eval_context.EvalContext.get() + with ctx: + if batch_size is None: + batch_size = self._batch_size + if batch_size is None: + raise ValueError("Batch size was not specified") + if not self._op_backend: + if self._max_batch_size and self._max_batch_size < batch_size: + raise ValueError( + f"`batch_size` {batch_size} is larger than the `max_batch_size` " + f"{self._max_batch_size} specified when the operator was created" + ) + self._max_batch_size = batch_size + self._init_backend(ctx, (), {}) + else: + if self._max_batch_size and self._max_batch_size != batch_size: + raise ValueError( + f"`batch_size` {batch_size} is different than the `max_batch_size` " + f"{self._max_batch_size} used in the previous call" + ) + meta = self._op_backend.GetReaderMeta() + idx = 0 + while idx < meta["epoch_size_padded"]: + outputs = super().run(ctx, batch_size=batch_size) + batch_size_returned = len(outputs[0]) + assert batch_size_returned == batch_size + idx += batch_size_returned + yield tuple(Batch(o) for o in outputs) + + +_all_ops = [] + + +def _initialize(): + from . import _op_builder + + global _all_ops + _all_ops = _op_builder.build_operators() diff --git a/dali/python/nvidia/dali/ops/_signatures.py b/dali/python/nvidia/dali/ops/_signatures.py index 4204abdefe..28536c37fb 100644 --- a/dali/python/nvidia/dali/ops/_signatures.py +++ b/dali/python/nvidia/dali/ops/_signatures.py @@ -63,6 +63,10 @@ def __repr__(self): # This is not the DataNode you are looking for. _DataNode = _create_annotation_placeholder("DataNode") +# for dynamic API +_Tensor = _create_annotation_placeholder("Tensor") +_Batch = _create_annotation_placeholder("Batch") + # The placeholder for the DALI Enum types, as the bindings from backend don't play nice, # we need actual Python classes. _DALIDataType = _create_annotation_placeholder("DALIDataType") @@ -138,28 +142,36 @@ def _get_annotation_input_regular(schema): return Union[_DataNode, _TensorLikeIn] -def _get_annotation_return_regular(schema): - """Produce the return annotation for DALI operator suitable for primary, non-MIS overload. +def _get_annotation_return_helper(schema, return_type=_DataNode): + """Produce the return annotation for DALI operator suitable for regular API that doesn't mix + entities (Tensor, Batch, DataNode, but not MIS). Note the flattening, single output is not packed in Sequence. 
""" if schema.HasOutputFn(): # Dynamic number of outputs, not known at "compile time" - return_annotation = Union[_DataNode, Sequence[_DataNode], None] + return_annotation = Union[return_type, Sequence[return_type], None] else: # Call it with a dummy spec, as we don't have Output function num_regular_output = schema.CalculateOutputs(_b.OpSpec("")) if num_regular_output == 0: return_annotation = None elif num_regular_output == 1: - return_annotation = _DataNode + return_annotation = return_type else: # Here we could utilize the fact, that the tuple has known length, but we can't # as DALI operators return a list # Also, we don't advertise the actual List type, hence the Sequence. - return_annotation = Sequence[_DataNode] + return_annotation = Sequence[return_type] return return_annotation +def _get_annotation_return_regular(schema): + """Produce the return annotation for DALI operator suitable for primary, non-MIS overload. + Note the flattening, single output is not packed in Sequence. + """ + return _get_annotation_return_helper(schema=schema, return_type=_DataNode) + + def _get_annotation_input_mis(schema): """Return the annotation for multiple input sets, used for the secondary operator overload. A function is used as a global variable can be confused with type alias. @@ -199,6 +211,25 @@ def _get_annotation_return_mis(schema): return_annotation = Union[Sequence[_DataNode], List[Sequence[_DataNode]]] return return_annotation +# Dynamic API handling: +# TODO(klecki): adjust to match the spec + + +def _get_annotation_input_dynamic_sample(schema): + return Union[_Tensor, _TensorLikeIn] + + +def _get_annotation_return_dynamic_sample(schema): + return _get_annotation_return_helper(schema=schema, return_type=_Tensor) + + +def _get_annotation_input_dynamic_batch(schema): + return Union[_Batch, _TensorLikeIn] + + +def _get_annotation_return_dynamic_batch(schema): + return _get_annotation_return_helper(schema=schema, return_type=_Batch) + def _get_positional_input_params(schema, input_annotation_gen=_get_annotation_input_regular): """Get the list of positional only inputs to the operator. @@ -487,6 +518,31 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True, ) +# TODO(klecki): Adjust _gen_dynamic_signature function to customize stubs for ndd module. +def _gen_dynamic_signature(schema, schema_name, fn_name): + """Write the stub of the fn API function with the docstring, for given operator.""" + return inspect_repr_fixups( + f""" +@overload +def {fn_name}{_call_signature(schema, include_inputs=True, include_kwargs=True, + input_annotation_gen=_get_annotation_input_dynamic_sample, + return_annotation_gen=_get_annotation_return_dynamic_sample)}: + \"""{_docs._docstring_generator_fn(schema_name)} + \""" + ... + +@overload +def {fn_name}{_call_signature(schema, include_inputs=True, include_kwargs=True, + input_annotation_gen=_get_annotation_input_dynamic_batch, + return_annotation_gen=_get_annotation_return_dynamic_batch)}: + \"""{_docs._docstring_generator_fn(schema_name)} + \""" + ... + +""" + ) + + # Preamble with license and helper imports for the stub file. # We need the placeholders for actual Python classes, as the ones that are exported from backend # don't seem to work with the intellisense. @@ -514,6 +570,34 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True, from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType +""" + +# TODO(klecki): Fill the missing type imports in dynamic API. 
+# Probably dali dtypes etc need to be added here. +_HEADER_DYNAMIC = """ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, Optional, overload +from typing import Any, List, Sequence + +from nvidia.dali._typing import TensorLikeIn, TensorLikeArg + +from ._tensor import Tensor +from ._batch import Batch + + """ @@ -585,12 +669,21 @@ def _group_signatures(api: str): "generated": [], } - api_module = fn if api == "fn" else ops + import nvidia.dali.experimental.dali2 as ndd + + if api == "fn": + api_module = fn + elif api == "ops": + api_module = ops + elif api == "dynamic": + api_module = ndd + + api_naming_convention = "fn" if api in ("fn", "dynamic") else "ops" for schema_name in sorted(_registry._all_registered_ops()): schema = _b.TryGetSchema(schema_name) - _, module_nesting, op_name = _names._process_op_name(schema_name, api=api) + _, module_nesting, op_name = _names._process_op_name(schema_name, api=api_naming_convention) op = _get_op(api_module, module_nesting + [op_name]) if schema is None: @@ -612,7 +705,8 @@ def _group_signatures(api: str): class StubFileManager: - def __init__(self, nvidia_dali_path: Path, api: str): + + def __init__(self, nvidia_dali_path: Path, api: Path): self._module_to_file = {} self._nvidia_dali_path = nvidia_dali_path self._api = api @@ -632,7 +726,12 @@ def get(self, module_nesting: List[str]): open(file_path, "w").close() # clear the file f = open(file_path, "a") self._module_to_file[module_path] = f - f.write(_HEADER) + # TODO(klecki): Adjust when it's ready + # if "dynamic" in str(self._api): + if "dali2" in str(self._api): + f.write(_HEADER_DYNAMIC) + else: + f.write(_HEADER) full_module_nesting = [""] + module_nesting # Find out all the direct submodules and add the imports submodules_dict = self._module_tree @@ -651,41 +750,56 @@ def close(self): def gen_all_signatures(nvidia_dali_path, api): - """Generate the signatures for "fn" or "ops" api. + """Generate the signatures for "fn", "ops" or "dynamic" api. Parameters ---------- nvidia_dali_path : Path The path to the wheel pre-packaging to the nvidia/dali directory. 
api : str - "fn" or "ops" + "fn", "ops" or "dynamic" """ nvidia_dali_path = Path(nvidia_dali_path) + api_path = api if api in ("fn", "ops") else Path("experimental") / "dali2" + api_naming_convention = "fn" if api in ("fn", "dynamic") else "ops" - with closing(StubFileManager(nvidia_dali_path, api)) as stub_manager: + with closing(StubFileManager(nvidia_dali_path, api_path)) as stub_manager: sig_groups = _group_signatures(api) + if api == "dynamic": + print(f"Generating signatures for dynamic API in {sig_groups.keys()=}") + for k, v in sig_groups.items(): + print(f"{k}: {len(v)} signatures") + # Python-only and the manually defined ones are reexported from their respective modules - for schema_name, op in sig_groups["python_only"] + sig_groups["python_wrapper"]: - _, module_nesting, op_name = _names._process_op_name(schema_name, api=api) + # TODO(klecki): Handle this in dynamic API, I have no idea if we have python function et al. + if api != "dynamic": + for schema_name, op in sig_groups["python_only"] + sig_groups["python_wrapper"]: + _, module_nesting, op_name = _names._process_op_name(schema_name, api=api) - stub_manager.get(module_nesting).write( - f"\n\nfrom {op._impl_module} import" f" ({op.__name__} as {op.__name__})\n\n" - ) + stub_manager.get(module_nesting).write( + f"\n\nfrom {op._impl_module} import" f" ({op.__name__} as {op.__name__})\n\n" + ) # we do not go over sig_groups["hidden_or_internal"] at all as they are supposed to not be # directly visible # Runtime generated classes use fully specified stubs. for schema_name, op in sig_groups["generated"]: - _, module_nesting, op_name = _names._process_op_name(schema_name, api=api) + _, module_nesting, op_name = _names._process_op_name( + schema_name, api=api_naming_convention + ) schema = _b.TryGetSchema(schema_name) if api == "fn": stub_manager.get(module_nesting).write( _gen_fn_signature(schema, schema_name, op_name) ) - else: + elif api == "ops": stub_manager.get(module_nesting).write( _gen_ops_signature(schema, schema_name, op_name) ) + elif api == "dynamic": + stub_manager.get(module_nesting).write( + _gen_dynamic_signature(schema, schema_name, op_name) + ) diff --git a/internal_tools/python_stub_generator.py b/internal_tools/python_stub_generator.py index 4fd16afacc..9921d29121 100755 --- a/internal_tools/python_stub_generator.py +++ b/internal_tools/python_stub_generator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,3 +36,4 @@ _signatures.gen_all_signatures(Path(args.wheel_path), "fn") _signatures.gen_all_signatures(Path(args.wheel_path), "ops") + _signatures.gen_all_signatures(Path(args.wheel_path), "dynamic") # experimental/dali2
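
For reviewers who want to poke at the new entry points, here is a minimal usage sketch of `tensor()` and `as_tensor()` as defined in `_tensor.py` and re-exported from `nvidia.dali.experimental.dali2`. The numpy input and the inspected attributes (`dtype`, `device`) are assumptions based on the surrounding code, not something this patch pins down.

```python
# Minimal sketch (not part of the patch): copy vs. wrap semantics of the new API.
# Assumes numpy arrays are accepted as tensor-like inputs and that Tensor exposes
# dtype/device attributes; both are inferences, not guarantees of this patch.
import numpy as np
import nvidia.dali.experimental.dali2 as ndd

a = np.arange(12, dtype=np.int32).reshape(3, 4)

t_copy = ndd.tensor(a)     # always copies the source data
t_view = ndd.as_tensor(a)  # avoids the copy when the buffer can be wrapped directly

print(t_copy.dtype, t_copy.device)
print(t_view.dtype, t_view.device)
```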
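The subscript handling in `TensorSlice.evaluate()` lowers Python indexing expressions to `lo_*`/`hi_*`/`step_*`/`at_*` keyword arguments of the `tensor_subscript` operator. The standalone helper below restates that mapping so the convention is easy to verify at a glance; it duplicates logic already in the patch and is illustrative only.

```python
# Illustrative re-statement of the slice -> tensor_subscript argument lowering
# used in TensorSlice.evaluate (not part of the patch).
def subscript_args(ranges, ndim):
    args = {}
    d = 0
    for i, r in enumerate(ranges):
        if r is Ellipsis:
            # jump to the dimensions covered by the indices that follow the ellipsis
            d = ndim - len(ranges) + i + 1
        elif isinstance(r, slice):
            if r.start is not None:
                args[f"lo_{d}"] = r.start
            if r.stop is not None:
                args[f"hi_{d}"] = r.stop
            if r.step is not None:
                args[f"step_{d}"] = r.step
            d += 1
        else:
            args[f"at_{d}"] = r  # scalar index selects a single element in dim d
            d += 1
    return args

# x[1, 2:5, ...] on a 4D tensor becomes:
assert subscript_args((1, slice(2, 5), Ellipsis), ndim=4) == {"at_0": 1, "lo_1": 2, "hi_1": 5}
```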
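`Operator.get()` caches operator instances on a key built from the device, `max_batch_size`, input arity, call-argument names, and the frozen init arguments. The snippet below restates the freezing step to show why argument order and list-vs-tuple spelling do not produce distinct cache entries; it mirrors `freeze_arg`/`freeze_args` from the patch and is illustrative only.

```python
# Sketch of the cache-key construction used by Operator.get (illustrative only).
def freeze_arg(arg):
    # lists are not hashable, so they are frozen into tuples
    return tuple(arg) if isinstance(arg, list) else arg

def freeze_args(args):
    # sorting makes the key independent of keyword-argument order
    return tuple((k, freeze_arg(args[k])) for k in sorted(args))

key_a = ("cpu", 8, 1, None, freeze_args({"dtype": "float", "shape": [2, 2]}))
key_b = ("cpu", 8, 1, None, freeze_args({"shape": [2, 2], "dtype": "float"}))
assert key_a == key_b  # same configuration -> same cached operator instance
```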
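For stateful operators the backend is kept alive across calls, so `check_compatible()` requires every call to present the same per-input metadata as the previous one, while stateless operators are simply reinitialized. The sketch below shows that contract with a stand-in metadata record; `FakeInput` is a placeholder used only for illustration.

```python
# Sketch of the per-call compatibility contract enforced for stateful operators
# (mirrors Operator._make_meta / is_compatible; illustrative only).
from collections import namedtuple

# Stand-in for the ndim/layout/dtype triple read off a Tensor or Batch input.
FakeInput = namedtuple("FakeInput", "ndim layout dtype")

def make_meta(x, is_batch):
    return {"is_batch": is_batch, "ndim": x.ndim, "layout": x.layout, "dtype": x.dtype}

first_call = [make_meta(FakeInput(3, "HWC", "uint8"), is_batch=True)]
second_call = [make_meta(FakeInput(3, "HWC", "uint8"), is_batch=True)]
third_call = [make_meta(FakeInput(4, "FHWC", "uint8"), is_batch=True)]

assert first_call == second_call  # same metadata: the existing backend can be reused
assert first_call != third_call   # different ndim/layout: check_compatible would raise
```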
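Stub generation now has a third flavour, "dynamic", which writes `.pyi` files under `experimental/dali2`. Below is a sketch of driving it directly from Python against an unpacked wheel tree; the path is a placeholder and the `_signatures` module path is inferred from the file location in this diff.

```python
# Sketch: regenerating the stubs for all three API flavours (path is a placeholder).
from pathlib import Path

from nvidia.dali.ops import _signatures

wheel_path = Path("/tmp/dali_wheel/nvidia/dali")  # pre-packaged nvidia/dali directory

_signatures.gen_all_signatures(wheel_path, "fn")
_signatures.gen_all_signatures(wheel_path, "ops")
_signatures.gen_all_signatures(wheel_path, "dynamic")  # new: experimental/dali2 stubs
```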