From 56c8bba370d80ef9ec77e6e7d1a2b3af396d83df Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Wed, 10 Apr 2024 23:02:00 +0200 Subject: [PATCH 1/7] Improve support for DALI enum types Support for DALI enum types (DALIDataType, DALIImageType, DALIInterpType) is added to Constant, Cast, Choice and Copy operators. Thanks to Constant op support, followin syntax will work now: ``` fn.random.choice([types.DALIInterpType.INTERP_LINEAR, types.DALIInterpType.INTERP_NN]) ``` allowing convenient selection of enum parameters. Casting support is allowed only between non-fp types and enums. Explicit error is raised when the enums are used with buffer protcol (conversion to numpy, printing etc) - Python expectes pointer-to-object representation there, while we return them as C-style enums with numeric value under the hood. As an alternative we can allow to just access the underlying data but I have chosen a bit more restricive approach for now. Signed-off-by: Krzysztof Lecki --- dali/kernels/common/cast_gpu.cu | 57 +++++++----- dali/operators/generic/cast.h | 8 +- dali/operators/generic/constant.h | 5 +- dali/operators/random/choice.h | 2 +- dali/operators/random/choice_cpu.cc | 3 + dali/pipeline/data/types.h | 12 +++ dali/python/nvidia/dali/_backend_enums.pyi | 3 + dali/python/nvidia/dali/types.py | 32 ++++++- dali/util/pybind.h | 14 +++ docs/data_types.rst | 3 +- include/dali/core/convert.h | 102 +++++++++++++++++++-- 11 files changed, 200 insertions(+), 41 deletions(-) diff --git a/dali/kernels/common/cast_gpu.cu b/dali/kernels/common/cast_gpu.cu index a0aa6761bdf..a381f07dc56 100644 --- a/dali/kernels/common/cast_gpu.cu +++ b/dali/kernels/common/cast_gpu.cu @@ -19,6 +19,7 @@ #include "dali/kernels/common/utils.h" #include "dali/kernels/kernel.h" #include "dali/kernels/dynamic_scratchpad.h" +#include "dali/pipeline/data/types.h" namespace dali { namespace kernels { @@ -102,32 +103,38 @@ void CastGPU::Run(KernelContext &ctx, #define INSTANTIATE_IMPL(Out, In) template struct DLL_PUBLIC CastGPU; -#define INSTANTIATE_FOREACH_INTYPE(Out) \ - INSTANTIATE_IMPL(Out, bool); \ - INSTANTIATE_IMPL(Out, uint8_t); \ - INSTANTIATE_IMPL(Out, uint16_t); \ - INSTANTIATE_IMPL(Out, uint32_t); \ - INSTANTIATE_IMPL(Out, uint64_t); \ - INSTANTIATE_IMPL(Out, int8_t); \ - INSTANTIATE_IMPL(Out, int16_t); \ - INSTANTIATE_IMPL(Out, int32_t); \ - INSTANTIATE_IMPL(Out, int64_t); \ - INSTANTIATE_IMPL(Out, float); \ - INSTANTIATE_IMPL(Out, double); \ - INSTANTIATE_IMPL(Out, dali::float16); - -INSTANTIATE_FOREACH_INTYPE(bool); \ -INSTANTIATE_FOREACH_INTYPE(uint8_t); \ -INSTANTIATE_FOREACH_INTYPE(uint16_t); \ -INSTANTIATE_FOREACH_INTYPE(uint32_t); \ -INSTANTIATE_FOREACH_INTYPE(uint64_t); \ -INSTANTIATE_FOREACH_INTYPE(int8_t); \ -INSTANTIATE_FOREACH_INTYPE(int16_t); \ -INSTANTIATE_FOREACH_INTYPE(int32_t); \ -INSTANTIATE_FOREACH_INTYPE(int64_t); \ -INSTANTIATE_FOREACH_INTYPE(float); \ -INSTANTIATE_FOREACH_INTYPE(double); \ +#define INSTANTIATE_FOREACH_INTYPE(Out) \ + INSTANTIATE_IMPL(Out, bool); \ + INSTANTIATE_IMPL(Out, uint8_t); \ + INSTANTIATE_IMPL(Out, uint16_t); \ + INSTANTIATE_IMPL(Out, uint32_t); \ + INSTANTIATE_IMPL(Out, uint64_t); \ + INSTANTIATE_IMPL(Out, int8_t); \ + INSTANTIATE_IMPL(Out, int16_t); \ + INSTANTIATE_IMPL(Out, int32_t); \ + INSTANTIATE_IMPL(Out, int64_t); \ + INSTANTIATE_IMPL(Out, float); \ + INSTANTIATE_IMPL(Out, double); \ + INSTANTIATE_IMPL(Out, dali::float16); \ + INSTANTIATE_IMPL(Out, DALIDataType); \ + INSTANTIATE_IMPL(Out, DALIImageType); \ + INSTANTIATE_IMPL(Out, DALIInterpType); + +INSTANTIATE_FOREACH_INTYPE(bool); +INSTANTIATE_FOREACH_INTYPE(uint8_t); +INSTANTIATE_FOREACH_INTYPE(uint16_t); +INSTANTIATE_FOREACH_INTYPE(uint32_t); +INSTANTIATE_FOREACH_INTYPE(uint64_t); +INSTANTIATE_FOREACH_INTYPE(int8_t); +INSTANTIATE_FOREACH_INTYPE(int16_t); +INSTANTIATE_FOREACH_INTYPE(int32_t); +INSTANTIATE_FOREACH_INTYPE(int64_t); +INSTANTIATE_FOREACH_INTYPE(float); +INSTANTIATE_FOREACH_INTYPE(double); INSTANTIATE_FOREACH_INTYPE(dali::float16); +INSTANTIATE_FOREACH_INTYPE(DALIDataType); +INSTANTIATE_FOREACH_INTYPE(DALIImageType); +INSTANTIATE_FOREACH_INTYPE(DALIInterpType); } // namespace cast } // namespace kernels diff --git a/dali/operators/generic/cast.h b/dali/operators/generic/cast.h index 7df38001056..4057dbbe405 100644 --- a/dali/operators/generic/cast.h +++ b/dali/operators/generic/cast.h @@ -19,13 +19,14 @@ #include "dali/core/convert.h" #include "dali/core/tensor_shape.h" +#include "dali/pipeline/data/types.h" #include "dali/pipeline/operator/checkpointing/stateless_operator.h" namespace dali { #define CAST_ALLOWED_TYPES \ (bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float16, float, \ - double) + double, DALIDataType, DALIImageType, DALIInterpType) template class Cast : public StatelessOperator { @@ -54,6 +55,11 @@ class Cast : public StatelessOperator { bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { const auto &input = ws.Input(0); DALIDataType out_type = is_cast_like_ ? ws.GetInputDataType(1) : dtype_arg_; + DALI_ENFORCE(!(IsEnum(input.type()) && IsFloatingPoint(out_type) || + IsEnum(out_type) && IsFloatingPoint(input.type())), + make_string("Cannot cast from ", input.type(), " to ", out_type, + ". Enums can only participate in casts with integral types, " + "but not floating point types.")); output_desc.resize(1); output_desc[0].shape = input.shape(); output_desc[0].type = out_type; diff --git a/dali/operators/generic/constant.h b/dali/operators/generic/constant.h index 6200f30281b..3b6f86d77e9 100644 --- a/dali/operators/generic/constant.h +++ b/dali/operators/generic/constant.h @@ -22,8 +22,9 @@ #include "dali/core/tensor_view.h" #include "dali/core/static_switch.h" -#define CONSTANT_OP_SUPPORTED_TYPES \ - (bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, float16) +#define CONSTANT_OP_SUPPORTED_TYPES \ + (bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, float16, \ + DALIDataType, DALIImageType, DALIInterpType) namespace dali { diff --git a/dali/operators/random/choice.h b/dali/operators/random/choice.h index 50e76a04b21..ed5ffd279ef 100644 --- a/dali/operators/random/choice.h +++ b/dali/operators/random/choice.h @@ -29,7 +29,7 @@ uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t #define DALI_CHOICE_1D_TYPES \ bool, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t, float16, float, \ - double + double, DALIDataType, DALIImageType, DALIInterpType namespace dali { diff --git a/dali/operators/random/choice_cpu.cc b/dali/operators/random/choice_cpu.cc index 1ec6ee94321..c1530cb3325 100644 --- a/dali/operators/random/choice_cpu.cc +++ b/dali/operators/random/choice_cpu.cc @@ -31,6 +31,9 @@ a single value per sample is generated. The type of the output matches the type of the input. For scalar inputs, only integral types are supported, otherwise any type can be used. +The operator supports selection from an input containing elements of one of DALI enum types, +that is: :meth:`nvidia.dali.types.DALIDataType`, :meth:`nvidia.dali.types.DALIImageType`, or +:meth:`nvidia.dali.types.DALIInterpType`. )code") .NumInput(1, 2) .InputDox(0, "a", "scalar or TensorList", diff --git a/dali/pipeline/data/types.h b/dali/pipeline/data/types.h index 4c77204d59f..79c5c4bdce6 100644 --- a/dali/pipeline/data/types.h +++ b/dali/pipeline/data/types.h @@ -277,6 +277,18 @@ constexpr bool IsUnsigned(DALIDataType type) { } } + +constexpr bool IsEnum(DALIDataType type) { + switch (type) { + case DALI_DATA_TYPE: + case DALI_IMAGE_TYPE: + case DALI_INTERP_TYPE: + return true; + default: + return false; + } +} + template struct id2type_helper; diff --git a/dali/python/nvidia/dali/_backend_enums.pyi b/dali/python/nvidia/dali/_backend_enums.pyi index 351a40c0a9d..a11829490c8 100644 --- a/dali/python/nvidia/dali/_backend_enums.pyi +++ b/dali/python/nvidia/dali/_backend_enums.pyi @@ -29,6 +29,9 @@ class DALIDataType(Enum): FLOAT64 = ... BOOL = ... STRING = ... + DATA_TYPE = ... + IMAGE_TYPE = ... + INTERP_TYPE = ... class DALIImageType(Enum): RGB = ... diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 67a5d3a1523..7c7ba2d49d2 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -519,6 +519,8 @@ def _type_from_value_or_list(v): has_floats = False has_ints = False has_bools = False + has_enums = False + enum_type = None for x in v: if isinstance(x, float): has_floats = True @@ -526,9 +528,36 @@ def _type_from_value_or_list(v): has_bools = True elif isinstance(x, int): has_ints = True + elif isinstance(x, (DALIDataType, DALIImageType, DALIInterpType)): + has_enums = True + enum_type = type(x) + break else: raise TypeError("Unexpected type: " + str(type(x))) + if has_enums: + for x in v: + if not isinstance(x, enum_type): + raise TypeError( + f"Expected all elements of the input to be the " + f"same enum type: `{enum_type.__name__}` but got `{type(x).__name__}` " + f"for one of the elements." + ) + + if has_enums: + if issubclass(enum_type, DALIDataType): + return DALIDataType.DATA_TYPE + elif issubclass(enum_type, DALIImageType): + return DALIDataType.IMAGE_TYPE + elif issubclass(enum_type, DALIInterpType): + return DALIDataType.INTERP_TYPE + else: + raise TypeError( + f"Unexpected enum type: `{enum_type.__name__}`, expected one of: " + "`nvidia.dali.types.DALIDataType`, `nvidia.dali.types.DALIImageType`, " + "or `nvidia.dali.types.DALIInterpType`." + ) + if has_floats: return DALIDataType.FLOAT if has_ints: @@ -582,7 +611,8 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs): Args ---- - value: `bool`, `int`, `float`, a `list` or `tuple` thereof or a `numpy.ndarray` + value: `bool`, `int`, `float`, `DALIDataType` `DALIImageType`, `DALIInterpType`, + a `list` or `tuple` thereof or a `numpy.ndarray` The constant value to wrap. If it is a scalar, it can be used as scalar value in mathematical expressions. Otherwise, it will produce a constant tensor node (optionally reshaped according to `shape` argument). diff --git a/dali/util/pybind.h b/dali/util/pybind.h index b33333eec66..7e30a044e92 100644 --- a/dali/util/pybind.h +++ b/dali/util/pybind.h @@ -22,6 +22,7 @@ #include #include "dali/pipeline/data/types.h" #include "dali/pipeline/data/dltensor.h" +#include "dali/pipeline/operator/error_reporting.h" namespace dali { @@ -54,6 +55,19 @@ static std::string FormatStrFromType(DALIDataType type) { return "=d"; case DALI_BOOL: return "=?"; + case DALI_DATA_TYPE: + case DALI_IMAGE_TYPE: + case DALI_INTERP_TYPE: + throw DaliTypeError( + "DALI enum types cannot be used with buffer protocol " + "when they are returned as Tensors or TensorLists from DALI pipeline." + "You can use `nvidia.dali.fn.cast` to convert those values to an integral type."); + // As an alternative, to allow the usage of tensors containing DALI enums (printing, use with + // buffer protocol, numpy conversion etc), we can return format specifier for the underlying + // type here. This would allow access to the actual numeric values, for example: + // case DALI_DATA_TYPE: + // return + // FormatStrFromType(TypeTable::GetTypeInfo>().id()); default: break; } diff --git a/docs/data_types.rst b/docs/data_types.rst index 41017402228..1badaa674e5 100644 --- a/docs/data_types.rst +++ b/docs/data_types.rst @@ -97,9 +97,10 @@ DALIDataType :member-order: bysource :exclude-members: name +.. autofunction:: to_numpy_type + DALIIterpType ^^^^^^^^^^^^^ -.. autofunction:: to_numpy_type .. autoenum:: DALIInterpType :members: :undoc-members: diff --git a/include/dali/core/convert.h b/include/dali/core/convert.h index 6ebc2fe8089..210b47384cd 100644 --- a/include/dali/core/convert.h +++ b/include/dali/core/convert.h @@ -233,19 +233,23 @@ inline __device__ unsigned long cuda_round_helper(double f, unsigned long) { // template ::value, - bool InIsFP = is_fp_or_half::value> + bool InIsFP = is_fp_or_half::value, + bool OutIsEnum = std::is_enum::value, + bool InIsEnum = std::is_enum::value> struct ConverterBase; template struct Converter : ConverterBase { - static_assert(is_arithmetic_or_half::value && is_arithmetic_or_half::value, - "Default ConverterBase can only be used with arithmetic types. For custom types, " - "specialize or overload dali::Convert"); + static_assert( + (is_arithmetic_or_half::value || std::is_enum::value) && + (is_arithmetic_or_half::value || std::is_enum::value), + "Default ConverterBase can only be used with arithmetic or enum types. For custom types, " + "specialize or overload dali::Convert"); }; /// Converts between two FP types template -struct ConverterBase { +struct ConverterBase { DALI_HOST_DEV static constexpr Out Convert(In value) { return value; } DALI_HOST_DEV @@ -258,7 +262,7 @@ struct ConverterBase { /// Converts integral to FP type template -struct ConverterBase { +struct ConverterBase { DALI_HOST_DEV static constexpr Out Convert(In value) { return value; } DALI_HOST_DEV @@ -272,7 +276,7 @@ struct ConverterBase { /// Converts integral to float16 special case template -struct ConverterBase { +struct ConverterBase { DALI_HOST_DEV static constexpr float16 Convert(In value) { auto out = ConverterBase::Convert(value); @@ -300,7 +304,7 @@ struct ConverterBase { /// Converts FP to integral type template -struct ConverterBase { +struct ConverterBase { DALI_HOST_DEV static constexpr Out Convert(In value) { #ifdef __CUDA_ARCH__ @@ -383,9 +387,87 @@ struct ConvertIntInt { /// Converts between integral types template -struct ConverterBase : ConvertIntInt { +struct ConverterBase : ConvertIntInt { static_assert(std::is_arithmetic::value && std::is_arithmetic::value, - "Default ConverterBase can only be used with arithmetic types. For custom types, " + "Default ConverterBase can only be used with arithmetic or enum types. For custom types, " + "specialize or overload dali::Convert"); +}; + +template +struct UnderlyingOrType; + +template +struct UnderlyingOrType::value>> { + using type = typename std::underlying_type::type; +}; + +template +struct UnderlyingOrType::value>> { + using type = T; +}; + + +template +using underlying_or_type_t = typename UnderlyingOrType::type; + + +/// Convert between arithmetic and enum types +template +struct ConverterEnum { + using UnderlyingOut = underlying_or_type_t; + using UnderlyingIn = underlying_or_type_t; + using Conv = Converter; + + DALI_HOST_DEV + static constexpr UnderlyingIn to_underlying(In value) noexcept { + return static_cast(value); + } + + DALI_HOST_DEV + static constexpr Out from_underlying(UnderlyingOut value) noexcept { + return static_cast(value); + } + + DALI_HOST_DEV + static constexpr Out Convert(In value) { + return from_underlying(Conv::Convert(to_underlying(value))); + } + + DALI_HOST_DEV + static constexpr Out ConvertNorm(In value) { + return from_underlying(Conv::ConvertNorm(to_underlying(value))); + } + + DALI_HOST_DEV + static constexpr Out ConvertSat(In value) { + return from_underlying(Conv::ConvertSat(to_underlying(value))); + } + + DALI_HOST_DEV + static constexpr Out ConvertSatNorm(In value) { + return from_underlying(Conv::ConvertSatNorm(to_underlying(value))); + } +}; + +/// Convert from integral to enum type +template +struct ConverterBase : ConverterEnum { + static_assert(std::is_enum::value && is_arithmetic_or_half::value, + "Default ConverterBase can only be used with arithmetic or enum types. For custom types, " + "specialize or overload dali::Convert"); +}; + +template +struct ConverterBase : ConverterEnum { + static_assert(is_arithmetic_or_half::value && std::is_enum::value, + "Default ConverterBase can only be used with arithmetic or enum types. For custom types, " + "specialize or overload dali::Convert"); +}; + +template +struct ConverterBase : ConverterEnum { + static_assert(std::is_enum::value && std::is_enum::value, + "Default ConverterBase can only be used with arithmetic or enum types. For custom types, " "specialize or overload dali::Convert"); }; From e6e0e2c41a031dd246b09eb30a1756d6d1c22c97 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Thu, 11 Apr 2024 17:14:34 +0200 Subject: [PATCH 2/7] Add initial test Signed-off-by: Krzysztof Lecki --- .../test/python/operator_2/test_enum_types.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 dali/test/python/operator_2/test_enum_types.py diff --git a/dali/test/python/operator_2/test_enum_types.py b/dali/test/python/operator_2/test_enum_types.py new file mode 100644 index 00000000000..e9a7867a28b --- /dev/null +++ b/dali/test/python/operator_2/test_enum_types.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvidia.dali import fn, pipeline_def, types + +import numpy as np + + +def test_enum_constant_capture(): + batch_size = 2 + + scalar_v = types.DALIDataType.INT16 + list_v = [ + types.DALIInterpType.INTERP_CUBIC, + types.DALIInterpType.INTERP_GAUSSIAN, + types.DALIInterpType.INTERP_LANCZOS3, + ] + + @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) + def enum_constant_pipe(): + scalar = fn.copy(scalar_v) + tensor = fn.copy(list_v) + + scalar_as_int = fn.cast(scalar, dtype=types.DALIDataType.INT32) + tensor_as_int = fn.cast(tensor, dtype=types.DALIDataType.INT32) + return scalar, tensor, scalar_as_int, tensor_as_int + + pipe = enum_constant_pipe() + pipe.build() + # Compare the cast values with Python values + scalar, tensor, scalar_as_int, tensor_as_int = pipe.run() + assert scalar.dtype == types.DALIDataType.DATA_TYPE + assert scalar.shape() == [()] * batch_size, f"{scalar.shape}" + assert tensor.dtype == types.DALIDataType.INTERP_TYPE + assert tensor.shape() == [(3,)] * batch_size + for i in range(batch_size): + assert np.array_equal(np.array(scalar_as_int[i]), np.array(scalar_v.value)) + assert np.array_equal( + np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v], dtype=np.int32) + ) From 3f40adc613a8802c26d02b1272b099a64f431681 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Fri, 12 Apr 2024 16:36:05 +0200 Subject: [PATCH 3/7] Adjust repr of Tensor and TensorList so it doesn't touch the data Signed-off-by: Krzysztof Lecki --- dali/python/backend_impl.cc | 12 +++-- dali/python/nvidia/dali/_debug_mode.py | 2 +- dali/python/nvidia/dali/tensors.py | 46 ++++++++++++++----- .../test/python/operator_2/test_enum_types.py | 1 + dali/util/pybind.h | 2 +- 5 files changed, 45 insertions(+), 18 deletions(-) diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 40db33d8870..4e64f72c8c9 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -627,7 +627,7 @@ void ExposeTensor(py::module &m) { return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t); }) .def("__repr__", [](Tensor &t) { - return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t); + return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t, false); }) .def_property("__array_interface__", &ArrayInterfaceRepr, nullptr, R"code( @@ -777,7 +777,7 @@ void ExposeTensor(py::module &m) { return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t); }) .def("__repr__", [](Tensor &t) { - return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t); + return FromPythonTrampoline("nvidia.dali.tensors", "_tensor_to_string")(t, false); }) .def_property("__cuda_array_interface__", &ArrayInterfaceRepr, nullptr, R"code( @@ -1193,7 +1193,11 @@ void ExposeTensorList(py::module &m) { return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t); }) .def("__repr__", [](TensorList &t) { - return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t); + // Repr might be used in exceptions and the data might not be possible to be represented + // (DALI enums do not support buffer protocol due to difference between C++ numeric + // representation and Python "O" - object/pointer-based representation). + // That why we skip the data part. + return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t, false); }) .def_property_readonly("dtype", [](TensorList &tl) { return tl.type(); @@ -1393,7 +1397,7 @@ void ExposeTensorList(py::module &m) { return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t); }) .def("__repr__", [](TensorList &t) { - return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t); + return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t, false); }) .def_property_readonly("dtype", [](TensorList &tl) { return tl.type(); diff --git a/dali/python/nvidia/dali/_debug_mode.py b/dali/python/nvidia/dali/_debug_mode.py index 7f93c9c6eef..482784333c0 100644 --- a/dali/python/nvidia/dali/_debug_mode.py +++ b/dali/python/nvidia/dali/_debug_mode.py @@ -45,7 +45,7 @@ def __str__(self): indent = " " * 4 return ( f'DataNodeDebug(\n{indent}name="{self.name}",\n{indent}data=' - + f'{_tensors._tensorlist_to_string(self._data, indent + " " * 5)})' + + f'{_tensors._tensorlist_to_string(self._data, indent=indent + " " * 5)})' ) __repr__ = __str__ diff --git a/dali/python/nvidia/dali/tensors.py b/dali/python/nvidia/dali/tensors.py index b0bcda067fc..a985c28725e 100644 --- a/dali/python/nvidia/dali/tensors.py +++ b/dali/python/nvidia/dali/tensors.py @@ -50,35 +50,56 @@ def import_numpy(): ) -def _tensor_to_string(self): - """Returns string representation of Tensor.""" +def _tensor_to_string(self, show_data=True): + """Returns string representation of Tensor. + + Parameters + ---------- + show_data : bool, optional + Access and format the underlying data, by default True + """ import_numpy() type_name = type(self).__name__ indent = " " * 4 layout = self.layout() - data = np.array(_transfer_to_cpu(self, type_name[-3:])) - data_str = np.array2string(data, prefix=indent, edgeitems=2) + if show_data: + data = np.array(_transfer_to_cpu(self, type_name[-3:])) + data_str = np.array2string(data, prefix=indent, edgeitems=2) params = ( - [f"{type_name}(\n{indent}{data_str}", f"dtype={self.dtype}"] + ([f"{data_str}"] if show_data else []) + + [f"dtype={self.dtype}"] + ([f"layout={layout}"] if layout else []) + [f"shape={self.shape()})"] ) - return _join_string(params, False, 0, ",\n" + indent) + return f"{type_name}(\n{indent}" + _join_string(params, False, 0, ",\n" + indent) -def _tensorlist_to_string(self, indent=""): - """Returns string representation of TensorList.""" +def _tensorlist_to_string(self, show_data=True, indent=""): + """Returns string representation of TensorList. + + Parameters + ---------- + show_data : bool, optional + Access and format the underlying data, by default True + indent : str, optional + optional indentation used in formatting, by default "" + """ import_numpy() edgeitems = 2 spaces_indent = indent + " " * 4 type_name = type(self).__name__ layout = self.layout() - data = _transfer_to_cpu(self, type_name[-3:]) - data_str = "[]" + if show_data: + data = _transfer_to_cpu(self, type_name[-3:]) + data_str = "[]" + else: + data = None + data_str = "" + crop = False if data: @@ -120,9 +141,10 @@ def _tensorlist_to_string(self, indent=""): ) params = ( - [f"{type_name}(\n{spaces_indent}{data_str}", f"dtype={self.dtype}"] + ([f"{data_str}"] if show_data else []) + + [f"dtype={self.dtype}"] + ([f'layout="{layout}"'] if layout else []) + [f"num_samples={len(self)}", f"{shape_prefix}{shape_str}])"] ) - return _join_string(params, False, 0, ",\n" + spaces_indent) + return f"{type_name}(\n{spaces_indent}" + _join_string(params, False, 0, ",\n" + spaces_indent) diff --git a/dali/test/python/operator_2/test_enum_types.py b/dali/test/python/operator_2/test_enum_types.py index e9a7867a28b..b1dcba76404 100644 --- a/dali/test/python/operator_2/test_enum_types.py +++ b/dali/test/python/operator_2/test_enum_types.py @@ -43,6 +43,7 @@ def enum_constant_pipe(): assert scalar.dtype == types.DALIDataType.DATA_TYPE assert scalar.shape() == [()] * batch_size, f"{scalar.shape}" assert tensor.dtype == types.DALIDataType.INTERP_TYPE + print(tensor.shape) assert tensor.shape() == [(3,)] * batch_size for i in range(batch_size): assert np.array_equal(np.array(scalar_as_int[i]), np.array(scalar_v.value)) diff --git a/dali/util/pybind.h b/dali/util/pybind.h index 7e30a044e92..c4ec541edd1 100644 --- a/dali/util/pybind.h +++ b/dali/util/pybind.h @@ -60,7 +60,7 @@ static std::string FormatStrFromType(DALIDataType type) { case DALI_INTERP_TYPE: throw DaliTypeError( "DALI enum types cannot be used with buffer protocol " - "when they are returned as Tensors or TensorLists from DALI pipeline." + "when they are returned as Tensors or TensorLists from DALI pipeline. " "You can use `nvidia.dali.fn.cast` to convert those values to an integral type."); // As an alternative, to allow the usage of tensors containing DALI enums (printing, use with // buffer protocol, numpy conversion etc), we can return format specifier for the underlying From e3e9be0abb9fc5655779ebb5b8c86921bcc96e7d Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Fri, 12 Apr 2024 18:19:28 +0200 Subject: [PATCH 4/7] Use only Constant nodes for enums, add more tests Signed-off-by: Krzysztof Lecki --- dali/python/nvidia/dali/types.py | 3 ++ .../test/python/operator_2/test_enum_types.py | 41 +++++++++++++++---- .../python/operator_2/test_random_choice.py | 36 ++++++++++++++++ 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 7c7ba2d49d2..84ab2547907 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -640,6 +640,9 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs): device is not None or (_is_compatible_array_type(value) and not _is_true_scalar(value)) or isinstance(value, (list, tuple)) + # we force true scalar enums through a Constant node rather than using ScalarConstant + # as they do not support any arithmetic operations + or isinstance(value, (DALIDataType, DALIImageType, DALIInterpType)) or not _is_scalar_shape(shape) or kwargs or layout is not None diff --git a/dali/test/python/operator_2/test_enum_types.py b/dali/test/python/operator_2/test_enum_types.py index b1dcba76404..1e4f2a55505 100644 --- a/dali/test/python/operator_2/test_enum_types.py +++ b/dali/test/python/operator_2/test_enum_types.py @@ -15,9 +15,27 @@ from nvidia.dali import fn, pipeline_def, types import numpy as np +import tree +from nose_utils import assert_raises +from nose2.tools import params -def test_enum_constant_capture(): + +@params( + *[ + # Automatic promotion + lambda value, dtype: fn.copy(value), + # Explicit conversion to constant op + lambda value, dtype: types.Constant(value=value, dtype=dtype), + # Detection of type from value + lambda value, dtype: types.Constant(value=value), + # Explicit type when passed the underlying numeric value of the enum + lambda value, dtype: types.Constant( + value=tree.map_structure(lambda v: v.value, value), dtype=dtype + ), + ] +) +def test_enum_constant_capture(converter): batch_size = 2 scalar_v = types.DALIDataType.INT16 @@ -29,8 +47,8 @@ def test_enum_constant_capture(): @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) def enum_constant_pipe(): - scalar = fn.copy(scalar_v) - tensor = fn.copy(list_v) + scalar = converter(scalar_v, types.DALIDataType.DATA_TYPE) + tensor = converter(list_v, types.DALIDataType.INTERP_TYPE) scalar_as_int = fn.cast(scalar, dtype=types.DALIDataType.INT32) tensor_as_int = fn.cast(tensor, dtype=types.DALIDataType.INT32) @@ -38,15 +56,22 @@ def enum_constant_pipe(): pipe = enum_constant_pipe() pipe.build() - # Compare the cast values with Python values scalar, tensor, scalar_as_int, tensor_as_int = pipe.run() assert scalar.dtype == types.DALIDataType.DATA_TYPE assert scalar.shape() == [()] * batch_size, f"{scalar.shape}" assert tensor.dtype == types.DALIDataType.INTERP_TYPE - print(tensor.shape) assert tensor.shape() == [(3,)] * batch_size + # Compare the cast values with Python values for i in range(batch_size): assert np.array_equal(np.array(scalar_as_int[i]), np.array(scalar_v.value)) - assert np.array_equal( - np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v], dtype=np.int32) - ) + assert np.array_equal(np.array(tensor_as_int[i]), np.array([elem.value for elem in list_v])) + with assert_raises( + TypeError, + glob="DALI enum types cannot be used with buffer protocol*" + "use `nvidia.dali.fn.cast` to convert", + ): + print(scalar) + + +def test_scalar_constant(): + print(types.ScalarConstant(types.DALIDataType.INT16)) diff --git a/dali/test/python/operator_2/test_random_choice.py b/dali/test/python/operator_2/test_random_choice.py index ebdce522eda..6dd6ccb2ce5 100644 --- a/dali/test/python/operator_2/test_random_choice.py +++ b/dali/test/python/operator_2/test_random_choice.py @@ -252,6 +252,12 @@ def choice_pipe(): "Data type float is not supported for 0D inputs. Supported types are: " "uint8, uint16, uint32, uint64, int8, int16, int32, int64", ), + ( + (types.DALIInterpType.INTERP_CUBIC,), + {}, + "Data type DALIInterpType is not supported for 0D inputs. Supported types are: " + "uint8, uint16, uint32, uint64, int8, int16, int32, int64", + ), ( (5,), {"p": np.array([0.25, 0.5, 0.25])}, @@ -270,3 +276,33 @@ def choice_pipe(): pipe = choice_pipe() pipe.build() pipe.run() + + +def test_enum_choice(): + batch_size = 8 + + interps_to_sample = [types.DALIInterpType.INTERP_LINEAR, types.DALIInterpType.INTERP_CUBIC] + + @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) + def choice_pipeline(): + interp = fn.random.choice(interps_to_sample, shape=[100]) + interp_as_int = fn.cast(interp, dtype=types.INT32) + imgs = fn.resize( + fn.random.uniform(range=[0, 255], dtype=types.UINT8, shape=(100, 100, 3)), + size=(25, 25), + interp_type=interp[0], + ) + return interp, interp_as_int, imgs + + pipe = choice_pipeline() + pipe.build() + (interp, interp_as_int, imgs) = pipe.run() + assert interp.dtype == types.DALIDataType.INTERP_TYPE + for i in range(batch_size): + check_sample( + np.array(interp_as_int[i]), + size=(100,), + a=np.array([v.value for v in interps_to_sample]), + p=None, + idx=i, + ) From 3fdcede072526a8bf4c7cf989b567883e684937b Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Fri, 12 Apr 2024 18:45:44 +0200 Subject: [PATCH 5/7] More tests Signed-off-by: Krzysztof Lecki --- dali/python/nvidia/dali/types.py | 18 +++++++++++++--- .../test/python/operator_2/test_enum_types.py | 21 ++++++++++++++++++- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 84ab2547907..5201b3a2171 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -636,13 +636,25 @@ def Constant(value, dtype=None, shape=None, layout=None, device=None, **kwargs): and the arguments are passed to the `dali.ops.Constant` operator """ + def is_enum(value, dtype): + # we force true scalar enums through a Constant node rather than using ScalarConstant + # as they do not support any arithmetic operations + if isinstance(value, (DALIDataType, DALIImageType, DALIInterpType)): + return True + elif dtype is not None and dtype in { + DALIDataType.DATA_TYPE, + DALIDataType.IMAGE_TYPE, + DALIDataType.INTERP_TYPE, + }: + return True + else: + False + if ( device is not None or (_is_compatible_array_type(value) and not _is_true_scalar(value)) or isinstance(value, (list, tuple)) - # we force true scalar enums through a Constant node rather than using ScalarConstant - # as they do not support any arithmetic operations - or isinstance(value, (DALIDataType, DALIImageType, DALIInterpType)) + or is_enum(value, dtype) or not _is_scalar_shape(shape) or kwargs or layout is not None diff --git a/dali/test/python/operator_2/test_enum_types.py b/dali/test/python/operator_2/test_enum_types.py index 1e4f2a55505..49449fa964e 100644 --- a/dali/test/python/operator_2/test_enum_types.py +++ b/dali/test/python/operator_2/test_enum_types.py @@ -74,4 +74,23 @@ def enum_constant_pipe(): def test_scalar_constant(): - print(types.ScalarConstant(types.DALIDataType.INT16)) + with assert_raises( + TypeError, glob="Expected scalar value of type 'bool', 'int' or 'float', got *.DALIDataType" + ): + types.ScalarConstant(types.DALIDataType.INT16) + + +@params(*[(1.0, types.DALIDataType.DATA_TYPE), (types.DALIImageType.RGB, types.DALIDataType.FLOAT)]) +def test_prohibited_cast(param, dtype): + @pipeline_def(batch_size=2, device_id=0, num_threads=4) + def pipeline(): + return fn.cast(param, dtype=dtype) + + with assert_raises( + RuntimeError, + glob="Cannot cast from *float*. Enums can only participate " + "in casts with integral types, but not floating point types.", + ): + p = pipeline() + p.build() + p.run() From 50e3edd896ac7e28286df7b1ea85d64d9cd2b15f Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Fri, 12 Apr 2024 19:13:26 +0200 Subject: [PATCH 6/7] Fixup Signed-off-by: Krzysztof Lecki --- dali/python/nvidia/dali/types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 5201b3a2171..de9d1985fbc 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -647,8 +647,7 @@ def is_enum(value, dtype): DALIDataType.INTERP_TYPE, }: return True - else: - False + return False if ( device is not None From d3daacb629a3c5cc757e28e24850dd2cdf3e6a39 Mon Sep 17 00:00:00 2001 From: Krzysztof Lecki Date: Mon, 15 Apr 2024 10:27:06 +0200 Subject: [PATCH 7/7] More parentheses to make gcc happy Signed-off-by: Krzysztof Lecki --- dali/operators/generic/cast.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dali/operators/generic/cast.h b/dali/operators/generic/cast.h index 4057dbbe405..5c2063dbd01 100644 --- a/dali/operators/generic/cast.h +++ b/dali/operators/generic/cast.h @@ -55,8 +55,8 @@ class Cast : public StatelessOperator { bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { const auto &input = ws.Input(0); DALIDataType out_type = is_cast_like_ ? ws.GetInputDataType(1) : dtype_arg_; - DALI_ENFORCE(!(IsEnum(input.type()) && IsFloatingPoint(out_type) || - IsEnum(out_type) && IsFloatingPoint(input.type())), + DALI_ENFORCE(!((IsEnum(input.type()) && IsFloatingPoint(out_type)) || + (IsEnum(out_type) && IsFloatingPoint(input.type()))), make_string("Cannot cast from ", input.type(), " to ", out_type, ". Enums can only participate in casts with integral types, " "but not floating point types."));