Skip to content

Commit e519621

Browse files
authored
convert : remove bug in convert.py permute function (ggml-org#3364)
1 parent ac43576 commit e519621

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def __repr__(self) -> str:
439439
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
440440
#print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) )
441441
if n_head_kv is not None and n_head != n_head_kv:
442-
n_head //= n_head_kv
442+
n_head = n_head_kv
443443
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
444444
.swapaxes(1, 2)
445445
.reshape(weights.shape))

0 commit comments

Comments
 (0)