Skip to content

Commit

Permalink
Update loss.py (#304)
Browse files Browse the repository at this point in the history
* Update loss.py

* fixing default

* trying to have good default values

* fixing stuff

* removing useless logic
  • Loading branch information
jeandut authored Jun 12, 2024
1 parent bf58082 commit 8799494
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def main(output_folder, debug=False, **kwargs):
action="store_false",
help="Generate a regression dataset. (Default)",
)
parser.set_defaults(classification=False)
parser.set_defaults(classification=True)

parser.add_argument(
"--output-folder",
Expand Down
4 changes: 2 additions & 2 deletions flamby/datasets/fed_synthetic/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class BaselineLoss(_Loss):
def __init__(self):
super(BaselineLoss, self).__init__()
self.bce = torch.nn.BCELoss()
self.ce = torch.nn.CrossEntropyLoss()

def forward(self, input: torch.Tensor, target: torch.Tensor):
return self.bce(input, target)
return self.ce(input, target.squeeze().long())
13 changes: 4 additions & 9 deletions flamby/datasets/fed_synthetic/metric.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import numpy as np
from sklearn.metrics import accuracy_score


def metric(y_true, y_pred):
y_true = y_true.astype("uint8")
# The try except is needed because when the metric is batched some batches
# have one class only
try:
# return roc_auc_score(y_true, y_pred)
# proposed modification in order to get a metric that calcs on center 2
# (y=1 only on that center)
return ((y_pred > 0.5) == y_true).mean()
except ValueError:
return np.nan
y_pred = np.argmax(y_pred, axis=1)
return accuracy_score(y_true, y_pred)

2 changes: 1 addition & 1 deletion flamby/datasets/fed_synthetic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class Baseline(nn.Module):
def __init__(self, input_dim=10, output_dim=1):
def __init__(self, input_dim=10, output_dim=3):
super(Baseline, self).__init__()
self.linear = torch.nn.Linear(input_dim, output_dim)

Expand Down

0 comments on commit 8799494

Please sign in to comment.