Skip to content

Commit

Permalink
Final commit
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Oct 5, 2023
1 parent 26b5b5f commit 80a4304
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 23 deletions.
264 changes: 258 additions & 6 deletions benchmarks_nonasv/notebooks/forest_ht_independent_data.ipynb

Large diffs are not rendered by default.

20 changes: 4 additions & 16 deletions sktree/conftest.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
# import pytest
import pytest


# def pytest_addoption(parser):
# parser.addoption(
# "--runslow", action="store_true", default=False, help="run slow tests"
# )
# With the following global module marker,
# monitoring is disabled by default:
pytestmark = [pytest.mark.monitor_skip_test]


def pytest_configure(config):
"""Set up pytest markers."""
config.addinivalue_line("markers", "slowtest: mark test as slow")


# def pytest_collection_modifyitems(config, items):
# if config.getoption("--runslow"):
# # --runslow given in cli: do not skip slow tests
# return
# skip_slow = pytest.mark.skip(reason="need --runslow option to run")
# for item in items:
# if "slow" in item.keywords:
# item.add_marker(skip_slow)
3 changes: 2 additions & 1 deletion sktree/stats/tests/meson.build
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
python_sources = [
'__init__.py',
'test_forestht.py'
'test_forestht.py',
'test_coleman.py'
]

py3.install_sources(
Expand Down
19 changes: 19 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,22 @@ def test_small_dataset():
stat, pvalue = clf.test(X, y, metric="mi")
assert stat == 0.0
assert pvalue > 0.05


# @pytest.mark.monitor_test
# def test_memory_usage():
# n_samples = 1000
# n_features = 5000
# X = rng.uniform(size=(n_samples, n_features))
# y = rng.integers(0, 2, size=n_samples) # Binary classification

# clf = FeatureImportanceForestClassifier(
# estimator=HonestForestClassifier(
# n_estimators=10, random_state=seed, n_jobs=-1, honest_fraction=0.5
# ),
# test_size=0.2,
# permute_per_tree=False,
# sample_dataset_per_tree=False,
# )

# stat, pvalue = clf.test(X, y, covariate_index=[1, 2], metric="mi")

0 comments on commit 80a4304

Please sign in to comment.