diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 55f83d905..ea6041a6c 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -37,6 +37,7 @@ def compute(self): labels = values['labels'] logits = values['logits'] mask = values['mask'] + sigmoid = jax.nn.sigmoid if USE_PYTORCH_DDP: # Sync labels, logits, and masks across devices. @@ -49,9 +50,14 @@ def compute(self): all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values + def sigmoid_np(x): + return 1 / (1 + np.exp(-x)) + + sigmoid = sigmoid_np + mask = mask.astype(bool) - probs = jax.nn.sigmoid(logits) + probs = sigmoid(logits) num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan)