From 82659b6d94ce541ad3c4845b08e705a76206c7ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=8B=E6=BD=AD?= <654523039@qq.com>
Date: Wed, 20 Sep 2023 00:26:12 +0800
Subject: [PATCH] fix lm_head type changed bug

---
 supervised_finetuning.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/supervised_finetuning.py b/supervised_finetuning.py
index af4f77c..b3e9a97 100644
--- a/supervised_finetuning.py
+++ b/supervised_finetuning.py
@@ -191,11 +191,25 @@ class PeftArguments(TrainingArguments):
     qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
 
 
-class CastOutputToFloat(torch.nn.Sequential):
+class CastOutputToFloat(torch.nn.Module):
     """Cast the output of the model to float"""
+    def __init__(self, ori_linear: torch.nn.Linear) -> None:
+        super().__init__()
+        self.in_features = ori_linear.in_features
+        self.out_features = ori_linear.out_features
+        self.weight = ori_linear.weight
+        if ori_linear.bias is not None:
+            self.bias = ori_linear.bias
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias).to(torch.float32)
 
-    def forward(self, x):
-        return super().forward(x).to(torch.float32)
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
 
 
 @dataclass