Commit 41db2e8

Merge branch 'staging-robust' of github.com:BasisResearch/causal_pyro into ra-eb-dr-reproduce

agrawalraj committed Jan 2, 2024
2 parents 6453bae + 878eb0d

Showing 10 changed files with 716 additions and 280 deletions.
8 changes: 5 additions & 3 deletions chirho/observational/handlers/condition.py
@@ -1,10 +1,10 @@
-from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union
+from typing import Callable, Generic, Mapping, TypeVar, Union
 
 import pyro
 import torch
 
 from chirho.observational.internals import ObserveNameMessenger
-from chirho.observational.ops import AtomicObservation, observe
+from chirho.observational.ops import Observation, observe
 
 T = TypeVar("T")
 R = Union[float, torch.Tensor]
@@ -62,7 +62,9 @@ class Observations(Generic[T], ObserveNameMessenger):
     a richer set of observational data types and enables counterfactual inference.
     """
 
-    def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]):
+    data: Mapping[str, Observation[T]]
+
+    def __init__(self, data: Mapping[str, Observation[T]]):
         self.data = data
         super().__init__()
 
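A minimal usage sketch for the updated handler, which now keys observations by plain strings instead of Hashable values (the toy model below is illustrative, not part of this commit):

import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.condition import Observations

def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))

# `data` now maps string site names to Observation values.
with Observations(data={"y": torch.tensor(1.5)}):
    model()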
21 changes: 21 additions & 0 deletions chirho/robust/handlers/estimators.py
@@ -0,0 +1,21 @@
+from typing import Any, Callable, Optional
+
+from typing_extensions import Concatenate
+
+from chirho.robust.ops import Functional, P, Point, S, T, influence_fn
+
+
+def one_step_correction(
+    model: Callable[P, Any],
+    guide: Callable[P, Any],
+    functional: Optional[Functional[P, S]] = None,
+    **influence_kwargs,
+) -> Callable[Concatenate[Point[T], P], S]:
+    influence_kwargs_one_step = influence_kwargs.copy()
+    influence_kwargs_one_step["pointwise_influence"] = False
+    eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step)
+
+    def _one_step(test_data: Point[T], *args, **kwargs) -> S:
+        return eif_fn(test_data, *args, **kwargs)
+
+    return _one_step
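A hedged usage sketch for the new one_step_correction; `model`, `guide`, `my_functional`, and `test_data` are placeholders, and the forwarded keyword argument is an assumption about what influence_fn accepts:

from chirho.robust.handlers.estimators import one_step_correction

# `model` and `guide` are assumed to be fitted Pyro modules, and
# `my_functional` a Functional from chirho.robust.ops (all hypothetical).
one_step = one_step_correction(
    model,
    guide,
    functional=my_functional,
    num_samples_outer=1000,  # assumed to be forwarded via **influence_kwargs
)

# `pointwise_influence` is always overridden to False above, so the estimator
# averages the influence function over the test points rather than returning
# one value per point.
correction = one_step(test_data)  # test_data: Point[T], site name -> tensor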
73 changes: 29 additions & 44 deletions chirho/robust/internals/linearize.py
@@ -5,10 +5,7 @@
 import torch
 from typing_extensions import Concatenate, ParamSpec
 
-from chirho.robust.internals.predictive import (
-    NMCLogPredictiveLikelihood,
-    PointLogPredictiveLikelihood,
-)
+from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood
 from chirho.robust.internals.utils import (
     ParamDict,
     make_flatten_unflatten,
@@ -52,8 +49,8 @@ def _flat_conjugate_gradient_solve(
     else:
         cg_iters = min(cg_iters, b.shape[1])
 
-    def _batched_dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-        return (a * b).sum(axis=-1)
+    def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return (x1 * x2).sum(axis=-1)  # type: ignore
 
     def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
         return a.unsqueeze(0).t() * B
@@ -95,29 +92,18 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor:
 
 
 def make_empirical_fisher_vp(
-    func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor],
+    batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor],
     log_prob_params: ParamDict,
     data: Point[T],
-    is_batched: bool = False,
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> Callable[[ParamDict], ParamDict]:
-    if not is_batched:
-        batched_func_log_prob: Callable[
-            [ParamDict, Point[T]], torch.Tensor
-        ] = torch.vmap(
-            lambda p, data: func_log_prob(p, data, *args, **kwargs),
-            in_dims=(None, 0),
-            randomness="different",
-        )
-    else:
-        batched_func_log_prob = functools.partial(func_log_prob, *args, **kwargs)
 
     N = data[next(iter(data))].shape[0]  # type: ignore
     mean_vector = 1 / N * torch.ones(N)
 
     def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor:
-        return batched_func_log_prob(params, data)
+        return batched_func_log_prob(params, data, *args, **kwargs)
 
     def _empirical_fisher_vp(v: ParamDict) -> ParamDict:
         def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
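For orientation, make_empirical_fisher_vp computes a matrix-free product with the empirical Fisher information F = (1/N) * J^T J, where J is the N x d matrix of per-datapoint score vectors. A self-contained sketch for a flat parameter tensor (illustrative only; the version in this diff operates on ParamDicts through a batched log-likelihood):

import torch

def empirical_fisher_vp(log_prob, params, v):
    # `log_prob(params)` returns a length-N vector of per-point log-likelihoods.
    # F v = (1/N) * J^T (J v), computed with one JVP and one VJP.
    _, jv = torch.func.jvp(log_prob, (params,), (v,))  # J v, shape (N,)
    _, vjp_fn = torch.func.vjp(log_prob, params)
    return vjp_fn(jv / jv.shape[0])[0]  # (1/N) J^T (J v), shape of params

# Example: Gaussian log-likelihood of data x as a function of the mean.
x = torch.randn(100)
fv = empirical_fisher_vp(
    lambda mu: -0.5 * (x - mu) ** 2, torch.tensor(0.0), torch.tensor(1.0)
)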
@@ -141,7 +127,8 @@ def linearize(
     num_samples_inner: Optional[int] = None,
     max_plate_nesting: Optional[int] = None,
     cg_iters: Optional[int] = None,
-    residual_tol: float = 1e-4,
+    residual_tol: float = 1e-10,
+    pointwise_influence: bool = True,
 ) -> Callable[Concatenate[Point[T], P], ParamDict]:
     assert isinstance(model, torch.nn.Module)
     assert isinstance(guide, torch.nn.Module)
@@ -154,11 +141,14 @@
         num_samples=num_samples_outer,
         parallel=True,
     )
-    predictive_params, func_predictive = make_functional_call(predictive)
-
-    if is_point_estimate:
-        log_prob_type = PointLogPredictiveLikelihood
-        is_batched = True
-    else:
-        log_prob_type = NMCLogPredictiveLikelihood
-        is_batched = False
+    batched_log_prob = BatchedNMCLogPredictiveLikelihood(
+        model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
+    )
+    log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
+    log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
+    if cg_iters is None:
+        cg_iters = log_prob_params_numel
@@ -177,36 +167,31 @@
 
     def _fn(
         points: Point[T],
-        pointwise_influence: bool = True,
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> ParamDict:
         with torch.no_grad():
-            data: Point[T] = func_predictive(predictive_params, *args, **kwargs)
+            data: Point[T] = predictive(*args, **kwargs)
             data = {k: data[k] for k in points.keys()}
-        fvp = make_efvp(func_log_prob, log_prob_params, data, *args, **kwargs)
+        fvp = make_empirical_fisher_vp(
+            batched_func_log_prob, log_prob_params, data, *args, **kwargs
+        )
         pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
         pinned_fvp_batched = torch.func.vmap(
             lambda v: pinned_fvp(v), randomness="different"
         )
-        if not is_point_estimate:
-            batched_func_log_prob = torch.vmap(
-                lambda p, data: func_log_prob(p, data, *args, **kwargs),
-                in_dims=(None, 0),
-                randomness="different",
-            )
-        else:
-            batched_func_log_prob = functools.partial(func_log_prob, *args, **kwargs)
-        if log_prob_params_numel > points[next(iter(points))].shape[0]:
-            score_fn = torch.func.jacrev(batched_func_log_prob)
-        else:
-            score_fn = torch.func.jacfwd(batched_func_log_prob, randomness="different")
-        point_scores: ParamDict = score_fn(log_prob_params, points)
-        if not pointwise_influence:
-            point_scores = {
-                k: v.mean(dim=0).unsqueeze(0) for k, v in point_scores.items()
-            }
-
+
+        def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor:
+            return batched_func_log_prob(p, points, *args, **kwargs)
+
+        if pointwise_influence:
+            score_fn = torch.func.jacrev(bound_batched_func_log_prob)
+            point_scores = score_fn(log_prob_params)
+        else:
+            score_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1]
+            N_pts = points[next(iter(points))].shape[0]  # type: ignore
+            point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0]
+            point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()}
         return cg_solver(pinned_fvp_batched, point_scores)
 
     return _fn
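And a sketch of calling the revised linearize, where pointwise_influence is now fixed at construction time instead of being passed per call (the model, guide, and sample sizes below are illustrative assumptions, not part of this commit):

import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.linearize import linearize

class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        return pyro.sample("y", dist.Normal(loc, 1.0))

class SimpleGuide(pyro.nn.PyroModule):
    def forward(self):
        loc_hat = pyro.param("loc_hat", lambda: torch.tensor(0.0))
        return pyro.sample("loc", dist.Delta(loc_hat))

# `linearize` asserts both are torch.nn.Module instances.
model, guide = SimpleModel(), SimpleGuide()

eif = linearize(
    model,
    guide,
    num_samples_outer=100,
    num_samples_inner=10,
    max_plate_nesting=1,
    pointwise_influence=True,  # per-point scores; False averages them
)
scores = eif({"y": torch.randn(20)})  # Point[T] -> ParamDict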