Skip to content

Commit eaf1f4e

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Implement ProductBaselines (#1212)
Summary: Pull Request resolved: #1212 Implement a Callable Baselines class that returns a sample from the Cartesian product of the inputs' available baselines. Reviewed By: vivekmig Differential Revision: D51582979 fbshipit-source-id: ec7c833a5572b6a15c5cc8acb3fb9b1bcf439065
1 parent 5398892 commit eaf1f4e

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

captum/attr/_utils/baselines.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
import random
3+
from typing import Any, Dict, List, Tuple, Union
4+
5+
6+
class ProductBaselines:
7+
"""
8+
A Callable Baselines class that returns a sample from the Cartesian product of
9+
the inputs' available baselines.
10+
11+
Args:
12+
baseline_values (List or Dict): A list or dict of lists containing
13+
the possible values for each feature. If a dict is provided, the keys
14+
can a string of the feature name and the values is a list of available
15+
baselines. The keys can also be a tuple of strings to group
16+
multiple features whose baselines are not independent to each other.
17+
If the key is a tuple, the value must be a list of tuples of
18+
the corresponding values.
19+
"""
20+
21+
def __init__(
22+
self,
23+
baseline_values: Union[
24+
List[List[Any]],
25+
Dict[Union[str, Tuple[str, ...]], List[Any]],
26+
],
27+
):
28+
if isinstance(baseline_values, dict):
29+
dict_keys = list(baseline_values.keys())
30+
baseline_values = [baseline_values[k] for k in dict_keys]
31+
else:
32+
dict_keys = []
33+
34+
self.dict_keys = dict_keys
35+
self.baseline_values = baseline_values
36+
37+
def sample(self) -> Union[List[Any], Dict[str, Any]]:
38+
baselines = [
39+
random.choice(baseline_list) for baseline_list in self.baseline_values
40+
]
41+
42+
if not self.dict_keys:
43+
return baselines
44+
45+
dict_baselines = {}
46+
for key, val in zip(self.dict_keys, baselines):
47+
if not isinstance(key, tuple):
48+
key, val = (key,), (val,)
49+
50+
for k, v in zip(key, val):
51+
dict_baselines[k] = v
52+
53+
return dict_baselines
54+
55+
def __call__(self) -> Union[List[Any], Dict[str, Any]]:
56+
"""
57+
Returns:
58+
59+
baselines (List or Dict): A sample from the Cartesian product of
60+
the inputs' available baselines
61+
"""
62+
return self.sample()

tests/attr/test_baselines.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import cast, Dict, List, Tuple, Union
2+
3+
from captum.attr._utils.baselines import ProductBaselines
4+
5+
# from parameterized import parameterized
6+
from tests.helpers.basic import BaseTest
7+
8+
9+
class TestProductBaselines(BaseTest):
10+
def test_list(self) -> None:
11+
baseline_values = [
12+
[1, 2, 3],
13+
[4, 5, 6, 7],
14+
[8, 9],
15+
]
16+
17+
baselines = ProductBaselines(baseline_values)
18+
19+
baseline_sample = baselines()
20+
21+
self.assertIsInstance(baseline_sample, list)
22+
for sample_val, vals in zip(baseline_sample, baseline_values):
23+
self.assertIn(sample_val, vals)
24+
25+
def test_dict(self) -> None:
26+
baseline_values = {
27+
"f1": [1, 2, 3],
28+
"f2": [4, 5, 6, 7],
29+
"f3": [8, 9],
30+
}
31+
32+
baselines = ProductBaselines(
33+
cast(Dict[Union[str, Tuple[str, ...]], List[int]], baseline_values)
34+
)
35+
36+
baseline_sample = baselines()
37+
38+
self.assertIsInstance(baseline_sample, dict)
39+
baseline_sample = cast(dict, baseline_sample)
40+
41+
for sample_key, sample_val in baseline_sample.items():
42+
self.assertIn(sample_val, baseline_values[sample_key])
43+
44+
def test_dict_tuple_key(self) -> None:
45+
baseline_values: Dict[Union[str, Tuple[str, ...]], List] = {
46+
("f1", "f2"): [(1, "1"), (2, "2"), (3, "3")],
47+
"f3": [4, 5],
48+
}
49+
50+
baselines = ProductBaselines(baseline_values)
51+
52+
baseline_sample = baselines()
53+
54+
self.assertIsInstance(baseline_sample, dict)
55+
baseline_sample = cast(dict, baseline_sample)
56+
57+
self.assertEqual(len(baseline_sample), 3)
58+
59+
self.assertIn(
60+
(baseline_sample["f1"], baseline_sample["f2"]),
61+
baseline_values[("f1", "f2")],
62+
)
63+
self.assertIn(baseline_sample["f3"], baseline_values["f3"])

0 commit comments

Comments
 (0)