Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwang04 committed Sep 14, 2024
1 parent d9278e9 commit 0efa199
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,10 @@ def __init__(

for i in range(self.split_num):
start_idx = i * split_size
if i == split_num - 1:
end_idx = self.inC
else:
end_idx = (i + 1) * split_size

end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
input_slice = self.slice(input, begin=[0, start_idx],
end=[self.batch, end_idx])
linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)

if i == 0:
res = linear_slice
else:
Expand All @@ -89,7 +84,6 @@ def run(
np.ndarray: result
"""
self.prefetchWeights(1, verify_size=False)

self.set_input_tensor(X, 0)
self.elapsed = backend_lib.run(self._mm)
if len(self.out) == 1:
Expand All @@ -107,10 +101,7 @@ def __init__(self, weight, bias, split_num):
for i in range(split_num):
new_linear = torch.nn.Linear(0, 0, bias=False)
start_idx = i * split_size
if i == split_num - 1:
end_idx = weight.size(1)
else:
end_idx = (i + 1) * split_size
end_idx = (i + 1) * split_size if i < split_num - 1 else weight.size(1)
new_weight = torch.nn.Parameter(weight[:, start_idx:end_idx],
requires_grad=False)
new_linear.weight = new_weight
Expand All @@ -130,18 +121,12 @@ def forward(self, hidden_states):
logits = logits.view(target_shape)
else:
split_size = hidden_states.size(-1) // self.split_num // 2 * 2

logits = None
for i in range(self.split_num):
start_idx = i * split_size
if i == self.split_num - 1:
end_idx = hidden_states.size(-1)
else:
end_idx = (i + 1) * split_size

end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
hidden_states_slice = hidden_states[:, :, start_idx:end_idx]
logits_slice = self.lm_heads[i](hidden_states_slice)

if logits is None:
logits = logits_slice
else:
Expand Down

0 comments on commit 0efa199

Please sign in to comment.