From b2723f2d0a33bad7ddbf3167ec9e93aa60af2acd Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Thu, 16 Dec 2021 14:42:02 -0800 Subject: [PATCH 1/4] make bb2 ddp compatible --- projects/blenderbot2/agents/blenderbot2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index e27a35c8b4e..530e48710fb 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -800,7 +800,10 @@ def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batc def eval_step(self, batch): output = super().eval_step(batch) - if output is None or not hasattr(self.model, 'retriever'): + model = self.model + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.model.module + if output is None or not hasattr(model, 'retriever'): return output if hasattr(self.model_api.retriever, 'top_docs'): output.top_docs = self.model_api.retriever.top_docs @@ -855,7 +858,10 @@ def compute_loss( Override Rag.compute_loss to add some additional metrics. """ loss, output = super().compute_loss(batch, return_output=True) - assert isinstance(self.model, BlenderBot2RagModel) + model = self.model + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.model.module + assert isinstance(model, BlenderBot2RagModel) if ( KnowledgeAccessMethod(self.opt['knowledge_access_method']) is KnowledgeAccessMethod.CLASSIFY From 3ba480e0d44f095c6414e6f0bc2e56078c4d43a9 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Thu, 16 Dec 2021 14:43:43 -0800 Subject: [PATCH 2/4] longfid ddp compatible --- projects/msc/agents/memory_agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/msc/agents/memory_agent.py b/projects/msc/agents/memory_agent.py index 169e40287c6..ba88dd7b231 100644 --- a/projects/msc/agents/memory_agent.py +++ b/projects/msc/agents/memory_agent.py @@ -205,6 +205,10 @@ def __init__(self, opt, dictionary, retriever_shared=None): if opt.get('fid_ddp_compatible', True): for param in self.long_term_memory.query_encoder.parameters(): param.requires_grad = False + for param in self.long_term_memory.memory_encoder.parameters(): + param.requires_grad = False + for param in self.retriever.parameters(): + param.requires_grad = False class MemoryLongFidAgent(LongFidAgent, MemoryRagAgent): From e3c32276606f7fcb55383a2ad9a1b9ff8fbbf055 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Mon, 20 Dec 2021 14:59:16 -0800 Subject: [PATCH 3/4] comments --- projects/blenderbot2/agents/blenderbot2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index 530e48710fb..996222c503a 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -800,9 +800,9 @@ def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batc def eval_step(self, batch): output = super().eval_step(batch) - model = self.model - if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): - model = self.model.module + model = self.model_api + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module if output is None or not hasattr(model, 'retriever'): return output if hasattr(self.model_api.retriever, 'top_docs'): @@ -858,9 +858,9 @@ def compute_loss( Override Rag.compute_loss to add some additional metrics. """ loss, output = super().compute_loss(batch, return_output=True) - model = self.model - if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): - model = self.model.module + model = self.model_api + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module assert isinstance(model, BlenderBot2RagModel) if ( KnowledgeAccessMethod(self.opt['knowledge_access_method']) From 7c98a5ca33d4e866e4d03ae1a429e42cca467115 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Mon, 20 Dec 2021 15:54:30 -0800 Subject: [PATCH 4/4] more coments --- projects/blenderbot2/agents/blenderbot2.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index 996222c503a..f14c79ccc05 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -800,10 +800,7 @@ def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batc def eval_step(self, batch): output = super().eval_step(batch) - model = self.model_api - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - model = model.module - if output is None or not hasattr(model, 'retriever'): + if output is None or not hasattr(self.model_api, 'retriever'): return output if hasattr(self.model_api.retriever, 'top_docs'): output.top_docs = self.model_api.retriever.top_docs @@ -858,10 +855,7 @@ def compute_loss( Override Rag.compute_loss to add some additional metrics. """ loss, output = super().compute_loss(batch, return_output=True) - model = self.model_api - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - model = model.module - assert isinstance(model, BlenderBot2RagModel) + assert isinstance(self.model_api, BlenderBot2RagModel) if ( KnowledgeAccessMethod(self.opt['knowledge_access_method']) is KnowledgeAccessMethod.CLASSIFY