Skip to content

Commit

Permalink
fix(server): fix reshaping of bloom past_key_values in concatenate() (#…
Browse files Browse the repository at this point in the history
…252)

Introduced in #214 

Fixes #249
  • Loading branch information
njhill authored Apr 27, 2023
1 parent db2b4e0 commit b4cf832
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion server/text_generation_server/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
for layer in batch.past_key_values
]
elif batch.past_key_values[0][0].shape == 3:
elif len(batch.past_key_values[0][0].shape) == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
Expand Down

0 comments on commit b4cf832

Please sign in to comment.