diff --git a/rocAL/rocAL_pybind/amd/rocal/plugin/tf.py b/rocAL/rocAL_pybind/amd/rocal/plugin/tf.py index b7e45785cb..bf725d1832 100644 --- a/rocAL/rocAL_pybind/amd/rocal/plugin/tf.py +++ b/rocAL/rocAL_pybind/amd/rocal/plugin/tf.py @@ -49,9 +49,16 @@ def __init__(self, pipeline, tensor_layout = types.NCHW, reverse_channels = Fals self.loader._name = self.loader._reader color_format = b.getOutputColorFormat(self.loader._handle) self.p = (1 if (color_format == int(types.GRAY)) else 3) - - self.out = np.zeros(( self.bs*self.n, self.p, int(self.h/self.bs), self.w,), dtype = "uint8") - + if self.tensor_dtype == types.FLOAT: + data_type="float32" + elif self.tensor_dtype == types.FLOAT16: + data_type="float16" + + if(types.NHWC == self.tensor_format): + self.out = np.zeros(( self.bs*self.n, int(self.h/self.bs), self.w, self.p), dtype = data_type) + else: + self.out = np.zeros(( self.bs*self.n, self.p, int(self.h/self.bs), self.w), dtype = data_type) + def next(self): return self.__next__() @@ -68,8 +75,11 @@ def __next__(self): if self.loader.run() != 0: self.reset() raise StopIteration - - self.loader.copyImage(self.out) + + if(types.NCHW == self.tensor_format): + self.loader.copyToTensorNCHW(self.out, self.multiplier, self.offset, self.reverse_channels, int(self.tensor_dtype)) + else: + self.loader.copyToTensorNHWC(self.out, self.multiplier, self.offset, self.reverse_channels, int(self.tensor_dtype)) if(self.loader._name == "TFRecordReaderDetection"): self.bbox_list =[] diff --git a/rocAL/rocAL_pybind/example/new_api/tf_petsTrainingExample/train_withROCAL_withTFRecordReader.py b/rocAL/rocAL_pybind/example/new_api/tf_petsTrainingExample/train_withROCAL_withTFRecordReader.py index 50a672d744..d715299efe 100755 --- a/rocAL/rocAL_pybind/example/new_api/tf_petsTrainingExample/train_withROCAL_withTFRecordReader.py +++ b/rocAL/rocAL_pybind/example/new_api/tf_petsTrainingExample/train_withROCAL_withTFRecordReader.py @@ -4,9 +4,10 @@ import amd.rocal.fn as fn import amd.rocal.types as types -import tensorflow as tf +import tensorflow.compat.v1 as tf tf.compat.v1.disable_v2_behavior() + import numpy as np import tensorflow_hub as hub @@ -104,7 +105,7 @@ def main(): } - trainPipe = Pipeline(batch_size=TRAIN_BATCH_SIZE, num_threads=1, rocal_cpu=RUN_ON_HOST) + trainPipe = Pipeline(batch_size=TRAIN_BATCH_SIZE, num_threads=1, rocal_cpu=RUN_ON_HOST, tensor_layout = types.NHWC) with trainPipe: inputs = fn.readers.tfrecord(path=TRAIN_RECORDS_DIR, index_path = "", reader_type=TFRecordReaderType, user_feature_key_map=featureKeyMap, features={ @@ -117,11 +118,17 @@ def main(): images = fn.decoders.image(jpegs, user_feature_key_map=featureKeyMap, output_type=types.RGB, path=TRAIN_RECORDS_DIR) resized = fn.resize(images, resize_x=crop_size[0], resize_y=crop_size[1]) flip_coin = fn.random.coin_flip(probability=0.5) - cmn_images = fn.crop_mirror_normalize(resized, crop=(crop_size[1], crop_size[0]), mean=[0,0,0], std=[255,255,255], mirror=flip_coin, output_dtype=types.FLOAT, output_layout=types.NCHW, pad_output=False) + cmn_images = fn.crop_mirror_normalize(resized, crop=(crop_size[1], crop_size[0]), + mean=[0,0,0], + std=[255,255,255], + mirror=flip_coin, + output_dtype=types.FLOAT, + output_layout=types.NHWC, + pad_output=False) trainPipe.set_outputs(cmn_images) trainPipe.build() - valPipe = Pipeline(batch_size=TRAIN_BATCH_SIZE, num_threads=1, rocal_cpu=RUN_ON_HOST) + valPipe = Pipeline(batch_size=TRAIN_BATCH_SIZE, num_threads=1, rocal_cpu=RUN_ON_HOST, tensor_layout = types.NHWC) with valPipe: inputs = fn.readers.tfrecord(path=VAL_RECORDS_DIR, index_path = "", reader_type=TFRecordReaderType, user_feature_key_map=featureKeyMap, features={ @@ -134,7 +141,13 @@ def main(): images = fn.decoders.image(jpegs, user_feature_key_map=featureKeyMap, output_type=types.RGB, path=VAL_RECORDS_DIR) resized = fn.resize(images, resize_x=crop_size[0], resize_y=crop_size[1]) flip_coin = fn.random.coin_flip(probability=0.5) - cmn_images = fn.crop_mirror_normalize(resized, crop=(crop_size[1], crop_size[0]), mean=[0,0,0], std=[255,255,255], mirror=flip_coin, output_dtype=types.FLOAT, output_layout=types.NCHW, pad_output=False) + cmn_images = fn.crop_mirror_normalize(resized, crop=(crop_size[1], crop_size[0]), + mean=[0,0,0], + std=[255,255,255], + mirror=flip_coin, + output_dtype=types.FLOAT, + output_layout=types.NHWC, + pad_output=False) valPipe.set_outputs(cmn_images) valPipe.build() @@ -148,11 +161,10 @@ def main(): while i < NUM_TRAIN_STEPS: for t, (train_image_ndArray, train_label_ndArray) in enumerate(trainIterator, 0): - train_image_ndArray_transposed = np.transpose(train_image_ndArray, [0, 2, 3, 1]) train_label_one_hot_list = get_label_one_hot(train_label_ndArray) train_loss, _, train_accuracy = sess.run( [cross_entropy_mean, train_op, accuracy], - feed_dict={decoded_images: train_image_ndArray_transposed, labels: train_label_one_hot_list}) + feed_dict={decoded_images: train_image_ndArray, labels: train_label_one_hot_list}) print ("Step :: %s\tTrain Loss :: %.2f\tTrain Accuracy :: %.2f%%\t" % (i, train_loss, (train_accuracy * 100))) is_final_step = (i == (NUM_TRAIN_STEPS - 1)) if i % EVAL_EVERY == 0 or is_final_step: @@ -160,11 +172,10 @@ def main(): mean_loss = 0 print("\n\n-------------------------------------------------------------------------------- BEGIN VALIDATION --------------------------------------------------------------------------------") for j, (val_image_ndArray, val_label_ndArray) in enumerate(valIterator, 0): - val_image_ndArray_transposed = np.transpose(val_image_ndArray, [0, 2, 3, 1]) val_label_one_hot_list = get_label_one_hot(val_label_ndArray) val_loss, val_accuracy, val_prediction, val_target, correct_predicate = sess.run( [cross_entropy_mean, accuracy, prediction, correct_label, correct_prediction], - feed_dict={decoded_images: val_image_ndArray_transposed, labels: val_label_one_hot_list}) + feed_dict={decoded_images: val_image_ndArray, labels: val_label_one_hot_list}) mean_acc += val_accuracy mean_loss += val_loss num_correct_predicate = 0