Skip to content

Commit

Permalink
Add output_dim parameter to LinearRegression (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
gokceneraslan authored Sep 10, 2020
1 parent 9ad00bf commit ca8e7b2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pl_bolts/models/regression/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class LinearRegression(pl.LightningModule):

def __init__(self,
input_dim: int,
output_dim: int = 1,
bias: bool = True,
learning_rate: float = 0.0001,
optimizer: Optimizer = Adam,
Expand All @@ -26,6 +27,7 @@ def __init__(self,
Args:
input_dim: number of dimensions of the input (1+)
output_dim: number of dimensions of the output (default=1)
bias: If false, will not use $+b$
learning_rate: learning_rate for the optimizer
optimizer: the optimizer to use (default='Adam')
Expand All @@ -37,7 +39,7 @@ def __init__(self,
self.save_hyperparameters()
self.optimizer = optimizer

self.linear = nn.Linear(in_features=self.hparams.input_dim, out_features=1, bias=bias)
self.linear = nn.Linear(in_features=self.hparams.input_dim, out_features=self.hparams.output_dim, bias=bias)

def forward(self, x):
y_hat = self.linear(x)
Expand Down Expand Up @@ -114,6 +116,7 @@ def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--input_dim', type=int, default=None)
parser.add_argument('--output_dim', type=int, default=1)
parser.add_argument('--bias', default='store_true')
parser.add_argument('--batch_size', type=int, default=16)
return parser
Expand Down

0 comments on commit ca8e7b2

Please sign in to comment.