From b5b20a2e551e1ecc5a7ac018e9dcddd6e57a235f Mon Sep 17 00:00:00 2001 From: Geoff Pleiss Date: Wed, 9 Jun 2021 17:39:10 -0700 Subject: [PATCH] Fix erroneous loss for ExactGP multitask models The code in ExactMLL divides by `targets.size(-1)`. For multitask models, this is the number of outputs. Instead, this code divides by `function_dist.event_shape.numel()`, which will result in losses that are similar in size to univariate ExactGP models. [Fixes #1155] [Fixes #1033] [Fixes #1159] [Addresses #1129] [Addresses #1633] --- gpytorch/mlls/exact_marginal_log_likelihood.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 21196eb2c..f2306d2f2 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -63,7 +63,7 @@ def forward(self, function_dist, target, *params): res = self._add_other_terms(res, params) # Scale by the amount of data we have - num_data = target.size(-1) + num_data = function_dist.event_shape.numel() return res.div_(num_data) def pyro_factor(self, output, target, *params):