Added EQRM.
Cian Eastwood committed Jan 28, 2023
1 parent 3ab2b63 commit 22002ea
Showing 4 changed files with 294 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@ The [currently available algorithms](domainbed/algorithms.py) are:
* Optimal Representations for Covariate Shift (CAD & CondCAD, [Ruan et al., 2022](https://arxiv.org/abs/2201.00057)), contributed by [@ryoungj](https://github.com/ryoungj)
* Quantifying and Improving Transferability in Domain Generalization (Transfer, [Zhang et al., 2021](https://arxiv.org/abs/2106.03632)), contributed by [@Gordon-Guojun-Zhang](https://github.com/Gordon-Guojun-Zhang)
* Invariant Causal Mechanisms through Distribution Matching (CausIRL with CORAL or MMD, [Chevalley et al., 2022](https://arxiv.org/abs/2206.11646)), contributed by [@MathieuChevalley](https://github.com/MathieuChevalley)
* Empirical Quantile Risk Minimization (EQRM, [Eastwood et al., 2022](https://arxiv.org/abs/2207.09944)), contributed by [@cianeastwood](https://github.com/cianeastwood)

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py).

48 changes: 47 additions & 1 deletion domainbed/algorithms.py
@@ -17,7 +17,7 @@
from domainbed import networks
from domainbed.lib.misc import (
random_pairs_of_minibatches, split_meta_train_test, ParamDict,
MovingAverage, l2_between_dicts, proj
MovingAverage, l2_between_dicts, proj, Nonparametric
)


@@ -51,6 +51,7 @@
'Transfer',
'CausIRL_CORAL',
'CausIRL_MMD',
'EQRM',
]

def get_algorithm_class(algorithm_name):
@@ -1987,3 +1988,48 @@ class CausIRL_CORAL(AbstractCausIRL):
def __init__(self, input_shape, num_classes, num_domains, hparams):
super(CausIRL_CORAL, self).__init__(input_shape, num_classes, num_domains,
hparams, gaussian=False)


class EQRM(ERM):
"""
Empirical Quantile Risk Minimization (EQRM).
Algorithm 1 from [https://arxiv.org/pdf/2207.09944.pdf].
"""
def __init__(self, input_shape, num_classes, num_domains, hparams, dist=None):
super().__init__(input_shape, num_classes, num_domains, hparams)
self.register_buffer('update_count', torch.tensor([0]))
self.register_buffer('alpha', torch.tensor(self.hparams["eqrm_quantile"], dtype=torch.float64))
if dist is None:
self.dist = Nonparametric()
else:
self.dist = dist

def risk(self, x, y):
return F.cross_entropy(self.network(x), y).reshape(1)

def update(self, minibatches, unlabeled=None):
env_risks = torch.cat([self.risk(x, y) for x, y in minibatches])

if self.update_count < self.hparams["eqrm_burnin_iters"]:
            # Burn-in/annealing period: use plain ERM, as penalty-based methods
            # (e.g. IRM, VREx) do by setting penalty_weight=0 during burn-in.
loss = torch.mean(env_risks)
else:
# Loss is the alpha-quantile value
self.dist.estimate_parameters(env_risks)
loss = self.dist.icdf(self.alpha)

if self.update_count == self.hparams['eqrm_burnin_iters']:
# Reset Adam (like IRM, VREx, etc.), because it doesn't like the sharp jump in
# gradient magnitudes that happens at this step.
self.optimizer = torch.optim.Adam(
self.network.parameters(),
lr=self.hparams["eqrm_lr"],
weight_decay=self.hparams['weight_decay'])

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

self.update_count += 1

return {'loss': loss.item()}
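
For intuition, here is a minimal standalone sketch (with hypothetical risk values) of what the post-burn-in objective computes: the alpha-quantile of the KDE-smoothed per-environment risk distribution, rather than its mean.

import torch
from domainbed.lib.misc import Nonparametric

risks = torch.tensor([0.31, 0.45, 0.52, 0.97])   # hypothetical per-environment risks
alpha = torch.tensor(0.75, dtype=torch.float64)  # default 'eqrm_quantile'

erm_loss = risks.mean()                          # burn-in objective (plain ERM): 0.5625
dist = Nonparametric()
dist.estimate_parameters(risks)                  # fit a 1D KDE over the risks
eqrm_loss = dist.icdf(alpha)                     # post-burn-in objective: the 0.75-quantile

# Because the risk distribution is right-skewed here, the quantile exceeds the mean,
# so the objective is dominated by the worse (higher-risk) environments.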
5 changes: 5 additions & 0 deletions domainbed/hparams_registry.py
@@ -136,6 +136,11 @@ def _hparam(name, default_val, random_val_fn):
_hparam('beta1', 0.5, lambda r: r.choice([0., 0.5]))
_hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))

elif algorithm == 'EQRM':
_hparam('eqrm_quantile', 0.75, lambda r: r.uniform(0.5, 0.99))
_hparam('eqrm_burnin_iters', 2500, lambda r: 10 ** r.uniform(2.5, 3.5))
_hparam('eqrm_lr', 1e-6, lambda r: 10 ** r.uniform(-7, -5))
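        # Illustrative note on the search ranges above: under DomainBed's random
        # hyper-parameter search, eqrm_quantile ~ Uniform(0.5, 0.99),
        # eqrm_burnin_iters ~ 10**Uniform(2.5, 3.5) (roughly 316 to 3162 steps), and
        # eqrm_lr ~ 10**Uniform(-7, -5) (roughly 1e-7 to 1e-5); the listed values
        # (0.75, 2500, 1e-6) are the defaults.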


# Dataset-and-algorithm-specific hparam definitions. Each block of code
# below corresponds to exactly one hparam. Avoid nested conditionals.
241 changes: 241 additions & 0 deletions domainbed/lib/misc.py
@@ -4,6 +4,7 @@
Things that don't belong anywhere else
"""

import math
import hashlib
import sys
from collections import OrderedDict
@@ -258,3 +259,243 @@ def __rsub__(self, other):

def __truediv__(self, other):
return self._prototype(other, operator.truediv)


############################################################
# A general PyTorch implementation of KDE. Builds on:
# https://github.com/EugenHotaj/pytorch-generative/blob/master/pytorch_generative/models/kde.py
############################################################

class Kernel(torch.nn.Module):
"""Base class which defines the interface for all kernels."""

def __init__(self, bw=None):
super().__init__()
self.bw = 0.05 if bw is None else bw

def _diffs(self, test_Xs, train_Xs):
"""Computes difference between each x in test_Xs with all train_Xs."""
test_Xs = test_Xs.view(test_Xs.shape[0], 1, *test_Xs.shape[1:])
train_Xs = train_Xs.view(1, train_Xs.shape[0], *train_Xs.shape[1:])
return test_Xs - train_Xs

    def forward(self, test_Xs, train_Xs):
        """Computes p(x) for each x in test_Xs given train_Xs."""
        raise NotImplementedError

    def sample(self, train_Xs):
        """Generates samples from the kernel distribution."""
        raise NotImplementedError


class GaussianKernel(Kernel):
"""Implementation of the Gaussian kernel."""

def forward(self, test_Xs, train_Xs):
diffs = self._diffs(test_Xs, train_Xs)
dims = tuple(range(len(diffs.shape))[2:])
if dims == ():
x_sq = diffs ** 2
else:
x_sq = torch.norm(diffs, p=2, dim=dims) ** 2

var = self.bw ** 2
exp = torch.exp(-x_sq / (2 * var))
coef = 1. / torch.sqrt(2 * np.pi * var)

return (coef * exp).mean(dim=1)

def sample(self, train_Xs):
        noise = torch.randn(train_Xs.shape, device=train_Xs.device) * self.bw
return train_Xs + noise

def cdf(self, test_Xs, train_Xs):
mus = train_Xs # kernel centred on each observation
sigmas = torch.ones(len(mus), device=test_Xs.device) * self.bw # bandwidth = stddev
x_ = test_Xs.repeat(len(mus), 1).T # repeat to allow broadcasting below
return torch.mean(torch.distributions.Normal(mus, sigmas).cdf(x_))


def estimate_bandwidth(x, method="silverman"):
x_, _ = torch.sort(x)
n = len(x_)
sample_std = torch.std(x_, unbiased=True)

if method == 'silverman':
# https://en.wikipedia.org/wiki/Kernel_density_estimation#A_rule-of-thumb_bandwidth_estimator
iqr = torch.quantile(x_, 0.75) - torch.quantile(x_, 0.25)
bandwidth = 0.9 * torch.min(sample_std, iqr / 1.34) * n ** (-0.2)

elif method.lower() == 'gauss-optimal':
bandwidth = 1.06 * sample_std * (n ** -0.2)

else:
raise ValueError(f"Invalid method selected: {method}.")

return bandwidth


class KernelDensityEstimator(torch.nn.Module):
"""The KernelDensityEstimator model."""

def __init__(self, train_Xs, kernel='gaussian', bw_select='Gauss-optimal'):
"""Initializes a new KernelDensityEstimator.
Args:
train_Xs: The "training" data to use when estimating probabilities.
kernel: The kernel to place on each of the train_Xs.
"""
super().__init__()
self.train_Xs = train_Xs
self._n_kernels = len(self.train_Xs)

if bw_select is not None:
self.bw = estimate_bandwidth(self.train_Xs, bw_select)
else:
self.bw = None

if kernel.lower() == 'gaussian':
self.kernel = GaussianKernel(self.bw)
else:
raise NotImplementedError(f"'{kernel}' kernel not implemented.")

@property
def device(self):
return self.train_Xs.device

# TODO(eugenhotaj): This method consumes O(train_Xs * x) memory. Implement an iterative version instead.
def forward(self, x):
return self.kernel(x, self.train_Xs)

def sample(self, n_samples):
idxs = np.random.choice(range(self._n_kernels), size=n_samples)
return self.kernel.sample(self.train_Xs[idxs])

def cdf(self, x):
return self.kernel.cdf(x, self.train_Xs)
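

# Usage sketch (illustrative; this helper is hypothetical and unused elsewhere): fit a
# 1D KDE on a handful of points and query its density and CDF.
def _kde_usage_example():
    train_Xs = torch.tensor([0.2, 0.4, 0.5, 0.9])
    kde = KernelDensityEstimator(train_Xs)                   # 'Gauss-optimal' bandwidth by default
    density = kde(torch.tensor([0.45]))                      # p(x) under the mixture of Gaussians
    prob = kde.cdf(torch.tensor(0.45))                       # P(X <= 0.45), averaged over kernels
    kde_silverman = KernelDensityEstimator(train_Xs, bw_select='silverman')  # alternative bandwidth rule
    return density, prob, kde_silverman.bw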


############################################################
# PyTorch implementation of 1D distributions.
############################################################

EPS = 1e-16


class Distribution1D:
def __init__(self, dist_function=None):
"""
        :param dist_function: callable that instantiates the distribution (self.dist) from the
            values of the `parameters` property, passed in the correct order.
"""
self.dist = None
self.dist_function = dist_function

@property
def parameters(self):
raise NotImplementedError

def create_dist(self):
if self.dist_function is not None:
return self.dist_function(*self.parameters)
else:
raise NotImplementedError("No distribution function was specified during intialization.")

def estimate_parameters(self, x):
raise NotImplementedError

def log_prob(self, x):
return self.create_dist().log_prob(x)

def cdf(self, x):
return self.create_dist().cdf(x)

def icdf(self, q):
return self.create_dist().icdf(q)

def sample(self, n=1):
if self.dist is None:
self.dist = self.create_dist()
n_ = torch.Size([]) if n == 1 else (n,)
return self.dist.sample(n_)

def sample_n(self, n=10):
return self.sample(n)


def continuous_bisect_fun_left(f, v, lo, hi, n_steps=32):
val_range = [lo, hi]
k = 0.5 * sum(val_range)
for _ in range(n_steps):
val_range[int(f(k) > v)] = k
next_k = 0.5 * sum(val_range)
if next_k == k:
break
k = next_k
return k
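

# Usage sketch (illustrative; this helper is hypothetical): invert a monotone,
# non-decreasing function by bisection, as Nonparametric.icdf does with the KDE CDF.
def _bisection_usage_example():
    f = lambda x: x ** 3                         # monotone increasing on [0, 3]
    k = continuous_bisect_fun_left(f, 10, 0, 3)  # solve f(k) = 10 within [0, 3]
    return k                                     # ~= 10 ** (1/3) ~= 2.154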


class Normal(Distribution1D):
def __init__(self, location=0, scale=1):
self.location = location
self.scale = scale
super().__init__(torch.distributions.Normal)

@property
def parameters(self):
return [self.location, self.scale]

def estimate_parameters(self, x):
mean = sum(x) / len(x)
var = sum([(x_i - mean) ** 2 for x_i in x]) / (len(x) - 1)
self.location = mean
self.scale = torch.sqrt(var + EPS)

def icdf(self, q):
if q >= 0:
return super().icdf(q)

else:
# To get q *very* close to 1 without numerical issues, we:
# 1) Use q < 0 to represent log(y), where q = 1 - y.
# 2) Use the inverse-normal-cdf approximation here:
# https://math.stackexchange.com/questions/2964944/asymptotics-of-inverse-of-normal-cdf
log_y = q
return self.location + self.scale * math.sqrt(-2 * log_y)
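

# Usage sketch (illustrative; this helper is hypothetical): fit a Normal to a few values
# and query quantiles, including the log-parameterised tail quantile (q < 0).
def _normal_usage_example():
    d = Normal()
    d.estimate_parameters(torch.tensor([0.3, 0.4, 0.5, 0.6]))
    q75 = d.icdf(torch.tensor(0.75))    # ordinary 0.75-quantile
    # A quantile alpha very close to 1 is encoded as q = log(1 - alpha) < 0, for which icdf
    # uses the asymptotic approximation loc + scale * sqrt(-2 * log(1 - alpha)).
    q_tail = d.icdf(math.log(1e-12))    # i.e. alpha = 1 - 1e-12
    return q75, q_tail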


class Nonparametric(Distribution1D):
def __init__(self, use_kde=True, bw_select='Gauss-optimal'):
self.use_kde = use_kde
self.bw_select = bw_select
self.bw, self.data, self.kde = None, None, None
super().__init__()

@property
def parameters(self):
return []

def estimate_parameters(self, x):
self.data, _ = torch.sort(x)

if self.use_kde:
self.kde = KernelDensityEstimator(self.data, bw_select=self.bw_select)
self.bw = torch.ones(1, device=self.data.device) * self.kde.bw

def icdf(self, q):
if not self.use_kde:
# Empirical or step CDF. Differentiable as torch.quantile uses (linear) interpolation.
return torch.quantile(self.data, float(q))

if q >= 0:
# Find quantile via binary search on the KDE CDF
lo = torch.distributions.Normal(self.data[0], self.bw[0]).icdf(q)
hi = torch.distributions.Normal(self.data[-1], self.bw[-1]).icdf(q)
return continuous_bisect_fun_left(self.kde.cdf, q, lo, hi)

else:
# To get q *very* close to 1 without numerical issues, we:
# 1) Use q < 0 to represent log(y), where q = 1 - y.
# 2) Use the inverse-normal-cdf approximation here:
# https://math.stackexchange.com/questions/2964944/asymptotics-of-inverse-of-normal-cdf
log_y = q
v = torch.mean(self.data + self.bw * math.sqrt(-2 * log_y))
return v
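

# Usage sketch (illustrative; this helper is hypothetical): the KDE-smoothed quantile used
# by EQRM versus the plain empirical quantile obtained with use_kde=False.
def _nonparametric_usage_example():
    risks = torch.tensor([0.2, 0.35, 0.4, 0.8])                      # e.g. per-environment risks
    kde_dist = Nonparametric()                                       # default: smooth with a 1D KDE
    kde_dist.estimate_parameters(risks)
    q_kde = kde_dist.icdf(torch.tensor(0.75, dtype=torch.float64))   # bisection on the KDE CDF
    emp_dist = Nonparametric(use_kde=False)                          # step/empirical CDF instead
    emp_dist.estimate_parameters(risks)
    q_emp = emp_dist.icdf(0.75)                                      # torch.quantile, linear interpolation
    return q_kde, q_emp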
