diff --git a/supervised_finetuning.py b/supervised_finetuning.py index af4f77c..b3e9a97 100644 --- a/supervised_finetuning.py +++ b/supervised_finetuning.py @@ -191,11 +191,25 @@ class PeftArguments(TrainingArguments): qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"}) -class CastOutputToFloat(torch.nn.Sequential): +class CastOutputToFloat(torch.nn.Module): """Cast the output of the model to float""" + def __init__(self, ori_linear: torch.nn.Linear) -> None: + super().__init__() + self.in_features = ori_linear.in_features + self.out_features = ori_linear.out_features + self.weight = ori_linear.weight + if ori_linear.bias is not None: + self.bias = ori_linear.bias + else: + self.register_parameter('bias', None) + + def forward(self, input): + return torch.nn.functional.linear(input, self.weight, self.bias).to(torch.float32) - def forward(self, x): - return super().forward(x).to(torch.float32) + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) @dataclass