From 22002eaae419ddb30c911b40c40bd27513baebb9 Mon Sep 17 00:00:00 2001
From: Cian Eastwood
Date: Sat, 28 Jan 2023 11:26:53 +0000
Subject: [PATCH] Added EQRM.

---
 README.md                     |   1 +
 domainbed/algorithms.py       |  48 ++++++-
 domainbed/hparams_registry.py |   5 +
 domainbed/lib/misc.py         | 241 ++++++++++++++++++++++++++++++++++
 4 files changed, 294 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 64fb6c1b..d6004763 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ The [currently available algorithms](domainbed/algorithms.py) are:
 * Optimal Representations for Covariate Shift (CAD & CondCAD, [Ruan et al., 2022](https://arxiv.org/abs/2201.00057)), contributed by [@ryoungj](https://github.com/ryoungj)
 * Quantifying and Improving Transferability in Domain Generalization (Transfer, [Zhang et al., 2021](https://arxiv.org/abs/2106.03632)), contributed by [@Gordon-Guojun-Zhang](https://github.com/Gordon-Guojun-Zhang)
 * Invariant Causal Mechanisms through Distribution Matching (CausIRL with CORAL or MMD, [Chevalley et al., 2022](https://arxiv.org/abs/2206.11646)), contributed by [@MathieuChevalley](https://github.com/MathieuChevalley)
+* Empirical Quantile Risk Minimization (EQRM, [Eastwood et al., 2022](https://arxiv.org/abs/2207.09944)), contributed by [@cianeastwood](https://github.com/cianeastwood)
 
 Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py).
 
diff --git a/domainbed/algorithms.py b/domainbed/algorithms.py
index 46ee4535..2204279d 100644
--- a/domainbed/algorithms.py
+++ b/domainbed/algorithms.py
@@ -17,7 +17,7 @@
 from domainbed import networks
 from domainbed.lib.misc import (
     random_pairs_of_minibatches, split_meta_train_test, ParamDict,
-    MovingAverage, l2_between_dicts, proj
+    MovingAverage, l2_between_dicts, proj, Nonparametric
 )
 
 
@@ -51,6 +51,7 @@
     'Transfer',
     'CausIRL_CORAL',
     'CausIRL_MMD',
+    'EQRM',
 ]
 
 def get_algorithm_class(algorithm_name):
@@ -1987,3 +1988,48 @@ class CausIRL_CORAL(AbstractCausIRL):
     def __init__(self, input_shape, num_classes, num_domains, hparams):
         super(CausIRL_CORAL, self).__init__(input_shape, num_classes, num_domains,
                                             hparams, gaussian=False)
+
+
+class EQRM(ERM):
+    """
+    Empirical Quantile Risk Minimization (EQRM).
+    Algorithm 1 from [https://arxiv.org/pdf/2207.09944.pdf].
+    """
+    def __init__(self, input_shape, num_classes, num_domains, hparams, dist=None):
+        super().__init__(input_shape, num_classes, num_domains, hparams)
+        self.register_buffer('update_count', torch.tensor([0]))
+        self.register_buffer('alpha', torch.tensor(self.hparams["eqrm_quantile"], dtype=torch.float64))
+        if dist is None:
+            self.dist = Nonparametric()
+        else:
+            self.dist = dist
+
+    def risk(self, x, y):
+        return F.cross_entropy(self.network(x), y).reshape(1)
+
+    def update(self, minibatches, unlabeled=None):
+        env_risks = torch.cat([self.risk(x, y) for x, y in minibatches])
+
+        if self.update_count < self.hparams["eqrm_burnin_iters"]:
+            # Burn-in/annealing period uses ERM, like penalty methods which set penalty_weight=0 (e.g. IRM, VREx).
+            loss = torch.mean(env_risks)
+        else:
+            # Loss is the alpha-quantile value
+            self.dist.estimate_parameters(env_risks)
+            loss = self.dist.icdf(self.alpha)
+
+        if self.update_count == self.hparams['eqrm_burnin_iters']:
+            # Reset Adam (like IRM, VREx, etc.), because it doesn't like the sharp jump in
+            # gradient magnitudes that happens at this step.
+            self.optimizer = torch.optim.Adam(
+                self.network.parameters(),
+                lr=self.hparams["eqrm_lr"],
+                weight_decay=self.hparams['weight_decay'])
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        self.update_count += 1
+
+        return {'loss': loss.item()}
diff --git a/domainbed/hparams_registry.py b/domainbed/hparams_registry.py
index 30ba0eb0..3fb4d8b6 100644
--- a/domainbed/hparams_registry.py
+++ b/domainbed/hparams_registry.py
@@ -136,6 +136,11 @@ def _hparam(name, default_val, random_val_fn):
         _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5]))
         _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
 
+    elif algorithm == 'EQRM':
+        _hparam('eqrm_quantile', 0.75, lambda r: r.uniform(0.5, 0.99))
+        _hparam('eqrm_burnin_iters', 2500, lambda r: 10 ** r.uniform(2.5, 3.5))
+        _hparam('eqrm_lr', 1e-6, lambda r: 10 ** r.uniform(-7, -5))
+
     # Dataset-and-algorithm-specific hparam definitions. Each block of code
     # below corresponds to exactly one hparam. Avoid nested conditionals.
 
diff --git a/domainbed/lib/misc.py b/domainbed/lib/misc.py
index 36eae928..cea5d885 100644
--- a/domainbed/lib/misc.py
+++ b/domainbed/lib/misc.py
@@ -4,6 +4,7 @@
 Things that don't belong anywhere else
 """
 
+import math
 import hashlib
 import sys
 from collections import OrderedDict
@@ -258,3 +259,243 @@ def __rsub__(self, other):
 
     def __truediv__(self, other):
         return self._prototype(other, operator.truediv)
+
+
+############################################################
+# A general PyTorch implementation of KDE. Builds on:
+# https://github.com/EugenHotaj/pytorch-generative/blob/master/pytorch_generative/models/kde.py
+############################################################
+
+class Kernel(torch.nn.Module):
+    """Base class which defines the interface for all kernels."""
+
+    def __init__(self, bw=None):
+        super().__init__()
+        self.bw = 0.05 if bw is None else bw
+
+    def _diffs(self, test_Xs, train_Xs):
+        """Computes difference between each x in test_Xs with all train_Xs."""
+        test_Xs = test_Xs.view(test_Xs.shape[0], 1, *test_Xs.shape[1:])
+        train_Xs = train_Xs.view(1, train_Xs.shape[0], *train_Xs.shape[1:])
+        return test_Xs - train_Xs
+
+    def forward(self, test_Xs, train_Xs):
+        """Computes p(x) for each x in test_Xs given train_Xs."""
+
+    def sample(self, train_Xs):
+        """Generates samples from the kernel distribution."""
+
+
+class GaussianKernel(Kernel):
+    """Implementation of the Gaussian kernel."""
+
+    def forward(self, test_Xs, train_Xs):
+        diffs = self._diffs(test_Xs, train_Xs)
+        dims = tuple(range(len(diffs.shape))[2:])
+        if dims == ():
+            x_sq = diffs ** 2
+        else:
+            x_sq = torch.norm(diffs, p=2, dim=dims) ** 2
+
+        var = self.bw ** 2
+        exp = torch.exp(-x_sq / (2 * var))
+        coef = 1. / torch.sqrt(2 * np.pi * var)
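+        # For 1D inputs, coef * exp is the Gaussian density N(x; x_i, bw^2) of each test point under a
+        # kernel centred on each training point; the mean over dim=1 below gives the KDE estimate of p(x).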
+
+        return (coef * exp).mean(dim=1)
+
+    def sample(self, train_Xs):
+        # device = train_Xs.device
+        noise = torch.randn(train_Xs.shape) * self.bw
+        return train_Xs + noise
+
+    def cdf(self, test_Xs, train_Xs):
+        mus = train_Xs                                                    # kernel centred on each observation
+        sigmas = torch.ones(len(mus), device=test_Xs.device) * self.bw   # bandwidth = stddev
+        x_ = test_Xs.repeat(len(mus), 1).T                                # repeat to allow broadcasting below
+        return torch.mean(torch.distributions.Normal(mus, sigmas).cdf(x_))
+
+
+def estimate_bandwidth(x, method="silverman"):
+    x_, _ = torch.sort(x)
+    n = len(x_)
+    sample_std = torch.std(x_, unbiased=True)
+
+    if method == 'silverman':
+        # https://en.wikipedia.org/wiki/Kernel_density_estimation#A_rule-of-thumb_bandwidth_estimator
+        iqr = torch.quantile(x_, 0.75) - torch.quantile(x_, 0.25)
+        bandwidth = 0.9 * torch.min(sample_std, iqr / 1.34) * n ** (-0.2)
+
+    elif method.lower() == 'gauss-optimal':
+        bandwidth = 1.06 * sample_std * (n ** -0.2)
+
+    else:
+        raise ValueError(f"Invalid method selected: {method}.")
+
+    return bandwidth
+
+
+class KernelDensityEstimator(torch.nn.Module):
+    """The KernelDensityEstimator model."""
+
+    def __init__(self, train_Xs, kernel='gaussian', bw_select='Gauss-optimal'):
+        """Initializes a new KernelDensityEstimator.
+        Args:
+            train_Xs: The "training" data to use when estimating probabilities.
+            kernel: The kernel to place on each of the train_Xs.
+        """
+        super().__init__()
+        self.train_Xs = train_Xs
+        self._n_kernels = len(self.train_Xs)
+
+        if bw_select is not None:
+            self.bw = estimate_bandwidth(self.train_Xs, bw_select)
+        else:
+            self.bw = None
+
+        if kernel.lower() == 'gaussian':
+            self.kernel = GaussianKernel(self.bw)
+        else:
+            raise NotImplementedError(f"'{kernel}' kernel not implemented.")
+
+    @property
+    def device(self):
+        return self.train_Xs.device
+
+    # TODO(eugenhotaj): This method consumes O(train_Xs * x) memory. Implement an iterative version instead.
+    def forward(self, x):
+        return self.kernel(x, self.train_Xs)
+
+    def sample(self, n_samples):
+        idxs = np.random.choice(range(self._n_kernels), size=n_samples)
+        return self.kernel.sample(self.train_Xs[idxs])
+
+    def cdf(self, x):
+        return self.kernel.cdf(x, self.train_Xs)
+
+
+############################################################
+# PyTorch implementation of 1D distributions.
+############################################################
+
+EPS = 1e-16
+
+
+class Distribution1D:
+    def __init__(self, dist_function=None):
+        """
+        :param dist_function: function to instantiate the distribution (self.dist).
+        :param parameters: list of parameters in the correct order for dist_function.
+ """ + self.dist = None + self.dist_function = dist_function + + @property + def parameters(self): + raise NotImplementedError + + def create_dist(self): + if self.dist_function is not None: + return self.dist_function(*self.parameters) + else: + raise NotImplementedError("No distribution function was specified during intialization.") + + def estimate_parameters(self, x): + raise NotImplementedError + + def log_prob(self, x): + return self.create_dist().log_prob(x) + + def cdf(self, x): + return self.create_dist().cdf(x) + + def icdf(self, q): + return self.create_dist().icdf(q) + + def sample(self, n=1): + if self.dist is None: + self.dist = self.create_dist() + n_ = torch.Size([]) if n == 1 else (n,) + return self.dist.sample(n_) + + def sample_n(self, n=10): + return self.sample(n) + + +def continuous_bisect_fun_left(f, v, lo, hi, n_steps=32): + val_range = [lo, hi] + k = 0.5 * sum(val_range) + for _ in range(n_steps): + val_range[int(f(k) > v)] = k + next_k = 0.5 * sum(val_range) + if next_k == k: + break + k = next_k + return k + + +class Normal(Distribution1D): + def __init__(self, location=0, scale=1): + self.location = location + self.scale = scale + super().__init__(torch.distributions.Normal) + + @property + def parameters(self): + return [self.location, self.scale] + + def estimate_parameters(self, x): + mean = sum(x) / len(x) + var = sum([(x_i - mean) ** 2 for x_i in x]) / (len(x) - 1) + self.location = mean + self.scale = torch.sqrt(var + EPS) + + def icdf(self, q): + if q >= 0: + return super().icdf(q) + + else: + # To get q *very* close to 1 without numerical issues, we: + # 1) Use q < 0 to represent log(y), where q = 1 - y. + # 2) Use the inverse-normal-cdf approximation here: + # https://math.stackexchange.com/questions/2964944/asymptotics-of-inverse-of-normal-cdf + log_y = q + return self.location + self.scale * math.sqrt(-2 * log_y) + + +class Nonparametric(Distribution1D): + def __init__(self, use_kde=True, bw_select='Gauss-optimal'): + self.use_kde = use_kde + self.bw_select = bw_select + self.bw, self.data, self.kde = None, None, None + super().__init__() + + @property + def parameters(self): + return [] + + def estimate_parameters(self, x): + self.data, _ = torch.sort(x) + + if self.use_kde: + self.kde = KernelDensityEstimator(self.data, bw_select=self.bw_select) + self.bw = torch.ones(1, device=self.data.device) * self.kde.bw + + def icdf(self, q): + if not self.use_kde: + # Empirical or step CDF. Differentiable as torch.quantile uses (linear) interpolation. + return torch.quantile(self.data, float(q)) + + if q >= 0: + # Find quantile via binary search on the KDE CDF + lo = torch.distributions.Normal(self.data[0], self.bw[0]).icdf(q) + hi = torch.distributions.Normal(self.data[-1], self.bw[-1]).icdf(q) + return continuous_bisect_fun_left(self.kde.cdf, q, lo, hi) + + else: + # To get q *very* close to 1 without numerical issues, we: + # 1) Use q < 0 to represent log(y), where q = 1 - y. + # 2) Use the inverse-normal-cdf approximation here: + # https://math.stackexchange.com/questions/2964944/asymptotics-of-inverse-of-normal-cdf + log_y = q + v = torch.mean(self.data + self.bw * math.sqrt(-2 * log_y)) + return v