Skip to content

Commit 4dedf6e

Browse files
committed
fix hasattr check for task-awareness
1 parent 91841ac commit 4dedf6e

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

avalanche/training/templates/base_sgd.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,7 @@ def make_train_dataloader(
456456
other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]
457457

458458
# use task-balanced dataloader for task-aware benchmarks
459-
if hasattr(self.experience, "task_label") or hasattr(
460-
self.experience, "task_labels"
461-
):
459+
if hasattr(self.experience, "task_labels"):
462460
self.dataloader = TaskBalancedDataLoader(
463461
self.adapted_dataset,
464462
oversample_small_groups=True,

avalanche/training/templates/problem_type/supervised_problem.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ def criterion(self):
4545
def forward(self):
4646
"""Compute the model's output given the current mini-batch."""
4747
# use task-aware forward only for task-aware benchmarks
48-
if hasattr(self.experience, "task_labels") or hasattr(
49-
self.experience, "task_label"
50-
):
48+
if hasattr(self.experience, "task_labels"):
5149
return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
5250
else:
5351
return self.model(self.mb_x)

0 commit comments

Comments
 (0)