Skip to content

Commit

Permalink
Process TFRecord reader binding classes only when it is enabled (#5360)
Browse files Browse the repository at this point in the history
Only parse the code and create the TFRecord reader API bindings/classes when
the feature is enabled.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki authored Mar 8, 2024
1 parent 6f89fb3 commit 717d704
Showing 1 changed file with 59 additions and 58 deletions.
117 changes: 59 additions & 58 deletions dali/python/nvidia/dali/ops/_operators/tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,61 +29,62 @@ def tfrecord_enabled():
return False


def _get_impl(name, schema_name, internal_schema_name):

class _TFRecordReaderImpl(
ops.python_op_factory(name, schema_name, internal_schema_name, generated=False)
):
"""custom wrappers around ops"""

def __init__(self, path, index_path, features, **kwargs):
if isinstance(path, list):
self._path = path
else:
self._path = [path]
if isinstance(index_path, list):
self._index_path = index_path
else:
self._index_path = [index_path]

kwargs.update({"path": self._path, "index_path": self._index_path})
self._features = features

super().__init__(**kwargs)

def __call__(self, *inputs, **kwargs):
feature_names = []
features = []
for feature_name, feature in self._features.items():
feature_names.append(feature_name)
features.append(feature)
if not isinstance(feature, _b.tfrecord.Feature):
raise TypeError(
"Expected `nvidia.dali.tfrecord.Feature` for the "
f'"{feature_name}", but got {type(feature)}. '
"Use `nvidia.dali.tfrecord.FixedLenFeature` or "
"`nvidia.dali.tfrecord.VarLenFeature` to define the features to extract."
)

kwargs.update({"feature_names": feature_names, "features": features})

# We won't have MIS as this op doesn't have any inputs (Reader)
linear_outputs = super().__call__(*inputs, **kwargs)
# We may have single, flattened output
if not isinstance(linear_outputs, list):
linear_outputs = [linear_outputs]
outputs = {}
for feature_name, output in zip(feature_names, linear_outputs):
outputs[feature_name] = output

return outputs

return _TFRecordReaderImpl


class TFRecordReader(_get_impl("_TFRecordReader", "TFRecordReader", "_TFRecordReader")):
pass


class TFRecord(_get_impl("_TFRecord", "readers__TFRecord", "readers___TFRecord")):
pass
if tfrecord_enabled():

def _get_impl(name, schema_name, internal_schema_name):

class _TFRecordReaderImpl(
ops.python_op_factory(name, schema_name, internal_schema_name, generated=False)
):
"""custom wrappers around ops"""

def __init__(self, path, index_path, features, **kwargs):
if isinstance(path, list):
self._path = path
else:
self._path = [path]
if isinstance(index_path, list):
self._index_path = index_path
else:
self._index_path = [index_path]

kwargs.update({"path": self._path, "index_path": self._index_path})
self._features = features

super().__init__(**kwargs)

def __call__(self, *inputs, **kwargs):
feature_names = []
features = []
for feature_name, feature in self._features.items():
feature_names.append(feature_name)
features.append(feature)
if not isinstance(feature, _b.tfrecord.Feature):
raise TypeError(
"Expected `nvidia.dali.tfrecord.Feature` for the "
f'"{feature_name}", but got {type(feature)}. '
"Use `nvidia.dali.tfrecord.FixedLenFeature` or "
"`nvidia.dali.tfrecord.VarLenFeature` "
"to define the features to extract."
)

kwargs.update({"feature_names": feature_names, "features": features})

# We won't have MIS as this op doesn't have any inputs (Reader)
linear_outputs = super().__call__(*inputs, **kwargs)
# We may have single, flattened output
if not isinstance(linear_outputs, list):
linear_outputs = [linear_outputs]
outputs = {}
for feature_name, output in zip(feature_names, linear_outputs):
outputs[feature_name] = output

return outputs

return _TFRecordReaderImpl

class TFRecordReader(_get_impl("_TFRecordReader", "TFRecordReader", "_TFRecordReader")):
pass

class TFRecord(_get_impl("_TFRecord", "readers__TFRecord", "readers___TFRecord")):
pass

0 comments on commit 717d704

Please sign in to comment.