Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add influence functions to interpret module #4988

Merged
merged 56 commits into from
Apr 19, 2021
Merged
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
47762ca
creating a new functionality to fields and instances to support outp…
leo-liuzy Feb 15, 2021
4895ad8
creating tests for the new functionality
leo-liuzy Feb 15, 2021
3689bde
fixing docs
leo-liuzy Feb 15, 2021
fde2c08
Delete __init__.py
Feb 15, 2021
756808c
Delete influence_interpreter.py
Feb 15, 2021
ae1900a
Delete use_if.py
Feb 15, 2021
b00153a
Delete simple_influence_test.py
Feb 15, 2021
c682a42
fixing docs
leo-liuzy Feb 15, 2021
0f7911d
Merge branch 'to_json' of github.com:allenai/allennlp into to_json
leo-liuzy Feb 15, 2021
6ed0be6
finishing up SimpleInfluence
leo-liuzy Feb 17, 2021
3fbbc10
Merge branch 'main' into simple-interpreter-new
Feb 17, 2021
2d48293
passing lint
leo-liuzy Feb 17, 2021
229872f
Merge branch 'simple-interpreter-new' of github.com:allenai/allennlp …
leo-liuzy Feb 17, 2021
e128609
passing format
leo-liuzy Feb 17, 2021
0881f8b
Merge branch 'main' into simple-interpreter-new
Mar 3, 2021
a389800
making small progress in coding
leo-liuzy Mar 10, 2021
1c1c14f
Delete fast_influence.py
Mar 11, 2021
ccb19f9
Delete faiss_utils.py
Mar 11, 2021
e2a2153
Delete gpt2_bug.py
Mar 11, 2021
b87927d
Delete text_class.py
Mar 11, 2021
884492a
Merge branch 'main' into simple-interpreter-new
Mar 17, 2021
280df3f
adding test file
leo-liuzy Mar 23, 2021
8489f74
adding testing files
leo-liuzy Mar 23, 2021
c5cbddb
deleted unwanted files
leo-liuzy Mar 24, 2021
84bb076
deleted unwanted files and rearrange test files
leo-liuzy Mar 24, 2021
0354672
small bug
leo-liuzy Mar 25, 2021
307b86b
adjust function call to save instance in json
leo-liuzy Mar 25, 2021
a5623f0
Update allennlp/interpret/influence_interpreters/influence_interprete…
Mar 25, 2021
e19dc58
Update allennlp/interpret/influence_interpreters/influence_interprete…
Mar 25, 2021
80558ab
Update allennlp/interpret/influence_interpreters/influence_interprete…
Mar 25, 2021
a3e42e7
move some documentation of parameters to base class
leo-liuzy Mar 25, 2021
6fe6149
delete one comment
leo-liuzy Mar 25, 2021
feb8b20
delete one deprecated abstract method
leo-liuzy Mar 25, 2021
4aa29c8
changing interface
leo-liuzy Mar 26, 2021
239fbfb
formatting
leo-liuzy Mar 26, 2021
90bd50c
formatting err
leo-liuzy Mar 26, 2021
f49ecec
passing mypy
leo-liuzy Mar 26, 2021
c4e2b43
passing mypy
leo-liuzy Mar 26, 2021
2eeaa34
passing mypy
leo-liuzy Mar 26, 2021
a580517
passing mypy
leo-liuzy Mar 26, 2021
2554813
passing integration test
leo-liuzy Mar 26, 2021
8c18d58
passing integration test
leo-liuzy Mar 26, 2021
e5055fc
adding a new option to the do-all function
leo-liuzy Mar 26, 2021
fd3cc62
modifying the callable function to the interface
leo-liuzy Mar 27, 2021
1cd011a
Merge branch 'main' into simple-interpreter-new
epwalsh Apr 13, 2021
a21a74e
update API, fixes
epwalsh Apr 13, 2021
81ff0c3
Merge branch 'main' into simple-interpreter-new
epwalsh Apr 13, 2021
a9e03aa
doc fixes
epwalsh Apr 13, 2021
c66d7c4
Merge branch 'main' into simple-interpreter-new
epwalsh Apr 14, 2021
1a365f2
add `from_path` and `from_archive` methods
epwalsh Apr 14, 2021
5dbdfbe
fix docs, improve logging
epwalsh Apr 14, 2021
f7bcb3d
add test
epwalsh Apr 14, 2021
c9a5329
fix merge conflicts
epwalsh Apr 15, 2021
2ce2c2e
address @matt-gardner's comments
epwalsh Apr 15, 2021
b644d3d
fixes to documentation
epwalsh Apr 19, 2021
aa4a5b5
update docs
epwalsh Apr 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`.
- Added a `quiet` parameter to the `MultiProcessDataLoader` that disables `Tqdm` progress bars.
- The test for distributed metrics now takes a parameter specifying how often you want to run it.


2 changes: 1 addition & 1 deletion allennlp/data/__init__.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
TensorDict,
allennlp_collate,
)
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, DatasetReaderInput
from allennlp.data.fields.field import DataArray, Field
from allennlp.data.fields.text_field import TextFieldTensors
from allennlp.data.instance import Instance
19 changes: 16 additions & 3 deletions allennlp/data/data_loaders/multiprocess_data_loader.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from multiprocessing.process import BaseProcess
import random
import traceback
from typing import List, Iterator, Optional, Iterable, Union
from typing import List, Iterator, Optional, Iterable, Union, TypeVar

from overrides import overrides
import torch
@@ -23,6 +23,9 @@
logger = logging.getLogger(__name__)


_T = TypeVar("_T")


@DataLoader.register("multiprocess")
class MultiProcessDataLoader(DataLoader):
"""
@@ -118,6 +121,9 @@ class MultiProcessDataLoader(DataLoader):
will automatically call [`set_target_device()`](#set_target_device) before iterating
over batches.
quiet : `bool`, optional (default = `False`)
If `True`, tqdm progress bars will be disabled.
# Best practices
- **Large datasets**
@@ -200,6 +206,7 @@ def __init__(
max_instances_in_memory: int = None,
start_method: str = "fork",
cuda_device: Optional[Union[int, str, torch.device]] = None,
quiet: bool = False,
) -> None:
# Do some parameter validation.
if num_workers is not None and num_workers < 0:
@@ -240,6 +247,7 @@ def __init__(
self.collate_fn = allennlp_collate
self.max_instances_in_memory = max_instances_in_memory
self.start_method = start_method
self.quiet = quiet
self.cuda_device: Optional[torch.device] = None
if cuda_device is not None:
if not isinstance(cuda_device, torch.device):
@@ -346,7 +354,7 @@ def iter_instances(self) -> Iterator[Instance]:

if self.num_workers <= 0:
# Just read all instances in main process.
for instance in Tqdm.tqdm(
for instance in self._maybe_tqdm(
self.reader.read(self.data_path), desc="loading instances"
):
self.reader.apply_token_indexers(instance)
@@ -365,7 +373,7 @@ def iter_instances(self) -> Iterator[Instance]:
workers = self._start_instance_workers(queue, ctx)

try:
for instance in Tqdm.tqdm(
for instance in self._maybe_tqdm(
self._gather_instances(queue), desc="loading instances"
):
if self.max_instances_in_memory is None:
@@ -569,6 +577,11 @@ def _instances_to_batches(
break
yield tensorize(batch)

def _maybe_tqdm(self, iterator: Iterable[_T], **tqdm_kwargs) -> Iterable[_T]:
    """
    Wrap `iterator` in a `Tqdm` progress bar, forwarding `tqdm_kwargs`,
    unless `self.quiet` is set, in which case `iterator` is returned untouched.
    """
    return iterator if self.quiet else Tqdm.tqdm(iterator, **tqdm_kwargs)


class WorkerError(Exception):
"""
9 changes: 8 additions & 1 deletion allennlp/data/data_loaders/simple_data_loader.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
import torch

from allennlp.common.util import lazy_groups_of
from allennlp.common.tqdm import Tqdm
from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate, TensorDict
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.instance import Instance
@@ -37,6 +38,8 @@ def __init__(
self._batch_generator: Optional[Iterator[TensorDict]] = None

def __len__(self) -> int:
    """
    Number of batches in one epoch.

    An explicit `batches_per_epoch` takes precedence; otherwise this is the
    number of instances divided by the batch size, rounded up.
    """
    fixed_count = self.batches_per_epoch
    if fixed_count is None:
        return math.ceil(len(self.instances) / self.batch_size)
    return fixed_count

@overrides
@@ -87,6 +90,10 @@ def from_dataset_reader(
batch_size: int,
shuffle: bool = False,
batches_per_epoch: Optional[int] = None,
quiet: bool = False,
) -> "SimpleDataLoader":
instances = list(reader.read(data_path))
instance_iter = reader.read(data_path)
if not quiet:
instance_iter = Tqdm.tqdm(instance_iter, desc="loading instances")
instances = list(instance_iter)
return cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch)
1 change: 1 addition & 0 deletions allennlp/interpret/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from allennlp.interpret.attackers.attacker import Attacker
from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
from allennlp.interpret.influence_interpreters.influence_interpreter import InfluenceInterpreter
2 changes: 2 additions & 0 deletions allennlp/interpret/influence_interpreters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from allennlp.interpret.influence_interpreters.influence_interpreter import InfluenceInterpreter
from allennlp.interpret.influence_interpreters.simple_influence import SimpleInfluence
427 changes: 427 additions & 0 deletions allennlp/interpret/influence_interpreters/influence_interpreter.py

Large diffs are not rendered by default.

244 changes: 244 additions & 0 deletions allennlp/interpret/influence_interpreters/simple_influence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import logging
from typing import List, Optional, Tuple, Union, Sequence

import numpy as np
from overrides import overrides
import torch
import torch.autograd as autograd

from allennlp.common import Lazy
from allennlp.common.tqdm import Tqdm
from allennlp.data import DatasetReader, DatasetReaderInput, Instance
from allennlp.data.data_loaders import DataLoader, SimpleDataLoader
from allennlp.interpret.influence_interpreters.influence_interpreter import (
InfluenceInterpreter,
)
from allennlp.models.model import Model


logger = logging.getLogger(__name__)


@InfluenceInterpreter.register("simple-influence")
class SimpleInfluence(InfluenceInterpreter):
    """
    Registered as an `InfluenceInterpreter` with name "simple-influence".

    This goes through every example in the train set to calculate the influence score. It uses
    [LiSSA (Linear time Stochastic Second-Order Algorithm)](https://api.semanticscholar.org/CorpusID:10569090)
    to approximate the inverse of the Hessian used for the influence score calculation.

    # Parameters

    lissa_batch_size : `int`, optional (default = `8`)
        The batch size to use for LiSSA.
        According to [Koh, P.W., & Liang, P. (2017)](https://api.semanticscholar.org/CorpusID:13193974),
        it is better to use batched samples for approximation for better stability.

    damping : `float`, optional (default = `3e-3`)
        This is a hyperparameter for LiSSA.
        A damping term added in case the approximated Hessian (during LiSSA) has
        negative eigenvalues.

    num_samples : `int`, optional (default = `1`)
        This is a hyperparameter for LiSSA that
        determines how many rounds of the recursion process we would like to run for approximation.

    recursion_depth : `Union[float, int]`, optional (default = `0.25`)
        This is a hyperparameter for LiSSA that
        determines the recursion depth we would like to go through.
        If a `float`, it means X% of the training examples (rounded down to a whole
        number of LiSSA batches, but never fewer than one batch when there is data).
        If an `int`, it means recurse for X times.

    scale : `float`, optional, (default = `1e4`)
        This is a hyperparameter for LiSSA to tune such that the Taylor expansion converges.
        It is applied to scale down the loss during LiSSA to ensure that `H <= I`,
        where `H` is the Hessian and `I` is the identity matrix.

        See footnote 2 of [Koh, P.W., & Liang, P. (2017)](https://api.semanticscholar.org/CorpusID:13193974).

    !!! Note
        We choose the same default values for the LiSSA hyperparameters as
        [Han, Xiaochuang et al. (2020)](https://api.semanticscholar.org/CorpusID:218628619).
    """

    def __init__(
        self,
        model: Model,
        train_data_path: DatasetReaderInput,
        train_dataset_reader: DatasetReader,
        *,
        test_dataset_reader: Optional[DatasetReader] = None,
        train_data_loader: Lazy[DataLoader] = Lazy(SimpleDataLoader.from_dataset_reader),
        test_data_loader: Lazy[DataLoader] = Lazy(SimpleDataLoader.from_dataset_reader),
        params_to_freeze: List[str] = None,
        cuda_device: int = -1,
        lissa_batch_size: int = 8,
        damping: float = 3e-3,
        num_samples: int = 1,
        recursion_depth: Union[float, int] = 0.25,
        scale: float = 1e4,
    ) -> None:
        super().__init__(
            model=model,
            train_data_path=train_data_path,
            train_dataset_reader=train_dataset_reader,
            test_dataset_reader=test_dataset_reader,
            train_data_loader=train_data_loader,
            test_data_loader=test_data_loader,
            params_to_freeze=params_to_freeze,
            cuda_device=cuda_device,
        )

        # LiSSA draws from its own shuffled loader over the training instances,
        # with its own batch size.
        self._lissa_dataloader = SimpleDataLoader(
            list(self._train_loader.iter_instances()),
            lissa_batch_size,
            shuffle=True,
            vocab=self.vocab,
        )
        self._lissa_dataloader.set_target_device(self.device)

        # `batches_per_epoch` on the LiSSA loader is what bounds the recursion depth.
        if isinstance(recursion_depth, float) and recursion_depth > 0.0:
            # At this point `batches_per_epoch` is still unset, so `len()` is the
            # number of batches needed to cover all the training instances.
            total_batches = len(self._lissa_dataloader)
            depth_in_batches = int(total_batches * recursion_depth)
            if depth_in_batches == 0 and total_batches > 0:
                # A small fraction of a small dataset truncates to zero, which would
                # mean no recursion steps at all; a positive depth should run at
                # least one step.
                depth_in_batches = 1
            self._lissa_dataloader.batches_per_epoch = depth_in_batches
        elif isinstance(recursion_depth, int) and recursion_depth > 0:
            self._lissa_dataloader.batches_per_epoch = recursion_depth
        else:
            raise ValueError("'recursion_depth' should be a positive int or float")

        self._damping = damping
        self._num_samples = num_samples
        self._recursion_depth = recursion_depth
        self._scale = scale

    @overrides
    def _calculate_influence_scores(
        self, test_instance: Instance, test_loss: float, test_grads: Sequence[torch.Tensor]
    ) -> List[float]:
        """
        Score every training instance against one test instance.

        The inverse Hessian-vector product of `test_grads` is approximated once with
        LiSSA, and each score is its dot product with the flattened gradients of a
        training instance. `test_instance` and `test_loss` are not used by this
        implementation.

        # Returns

        `List[float]`
            One score per entry in `self.train_instances`, in the same order.
        """
        # Approximate the inverse of Hessian-Vector Product through LiSSA
        inv_hvp = get_inverse_hvp_lissa(
            test_grads,
            self.model,
            self.used_params,
            self._lissa_dataloader,
            self._damping,
            self._num_samples,
            self._scale,
        )
        return [
            # dL_test * d theta as in 2.2 of [https://arxiv.org/pdf/2005.06676.pdf]
            # TODO (epwalsh): should we divide `x.grads` by `self._scale`?
            torch.dot(inv_hvp, _flatten_tensors(x.grads)).item()
            for x in Tqdm.tqdm(self.train_instances, desc="scoring train instances")
        ]


def get_inverse_hvp_lissa(
    vs: Sequence[torch.Tensor],
    model: Model,
    used_params: Sequence[torch.Tensor],
    lissa_data_loader: DataLoader,
    damping: float,
    num_samples: int,
    scale: float,
) -> torch.Tensor:
    """
    This function approximates the product of the inverse of the Hessian and
    the vectors `vs` using LiSSA.

    Adapted from [github.com/kohpangwei/influence-release]
    (https://github.com/kohpangwei/influence-release/blob/0f656964867da6ddcca16c14b3e4f0eef38a7472/influence/genericNeuralNet.py#L475),
    the repo for [Koh, P.W., & Liang, P. (2017)](https://api.semanticscholar.org/CorpusID:13193974),
    and [github.com/xhan77/influence-function-analysis]
    (https://github.com/xhan77/influence-function-analysis/blob/78d5a967aba885f690d34e88d68da8678aee41f1/bert_util.py#L336),
    the repo for [Han, Xiaochuang et al. (2020)](https://api.semanticscholar.org/CorpusID:218628619).

    # Parameters

    vs : `Sequence[torch.Tensor]`
        The vectors to multiply by the inverse Hessian, one per tensor in `used_params`.
    model : `Model`
        Called as `model(**batch)` on every LiSSA batch; the output must contain a `"loss"`.
    used_params : `Sequence[torch.Tensor]`
        The parameters the Hessian is taken with respect to.
    lissa_data_loader : `DataLoader`
        Iterated once per sample; each batch it yields is one recursion step.
    damping : `float`
        Damping applied to the current estimate at every recursion step.
    num_samples : `int`
        How many independent recursions to run and average over. Must be positive.
    scale : `float`
        The Hessian-vector products, and the final estimates, are divided by this.

    # Returns

    `torch.Tensor`
        The flattened estimates, averaged over `num_samples`.

    # Raises

    `ValueError`
        If `num_samples` is less than 1 (there would be nothing to average).
    """
    if num_samples < 1:
        raise ValueError("'num_samples' should be a positive int")

    # The loader's length doesn't change while we iterate, so look it up once.
    num_batches = len(lissa_data_loader)

    # Accumulators that match each vector's shape, dtype, and device.
    inverse_hvps = [torch.zeros_like(v) for v in vs]
    for _ in Tqdm.tqdm(range(num_samples), desc="LiSSA samples", total=num_samples):
        # See an explanation at "Stochastic estimation" paragraph in [https://arxiv.org/pdf/1703.04730.pdf]
        # initialize \tilde{H}^{−1}_0 v = v
        cur_estimates = vs
        recursion_iter = Tqdm.tqdm(lissa_data_loader, desc="LiSSA depth", total=num_batches)
        for j, training_batch in enumerate(recursion_iter):
            # TODO (epwalsh): should we make sure `model` is in "train" or "eval" mode here?
            model.zero_grad()
            train_output_dict = model(**training_batch)
            # Hessian of loss @ \tilde{H}^{−1}_{j - 1} v
            hvps = get_hvp(train_output_dict["loss"], used_params, cur_estimates)

            # This is the recursive step:
            # cur_estimate = \tilde{H}^{−1}_{j - 1} v
            # (i.e. Hessian-Vector Product estimate from last iteration)
            # Updating for \tilde{H}^{−1}_j v, the new current estimate becomes:
            # v + (I - (Hessian_at_x + damping)) * cur_estimate
            # = v + (1 - damping) * cur_estimate - Hessian_at_x * cur_estimate
            # We divide `hvp / scale` here (or, equivalently `Hessian_at_x / scale`)
            # so that we're effectively dividing the loss by `scale`.
            cur_estimates = [
                v + (1 - damping) * cur_estimate - hvp / scale
                for v, cur_estimate, hvp in zip(vs, cur_estimates, hvps)
            ]

            # Update the Tqdm progress bar with the current norm so the user can
            # see it converge.
            if (j % 50 == 0) or (j == num_batches - 1):
                norm = np.linalg.norm(_flatten_tensors(cur_estimates).cpu().numpy())
                recursion_iter.set_description(desc=f"calculating inverse HVP, norm = {norm:.5f}")

        # Accumulating X_{[i,S_2]} (notation from the LiSSA (algo. 1) [https://arxiv.org/pdf/1602.03943.pdf]
        # Need to divide by `scale` again here because the `vs` represent gradients
        # that haven't been scaled yet.
        inverse_hvps = [
            inverse_hvp + cur_estimate / scale
            for inverse_hvp, cur_estimate in zip(inverse_hvps, cur_estimates)
        ]
    return_ihvp = _flatten_tensors(inverse_hvps)
    return_ihvp /= num_samples
    return return_ihvp


def get_hvp(
    loss: torch.Tensor, params: Sequence[torch.Tensor], vectors: Sequence[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    """
    Compute Hessian-Vector Products (HVPs) of `loss` with respect to `params`,
    using the corresponding entries of `vectors` as the vectors.

    # Parameters

    loss : `torch.Tensor`
        The loss calculated from the output of the model.
    params : `Sequence[torch.Tensor]`
        Tunable and used parameters in the model that we will calculate the gradient and hessian
        with respect to.
    vectors : `Sequence[torch.Tensor]`
        The list of vectors for calculating the HVP. There must be one per entry in
        `params`, each with the same size as its parameter.

    # Returns

    `Tuple[torch.Tensor, ...]`
        The result of the second `autograd.grad` call: one tensor per entry in `params`.
    """
    # Every vector has to pair up with, and be shaped like, a parameter tensor.
    assert len(vectors) == len(params)
    for param, vector in zip(params, vectors):
        assert param.size() == vector.size()

    # First-order gradients, keeping the graph so they can be differentiated again.
    first_order_grads = autograd.grad(loss, params, create_graph=True, retain_graph=True)
    # Differentiate the gradients, weighting them by `vectors`.
    return autograd.grad(first_order_grads, params, grad_outputs=vectors)


def _flatten_tensors(tensors: Sequence[torch.Tensor]) -> torch.Tensor:
    """
    Flattens a sequence of tensors (e.g. parameter gradients) into a single
    1-dimensional tensor, densifying any sparse tensors first.

    # Returns

    `torch.Tensor`
        A tensor of shape `(x,)` where `x` is the total number of entries in `tensors`.
    """
    flat_pieces = [
        (tensor.data.to_dense() if tensor.data.is_sparse else tensor.data).view(-1)
        for tensor in tensors
    ]
    return torch.cat(flat_pieces, 0)
5 changes: 2 additions & 3 deletions allennlp/models/archival.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@
import tempfile
import tarfile
import shutil
from pathlib import Path
from contextlib import contextmanager
import glob

@@ -157,7 +156,7 @@ def archive_model(


def load_archive(
archive_file: Union[str, Path],
archive_file: Union[str, PathLike],
cuda_device: int = -1,
overrides: Union[str, Dict[str, Any]] = "",
weights_file: str = None,
@@ -167,7 +166,7 @@ def load_archive(
# Parameters
archive_file : `Union[str, Path]`
archive_file : `Union[str, PathLike]`
The archive file to load the model from.
cuda_device : `int`, optional (default = `-1`)
If `cuda_device` is >= 0, the model will be loaded onto the
6 changes: 3 additions & 3 deletions allennlp/predictors/predictor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Iterator, Dict, Tuple, Any, Type, Union, Optional
import logging
from os import PathLike
import json
import re
from contextlib import contextmanager
from pathlib import Path

import numpy
import torch
@@ -314,7 +314,7 @@ def _batch_json_to_instances(self, json_dicts: List[JsonDict]) -> List[Instance]
@classmethod
def from_path(
cls,
archive_path: Union[str, Path],
archive_path: Union[str, PathLike],
predictor_name: str = None,
cuda_device: int = -1,
dataset_reader_to_load: str = "validation",
@@ -330,7 +330,7 @@ def from_path(
# Parameters
archive_path : `Union[str, Path]`
archive_path : `Union[str, PathLike]`
The path to the archive.
predictor_name : `str`, optional (default=`None`)
Name that the predictor is registered as, or None to use the
7 changes: 6 additions & 1 deletion scripts/py2md.py
Original file line number Diff line number Diff line change
@@ -274,7 +274,12 @@ class AllenNlpFilterProcessor(Struct):
Used to filter out nodes that we don't want to document.
"""

PRIVATE_METHODS_TO_KEEP = {"DatasetReader._read", "__call__", "__iter__"}
PRIVATE_METHODS_TO_KEEP = {
"DatasetReader._read",
"__call__",
"__iter__",
"InfluenceInterpreter._calculate_influence_scores",
}

def process(self, graph, _resolver):
graph.visit(self._process_node)
102 changes: 102 additions & 0 deletions tests/interpret/simple_influence_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.fields import TensorField
from allennlp.data import Instance
from allennlp.models.model import Model
from allennlp.data.data_loaders import SimpleDataLoader

from allennlp.interpret import InfluenceInterpreter
from allennlp.interpret.influence_interpreters.simple_influence import (
_flatten_tensors,
get_hvp,
get_inverse_hvp_lissa,
)


class DummyBilinearModelForTestingIF(Model):
    """
    Toy model for the influence-function tests: its only parameter is the
    vector `x`, and its loss is `1/2 * (A @ x @ x)` for the input tensor `A`.
    """

    def __init__(self, vocab, params):
        super().__init__(vocab)
        self.x = torch.nn.Parameter(params.float(), requires_grad=True)

    def forward(self, tensors):
        # `tensors` plays the role of A: (batch_size, ..., ...)
        loss = 1 / 2 * (tensors @ self.x @ self.x)
        return {"loss": loss}


def test_get_hvp():
    """`get_hvp` on a quadratic loss should return `1/2 * (X + X^T) @ v`."""
    # Stand-in for a train data point.
    data_point = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    # Stand-in for the model weights.
    weights = torch.nn.Parameter(torch.tensor([1, 2]).float(), requires_grad=True)
    # The vector in the HVP.
    vector = torch.tensor([10, 20]).float()

    # Forward pass / loss calculation of the "model".
    loss = 1 / 2 * (weights @ data_point @ weights.T)
    (hessian_vector_product,) = get_hvp(loss, [weights], [vector])[:1]

    expected_answer = 1 / 2 * (data_point + data_point.T) @ vector
    assert torch.equal(hessian_vector_product, expected_answer)


def test_flatten_tensors():
    """Two 2x2 parameters should flatten to their eight entries, in order."""
    first = torch.nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]]), requires_grad=True)
    second = torch.nn.Parameter(torch.tensor([[5.0, 6.0], [7.0, 8.0]]), requires_grad=True)
    expected = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).float()
    assert torch.equal(_flatten_tensors([first, second]), expected)


def test_get_inverse_hvp_lissa():
    """A single LiSSA step on the toy bilinear model gives a known result."""
    # A fake model with parameters [1, 2].
    model = DummyBilinearModelForTestingIF(Vocabulary(), torch.tensor([1, 2]).float())

    # A fake instance (just a matrix), wrapped into a one-batch data loader.
    matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    lissa_data_loader = SimpleDataLoader(
        [Instance({"tensors": TensorField(matrix)})], batch_size=1, batches_per_epoch=1
    )

    inverse_hvp = get_inverse_hvp_lissa(
        vs=[torch.tensor([1.0, 1.0])],
        model=model,
        used_params=list(model.parameters()),
        lissa_data_loader=lissa_data_loader,
        damping=0.0,
        num_samples=1,
        scale=1.0,
    )
    # I tried to increase recursion depth to actually approx the inverse Hessian vector product,
    # but I suspect due to extremely small number of data point, the algorithm doesn't work well
    # on this toy example
    expected = torch.tensor([-1.5, -4.5])
    assert torch.equal(inverse_hvp, expected)


class TestSimpleInfluence(AllenNlpTestCase):
    """End-to-end check of the influence interpreter loaded from an archived model."""

    def setup_method(self):
        super().setup_method()
        fixtures = self.FIXTURES_ROOT
        self.archive_path = fixtures / "basic_classifier" / "serialization" / "model.tar.gz"
        self.data_path = fixtures / "data" / "text_classification_json" / "imdb_corpus.jsonl"

    def test_simple_influence(self):
        # NOTE: We use the same data here for test and train, which is pointless in
        # real life but convenient here.
        interpreter = InfluenceInterpreter.from_path(
            self.archive_path, train_data_path=self.data_path, recursion_depth=3
        )
        results = interpreter.interpret_from_file(self.data_path, k=1)
        assert len(results) == 3
        # With `k=1`, every result carries exactly one top-k entry.
        for entry in results:
            assert len(entry.top_k) == 1