Bug description

Currently Merlin Models does not support data parallel training with tf.distribute.MirroredStrategy(), which allows the use of multiple GPU devices in a single machine. This issue reports the errors we found while investigating issue #744. The errors occur with both retrieval and ranking models.
We get the following errors in model.compile():
Issue [BUG] Merlin TF two-tower - error using tf.distribute.MirroredStrategy() #744 reports exception [1], which occurs in model.compile(). It is caused by the _should_compute_train_metrics_for_batch Variable created in BaseModel.compile() having synchronization=tf.VariableSynchronization.NONE. Changing it to tf.VariableSynchronization.AUTO avoids the exception.
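As a minimal illustration of that change (a sketch only; the actual variable creation in BaseModel.compile() may look different), creating the flag with AUTO synchronization is accepted inside a MirroredStrategy scope, whereas NONE is rejected:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # With synchronization=tf.VariableSynchronization.NONE this raises:
    #   ValueError: `NONE` variable synchronization mode is not supported
    #   with tf.distribute strategy. ...
    # AUTO lets the distribution strategy pick the synchronization mode.
    should_compute_train_metrics_for_batch = tf.Variable(
        True,
        trainable=False,
        dtype=tf.bool,
        name="should_compute_train_metrics_for_batch",
        synchronization=tf.VariableSynchronization.AUTO,
    )
```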
Then we get exception [2] on model.compile(optimizer="adam", run_eagerly=False), because our default metrics are used when no metrics are provided to model.compile(), and Keras expects them to be created inside the with strategy.scope(): block. If we instantiate the metrics inside the distribution strategy scope and pass them as an argument to compile(), it works (see the code snippet under "Steps/Code to reproduce bug" below).
Then we get the following errors in model.fit():
An exception [3] occurs because our overridden train_step() / test_step() calls a method decorated with @tf.function (compute_metrics()). If we replace the line metrics = self.compute_metrics(outputs, training=False) with metrics = {}, this error goes away (a sketch of that workaround follows below).
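For illustration only, this is the general shape of that workaround in a standard Keras custom train_step (a hedged sketch, not Merlin's actual BaseModel code): the nested @tf.function metric computation is skipped and only the loss is returned.

```python
import tensorflow as tf


class PatchedModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Workaround for exception [3]: do not call a @tf.function-decorated
        # compute_metrics() from inside train_step under strategy.run();
        # return an empty metrics dict (plus the loss) instead.
        return {"loss": loss}
```

This mirrors the hint in the exception message itself: an overridden train_step / test_step should not reach a nested @tf.function that contains a synchronization point.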
Then we get exception [4], which is related to our DataLoader: the devices of the loaded tensors and of the model replicas do not seem to match. This exception needs further investigation.
More than one GPU must be visible to the process for the errors in model.fit() to occur (e.g. CUDA_VISIBLE_DEVICES=0,1).
Steps/Code to reproduce bug
Wrap the definition of any model inside the distribution strategy scope, as in the following example with TwoTowerModel.
P.S.: you can edit an existing unit test to reproduce it (e.g. test_retrieval.py -> test_two_tower_model).
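A minimal reproduction sketch along the lines described above (the synthetic-data call, tower sizes, and metric choices are illustrative, assuming the usual Merlin Models test setup; adjust to your data and environment):

```python
import tensorflow as tf
import merlin.models.tf as mm
from merlin.datasets.synthetic import generate_data

# More than one GPU must be visible to the process, e.g. CUDA_VISIBLE_DEVICES=0,1
train, valid = generate_data("e-commerce", num_rows=1000, set_sizes=(0.8, 0.2))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = mm.TwoTowerModel(train.schema, query_tower=mm.MLPBlock([128, 64]))
    # Instantiating the metrics inside the strategy scope and passing them to
    # compile() avoids exception [2]; relying on the default metrics does not.
    metrics = [mm.RecallAt(k=10), mm.MRRAt(k=10), mm.NDCGAt(k=10)]
    model.compile(optimizer="adam", run_eagerly=False, metrics=metrics)

# Exceptions [3] and [4] still occur here.
model.fit(train, batch_size=1024, epochs=1)
```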
Expected behavior

It should be possible to train Merlin Models on multiple GPUs using tf.distribute.MirroredStrategy().
Environment details
Merlin version: 22.09
Platform: Ubuntu 20.04
Python version: 3.8.10
TensorFlow version: 2.9
Hardware environment: 2x V100 GPUs 32 GB
Exceptions appendix
[1]
ValueError: `NONE` variable synchronization mode is not supported with tf.distribute strategy. Please change the `synchronization` for variable: should_compute_train_metrics_for_batch
[2]
ValueError: Metric (TopKMetricsAggregator(name=top_k_metrics_aggregator,dtype=float32,topk_metrics=[{'class_name': 'merlin.models>RecallAt', 'config': {'name': 'recall_at', 'dtype': 'float32', 'k': 10, 'pre_sorted': True}}, {'class_name': 'merlin.models>MRRAt', 'config': {'name': 'mrr_at', 'dtype': 'float32', 'k': 10, 'pre_sorted': True}}, {'class_name': 'merlin.models>NDCGAt', 'config': {'name': 'ndcg_at', 'dtype': 'float32', 'k': 10, 'pre_sorted': True}}, {'class_name': 'merlin.models>AvgPrecisionAt', 'config': {'name': 'map_at', 'dtype': 'float32', 'k': 10, 'pre_sorted': True}}, {'class_name': 'merlin.models>PrecisionAt', 'config': {'name': 'precision_at', 'dtype': 'float32', 'k': 10, 'pre_sorted': True}}])) passed to `model.compile` was created inside a different distribution strategy scope than the model. All metrics must be created in the same distribution strategy scope as the model (in this case <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f9cfc469c40>). If you pass in a string identifier for a metric to compile, the metric will automatically be created in the correct distribution strategy scope.
[3]
RuntimeError: `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.
[4]
(0) INVALID_ARGUMENT: ValueError: Array device must be same as the current device: array device = 1 while current = 0
Traceback (most recent call last):
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
ret = func(*args)
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
return func(*args, **kwargs)
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 831, in wrapped_generator
for data in generator_fn():
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 957, in generator_fn
yield x[i]
File "/home/gmoreira/projects/nvidia/nvidia_merlin/models/merlin/models/tf/loader.py", line 337, in __getitem__
return DataLoader.__next__(self)
File "/home/gmoreira/projects/nvidia/nvidia_merlin/models/merlin/models/loader/backend.py", line 356, in __next__
return self._get_next_batch()
File "/home/gmoreira/projects/nvidia/nvidia_merlin/models/merlin/models/loader/backend.py", line 385, in _get_next_batch
DataLoader.__iter__(self)
File "/home/gmoreira/projects/nvidia/nvidia_merlin/models/merlin/models/loader/backend.py", line 344, in __iter__
self._shuffle_indices()
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/gmoreira/projects/nvidia/nvidia_merlin/models/merlin/models/loader/backend.py", line 332, in _shuffle_indices
cp.random.shuffle(self.indices)
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/cupy/random/_permutations.py", line 15, in shuffle
return rs.shuffle(a)
File "/home/gmoreira/miniconda3/envs/merlin_22.07_dev/lib/python3.8/site-packages/cupy/random/_generator.py", line 1090, in shuffle
a[:] = a[self._permutation(len(a))]
File "cupy/core/core.pyx", line 1285, in cupy.core.core.ndarray.__getitem__
File "cupy/core/_routines_indexing.pyx", line 44, in cupy.core._routines_indexing._ndarray_getitem
File "cupy/core/core.pyx", line 683, in cupy.core.core.ndarray.take
File "cupy/core/_routines_indexing.pyx", line 119, in cupy.core._routines_indexing._ndarray_take
File "cupy/core/_routines_indexing.pyx", line 721, in cupy.core._routines_indexing._take
File "cupy/core/_kernel.pyx", line 777, in cupy.core._kernel.ElementwiseKernel.__call__
File "cupy/core/_kernel.pyx", line 107, in cupy.core._kernel._preprocess_args
File "cupy/core/_kernel.pyx", line 77, in cupy.core._kernel._check_array_device_id
ValueError: Array device must be same as the current device: array device = 1 while current = 0
gabrielspmoreira changed the title from "[BUG] Models does not support to tf.distribute.MirroredStrategy()" to "[BUG] Models does not support to tf.distribute.MirroredStrategy() for data parallel training" on Sep 22, 2022.
@jperez999 @karlhigley Do we have a known limitation that prevents the Merlin Data Loader from working with tf.distribute.MirroredStrategy()?
Note: Please check exception [4] in the bug description.
(I'm not involved with the dataloaders anymore, but @benfred might know)
rnyak changed the title from "[BUG] Models does not support to tf.distribute.MirroredStrategy() for data parallel training" to "[Task] Models does not support to tf.distribute.MirroredStrategy() for data parallel training" on Oct 3, 2022.
@gabrielspmoreira per our conversation with DLFW, the recommended method for parallelism in TF is still Horovod. Closing this out, as we're not likely to change this in the near future.