diff --git a/projects/blenderbot2/agents/modules.py b/projects/blenderbot2/agents/modules.py index e4f12c93ad2..10dc176a1a9 100644 --- a/projects/blenderbot2/agents/modules.py +++ b/projects/blenderbot2/agents/modules.py @@ -566,9 +566,11 @@ def access_long_term_memory( indices = memory_indices.tolist() if memory_vec is not None: + # Only look in memory_vec for batch elements with memories + memory_ids = [m for m in indices if num_memories[m] > 0] memory_dict = { batch_id: memory_vec[batch_id, : num_memories[mem_id]] - for batch_id, mem_id in enumerate(indices) + for batch_id, mem_id in enumerate(memory_ids) } if memory_decoder_vec is not None: for batch_id in indices: