Commit acdecd7
Jinpu Zhou authored and meta-codesync[bot] committed
Add NMSE metric for APS model
Summary: Add NMSE and NRMSE to the APS metrics for an early signal of model performance.

Reviewed By: iamzainhuda

Differential Revision: D85412386

fbshipit-source-id: b891d45587c30aa099cce1acdf70ee30212a0b1f
1 parent 047760b commit acdecd7
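For context: NMSE and NRMSE normalize the model's weighted MSE and RMSE by the error of a constant baseline that predicts 1.0 for every example (see torch.ones_like in get_norm_mse_states below), so values below 1.0 indicate the model beats that trivial predictor. The following is a minimal illustrative sketch of that math on plain tensors; the helper name and example tensors are mine, not part of this commit, and it assumes the weighted error-sum convention of torchrec/metrics/mse.py.

import torch

# Illustrative only: the lifetime NMSE/NRMSE math from nmse.py on plain
# tensors, outside the RecMetric state machinery.
def nmse_and_nrmse(labels, predictions, weights):
    # Weighted squared-error sums for the model and for a constant
    # baseline that always predicts 1.0.
    error_sum = torch.sum(weights * (labels - predictions) ** 2)
    const_pred_error_sum = torch.sum(weights * (labels - 1.0) ** 2)
    weighted_num_samples = torch.sum(weights)

    mse = error_sum / weighted_num_samples
    const_pred_mse = const_pred_error_sum / weighted_num_samples

    # Guard the zero-baseline case, as compute_norm does via torch.where.
    zero = torch.tensor(0.0)
    nmse = torch.where(const_pred_mse == 0, zero, mse / const_pred_mse)
    nrmse = torch.where(
        const_pred_mse == 0, zero, torch.sqrt(mse) / torch.sqrt(const_pred_mse)
    )
    return nmse.double(), nrmse.double()

labels = torch.tensor([0.2, 1.5, 0.9])
predictions = torch.tensor([0.3, 1.2, 1.0])
weights = torch.ones(3)
print(nmse_and_nrmse(labels, predictions, weights))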

File tree: 5 files changed, +508 -0 lines changed

torchrec/metrics/metric_module.py

Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,7 @@
 from torchrec.metrics.ne import NEMetric
 from torchrec.metrics.ne_positive import NEPositiveMetric
 from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric
+from torchrec.metrics.nmse import NMSEMetric
 from torchrec.metrics.output import OutputMetric
 from torchrec.metrics.precision import PrecisionMetric
 from torchrec.metrics.precision_session import PrecisionSessionMetric

@@ -105,6 +106,7 @@
     RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric,
     RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric,
     RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric,
+    RecMetricEnum.NMSE: NMSEMetric,
 }
torchrec/metrics/metrics_config.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ class RecMetricEnum(RecMetricEnumBase):
     CALI_FREE_NE = "cali_free_ne"
     UNWEIGHTED_NE = "unweighted_ne"
     HINDSIGHT_TARGET_PR = "hindsight_target_pr"
+    NMSE = "nmse"


 @dataclass(unsafe_hash=True, eq=True)
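With the enum value registered, the metric can presumably be requested through the regular metrics configuration path. A minimal sketch, assuming the MetricsConfig, RecMetricDef, and RecTaskInfo dataclasses defined in torchrec/metrics/metrics_config.py keep their current fields; none of this example is part of the commit.

from torchrec.metrics.metrics_config import (
    MetricsConfig,
    RecMetricDef,
    RecMetricEnum,
    RecTaskInfo,
)

# Hypothetical single-task regression setup that turns on the new NMSE metric.
task = RecTaskInfo(
    name="my_regression_task",
    label_name="label",
    prediction_name="prediction",
    weight_name="weight",
)
metrics_config = MetricsConfig(
    rec_tasks=[task],
    rec_metrics={RecMetricEnum.NMSE: RecMetricDef(rec_tasks=[task])},
)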

torchrec/metrics/metrics_namespace.py

Lines changed: 5 additions & 0 deletions

@@ -90,6 +90,9 @@ class MetricName(MetricNameBase):

     EFFECTIVE_SAMPLE_RATE = "effective_sample_rate"

+    NMSE = "nmse"
+    NRMSE = "nrmse"
+

 class MetricNamespaceBase(StrValueMixin, Enum):
     pass

@@ -148,6 +151,8 @@ class MetricNamespace(MetricNamespaceBase):
     # This is particularly useful for MTML models train with composite pipelines to figure out per-batch blending ratio.
     EFFECTIVE_RATE = "effective_rate"

+    NMSE = "nmse"
+

 class MetricPrefix(StrValueMixin, Enum):
     DEFAULT = ""

torchrec/metrics/nmse.py

Lines changed: 164 additions & 0 deletions

@@ -0,0 +1,164 @@ (new file; shown without diff markers since every line is an addition)

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.mse import (
    compute_error_sum,
    compute_mse,
    compute_rmse,
    ERROR_SUM,
    get_mse_states,
    MSEMetricComputation,
    WEIGHTED_NUM_SAMPES,
)
from torchrec.metrics.rec_metric import (
    MetricComputationReport,
    RecMetric,
    RecMetricException,
)

CONST_PRED_ERROR_SUM = "const_pred_error_sum"


def compute_norm(
    model_error_sum: torch.Tensor, baseline_error_sum: torch.Tensor
) -> torch.Tensor:
    return torch.where(
        baseline_error_sum == 0,
        torch.tensor(0.0),
        model_error_sum / baseline_error_sum,
    ).double()


def get_norm_mse_states(
    labels: torch.Tensor,
    predictions: torch.Tensor,
    weights: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    return {
        **get_mse_states(labels, predictions, weights),
        **(
            {
                CONST_PRED_ERROR_SUM: compute_error_sum(
                    labels, torch.ones_like(labels), weights
                )
            }
        ),
    }


class NMSEMetricComputation(MSEMetricComputation):
    r"""
    This class extends the MSEMetricComputation for normalization computation for L2 regression metrics.

    The constructor arguments are defined in RecMetricComputation.
    See the docstring of RecMetricComputation for more detail.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._add_state(
            CONST_PRED_ERROR_SUM,
            torch.zeros(self._n_tasks, dtype=torch.double),
            add_window_state=True,
            dist_reduce_fx="sum",
            persistent=True,
        )

    def update(
        self,
        *,
        predictions: Optional[torch.Tensor],
        labels: torch.Tensor,
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
        if predictions is None or weights is None:
            raise RecMetricException(
                "Inputs 'predictions' and 'weights' should not be None for NMSEMetricComputation update"
            )
        states = get_norm_mse_states(labels, predictions, weights)
        num_samples = predictions.shape[-1]
        for state_name, state_value in states.items():
            state = getattr(self, state_name)
            state += state_value
            self._aggregate_window_state(state_name, state_value, num_samples)

    def _compute(self) -> List[MetricComputationReport]:
        mse = compute_mse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_mse = compute_mse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nmse = compute_norm(mse, const_pred_mse)

        rmse = compute_rmse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_rmse = compute_rmse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nrmse = compute_norm(rmse, const_pred_rmse)

        window_mse = compute_mse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_mse = compute_mse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nmse = compute_norm(window_mse, window_const_pred_mse)

        window_rmse = compute_rmse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_rmse = compute_rmse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nrmse = compute_norm(window_rmse, window_const_pred_rmse)

        return [
            MetricComputationReport(
                name=MetricName.NMSE,
                metric_prefix=MetricPrefix.LIFETIME,
                value=nmse,
            ),
            MetricComputationReport(
                name=MetricName.NRMSE,
                metric_prefix=MetricPrefix.LIFETIME,
                value=nrmse,
            ),
            MetricComputationReport(
                name=MetricName.NMSE,
                metric_prefix=MetricPrefix.WINDOW,
                value=window_nmse,
            ),
            MetricComputationReport(
                name=MetricName.NRMSE,
                metric_prefix=MetricPrefix.WINDOW,
                value=window_nrmse,
            ),
        ]


class NMSEMetric(RecMetric):
    _namespace: MetricNamespace = MetricNamespace.NMSE
    _computation_class: Type[NMSEMetricComputation] = NMSEMetricComputation
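For a quick smoke test, the new metric can also be driven directly through the RecMetric interface. A sketch under assumptions: the constructor and the update/compute contract match the other RecMetric subclasses (world_size / my_rank / batch_size / tasks constructor arguments, per-task dictionaries keyed by task name in update); none of these call details come from this diff, so check rec_metric.py for the exact contract.

import torch

from torchrec.metrics.metrics_config import RecTaskInfo
from torchrec.metrics.nmse import NMSEMetric

task = RecTaskInfo(
    name="t1", label_name="label", prediction_name="prediction", weight_name="weight"
)
# Assumed constructor signature, shared with other RecMetric subclasses.
metric = NMSEMetric(world_size=1, my_rank=0, batch_size=4, tasks=[task])

metric.update(
    predictions={"t1": torch.tensor([0.3, 1.2, 1.0, 0.5])},
    labels={"t1": torch.tensor([0.2, 1.5, 0.9, 0.4])},
    weights={"t1": torch.ones(4)},
)
# Expect keys along the lines of nmse-t1|lifetime_nmse, lifetime_nrmse,
# window_nmse, and window_nrmse, per the MetricComputationReport entries above.
print(metric.compute())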
