From c83c763040033a09ff168997da89ee2b3eae0bbe Mon Sep 17 00:00:00 2001 From: Richard Stanton Date: Wed, 23 Jun 2021 21:20:39 +0100 Subject: [PATCH] removing bias from linear model regularisation --- pl_bolts/models/regression/linear_regression.py | 4 ++-- pl_bolts/models/regression/logistic_regression.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index b4cff87f44..778a2cae0e 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -58,12 +58,12 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st # L1 regularizer if self.hparams.l1_strength > 0: - l1_reg = sum(param.abs().sum() for param in self.parameters()) + l1_reg = self.linear.weight.abs().sum() loss += self.hparams.l1_strength * l1_reg # L2 regularizer if self.hparams.l2_strength > 0: - l2_reg = sum(param.pow(2).sum() for param in self.parameters()) + l2_reg = self.linear.weight.pow(2).sum() loss += self.hparams.l2_strength * l2_reg loss /= x.size(0) diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index e61b58f720..efc805c510 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -61,12 +61,12 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st # L1 regularizer if self.hparams.l1_strength > 0: - l1_reg = sum(param.abs().sum() for param in self.parameters()) + l1_reg = self.linear.weight.abs().sum() loss += self.hparams.l1_strength * l1_reg # L2 regularizer if self.hparams.l2_strength > 0: - l2_reg = sum(param.pow(2).sum() for param in self.parameters()) + l2_reg = self.linear.weight.pow(2).sum() loss += self.hparams.l2_strength * l2_reg loss /= x.size(0)