diff --git a/README.md b/README.md index 7cffa872..610341ee 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,7 @@ def train_mnist(config): trainer = pl.Trainer( max_epochs=4, callbacks=callbacks, - strategy=[RayStrategy(num_workers=4, use_gpu=False)]) + strategy=RayStrategy(num_workers=4, use_gpu=False)) trainer.fit(model) config = {