Skip to content

Commit

Permalink
reset self._per_class_res after use
Browse files Browse the repository at this point in the history
  • Loading branch information
KaiyangZhou committed Jun 8, 2021
1 parent d43a78c commit 29881c7
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions dassl/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def __init__(self, cfg, lab2cname=None, **kwargs):
def reset(self):
self._correct = 0
self._total = 0
if self._per_class_res is not None:
self._per_class_res = defaultdict(list)

def process(self, mo, gt):
# mo (torch.Tensor): model output [batch, num_classes]
# gt (torch.LongTensor): ground truth [batch]
pred = mo.max(1)[1]
matched = pred.eq(gt).float()
self._correct += int(matched.sum().item())
matches = pred.eq(gt).float()
self._correct += int(matches.sum().item())
self._total += gt.shape[0]

self._y_true.extend(gt.data.cpu().numpy().tolist())
Expand All @@ -57,8 +59,8 @@ def process(self, mo, gt):
if self._per_class_res is not None:
for i, label in enumerate(gt):
label = label.item()
matched_i = int(matched[i].item())
self._per_class_res[label].append(matched_i)
matches_i = int(matches[i].item())
self._per_class_res[label].append(matches_i)

def evaluate(self):
results = OrderedDict()
Expand Down

0 comments on commit 29881c7

Please sign in to comment.