From 6c2b2051ca1e13d7df5736530aa60f0b8d171872 Mon Sep 17 00:00:00 2001 From: The TensorFlow Datasets Authors Date: Tue, 28 May 2024 19:38:10 -0700 Subject: [PATCH] Add "display_image" feature to robotics dataset importer builder. PiperOrigin-RevId: 638110575 --- .../robotics/dataset_importer_builder.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tensorflow_datasets/robotics/dataset_importer_builder.py b/tensorflow_datasets/robotics/dataset_importer_builder.py index 162a8b69500..270504bb52f 100644 --- a/tensorflow_datasets/robotics/dataset_importer_builder.py +++ b/tensorflow_datasets/robotics/dataset_importer_builder.py @@ -54,6 +54,8 @@ class DatasetImporterBuilder( 'ssot_session_key', ] + images_from_observation_dict = {} + @abc.abstractmethod def get_description(self): @@ -83,9 +85,15 @@ def _info(self) -> tfds.core.DatasetInfo: tmp = dict(features) + # add all image features from observations to a new featuresdict + self.images_from_observation_dict = self.get_images_from_observation_dict() + if self.images_from_observation_dict: + tmp['display_image'] = self.images_from_observation_dict + for key in self.KEYS_TO_STRIP: if key in tmp: del tmp[key] + features = tfds.features.FeaturesDict(tmp) return tfds.core.DatasetInfo( @@ -120,15 +128,28 @@ def _generate_examples( def converter_fn(example): # Decode the RLDS Episode and transform it to numpy. example_out = dict(example) + example_out['steps'] = tf.data.Dataset.from_tensor_slices( example_out['steps'] ).map(decode_fn) + steps = list(iter(example_out['steps'].take(-1))) example_out['steps'] = steps example_out = dataset_utils.as_numpy(example_out) + first_step = example_out['steps'][0] + image_feature_dict = {} + + for feature_name in self.images_from_observation_dict: + image_feature_dict[feature_name] = first_step['observation'][ + feature_name + ] + + if image_feature_dict: + example_out['display_image'] = image_feature_dict example_id = example_out['tfds_id'].decode('utf-8') + del example_out['tfds_id'] for key in self.KEYS_TO_STRIP: if key in example_out: @@ -148,3 +169,17 @@ def get_ds_builder(self): ds_location = self.get_dataset_location() ds_builder = tfds.builder_from_directory(ds_location) return ds_builder + + def get_images_from_observation_dict(self): + features = self.get_ds_builder().info.features + tmp = dict(features) + images_from_observation = {} + if 'steps' in tmp and 'observation' in tmp['steps']: + observation = tmp['steps']['observation'] + for feature_name, feature_data in observation.items(): + if isinstance(feature_data, tfds.features.Image): + images_from_observation[feature_name] = feature_data + images_from_observation_dict = tfds.features.FeaturesDict( + images_from_observation + ) + return images_from_observation_dict