diff --git a/cpp/nvtabular/inference/categorify.cc b/cpp/nvtabular/inference/categorify.cc index 9ec8a285ba..d0e1f6496d 100644 --- a/cpp/nvtabular/inference/categorify.cc +++ b/cpp/nvtabular/inference/categorify.cc @@ -325,6 +325,11 @@ namespace nvtabular { py::object supports = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("Supports"); return supports.attr("CPU_DICT_ARRAY"); + }) + .def_property_readonly("supported_formats", [](py::object self) + { + py::object supported = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("DataFormats"); + return supported.attr("NUMPY_DICT_ARRAY"); }); } } // namespace inference diff --git a/cpp/nvtabular/inference/fill.cc b/cpp/nvtabular/inference/fill.cc index 03cee09df9..a9b3342f5d 100644 --- a/cpp/nvtabular/inference/fill.cc +++ b/cpp/nvtabular/inference/fill.cc @@ -115,6 +115,11 @@ namespace nvtabular py::object supports = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("Supports"); return supports.attr("CPU_DICT_ARRAY"); }) + .def_property_readonly("supported_formats", [](py::object self) + { + py::object supported = py::module_::import("nvtabular").attr("graph").attr("base_operator").attr("DataFormats"); + return supported.attr("NUMPY_DICT_ARRAY"); + }) .def("transform", &FillTransform::transform); } } // namespace inference diff --git a/nvtabular/ops/normalize.py b/nvtabular/ops/normalize.py index 614a93e7e1..01b021119c 100644 --- a/nvtabular/ops/normalize.py +++ b/nvtabular/ops/normalize.py @@ -23,7 +23,7 @@ flatten_list_column_values, is_list_dtype, ) -from merlin.dag import Supports +from merlin.dag import DataFormats, Supports from merlin.schema import Tags from nvtabular.ops.moments import _custom_moments from nvtabular.ops.operator import ColumnSelector, Operator @@ -98,6 +98,15 @@ def supports(self): | Supports.GPU_DATAFRAME ) + @property + def supported_formats(self): + return ( + DataFormats.PANDAS_DATAFRAME + | DataFormats.CUDF_DATAFRAME + | DataFormats.NUMPY_DICT_ARRAY + | DataFormats.CUPY_DICT_ARRAY + ) + def clear(self): self.means = {} self.stds = {} @@ -181,6 +190,15 @@ def supports(self): | Supports.GPU_DATAFRAME ) + @property + def supported_formats(self): + return ( + DataFormats.PANDAS_DATAFRAME + | DataFormats.CUDF_DATAFRAME + | DataFormats.NUMPY_DICT_ARRAY + | DataFormats.CUPY_DICT_ARRAY + ) + @property def output_tags(self): return [Tags.CONTINUOUS] diff --git a/nvtabular/ops/operator.py b/nvtabular/ops/operator.py index c63754e9c1..fec0efa63a 100644 --- a/nvtabular/ops/operator.py +++ b/nvtabular/ops/operator.py @@ -19,7 +19,7 @@ import nvtabular as nvt from merlin.core.dispatch import DataFrameType -from merlin.dag import BaseOperator, ColumnSelector +from merlin.dag import BaseOperator, ColumnSelector, DataFormats class Operator(BaseOperator): @@ -53,3 +53,7 @@ def inference_initialize( def create_node(self, selector): return nvt.workflow.node.WorkflowNode(selector) + + @property + def supported_formats(self) -> DataFormats: + return DataFormats.PANDAS_DATAFRAME | DataFormats.CUDF_DATAFRAME