From 04c6800ff124b154eda108651a0e082cd61843a7 Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Tue, 17 Aug 2021 13:02:22 +0300
Subject: [PATCH] add size argument and check for NoDistribution

---
Review notes (free-form commentary placed between "---" and the diffstat):
- NOTE(review): rng_fn now calls rng.randint(), while its default argument is
  still np.random.default_rng(). The comment removed by this patch reports that
  the sampling functions pass a RandomState without integers(); the default
  Generator is the opposite case (integers() but no randint()) -- confirm the
  default is never relied on.
- NOTE(review): isinstance(size, int) does not match numpy integer scalars such
  as np.int64; confirm callers only pass None, a Python int, or a sequence.
- NOTE(review): the branch taken when cls.all_trees is empty returns
  np.full_like(cls.Y, cls.Y.mean()) and does not use size.

 pymc3/distributions/bart.py | 40 ++++++++++++++++++++++---------------
 pymc3/sampling.py           |  7 +++++--
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py
index d7694fe6616..03e6d587bd8 100644
--- a/pymc3/distributions/bart.py
+++ b/pymc3/distributions/bart.py
@@ -36,27 +36,35 @@ class BARTRV(RandomVariable):
     def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
         return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
 
-    def __new__(cls, *args, **kwargs):
-        return super().__new__(cls)
-
     @classmethod
-    def rng_fn(cls, rng=np.random.default_rng(), X_new=None, *args, **kwargs):
+    def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
+        size = kwargs.pop("size", None)
+        X_new = kwargs.pop("X_new", None)
         all_trees = cls.all_trees
         if all_trees:
-            # this should be rng.integers() but when sampling from the prior/posterior predictive
-            # I get 'numpy.random.mtrand.RandomState' object has no attribute 'integers'
-            # So I guess those functions need to be updated
-            idx = np.random.randint(len(all_trees))
-            trees = all_trees[idx]
+
+            if size is None:
+                size = ()
+            elif isinstance(size, int):
+                size = [size]
+
+            flatten_size = 1
+            for s in size:
+                flatten_size *= s
+
+            idx = rng.randint(len(all_trees), size=flatten_size)
+
             if X_new is None:
-                pred = np.zeros(trees[0].num_observations)
-                for tree in trees:
-                    pred += tree.predict_output()
+                pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
+                for ind, p in enumerate(pred):
+                    for tree in all_trees[idx[ind]]:
+                        p += tree.predict_output()
             else:
-                pred = np.zeros(X_new.shape[0])
-                for tree in trees:
-                    pred += np.array([tree.predict_out_of_sample(x) for x in X_new])
-            return pred
+                pred = np.zeros((flatten_size, X_new.shape[0]))
+                for ind, p in enumerate(pred):
+                    for tree in all_trees[idx[ind]]:
+                        p += np.array([tree.predict_out_of_sample(x) for x in X_new])
+            return pred.reshape((*size, -1))
         else:
             return np.full_like(cls.Y, cls.Y.mean())
 
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 205fc5ab98c..dd4a7a091b4 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -42,7 +42,7 @@
 from pymc3.backends.base import BaseTrace, MultiTrace
 from pymc3.backends.ndarray import NDArray
 from pymc3.blocking import DictToArrayBijection
-from pymc3.distributions.bart import BARTRV
+from pymc3.distributions import NoDistribution
 from pymc3.exceptions import IncorrectArgumentsError, SamplingError
 from pymc3.model import Model, Point, modelcontext
 from pymc3.parallel_sampling import Draw, _cpu_count
@@ -240,7 +240,10 @@ def all_continuous(vars, model):
 
     if any(
         [
-            (var.dtype in discrete_types or isinstance(model.values_to_rvs[var].owner.op, BARTRV))
+            (
+                var.dtype in discrete_types
+                or isinstance(model.values_to_rvs[var].owner.op, NoDistribution)
+            )
             for var in vars_
         ]
     ):