diff --git a/memoria/memoria.py b/memoria/memoria.py index 07dcfcd..8491ad1 100644 --- a/memoria/memoria.py +++ b/memoria/memoria.py @@ -262,7 +262,7 @@ def _search_longterm_memories_with_initials( found_ltm_indices[:, depth + 1] = current_ltm_indices unreachable[index_0, current_ltm_indices] = True - return found_ltm_indices.view(batch_size, -1) + return found_ltm_indices.view(batch_size, -1).unique(dim=1) @torch.no_grad() def _select_final_ltms(self, working_memory: Engrams, found_longterm_memory_indices: torch.Tensor) -> torch.Tensor: