Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set kernel of support vector machine #350

Merged
merged 92 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 69 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
fa24585
Added tests for empty table add_column and add_columns
patrikguempel May 12, 2023
8fecfd4
hotfix table eq, check column names on table without rows
patrikguempel May 12, 2023
f3abb76
add check for completely empty table
patrikguempel May 12, 2023
99d2b9a
added code in parametrized test to add row in empty table
jxnior01 May 12, 2023
4843b1e
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
jxnior01 May 12, 2023
716221b
empty table now support adding any row
patrikguempel May 12, 2023
b0a93e7
added add_rows for empty tables
patrikguempel May 12, 2023
b0c9014
renamed test
jxnior01 May 19, 2023
aa4b4e2
added support to read empty file to get an empty table using from_csv…
patrikguempel May 19, 2023
1236c81
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
8be21a3
added test for empty table in test_has_column.py
patrikguempel May 19, 2023
ab5dd0f
created tests for empty table within test_inverse_transform_table.py,…
patrikguempel May 19, 2023
2d4ff5b
created tests for empty table for each plot
patrikguempel May 19, 2023
1c993d0
created tests for empty table for test_remove_columns.py, test_remove…
patrikguempel May 19, 2023
50f198d
created tests for empty table for test_rename.py, test_replace_column…
patrikguempel May 19, 2023
c55fb9e
fixes in from json file and from csv file regarding empty file
patrikguempel May 19, 2023
967aee1
added parametrized tests for empty tables
jxnior01 May 19, 2023
b341715
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
jxnior01 May 19, 2023
1692251
added parametrized tests for test_to_csv_file.py
jxnior01 May 19, 2023
22deaef
added tests for sort_rows and split regarding empty table,
patrikguempel May 19, 2023
514ca49
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
abd27b9
added empty table test in test_str.py
patrikguempel May 19, 2023
62647ac
empty summary for empty table
patrikguempel May 19, 2023
33356e9
added parametrized tests for test_transform_column.py
jxnior01 May 19, 2023
1ad7b0f
added parametrized tests for test_to_json_file.py
jxnior01 May 19, 2023
0362f04
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
jxnior01 May 19, 2023
c11299f
undone path test remove
patrikguempel May 19, 2023
9a892a9
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
27217b2
rewrote test_to_columns.py, added empy table check in the progress
patrikguempel May 19, 2023
b65531d
added parametrized tests for test_to_json_file.py
jxnior01 May 19, 2023
6f14390
added parametrized tests for test_to_excel_file.py
jxnior01 May 19, 2023
455b087
added empty table check for test_to_rows.py
patrikguempel May 19, 2023
4452347
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
9689b6d
added empty table check for test_transform_table.py
patrikguempel May 19, 2023
39f73d4
Merge branch 'main' into 123-check-that-methods-of-table-can-handle-a…
patrikguempel May 19, 2023
295b31e
added annotations for test_keep_only_columns.py and test_rename.py
jxnior01 May 19, 2023
75552e9
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
jxnior01 May 19, 2023
dff9623
added annotations for test_remove_columns.py
jxnior01 May 19, 2023
22bf39d
sugar for linter
patrikguempel May 19, 2023
a5f94b1
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
ed8c73b
json file will now read {} as empty, linter doesnt like empty json files
patrikguempel May 19, 2023
799bee0
path fixed
patrikguempel May 19, 2023
9921342
style: apply automated linter fixes
megalinter-bot May 19, 2023
aaecc76
fixed code coverage
patrikguempel May 19, 2023
6b64718
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
021f158
Merge branch 'main' into 123-check-that-methods-of-table-can-handle-a…
patrikguempel May 19, 2023
968d46c
fixed code coverage
patrikguempel May 19, 2023
6df249f
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 19, 2023
98d8b4c
fixed code coverage. part 3
patrikguempel May 19, 2023
9aa9a8b
semantical fix
patrikguempel May 19, 2023
83bf00a
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 24, 2023
0314518
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 25, 2023
294f637
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 25, 2023
671ddbc
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 25, 2023
cd0f95e
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 26, 2023
09bcab4
Update src/safeds/data/tabular/containers/_table.py
patrikguempel May 26, 2023
e3d3156
- removed docstring
patrikguempel May 26, 2023
9f756e1
Merge remote-tracking branch 'origin/123-check-that-methods-of-table-…
patrikguempel May 26, 2023
10f8b0f
added a new parameter kernel and getters for the c and kernel parameters
jxnior01 May 26, 2023
c85c972
added a new parameter kernel and getters for the c and kernel parameters
jxnior01 May 26, 2023
c43904f
added abstract base class SupportVectorMachineKernel and its subclasses
jxnior01 May 26, 2023
7b59c57
added a new parameter kernel, getters, the SupportVectorMachine Class…
jxnior01 Jun 2, 2023
0b97f4f
added all necessary codes for the kernel parameter in the SupportVect…
jxnior01 Jun 8, 2023
ddaa850
edited tests files to test the kernel parameter in the SupportVectorM…
jxnior01 Jun 8, 2023
0e471f7
tried to resolve merge conflicts
jxnior01 Jun 8, 2023
c78dcc3
removed redefined unused c
jxnior01 Jun 8, 2023
50d907a
resolving implicit Optional problems using PEP 604 syntax
jxnior01 Jun 8, 2023
f5f4f82
updated return type of the kernel property in both classicals
jxnior01 Jun 8, 2023
a1257de
style: apply automated linter fixes
megalinter-bot Jun 8, 2023
093ef81
removed inappropriate text
jxnior01 Jun 8, 2023
3b95eba
restructured code to use appropriate return types for assertions in t…
jxnior01 Jun 9, 2023
f3ef966
Merge branch 'main' into 172-set-kernel-of-support-vector-machine
jxnior01 Jun 9, 2023
67262bd
Merge remote-tracking branch 'origin/main' into 172-set-kernel-of-sup…
jxnior01 Jun 9, 2023
abe57c8
Merge branch '172-set-kernel-of-support-vector-machine' of github.com…
jxnior01 Jun 9, 2023
e765dd8
tests classification: added test_should_get_sklearn_kernel_linear for…
jxnior01 Jun 9, 2023
21219b0
pep 604 fixes
jxnior01 Jun 9, 2023
e12d7a7
pep 604
jxnior01 Jun 9, 2023
0f4d904
added assert for linter to properly access get_sklearn_kernel
jxnior01 Jun 9, 2023
424bcf6
added test_should_raise_if_degree_less_than_1 and test_should_get_skl…
jxnior01 Jun 9, 2023
6a147e5
added test_should_get_sklearn_kernel_sigmoid and test_should_get_skle…
jxnior01 Jun 9, 2023
5680be4
added test_should_get_kernel_name_invalid_kernel_type and test_should…
jxnior01 Jun 9, 2023
498a6eb
Added kernel tests for regression and removed test_should_get_kernel_…
jxnior01 Jun 9, 2023
7be415e
coverage 99.90%, readded last test: test_should_get_kernel_name_inval…
jxnior01 Jun 9, 2023
4196e1c
set expected type to be None as used in the init method
jxnior01 Jun 9, 2023
dcd0057
Last commit, coverage should hit 100%
jxnior01 Jun 9, 2023
cea8b83
performed some changes in docstrings
jxnior01 Jun 9, 2023
87a61bc
changes as per review
jxnior01 Jun 9, 2023
30caa1a
changes as per review
jxnior01 Jun 9, 2023
b60e83a
Update src/safeds/ml/classical/classification/_support_vector_machine.py
jxnior01 Jun 9, 2023
a97f466
Update src/safeds/ml/classical/regression/_support_vector_machine.py
jxnior01 Jun 9, 2023
784228a
reviewed doctrings
jxnior01 Jun 9, 2023
5afcc83
Merge branch 'main' into 172-set-kernel-of-support-vector-machine
jxnior01 Jun 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from sklearn.svm import SVC as sk_SVC # noqa: N811

from safeds.ml.classical._util_sklearn import fit, predict

from ._classifier import Classifier
from safeds.ml.classical.classification._classifier import Classifier
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
from sklearn.base import ClassifierMixin

from safeds.data.tabular.containers import Table, TaggedTable


class SupportVectorMachineKernel(ABC):
"""The abstract base class of the different subclasses supported by the `Kernel`."""

@abstractmethod
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
"""
Get the kernel of the given SupportVectorMachine.git.
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
svm: SupportVectorMachine. The SupportVectorMachine instance.

Returns
-------
object
The kernel of the SupportVectorMachine.
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved
"""


class SupportVectorMachine(Classifier):
"""
Support vector machine.
Expand All @@ -22,14 +41,15 @@ class SupportVectorMachine(Classifier):
----------
c: float
The strength of regularization. Must be strictly positive.
kernel: The type of kernel to be used. Defaults to None.

Raises
------
ValueError
If `c` is less than or equal to 0.
"""

def __init__(self, *, c: float = 1.0) -> None:
def __init__(self, *, c: float = 1.0, kernel: SupportVectorMachineKernel | None = None) -> None:
# Internal state
self._wrapped_classifier: sk_SVC | None = None
self._feature_names: list[str] | None = None
Expand All @@ -39,11 +59,50 @@ def __init__(self, *, c: float = 1.0) -> None:
if c <= 0:
raise ValueError("The parameter 'c' has to be strictly positive.")
self._c = c
self._kernel = kernel

@property
def c(self) -> float:
return self._c

@property
def kernel(self) -> SupportVectorMachineKernel | None:
return self._kernel

class Kernel:
class Linear(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class Polynomial(SupportVectorMachineKernel):
def __init__(self, degree: int):
if degree < 1:
raise ValueError("The parameter 'degree' has to be greater than or equal to 1.")
self._degree = degree

def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class Sigmoid(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class RadialBasisFunction(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

def _get_kernel_name(self) -> str:
if isinstance(self.kernel, SupportVectorMachine.Kernel.Linear):
return "linear"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.Polynomial):
return "poly"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.Sigmoid):
return "sigmoid"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.RadialBasisFunction):
return "rbf"
else:
raise TypeError("Invalid kernel type.")

def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
"""
Create a copy of this classifier and fit it with the given training data.
Expand All @@ -68,7 +127,7 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
wrapped_classifier = self._get_sklearn_classifier()
fit(wrapped_classifier, training_set)

result = SupportVectorMachine(c=self._c)
result = SupportVectorMachine(c=self._c, kernel=self._kernel)
result._wrapped_classifier = wrapped_classifier
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
67 changes: 63 additions & 4 deletions src/safeds/ml/classical/regression/_support_vector_machine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from sklearn.svm import SVR as sk_SVR # noqa: N811

from safeds.ml.classical._util_sklearn import fit, predict

from ._regressor import Regressor
from safeds.ml.classical.regression._regressor import Regressor
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
from sklearn.base import RegressorMixin

from safeds.data.tabular.containers import Table, TaggedTable


class SupportVectorMachineKernel(ABC):
"""The abstract base class of the different subclasses supported by the `Kernel`."""

@abstractmethod
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
"""
Get the kernel of the given SupportVectorMachine.

Parameters
----------
svm: SupportVectorMachine. The SupportVectorMachine instance.

Returns
-------
object
The kernel of the SupportVectorMachine.
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved
"""


class SupportVectorMachine(Regressor):
"""
Support vector machine.
Expand All @@ -22,14 +41,15 @@ class SupportVectorMachine(Regressor):
----------
c: float
The strength of regularization. Must be strictly positive.
kernel: The type of kernel to be used. Defaults to None.

Raises
------
ValueError
If `c` is less than or equal to 0.
"""

def __init__(self, *, c: float = 1.0) -> None:
def __init__(self, *, c: float = 1.0, kernel: SupportVectorMachineKernel | None = None) -> None:
# Internal state
self._wrapped_regressor: sk_SVR | None = None
self._feature_names: list[str] | None = None
Expand All @@ -39,11 +59,50 @@ def __init__(self, *, c: float = 1.0) -> None:
if c <= 0:
raise ValueError("The parameter 'c' has to be strictly positive.")
self._c = c
self._kernel = kernel

@property
def c(self) -> float:
return self._c

@property
def kernel(self) -> SupportVectorMachineKernel | None:
return self._kernel

class Kernel:
class Linear(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class Polynomial(SupportVectorMachineKernel):
def __init__(self, degree: int):
if degree < 1:
raise ValueError("The parameter 'degree' has to be greater than or equal to 1.")
self._degree = degree

def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class Sigmoid(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

class RadialBasisFunction(SupportVectorMachineKernel):
def get_sklearn_kernel(self, svm: SupportVectorMachine) -> object:
return svm.kernel

def _get_kernel_name(self) -> str:
if isinstance(self.kernel, SupportVectorMachine.Kernel.Linear):
return "linear"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.Polynomial):
return "poly"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.Sigmoid):
return "sigmoid"
elif isinstance(self.kernel, SupportVectorMachine.Kernel.RadialBasisFunction):
return "rbf"
else:
raise TypeError("Invalid kernel type.")

def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
"""
Create a copy of this regressor and fit it with the given training data.
Expand All @@ -68,7 +127,7 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine:
wrapped_regressor = self._get_sklearn_regressor()
fit(wrapped_regressor, training_set)

result = SupportVectorMachine(c=self._c)
result = SupportVectorMachine(c=self._c, kernel=self._kernel)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ def training_set() -> TaggedTable:

class TestC:
def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set)
kernel = SupportVectorMachine.Kernel.Linear()
fitted_model = SupportVectorMachine(c=2, kernel=kernel).fit(training_set=training_set)
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved
assert fitted_model.c == 2
assert isinstance(fitted_model.kernel, SupportVectorMachine.Kernel.Linear)
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set)
kernel = SupportVectorMachine.Kernel.Linear()
fitted_model = SupportVectorMachine(c=2, kernel=kernel).fit(training_set)
assert fitted_model._wrapped_classifier is not None
assert fitted_model._wrapped_classifier.C == 2
assert isinstance(fitted_model.kernel, SupportVectorMachine.Kernel.Linear)
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

def test_should_raise_if_less_than_or_equal_to_0(self) -> None:
with pytest.raises(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ def training_set() -> TaggedTable:

class TestC:
def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set)
kernel = SupportVectorMachine.Kernel.Linear()
fitted_model = SupportVectorMachine(c=2, kernel=kernel).fit(training_set=training_set)
assert fitted_model.c == 2
assert isinstance(fitted_model.kernel, SupportVectorMachine.Kernel.Linear)
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None:
fitted_model = SupportVectorMachine(c=2).fit(training_set)
kernel = SupportVectorMachine.Kernel.Linear()
fitted_model = SupportVectorMachine(c=2, kernel=kernel).fit(training_set)
assert fitted_model._wrapped_regressor is not None
assert fitted_model._wrapped_regressor.C == 2
assert isinstance(fitted_model.kernel, SupportVectorMachine.Kernel.Linear)
jxnior01 marked this conversation as resolved.
Show resolved Hide resolved

def test_should_raise_if_less_than_or_equal_to_0(self) -> None:
with pytest.raises(
Expand Down