Skip to content

Commit

Permalink
API: accept additional partitioner parameters in simulate_feature() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
j-ittner authored Oct 20, 2022
1 parent ab3d158 commit ad013dd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
8 changes: 6 additions & 2 deletions src/facet/simulation/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,16 @@ def baseline(self) -> float:
return 0.0

def simulate_feature(
self, feature_name: str, *, partitioner: Partitioner[T_Values]
self,
feature_name: str,
*,
partitioner: Partitioner[T_Values],
**partitioner_params: Any,
) -> UnivariateSimulationResult[T_Values]:
"""[see superclass]"""

result = super().simulate_feature(
feature_name=feature_name, partitioner=partitioner
feature_name=feature_name, partitioner=partitioner, **partitioner_params
)

# offset the mean values to get uplift instead of absolute outputs
Expand Down
12 changes: 10 additions & 2 deletions src/facet/simulation/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,22 @@ def __init__(
)

def simulate_feature(
self, feature_name: str, *, partitioner: Partitioner[T_Value]
self,
feature_name: str,
*,
partitioner: Partitioner[T_Value],
**partitioner_params: Any,
) -> UnivariateSimulationResult[T_Value]:
"""
Simulate the average target uplift when fixing the value of the given feature
across all observations.
Simulations are run for a set of values determined by the given partitioner,
which is fitted to the observed values for the feature being simulated.
:param feature_name: the feature to run the simulation for
:param partitioner: the partitioner of feature values to run simulations for
:param partitioner_params: additional parameters to pass to the partitioner
:return: a mapping of output names to simulation results
"""

Expand All @@ -149,7 +157,7 @@ def simulate_feature(
mean, sem = self._simulate_feature_with_values(
feature_name=feature_name,
simulation_values=partitioner.fit(
sample.features.loc[:, feature_name]
sample.features.loc[:, feature_name], **partitioner_params
).partitions_,
)
return UnivariateSimulationResult(
Expand Down
11 changes: 6 additions & 5 deletions test/test/facet/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,20 @@ def test_univariate_target_subsample_simulation_80(
] = target_simulator.simulate_feature(
feature_name=parameterized_feature,
partitioner=partitioner,
lower_bound=3.8,
)

# test simulation results

index = pd.Index(
data=[2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0],
data=[4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0],
name=UnivariateSimulationResult.IDX_PARTITION,
)

assert_series_equal(
simulation_result.data.loc[:, UnivariateSimulationResult.COL_LOWER_BOUND],
pd.Series(
[25.05676, 25.05676, 25.05676, 22.96243, 21.43395]
[25.05676, 25.05676, 22.96243, 21.43395]
+ [21.21544, 20.76824, 20.49282, 20.49282],
name=UnivariateSimulationResult.COL_LOWER_BOUND,
index=index,
Expand All @@ -166,7 +167,7 @@ def test_univariate_target_subsample_simulation_80(
assert_series_equal(
simulation_result.data.loc[:, UnivariateSimulationResult.COL_MEAN],
pd.Series(
[25.642227, 25.642227, 25.642227, 23.598706, 22.067057]
[25.642227, 25.642227, 23.598706, 22.067057]
+ [21.864828, 21.451056, 21.195954, 21.195954],
name=UnivariateSimulationResult.COL_MEAN,
index=index,
Expand All @@ -176,15 +177,15 @@ def test_univariate_target_subsample_simulation_80(
assert_series_equal(
simulation_result.data.loc[:, UnivariateSimulationResult.COL_UPPER_BOUND],
pd.Series(
[26.22769, 26.22769, 26.22769, 24.23498, 22.70016]
[26.22769, 26.22769, 24.23498, 22.70016]
+ [22.51422, 22.13387, 21.89909, 21.89909],
name=UnivariateSimulationResult.COL_UPPER_BOUND,
index=index,
),
)

assert_array_equal(
simulation_result.partitioner.frequencies_, [1, 4, 9, 10, 10, 6, 2, 1, 4]
simulation_result.partitioner.frequencies_, [4, 9, 10, 10, 6, 2, 1, 4]
)

SimulationDrawer(style="text").draw(
Expand Down

0 comments on commit ad013dd

Please sign in to comment.