Skip to content

Commit

Permalink
move BART to its own module (#5058)
Browse files Browse the repository at this point in the history
* move BART to its own module

* add missing file
  • Loading branch information
aloctavodia authored Oct 8, 2021
1 parent 70f1975 commit a3cc81c
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 9 deletions.
1 change: 1 addition & 0 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __set_compiler_flags():
to_inference_data,
)
from pymc.backends.tracetab import *
from pymc.bart import *
from pymc.blocking import *
from pymc.data import *
from pymc.distributions import *
Expand Down
19 changes: 19 additions & 0 deletions pymc/bart/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc.bart.bart import BART
from pymc.bart.pgbart import PGBART

__all__ = ["BART", "PGBART"]
File renamed without changes.
4 changes: 2 additions & 2 deletions pymc/step_methods/pgbart.py → pymc/bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from pandas import DataFrame, Series

from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc.bart.bart import BARTRV
from pymc.bart.tree import LeafNode, SplitNode, Tree
from pymc.blocking import RaveledVars
from pymc.distributions.bart import BARTRV
from pymc.distributions.tree import LeafNode, SplitNode, Tree
from pymc.model import modelcontext
from pymc.step_methods.arraystep import ArrayStepShared, Competence

Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
logpt_sum,
)

from pymc.distributions.bart import BART
from pymc.distributions.bound import Bound
from pymc.distributions.continuous import (
AsymmetricLaplace,
Expand Down Expand Up @@ -190,7 +189,6 @@
"Rice",
"Moyal",
"Simulator",
"BART",
"CAR",
"PolyaGamma",
"logpt",
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
from pymc.backends.arviz import _DefaultTrace
from pymc.backends.base import BaseTrace, MultiTrace
from pymc.backends.ndarray import NDArray
from pymc.bart.pgbart import PGBART
from pymc.blocking import DictToArrayBijection
from pymc.distributions import NoDistribution
from pymc.exceptions import IncorrectArgumentsError, SamplingError
from pymc.model import Model, Point, modelcontext
from pymc.parallel_sampling import Draw, _cpu_count
from pymc.step_methods import (
NUTS,
PGBART,
BinaryGibbsMetropolis,
BinaryMetropolis,
CategoricalGibbsMetropolis,
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@
MetropolisMLDA,
RecursiveDAProposal,
)
from pymc.step_methods.pgbart import PGBART
from pymc.step_methods.slicer import Slice
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pymc.aesaraf import floatX
from pymc.backends.report import SamplerWarning, WarningType
from pymc.distributions.bart import BARTRV
from pymc.bart.bart import BARTRV
from pymc.math import logbern, logdiffexp_numpy
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def test_split_node():
split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
assert split_node.index == 5
assert split_node.idx_split_variable == 2
assert split_node.split_value == 3.0
Expand All @@ -18,7 +18,7 @@ def test_split_node():


def test_leaf_node():
leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
leaf_node = pm.bart.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
assert leaf_node.index == 5
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
assert leaf_node.value == 3.14
Expand Down

0 comments on commit a3cc81c

Please sign in to comment.