Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet committed Aug 8, 2024
1 parent 156dbc9 commit db17513
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,14 +835,6 @@ def pad_sequence(sequences: List[Tensor],
return torch.stack(padded_sequences)

def _pre_forward_hook(self, module, args, kwargs):
if args and any([isinstance(arg, torch.Tensor) for arg in args]):
logger.warning_once(f'forward args containing torch.Tensor, '
f'please make sure this model does not need to expand bs dim in streaming mode.')

for key in self.keys_expanded:
if key in kwargs:
kwargs[key] = kwargs[key][0].squeeze()

self.pre_forward(kwargs)
return args, kwargs

Expand Down Expand Up @@ -951,7 +943,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
if loss_scale is not None:
res['loss_scale'] = loss_scale

self._expand_dim(res)
# self._expand_dim(res)
return res

@staticmethod
Expand Down

0 comments on commit db17513

Please sign in to comment.