Skip to content

Commit

Permalink
automated robust inference clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
agrawalraj committed Dec 19, 2023
1 parent 7fa3a98 commit 6453bae
Show file tree
Hide file tree
Showing 3 changed files with 623 additions and 248 deletions.
21 changes: 18 additions & 3 deletions chirho/robust/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def influence_fn(
model: Callable[P, Any],
guide: Callable[P, Any],
functional: Optional[Functional[P, S]] = None,
**linearize_kwargs
) -> Callable[Concatenate[Point[T], P], S]:
**linearize_kwargs,
) -> Callable[Concatenate[Point[T], bool, P], S]:
from chirho.robust.internals.linearize import linearize
from chirho.robust.internals.predictive import PredictiveFunctional
from chirho.robust.internals.utils import make_functional_call
Expand All @@ -43,7 +43,7 @@ def _fn(
points: Point[T],
pointwise_influence: bool = False,
*args: P.args,
**kwargs: P.kwargs
**kwargs: P.kwargs,
) -> S:
param_eif = linearized(
points, pointwise_influence=pointwise_influence, *args, **kwargs
Expand All @@ -53,6 +53,21 @@ def _fn(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn


def one_step_correction(
model: Callable[P, Any],
guide: Callable[P, Any],
test_data: Point[T],
functional: Optional[Functional[P, S]] = None,
**influence_kwargs,
) -> Callable[P, S]:
def _one_step(*args, **kwargs) -> S:
eif_fn = influence_fn(model, guide, functional, **influence_kwargs)
return eif_fn(test_data, pointwise_influence=False, *args, **kwargs)

return _one_step
Loading

0 comments on commit 6453bae

Please sign in to comment.