diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 00b5b18d..39e6c8fc 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -95,9 +95,6 @@ def fit( callbacks.on_train_batch_begin(batch=batch_idx) batch = next(batch_train_ds) - - # print(jax.tree_map(lambda x: x.devices(), batch)) - ( (state, train_batch_metrics), batch_loss,