diff --git a/projects/blenderbot2/agents/modules.py b/projects/blenderbot2/agents/modules.py index cc8924f43f9..accc5390a81 100644 --- a/projects/blenderbot2/agents/modules.py +++ b/projects/blenderbot2/agents/modules.py @@ -828,6 +828,8 @@ def __init__(self, opt: Opt, null_idx: int): class BlenderBot2FidModelMixin: embedding_size: int pad_idx: int + long_term_memory: LongTermMemory + retriever: RagRetriever def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None): super().__init__( @@ -837,6 +839,12 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None) opt, dictionary[dictionary.null_token] ) self.embedding_size = opt['embedding_size'] + for param in self.long_term_memory.query_encoder.parameters(): + param.requires_grad = False + for param in self.long_term_memory.memory_encoder.parameters(): + param.requires_grad = False + for param in self.retriever.parameters(): + param.requires_grad = False def encoder( self, @@ -935,4 +943,12 @@ class BB2SearchQueryFaissIndexRetriever( class BB2ObservationEchoRetriever(BB2SearchRetrieverMixin, ObservationEchoRetriever): """ A retriever that reads retrieved docs as part of the observed example message. + + Provides backwards compatibility with BB2 models by instantiating a query encoder. """ + + def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared=None): + super().__init__(opt, dictionary, shared) + self.query_encoder = DprQueryEncoder( + opt, dpr_model=opt['query_model'], pretrained_path=opt['dpr_model_file'] + )