diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index d31b5002606..7255865d7da 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -15,6 +15,8 @@
 - `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
 - `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
 - The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
+- Generalize BART. A BART variable can be combined with other random variables. The `inv_link` argument has been removed (see [4914](https://github.com/pymc-devs/pymc3/pull/4914)).
+- Move BART to its own module (see [5058](https://github.com/pymc-devs/pymc3/pull/5058)).
 - ...
 
 ### New Features
@@ -32,6 +34,8 @@
 - New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
+- BART: add linear response, increase number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
+- BART: add partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
 - ...
 
 ### Maintenance
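The "Generalize BART" entry above means a BART variable now behaves like any other random variable in a model. A minimal sketch of the intended usage, mirroring the pattern in `test_bart.py` at the end of this diff (data and sizes are illustrative, not from the PR):

```python
import numpy as np
import pymc as pm

# Simulated data, for illustration only.
X = np.random.normal(0, 1, size=(100, 2))
Y = np.random.normal(0, 1, size=100)

with pm.Model() as model:
    mu = pm.BART("mu", X, Y, m=50)       # BART prior over the mean function,
    sigma = pm.HalfNormal("sigma", 1)    # freely combined with other random variables
    y = pm.Normal("y", mu, sigma, observed=Y)
    idata = pm.sample()
```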
diff --git a/pymc/bart/__init__.py b/pymc/bart/__init__.py
index abace693c1f..b244c69cf65 100644
--- a/pymc/bart/__init__.py
+++ b/pymc/bart/__init__.py
@@ -15,5 +15,6 @@
 
 from pymc.bart.bart import BART
 from pymc.bart.pgbart import PGBART
+from pymc.bart.utils import plot_dependence, predict
 
 __all__ = ["BART", "PGBART"]
diff --git a/pymc/bart/bart.py b/pymc/bart/bart.py
index f8593d80d25..783378d3001 100644
--- a/pymc/bart/bart.py
+++ b/pymc/bart/bart.py
@@ -41,35 +41,7 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
 
     @classmethod
     def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
-        size = kwargs.pop("size", None)
-        X_new = kwargs.pop("X_new", None)
-        all_trees = cls.all_trees
-        if all_trees:
-
-            if size is None:
-                size = ()
-            elif isinstance(size, int):
-                size = [size]
-
-            flatten_size = 1
-            for s in size:
-                flatten_size *= s
-
-            idx = rng.randint(len(all_trees), size=flatten_size)
-
-            if X_new is None:
-                pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
-                for ind, p in enumerate(pred):
-                    for tree in all_trees[idx[ind]]:
-                        p += tree.predict_output()
-            else:
-                pred = np.zeros((flatten_size, X_new.shape[0]))
-                for ind, p in enumerate(pred):
-                    for tree in all_trees[idx[ind]]:
-                        p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
-            return pred.reshape((*size, -1))
-        else:
-            return np.full_like(cls.Y, cls.Y.mean())
+        return np.full_like(cls.Y, cls.Y.mean())
 
 
 bart = BARTRV()
@@ -117,7 +89,6 @@ def __new__(
         **kwargs,
     ):
 
-        cls.all_trees = []
         X, Y = preprocess_XY(X, Y)
 
         bart_op = type(
@@ -125,7 +96,6 @@ def __new__(
             (BARTRV,),
             dict(
                 name="BART",
-                all_trees=cls.all_trees,
                 inplace=False,
                 initval=Y.mean(),
                 X=X,
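With the per-class tree store gone, `rng_fn` keeps only its fallback branch: a draw from the BART variable is simply the response mean broadcast to the shape of `Y`, and prediction from fitted trees is delegated to the new `pymc.bart.utils.predict` helper later in this diff. A standalone sketch of that fallback (values are illustrative):

```python
import numpy as np

# What the simplified rng_fn computes, shown outside the class for clarity.
Y = np.array([1.0, 2.0, 3.0, 6.0])
draw = np.full_like(Y, Y.mean())  # array([3., 3., 3., 3.])
```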
Use with caution.") @@ -159,6 +160,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo tree_id=0, leaf_node_value=self.init_mean / self.m, idx_data_points=np.arange(self.num_observations, dtype="int32"), + m=self.m, ) self.mean = fast_mean() self.linear_fit = fast_linear_fit() @@ -169,8 +171,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo self.tune = True self.idx = 0 - self.iter = 0 - self.sum_trees = [] self.batch = batch if self.batch == "auto": @@ -193,12 +193,12 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo self.init_likelihood, ) self.all_particles.append(p) + self.all_trees = np.array([p.tree for p in self.all_particles]) super().__init__(vars, shared) def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: point_map_info = q.point_map_info sum_trees_output = q.data - variable_inclusion = np.zeros(self.num_variates, dtype="int") if self.idx == self.m: @@ -212,7 +212,6 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: particles = self.init_particles(tree_id) # Compute the sum of trees without the tree we are attempting to replace self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output() - self.idx += 1 # The old tree is not growing so we update the weights only once. self.update_weight(particles[0]) @@ -258,6 +257,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: # Get the new tree and update new_particle = np.random.choice(particles, p=normalized_weights) new_tree = new_particle.tree + self.all_trees[self.idx] = new_tree new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles self.all_particles[tree_id] = new_particle sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output() @@ -268,17 +268,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: self.ssv = SampleSplittingVariable(self.split_prior) else: self.batch = max(1, int(self.m * 0.2)) - self.iter += 1 - self.sum_trees.append(new_tree) - if not self.iter % self.m: - # XXX update the all_trees variable in BARTRV to be used in the rng_fn method - # this fails for chains > 1 as the variable is not shared between proccesses - self.bart.all_trees.append(self.sum_trees) - self.sum_trees = [] for index in new_particle.used_variates: variable_inclusion[index] += 1 + self.idx += 1 - stats = {"variable_inclusion": variable_inclusion} + stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)} sum_trees_output = RaveledVars(sum_trees_output, point_map_info) return sum_trees_output, [stats] @@ -526,11 +520,11 @@ def linear_fit(X, Y): xbar = np.sum(X) / n ybar = np.sum(Y) / n - if np.all(X == xbar): - b = 0 + den = X @ X - n * xbar ** 2 + if den > 1e-10: + b = (X @ Y - n * xbar * ybar) / den else: - b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2) - + b = 0 a = ybar - b * xbar Y_fit = a + b * X return Y_fit, [a, b, 0] diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py index 4024d4786b2..b982e80bb66 100644 --- a/pymc/bart/tree.py +++ b/pymc/bart/tree.py @@ -45,7 +45,8 @@ class Tree: Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART. num_observations : int Number of observations used to fit BART. 
diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py
index 4024d4786b2..b982e80bb66 100644
--- a/pymc/bart/tree.py
+++ b/pymc/bart/tree.py
@@ -45,7 +45,8 @@ class Tree:
         Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
     num_observations : int
         Number of observations used to fit BART.
-
+    m : int
+        Number of trees.
 
     Parameters
    ----------
@@ -53,13 +54,14 @@ class Tree:
     num_observations : int, optional
     """
 
-    def __init__(self, tree_id=0, num_observations=0):
+    def __init__(self, tree_id=0, num_observations=0, m=0):
         self.tree_structure = {}
         self.num_nodes = 0
         self.idx_leaf_nodes = []
         self.idx_prunable_split_nodes = []
         self.tree_id = tree_id
         self.num_observations = num_observations
+        self.m = m
 
     def __getitem__(self, index):
         return self.get_node(index)
@@ -94,7 +96,7 @@ def predict_output(self):
 
         return output.astype(aesara.config.floatX)
 
-    def predict_out_of_sample(self, X, m):
+    def predict_out_of_sample(self, X):
         """
         Predict output of tree for an unobserved point x.
 
@@ -102,8 +104,6 @@ def predict_out_of_sample(self, X, m):
         ----------
         X : numpy array
             Unobserved point
-        m : int
-            Number of trees
 
         Returns
         -------
@@ -116,7 +116,7 @@ def predict_out_of_sample(self, X, m):
             return leaf_node.value
         else:
             x = X[split_variable].item()
-            y_x = (linear_params[0] + linear_params[1] * x) / m
+            y_x = (linear_params[0] + linear_params[1] * x) / self.m
             return y_x + linear_params[2]
 
     def _traverse_tree(self, x, node_index=0, split_variable=None):
@@ -170,7 +170,7 @@ def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_no
             self.idx_prunable_split_nodes.remove(parent_index)
 
     @staticmethod
-    def init_tree(tree_id, leaf_node_value, idx_data_points):
+    def init_tree(tree_id, leaf_node_value, idx_data_points, m):
         """
 
         Parameters
@@ -178,12 +178,14 @@ def init_tree(tree_id, leaf_node_value, idx_data_points):
         tree_id
         leaf_node_value
         idx_data_points
+        m : int
+            Number of trees in BART.
 
         Returns
         -------
 
         """
-        new_tree = Tree(tree_id, len(idx_data_points))
+        new_tree = Tree(tree_id, len(idx_data_points), m)
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
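Storing `m` on the tree keeps the linear leaf response correctly scaled: each of the `m` trees contributes `(a + b * x) / m`, so the sum over the ensemble recovers the full linear term. A toy check with illustrative numbers:

```python
import numpy as np

# If every one of m trees carried the same linear leaf (a, b), summing the
# per-tree contributions (a + b * x) / m gives back a + b * x.
m, a, b, x = 10, 0.5, 2.0, 3.0
per_tree = (a + b * x) / m
assert np.isclose(m * per_tree, a + b * x)
```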
+ """ + bart_trees = idata.sample_stats.bart_trees + stacked_trees = bart_trees.stack(trees=["chain", "draw"]) + if size is None: + size = () + elif isinstance(size, int): + size = [size] + + flatten_size = 1 + for s in size: + flatten_size *= s + + idx = rng.randint(len(stacked_trees.trees), size=flatten_size) + + if X_new is None: + pred = np.zeros((flatten_size, stacked_trees[0, 0].item().num_observations)) + for ind, p in enumerate(pred): + for tree in stacked_trees.isel(trees=idx[ind]).values: + p += tree.predict_output() + else: + pred = np.zeros((flatten_size, X_new.shape[0])) + for ind, p in enumerate(pred): + for tree in stacked_trees.isel(trees=idx[ind]).values: + p += np.array([tree.predict_out_of_sample(x) for x in X_new]) + return pred.reshape((*size, -1)) + + +def plot_dependence( + idata, + X=None, + Y=None, + kind="pdp", + xs_interval="linear", + xs_values=None, + var_idx=None, + samples=50, + instances=10, + random_seed=None, + sharey=True, + rug=True, + smooth=True, + indices=None, + grid="long", + color="C0", + color_mean="C0", + alpha=0.1, + figsize=None, + smooth_kwargs=None, + ax=None, +): + """ + Partial dependence or individual conditional expectation plot + + Parameters + ---------- + idata: InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : array-like + The covariate matrix. + Y : array-like + The response vector. + kind : str + Whether to plor a partial dependence plot ("pdp") or an individual conditional expectation + plot ("ice"). Defaults to pdp. + xs_interval : str + Method used to compute the values X used to evaluate the predicted function. "linear", + evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified + quantiles of X. "insample", the evaluation is done at the values of X. + xs_values : int or list + Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of + points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of + quantiles to compute, which must be between 0 and 1 inclusive. + Ignored when ``xs_interval="insample"``. + var_idx : list + List of the indices of the covariate for which to compute the pdp or ice. + samples : int + Number of posterior samples used in the predictions. Defaults to 50 + instances : int + Number of instances of X to plot. Only relevant if ice ``kind="ice"`` plots. + random_seed : int + random_seed used to sample from the posterior. Defaults to None. + sharey : bool + Controls sharing of properties among y-axes. Defaults to True. + rug : bool + Whether to include a rugplot. Defaults to True. + smooth=True, + If True the result will be smoothed by first computing a linear interpolation of the data + over a regular grid and then applying the Savitzky-Golay filter to the interpolated data. + Defaults to True. + grid : str or tuple + How to arrange the subplots. Defaults to "long", one subplot below the other. + Other options are "wide", one subplot next to eachother or a tuple indicating the number of + rows and columns. + color : matplotlib valid color + Color used to plot the pdp or ice. Defaults to "C0" + color_mean : matplotlib valid color + Color used to plot the mean pdp or ice. Defaults to "C0", + alpha : float + Transparency level, should in the interval [0, 1]. + figsize : tuple + Figure size. If None it will be defined automatically. + smooth_kwargs : dict + Additional keywords modifying the Savitzky-Golay filter. + See scipy.signal.savgol_filter() for details. 
+
+
+def plot_dependence(
+    idata,
+    X=None,
+    Y=None,
+    kind="pdp",
+    xs_interval="linear",
+    xs_values=None,
+    var_idx=None,
+    samples=50,
+    instances=10,
+    random_seed=None,
+    sharey=True,
+    rug=True,
+    smooth=True,
+    indices=None,
+    grid="long",
+    color="C0",
+    color_mean="C0",
+    alpha=0.1,
+    figsize=None,
+    smooth_kwargs=None,
+    ax=None,
+):
+    """
+    Partial dependence or individual conditional expectation plot.
+
+    Parameters
+    ----------
+    idata : InferenceData
+        InferenceData containing a collection of BART trees in the sample_stats group.
+    X : array-like
+        The covariate matrix.
+    Y : array-like
+        The response vector.
+    kind : str
+        Whether to plot a partial dependence plot ("pdp") or an individual conditional
+        expectation plot ("ice"). Defaults to "pdp".
+    xs_interval : str
+        Method used to compute the values of X used to evaluate the predicted function. "linear",
+        evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
+        quantiles of X. "insample", the evaluation is done at the values of X.
+    xs_values : int or list
+        Values of X used to evaluate the predicted function. If ``xs_interval="linear"``, number of
+        points in the evenly spaced grid. If ``xs_interval="quantiles"``, quantile or sequence of
+        quantiles to compute, which must be between 0 and 1 inclusive.
+        Ignored when ``xs_interval="insample"``.
+    var_idx : list
+        List of the indices of the covariates for which to compute the pdp or ice.
+    samples : int
+        Number of posterior samples used in the predictions. Defaults to 50.
+    instances : int
+        Number of instances of X to plot. Only relevant for ``kind="ice"`` plots. Defaults to 10.
+    random_seed : int
+        Random seed used to sample from the posterior. Defaults to None.
+    sharey : bool
+        Controls sharing of properties among y-axes. Defaults to True.
+    rug : bool
+        Whether to include a rugplot. Defaults to True.
+    smooth : bool
+        If True the result will be smoothed by first computing a linear interpolation of the data
+        over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
+        Defaults to True.
+    grid : str or tuple
+        How to arrange the subplots. Defaults to "long", one subplot below the other.
+        Other options are "wide", one subplot next to each other, or a tuple indicating the number
+        of rows and columns.
+    color : matplotlib valid color
+        Color used to plot the pdp or ice. Defaults to "C0".
+    color_mean : matplotlib valid color
+        Color used to plot the mean pdp or ice. Defaults to "C0".
+    alpha : float
+        Transparency level, should be in the interval [0, 1].
+    figsize : tuple
+        Figure size. If None it will be defined automatically.
+    smooth_kwargs : dict
+        Additional keywords modifying the Savitzky-Golay filter.
+        See scipy.signal.savgol_filter() for details.
+    ax : axes
+        Matplotlib axes.
+
+    Returns
+    -------
+    axes : matplotlib axes
+    """
+    if kind not in ["pdp", "ice"]:
+        raise ValueError(f"kind={kind} is not supported. Available options are 'pdp' or 'ice'")
+
+    if xs_interval not in ["insample", "linear", "quantiles"]:
+        raise ValueError(
+            f"""{xs_interval} is not supported.
+          Available options are 'insample', 'linear' or 'quantiles'"""
+        )
+
+    rng = RandomState(seed=random_seed)
+
+    if isinstance(X, pd.DataFrame):
+        X_names = list(X.columns)
+        X = X.values
+    else:
+        X_names = []
+
+    if isinstance(Y, pd.Series):
+        Y_label = f"Predicted {Y.name}"
+    else:
+        Y_label = "Predicted Y"
+
+    num_observations = X.shape[0]
+    num_covariates = X.shape[1]
+
+    indices = list(range(num_covariates))
+
+    if var_idx is None:
+        var_idx = indices
+
+    if X_names:
+        X_labels = [X_names[idx] for idx in var_idx]
+    else:
+        X_labels = [f"X_{idx}" for idx in var_idx]
+
+    if xs_interval == "linear" and xs_values is None:
+        xs_values = 10
+
+    if xs_interval == "quantiles" and xs_values is None:
+        xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]
+
+    if kind == "ice":
+        instances = rng.choice(range(X.shape[0]), replace=False, size=instances)
+
+    new_Y = []
+    new_X_target = []
+
+    new_X = np.zeros_like(X)
+    idx_s = list(range(X.shape[0]))
+    for i in var_idx:
+        indices_mi = indices[:]
+        indices_mi.pop(i)
+        y_pred = []
+        if kind == "pdp":
+            if xs_interval == "linear":
+                new_X_i = np.linspace(X[:, i].min(), X[:, i].max(), xs_values)
+            elif xs_interval == "quantiles":
+                new_X_i = np.quantile(X[:, i], q=xs_values)
+            elif xs_interval == "insample":
+                new_X_i = X[:, i]
+
+            for x_i in new_X_i:
+                new_X[:, indices_mi] = X[:, indices_mi]
+                new_X[:, i] = x_i
+                y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 1))
+            new_X_target.append(new_X_i)
+        else:
+            for instance in instances:
+                new_X = X[idx_s]
+                new_X[:, indices_mi] = X[:, indices_mi][instance]
+                y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 0))
+            new_X_target.append(new_X[:, i])
+        new_Y.append(np.array(y_pred).T)
+
+    if ax is None:
+        if grid == "long":
+            fig, axes = plt.subplots(len(var_idx), sharey=sharey, figsize=figsize)
+        elif grid == "wide":
+            fig, axes = plt.subplots(1, len(var_idx), sharey=sharey, figsize=figsize)
+        elif isinstance(grid, tuple):
+            _, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
+            axes = np.ravel(axes)
+    else:
+        axes = [ax]
+
+    if rug:
+        lb = np.min(new_Y)
+
+    for i, ax in enumerate(axes):
+        if i >= len(var_idx):
+            ax.set_axis_off()
+        else:
+            if smooth:
+                if smooth_kwargs is None:
+                    smooth_kwargs = {}
+                smooth_kwargs.setdefault("window_length", 55)
+                smooth_kwargs.setdefault("polyorder", 2)
+                x_data = np.linspace(new_X_target[i].min(), new_X_target[i].max(), 200)
+                x_data[0] = (x_data[0] + x_data[1]) / 2
+                if kind == "pdp":
+                    interp = griddata(new_X_target[i], new_Y[i].mean(0), x_data)
+                else:
+                    interp = griddata(new_X_target[i], new_Y[i], x_data)
+
+                y_data = savgol_filter(interp, axis=0, **smooth_kwargs)
+
+                if kind == "pdp":
+                    az.plot_hdi(
+                        new_X_target[i], new_Y[i], color=color, fill_kwargs={"alpha": alpha}, ax=ax
+                    )
+                    ax.plot(x_data, y_data, color=color_mean)
+                else:
+                    ax.plot(x_data, y_data.mean(1), color=color_mean)
+                    ax.plot(x_data, y_data, color=color, alpha=alpha)
+
+            else:
+                idx = np.argsort(new_X_target[i])
+                if kind == "pdp":
+                    az.plot_hdi(
+                        new_X_target[i],
+                        new_Y[i],
+                        smooth=smooth,
+                        fill_kwargs={"alpha": alpha},
+                        ax=ax,
+                    )
+                    ax.plot(new_X_target[i][idx], new_Y[i][idx].mean(0), color=color)
+                else:
+                    ax.plot(new_X_target[i][idx], new_Y[i][idx], color=color, alpha=alpha)
+                    ax.plot(new_X_target[i][idx], new_Y[i][idx].mean(1), color=color_mean)
+
+            if rug:
+                ax.plot(X[:, i], np.full_like(X[:, i], lb), "k|")
+
+            ax.set_xlabel(X_labels[i])
+
+    ax.get_figure().text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)
+    return axes
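A usage sketch for `plot_dependence`, assuming `idata`, `X`, and `Y` from the `predict` example above; the keyword combinations mirror the parametrized cases in `test_pdp` below:

```python
import matplotlib.pyplot as plt

import pymc as pm

# Partial dependence evaluated at a few quantiles of each covariate.
pm.bart.utils.plot_dependence(
    idata, X=X, Y=Y, kind="pdp", xs_interval="quantiles", xs_values=[0.25, 0.5, 0.75], samples=2
)

# Individual conditional expectation curves for two instances.
pm.bart.utils.plot_dependence(idata, X=X, Y=Y, kind="ice", instances=2)
plt.show()
```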
diff --git a/pymc/sampling.py b/pymc/sampling.py
index d5a3a27b81c..ff1261e4ce4 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -628,6 +628,15 @@ def sample(
                     else:
                         stat["variable_inclusion"] = [np.vstack(stat["variable_inclusion"])]
 
+    if "bart_trees" in trace.stat_names:
+        for strace in trace._straces.values():
+            for stat in strace._stats:
+                if "bart_trees" in stat:
+                    if trace.nchains > 1:
+                        stat["bart_trees"] = np.vstack(stat["bart_trees"])
+                    else:
+                        stat["bart_trees"] = [np.vstack(stat["bart_trees"])]
+
     n_chains = len(trace.chains)
     _log.info(
         f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py
index 1e01a449005..901e4e4f912 100644
--- a/pymc/tests/test_bart.py
+++ b/pymc/tests/test_bart.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 
 from numpy.random import RandomState
 from numpy.testing import assert_almost_equal
@@ -47,33 +48,52 @@ def test_bart_vi():
     assert_almost_equal(var_imp.sum(), 1)
 
 
-def test_bart_random():
+def test_missing_data():
     X = np.random.normal(0, 1, size=(2, 50)).T
     Y = np.random.normal(0, 1, size=50)
+    X[10:20, 0] = np.nan
 
     with pm.Model() as model:
         mu = pm.BART("mu", X, Y, m=10)
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415, chains=1)
-
-    rng = RandomState(12345)
-    pred_all = mu.owner.op.rng_fn(rng, size=2)
-    rng = RandomState(12345)
-    pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])
-
-    assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
-    assert pred_all.shape == (2, 50)
-    assert pred_first.shape == (10,)
+        idata = pm.sample(random_seed=3415)
 
 
-def test_missing_data():
+class TestUtils:
     X = np.random.normal(0, 1, size=(2, 50)).T
     Y = np.random.normal(0, 1, size=50)
-    X[10:20, 0] = np.nan
 
     with pm.Model() as model:
-        mu = pm.BART("mu", X, Y, m=10)
+        mu = pm.BART("mu", X, Y, m=10, response="mix")
         sigma = pm.HalfNormal("sigma", 1)
         y = pm.Normal("y", mu, sigma, observed=Y)
         idata = pm.sample(random_seed=3415)
+
+    def test_predict(self):
+        rng = RandomState(12345)
+        pred_all = pm.bart.utils.predict(self.idata, rng, size=2)
+        rng = RandomState(12345)
+        pred_first = pm.bart.utils.predict(self.idata, rng, X_new=self.X[:10])
+
+        assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
+        assert pred_all.shape == (2, 50)
+        assert pred_first.shape == (10,)
+
+    @pytest.mark.parametrize(
+        "kwargs",
+        [
+            {},
+            {
+                "kind": "pdp",
+                "samples": 2,
+                "xs_interval": "quantiles",
+                "xs_values": [0.25, 0.5, 0.75],
+            },
+            {"kind": "ice", "instances": 2},
+            {"var_idx": [0], "rug": False, "smooth": False, "color": "k"},
+            {"grid": (1, 2), "sharey": "none", "alpha": 1},
+        ],
+    )
+    def test_pdp(self, kwargs):
+        pm.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs)
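For reference, the `pymc/sampling.py` hook above is what turns the per-draw `bart_trees` stats into the rectangular array that ends up in `idata.sample_stats.bart_trees` and is re-stacked by `predict`. A minimal shape sketch with dummy objects standing in for trees:

```python
import numpy as np

# Each draw records an object array of m trees; np.vstack over the draws
# yields a (draws, m) array per chain.
m, draws = 10, 5
per_draw_stats = [np.array([object() for _ in range(m)]) for _ in range(draws)]
stacked = np.vstack(per_draw_stats)
assert stacked.shape == (draws, m)
```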