diff --git a/examples/research_projects/rag/callbacks_rag.py b/examples/research_projects/rag/callbacks_rag.py
index e9eda20de300fc..a2d87f82247c4a 100644
--- a/examples/research_projects/rag/callbacks_rag.py
+++ b/examples/research_projects/rag/callbacks_rag.py
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        every_n_epochs=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
 
diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py
index a1721623dd60cc..2fd4ef7659c543 100644
--- a/examples/research_projects/rag/finetune_rag.py
+++ b/examples/research_projects/rag/finetune_rag.py
@@ -254,7 +254,7 @@ def pad(self) -> int:
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
 
-        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
         tgt_pad_token_id = (
             self.tokenizer.generator.pad_token_id
@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
             raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
         # Connect to an existing Ray cluster.
         try:
-            ray.init(address=args.ray_address)
+            ray.init(address=args.ray_address, namespace="rag")
         except (ConnectionError, ValueError):
             logger.warning(
                 "Connection to Ray cluster failed. Make sure a Ray"
diff --git a/examples/research_projects/rag/lightning_base.py b/examples/research_projects/rag/lightning_base.py
index 0d93626677cc48..1e0f67627e7c34 100644
--- a/examples/research_projects/rag/lightning_base.py
+++ b/examples/research_projects/rag/lightning_base.py
@@ -266,6 +266,15 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument("--adafactor", action="store_true")
 
 
+class InitCallback(pl.Callback):
+    # This method is better than using a custom DDP plugin with the latest pytorch-lightning (@shamanez)
+    def on_sanity_check_start(self, trainer, pl_module):
+        if (
+            trainer.is_global_zero and trainer.global_rank == 0
+        ):  # we initialize the retriever only on the master worker with RAY. In new pytorch-lightning, accelerators are removed.
+            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
+
+
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
         lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
@@ -368,19 +377,21 @@ def generic_train(
     # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
+        # train_params["amp_level"] = args.fp16_opt_level
 
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"  # "ddp"
+        train_params["strategy"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
     train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  #get unwanted logs
+    train_params["devices"] = "auto"
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
-        plugins=[custom_ddp_plugin],
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
+        # plugins=[custom_ddp_plugin],
         logger=logger,
         **train_params,
     )
diff --git a/examples/research_projects/rag/requirements.txt b/examples/research_projects/rag/requirements.txt
index 652821a216cbe1..fdeb5567d24d55 100644
--- a/examples/research_projects/rag/requirements.txt
+++ b/examples/research_projects/rag/requirements.txt
@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
 datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
+ray >= 1.10.0
+pytorch-lightning >= 1.5.10
 transformers
-pytorch-lightning
-GitPython
+GitPython
\ No newline at end of file
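
A minimal usage sketch of the setup in the patch above: the sanity-check-start callback initializes the Ray-backed retriever once, on the global-zero worker, and the Trainer mirrors the accelerator/devices/strategy values that generic_train sets via train_params. The `rag_module` name is a hypothetical stand-in for the GenerativeQAModule built in finetune_rag.py, and flag availability depends on the installed pytorch-lightning version.

# Sketch only: assumes pytorch-lightning is installed and the LightningModule
# exposes `model.rag.retriever` (as GenerativeQAModule does).
import pytorch_lightning as pl


class InitCallback(pl.Callback):
    def on_sanity_check_start(self, trainer, pl_module):
        # Initialize the Ray-backed retriever once, on the global-zero worker only.
        if trainer.is_global_zero:
            pl_module.model.rag.retriever.init_retrieval()


trainer = pl.Trainer(
    accelerator="auto",  # mirrors train_params["accelerator"] = "auto"
    devices="auto",      # mirrors train_params["devices"] = "auto"
    strategy="ddp",      # mirrors train_params["strategy"] = "ddp"
    callbacks=[InitCallback()],
)
# trainer.fit(rag_module)  # rag_module: the GenerativeQAModule from finetune_rag.py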