Updated the RAG training with the latest PyTorch Lightning library and Ray #15653

Merged
1 commit merged on Feb 15, 2022
2 changes: 1 addition & 1 deletion examples/research_projects/rag/callbacks_rag.py
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
every_n_epochs=1, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

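For context, this is roughly what the updated checkpoint callback boils down to under pytorch-lightning >= 1.5, where the `period` argument was removed in favour of `every_n_epochs`. A minimal sketch, not the full script; the function name, `output_dir`, and the "em" metric/filename pattern are illustrative:

from pytorch_lightning.callbacks import ModelCheckpoint

def get_rag_checkpoint_callback(output_dir, metric="em"):
    # Keep the three best checkpoints by the validation metric, saving at most once per epoch.
    return ModelCheckpoint(
        dirpath=output_dir,
        filename="{epoch}-{val_" + metric + ":.2f}",  # illustrative filename pattern
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=3,
        every_n_epochs=1,  # replaces the removed `period=1`
    )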
4 changes: 2 additions & 2 deletions examples/research_projects/rag/finetune_rag.py
@@ -254,7 +254,7 @@ def pad(self) -> int:
     def training_step(self, batch, batch_idx) -> Dict:
         loss_tensors = self._step(batch)
 
-        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
+        logs = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
         tgt_pad_token_id = (
             self.tokenizer.generator.pad_token_id
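The only change in this hunk is that the logged losses are detached. Keeping graph-attached tensors in the logs dict can hold the whole autograd graph in memory after the step; detaching before logging avoids that. A tiny standalone illustration of the pattern (the names and values below are made up for the example):

import torch

loss_names = ["loss", "nll_loss"]  # illustrative names
loss_tensors = [torch.rand(1, requires_grad=True) * 2, torch.rand(1, requires_grad=True) + 1]

# Detach before storing, so the logging dict does not keep the autograd graph alive.
logs = {name: loss.detach() for name, loss in zip(loss_names, loss_tensors)}
assert all(not value.requires_grad for value in logs.values())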
@@ -517,7 +517,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
                 raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
             # Connect to an existing Ray cluster.
             try:
-                ray.init(address=args.ray_address)
+                ray.init(address=args.ray_address, namespace="rag")
             except (ConnectionError, ValueError):
                 logger.warning(
                     "Connection to Ray cluster failed. Make sure a Ray"
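Passing an explicit namespace to ray.init matters because named actors in recent Ray releases are scoped to a namespace, so the retriever actors created by the script can be looked up by every worker that connects with the same namespace. A minimal sketch of the idea, separate from the training script; the actor class and name below are made up for illustration:

import ray

# Connect (or start a local cluster) under an explicit namespace so that named
# actors created here are visible to any other process using the same namespace.
ray.init(namespace="rag")

@ray.remote
class RetrievalWorker:  # hypothetical stand-in for the RAG retrieval actor
    def ping(self):
        return "ok"

worker_handle = RetrievalWorker.options(name="retrieval_worker_0").remote()
# Another process calling ray.init(address=..., namespace="rag") could fetch the same actor by name.
same_worker = ray.get_actor("retrieval_worker_0")
assert ray.get(same_worker.ping.remote()) == "ok"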
19 changes: 15 additions & 4 deletions examples/research_projects/rag/lightning_base.py
@@ -266,6 +266,15 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument("--adafactor", action="store_true")
 
 
+class InitCallback(pl.Callback):
+    # Using a callback hook is better than a custom DDP plugin with the latest pytorch-lightning (@shamanez)
+    def on_sanity_check_start(self, trainer, pl_module):
+        if (
+            trainer.is_global_zero and trainer.global_rank == 0
+        ):  # we initialize the retriever only on the master worker with Ray; in the new pytorch-lightning the custom accelerators are removed.
+            pl_module.model.rag.retriever.init_retrieval()  # better to use hook functions.
+
+
 class LoggingCallback(pl.Callback):
     def on_batch_end(self, trainer, pl_module):
         lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
@@ -368,19 +377,21 @@ def generic_train(
     # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
+        # train_params["amp_level"] = args.fp16_opt_level
 
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"  # "ddp"
+        train_params["strategy"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
     train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  # get unwanted logs
+    train_params["devices"] = "auto"
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
-        plugins=[custom_ddp_plugin],
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback] + [InitCallback()],
+        # plugins=[custom_ddp_plugin],
         logger=logger,
         **train_params,
     )
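Taken together, the lightning_base.py changes move to the pytorch-lightning >= 1.5 way of configuring distribution: accelerator/strategy/devices instead of a custom DDP plugin, native AMP without amp_level, and one-time retriever setup in a callback hook. A condensed, illustrative sketch of the resulting Trainer setup; the callback class below is a hypothetical stand-in for the InitCallback above, not the script's actual code:

import torch
import pytorch_lightning as pl

class InitRetrievalCallback(pl.Callback):
    # Stand-in for InitCallback: run one-time setup on the rank-0 worker only,
    # just before the validation sanity check.
    def on_sanity_check_start(self, trainer, pl_module):
        if trainer.is_global_zero:
            print("initializing the retriever on the master worker")

n_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    accelerator="auto",                      # replaces accelerator="ddp"
    strategy="ddp" if n_gpus > 1 else None,  # DDP is now selected via `strategy`
    devices="auto",
    precision=16 if n_gpus > 0 else 32,      # native AMP, no `amp_level` knob
    callbacks=[InitRetrievalCallback()],
)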
5 changes: 3 additions & 2 deletions examples/research_projects/rag/requirements.txt
@@ -2,6 +2,7 @@ faiss-cpu >= 1.6.3
 datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
+ray >= 1.10.0
+pytorch-lightning >= 1.5.10
 transformers
-pytorch-lightning
-GitPython
+GitPython
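The requirements now pin the minimum Ray and pytorch-lightning versions the updated scripts rely on. If useful, a quick environment sanity check (assuming the pinned packages are installed in the current environment):

from packaging import version

import pytorch_lightning as pl
import ray

# Confirm the installed versions satisfy the minimums pinned above.
assert version.parse(pl.__version__) >= version.parse("1.5.10")
assert version.parse(ray.__version__) >= version.parse("1.10.0")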