Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: plot decision tree #876

Merged
merged 37 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
dd1300d
added plot for DecisionTreeClassifier
timwirt Jun 28, 2024
7c730d0
added plot for DecisionTreeRegressor
timwirt Jun 28, 2024
ec712e2
new snapshot for regressor
timwirt Jun 28, 2024
5ab2a05
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
c2373f3
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
4a8fb50
added snapshot for linux
SamanHushi Jul 12, 2024
829b95e
fixed force plot from sklearn
SamanHushi Jul 12, 2024
dffdf42
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
a441379
fixed force plot from sklearn in regressor
SamanHushi Jul 12, 2024
0fc9880
styling
SamanHushi Jul 12, 2024
d753493
Merge branch '856-plot-decision-tree' of github.com:Safe-DS/Library i…
SamanHushi Jul 12, 2024
bb848f5
mac snapshot
LIEeOoNn Jul 12, 2024
b8b22a5
skip linux and mac
timwirt Jul 12, 2024
dd4ba10
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
410cf3e
update windows snapshot
timwirt Jul 12, 2024
7926e2b
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
ef25c03
added documentation
timwirt Jul 12, 2024
1fcbb02
Merge branch '856-plot-decision-tree' of https://github.com/Safe-DS/L…
timwirt Jul 12, 2024
b8b47f3
fixed ruff errors
timwirt Jul 12, 2024
d20352f
test: update snapshots
lars-reimann Jul 12, 2024
c84d691
removed linux from skip test
timwirt Jul 12, 2024
9bdfadc
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
5c5ae1e
Update src/safeds/ml/classical/classification/_decision_tree_classifi…
timwirt Jul 12, 2024
64f955a
Update src/safeds/ml/classical/regression/_decision_tree_regressor.py
timwirt Jul 12, 2024
8c704e0
Update tests/safeds/ml/classical/regression/test_decision_tree.py
timwirt Jul 12, 2024
c9ed5d6
Update tests/safeds/ml/classical/regression/test_decision_tree.py
timwirt Jul 12, 2024
1ba5db8
Update tests/safeds/ml/classical/classification/test_decision_tree.py
timwirt Jul 12, 2024
3a3564f
Update tests/safeds/ml/classical/classification/test_decision_tree.py
timwirt Jul 12, 2024
d5d87b6
added raised errors to documentation
timwirt Jul 12, 2024
603cd74
push
timwirt Jul 12, 2024
c8bf7f9
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
d6ea890
ignore mac os on regressor
timwirt Jul 12, 2024
63b0e06
Merge branch '856-plot-decision-tree' of https://github.com/Safe-DS/L…
timwirt Jul 12, 2024
8ea6bf1
updated snapshots
timwirt Jul 12, 2024
ccc3023
style: apply automated linter fixes
megalinter-bot Jul 12, 2024
478c8f9
docs: result names instead of types
lars-reimann Jul 12, 2024
fde7cbb
Merge branch 'refs/heads/main' into 856-plot-decision-tree
lars-reimann Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds.data.image.containers import Image
from safeds.exceptions._ml import ModelNotFittedError
from safeds.ml.classical._bases import _DecisionTreeBase

from ._classifier import Classifier
Expand Down Expand Up @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> ClassifierMixin:
max_depth=self._max_depth,
min_samples_leaf=self._min_sample_count_in_leaves,
)

# ------------------------------------------------------------------------------------------------------------------
# Plot
# ------------------------------------------------------------------------------------------------------------------

def plot(self) -> Image:
"""
Get the image of the decision tree.

Returns
-------
plot:
The decision tree figure as an image.

Raises
------
ModelNotFittedError:
If model is not fitted.
"""
timwirt marked this conversation as resolved.
Show resolved Hide resolved
if not self.is_fitted:
raise ModelNotFittedError

from io import BytesIO

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plot_tree(self._wrapped_model)

# save plot fig bytes in buffer
with BytesIO() as buffer:
plt.savefig(buffer)
image = buffer.getvalue()

# prevent forced plot from sklearn showing
plt.close()

return Image.from_bytes(image)
40 changes: 40 additions & 0 deletions src/safeds/ml/classical/regression/_decision_tree_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds.data.image.containers import Image
from safeds.exceptions._ml import ModelNotFittedError
from safeds.ml.classical._bases import _DecisionTreeBase

from ._regressor import Regressor
Expand Down Expand Up @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> RegressorMixin:
max_depth=self._max_depth,
min_samples_leaf=self._min_sample_count_in_leaves,
)

# ------------------------------------------------------------------------------------------------------------------
# Plot
# ------------------------------------------------------------------------------------------------------------------

def plot(self) -> Image:
"""
Get the image of the decision tree.

Returns
-------
plot:
The decision tree figure as an image.

Raises
------
ModelNotFittedError:
If model is not fitted.
"""
timwirt marked this conversation as resolved.
Show resolved Hide resolved
if not self.is_fitted:
raise ModelNotFittedError

from io import BytesIO

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plot_tree(self._wrapped_model)

# save plot fig bytes in buffer
with BytesIO() as buffer:
plt.savefig(buffer)
image = buffer.getvalue()

# prevent forced plot from sklearn showing
plt.close()

return Image.from_bytes(image)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 22 additions & 1 deletion tests/safeds/ml/classical/classification/test_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Table
from safeds.exceptions import OutOfBoundsError
from safeds.exceptions import ModelNotFittedError, OutOfBoundsError
from safeds.ml.classical.classification import DecisionTreeClassifier
from syrupy import SnapshotAssertion

from tests.helpers import os_mac, skip_if_os


@pytest.fixture()
Expand Down Expand Up @@ -41,3 +44,21 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None
def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None:
with pytest.raises(OutOfBoundsError):
DecisionTreeClassifier(min_sample_count_in_leaves=min_sample_count_in_leaves)


class TestPlot:
def test_should_raise_if_model_is_not_fitted(self) -> None:
model = DecisionTreeClassifier()
with pytest.raises(ModelNotFittedError):
model.plot()

def test_should_check_that_plot_image_is_same_as_snapshot(
self,
training_set: TabularDataset,
snapshot_png_image: SnapshotAssertion,
) -> None:
skip_if_os([os_mac])

fitted_model = DecisionTreeClassifier().fit(training_set)
image = fitted_model.plot()
assert image == snapshot_png_image
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 18 additions & 1 deletion tests/safeds/ml/classical/regression/test_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Table
from safeds.exceptions import OutOfBoundsError
from safeds.exceptions import ModelNotFittedError, OutOfBoundsError
from safeds.ml.classical.regression import DecisionTreeRegressor
from syrupy import SnapshotAssertion


@pytest.fixture()
Expand Down Expand Up @@ -41,3 +42,19 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None
def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None:
with pytest.raises(OutOfBoundsError):
DecisionTreeRegressor(min_sample_count_in_leaves=min_sample_count_in_leaves)


class TestPlot:
def test_should_raise_if_model_is_not_fitted(self) -> None:
model = DecisionTreeRegressor()
with pytest.raises(ModelNotFittedError):
model.plot()

def test_should_check_that_plot_image_is_same_as_snapshot(
self,
training_set: TabularDataset,
snapshot_png_image: SnapshotAssertion,
) -> None:
fitted_model = DecisionTreeRegressor().fit(training_set)
image = fitted_model.plot()
assert image == snapshot_png_image