diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py index 4e15f9528a..5d6d08c6c3 100644 --- a/xformers/factory/model_factory.py +++ b/xformers/factory/model_factory.py @@ -292,6 +292,9 @@ def forward( if not self.decoders: return memory + else: + # Decoder-only + memory = src # If decoder: either use the encoder ouput, or just decode, both options are possible if len(self.decoders) > 0: @@ -300,7 +303,6 @@ def forward( for decoder in self.decoders: tgt = decoder( target=tgt, - # pyre-fixme[61]: `memory` is not always initialized here. memory=memory, input_mask=decoder_input_mask, )