Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

225 add a plugin registry #228

Merged
merged 6 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from synthcity.utils.reproducibility import enable_reproducible_results
from synthcity.utils.serialization import load_from_file, save_to_file

# Sentinel returned by the base `Plugin.name()`; `Plugin.__init__` raises a
# ValueError when the instance still reports this value.
PLUGIN_NAME_NOT_SET: str = "plugin_name_not_set"
# Sentinel returned by the base `Plugin.type()`; `Plugin.__init__` raises a
# ValueError when the instance still reports this value.
PLUGIN_TYPE_NOT_SET: str = "plugin_type_not_set"


class Plugin(Serializable, metaclass=ABCMeta):
"""
Expand Down Expand Up @@ -82,6 +85,15 @@ def __init__(
compress_dataset: bool = False,
sampling_strategy: str = "marginal", # uniform, marginal
) -> None:
if self.name() == PLUGIN_NAME_NOT_SET:
raise ValueError(
f"Plugin {self.__class__.__name__} `name` was not set, use Plugins().add({self.__class__.__name__}, {self.__class__})"
)
if self.type() == PLUGIN_TYPE_NOT_SET:
raise ValueError(
f"Plugin {self.__class__.__name__} `type` was not set, use Plugins().add({self.__class__.__name__}, {self.__class__})"
)

super().__init__()

enable_reproducible_results(random_state)
Expand Down Expand Up @@ -145,13 +157,13 @@ def sample_hyperparameters_optuna(
@abstractmethod
def name() -> str:
    """The name of the plugin.

    Subclasses are expected to override this. The base implementation
    returns the ``PLUGIN_NAME_NOT_SET`` sentinel, which ``Plugin.__init__``
    rejects with a ``ValueError``.
    """
    return PLUGIN_NAME_NOT_SET

@staticmethod
@abstractmethod
def type() -> str:
    """The type of the plugin.

    Subclasses are expected to override this. The base implementation
    returns the ``PLUGIN_TYPE_NOT_SET`` sentinel, which ``Plugin.__init__``
    rejects with a ``ValueError``.
    """
    return PLUGIN_TYPE_NOT_SET

@classmethod
def fqdn(cls) -> str:
Expand Down Expand Up @@ -537,14 +549,18 @@ def plot(
plot_tsne(plt, X, X_syn)


# Module-level map from a plugin type (the value of `cls.type()`) to the plugin
# names registered under it; written by `PluginLoader._add_category`.
PLUGIN_CATEGORY_REGISTRY: Dict[str, List[str]] = dict()
# Module-level map from plugin name to plugin class; written by
# `PluginLoader.add`.
PLUGIN_REGISTRY: Dict[str, Type[Plugin]] = dict()


class PluginLoader:
"""Plugin loading utility class.
Used to load the plugins from the current folder.
"""

@validate_arguments
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
self._plugins: Dict[str, Type] = {}
self._refresh()
self._available_plugins = {}
for plugin in plugins:
stem = Path(plugin).stem.split("plugin_")[-1]
Expand All @@ -553,7 +569,11 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None
continue
self._available_plugins[stem] = plugin
self._expected_type = expected_type
self._categories = categories

def _refresh(self) -> None:
    """Re-point this loader at the module-level plugin registries."""
    # Bound by reference, not copied: entries added to the registries later
    # are visible through these attributes.
    self._categories: Dict[str, List[str]] = PLUGIN_CATEGORY_REGISTRY
    self._plugins: Dict[str, Type[Plugin]] = PLUGIN_REGISTRY

@validate_arguments
def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]:
Expand Down Expand Up @@ -610,6 +630,7 @@ def _load_single_plugin(self, plugin_name: str) -> bool:

def list(self) -> List[str]:
"""Get all the available plugins."""
self._refresh()
all_plugins = list(self._plugins.keys()) + list(self._available_plugins.keys())
plugins = []
for plugin in all_plugins:
Expand All @@ -620,18 +641,42 @@ def list(self) -> List[str]:

def types(self) -> List[Type]:
    """Return the classes of all plugins currently in the registry."""
    self._refresh()
    return [*self._plugins.values()]

def _add_category(self, category: str, name: str) -> "PluginLoader":
    """Record ``name`` as a member of the plugin category ``category``.

    Args:
        category: Key in ``PLUGIN_CATEGORY_REGISTRY``.
        name: Plugin name to list under that category.

    Returns:
        This loader, so calls can be chained.

    Raises:
        TypeError: If ``name`` is already listed under ``category``.
    """
    log.debug(f"Registering plugin category {category}")
    if name in PLUGIN_CATEGORY_REGISTRY.get(category, ()):
        raise TypeError(
            f"Plugin {name} is already registered as category: {category}"
        )

    members = PLUGIN_CATEGORY_REGISTRY.get(category, None)
    if members is None:
        # First plugin of this category: start a new list.
        PLUGIN_CATEGORY_REGISTRY[category] = [name]
    else:
        members.append(name)
    return self

def add(self, name: str, cls: Type) -> "PluginLoader":
    """Register a plugin class under ``name``.

    Args:
        name: Key under which the plugin class is stored.
        cls: The plugin class; must be a subclass of the type this loader
            was created with.

    Returns:
        This loader, so calls can be chained.

    Raises:
        ValueError: If ``cls`` does not derive from the expected type.
    """
    self._refresh()

    # Validate before anything else so that a rejected class never produces
    # the misleading "Overwriting" log line below.
    if not issubclass(cls, self._expected_type):
        raise ValueError(
            f"Plugin {name} must derive the {self._expected_type} interface."
        )

    if name in self._plugins:
        log.info(f"Plugin {name} already exists. Overwriting")

    # `self._refresh()` above bound `self._plugins` to PLUGIN_REGISTRY, so
    # this single write updates the shared registry.
    self._plugins[name] = cls

    # Use the same (stringified) key for the lookup and for the insertion so
    # the two can never disagree.
    category = str(cls.type())
    if name not in PLUGIN_CATEGORY_REGISTRY.get(category, []):
        self._add_category(category, name)
    return self

@validate_arguments
Expand All @@ -649,6 +694,7 @@ def get(self, name: str, *args: Any, **kwargs: Any) -> Any:
Returns:
The new object
"""
self._refresh()
if name not in self._plugins and name not in self._available_plugins:
raise ValueError(f"Plugin {name} doesn't exist.")

Expand All @@ -669,6 +715,7 @@ def get_type(self, name: str) -> Type:
Returns:
The class of the plugin
"""
self._refresh()
if name not in self._plugins and name not in self._available_plugins:
raise ValueError(f"Plugin {name} doesn't exist.")

Expand All @@ -682,6 +729,7 @@ def get_type(self, name: str) -> Type:

def __iter__(self) -> Generator:
    """Yield the name of each plugin currently in the registry."""
    self._refresh()
    yield from self._plugins

Expand Down
58 changes: 58 additions & 0 deletions tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import platform
from copy import copy
from pathlib import Path
from typing import Any, List

# third party
import pytest
Expand All @@ -13,10 +14,15 @@
# synthcity absolute
from synthcity.benchmark import Benchmarks
from synthcity.benchmark.utils import get_json_serializable_kwargs
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (
DataLoader,
GenericDataLoader,
SurvivalAnalysisDataLoader,
)
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema


def test_benchmark_sanity() -> None:
Expand Down Expand Up @@ -286,3 +292,55 @@ def test_benchmark_workspace_cache() -> None:

assert X_augment_cache_file.exists()
assert augment_generator_file.exists()


def test_benchmark_added_plugin() -> None:
    """A plugin registered at runtime can be evaluated by ``Benchmarks``."""
    features, target = load_iris(return_X_y=True, as_frame=True)
    features["target"] = target

    class DummyCopyDataPlugin(Plugin):
        """Dummy plugin for debugging."""

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(**kwargs)

        @staticmethod
        def name() -> str:
            return "copy_data"

        @staticmethod
        def type() -> str:
            return "debug"

        @staticmethod
        def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
            return []

        def _fit(
            self, X: DataLoader, *args: Any, **kwargs: Any
        ) -> "DummyCopyDataPlugin":
            self.features_count = X.shape[1]
            self.X = X
            return self

        def _generate(
            self, count: int, syn_schema: Schema, **kwargs: Any
        ) -> DataLoader:
            return self.X.sample(count)

    # Register the new plugin in the collection before benchmarking it.
    Plugins().add("copy_data", DummyCopyDataPlugin)

    requested_metrics = {"performance": ["linear_model"]}
    results = Benchmarks.evaluate(
        [("copy_data", "copy_data", {})],
        GenericDataLoader(features, target_column="target"),
        metrics=requested_metrics,
    )
    assert "copy_data" in results
75 changes: 75 additions & 0 deletions tests/plugins/test_plugin_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# stdlib
import glob
from pathlib import Path
from typing import Any, List

# third party
from sklearn.datasets import load_breast_cancer

# synthcity absolute
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema


class DummyCopyDataPlugin(Plugin):
    """Dummy plugin for debugging.

    It stores the loader passed to ``_fit`` and answers ``_generate`` by
    calling ``sample`` on that stored loader.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Forward every keyword argument unchanged to ``Plugin.__init__``.
        super().__init__(**kwargs)

    @staticmethod
    def name() -> str:
        """Name reported by this plugin."""
        return "copy_data"

    @staticmethod
    def type() -> str:
        """Plugin type (category) reported by this plugin."""
        return "debug"

    @staticmethod
    def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
        """This plugin exposes no hyperparameters."""
        return []

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "DummyCopyDataPlugin":
        """Keep a reference to ``X`` and record ``X.shape[1]``."""
        self.features_count = X.shape[1]
        self.X = X
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader:
        """Return ``self.X.sample(count)``; ``syn_schema`` is not used."""
        return self.X.sample(count)


def test_add_dummy_plugin() -> None:
    """A plugin added at runtime can be fetched, fitted, sampled and listed."""
    # get the list of plugins that are loaded
    generators = Plugins()

    # Get the names of the plugins that come with the package.
    # NOTE: the path is relative to the current working directory, so the test
    # expects to be run from the repository root.
    plugins_dir = Path.cwd() / "src/synthcity/plugins"
    plugins_list = []
    for plugin_type in plugins_dir.iterdir():
        # `iterdir()` already yields paths that include `plugins_dir`.
        plugin_paths = glob.glob(str(plugin_type / "plugin*.py"))
        # Strip the "plugin_" file prefix so the entries are plugin names
        # ("plugin_foo.py" -> "foo"); without this the assertion below could
        # never fail.
        plugins_list.extend(
            Path(path).stem.split("plugin_")[-1] for path in plugin_paths
        )

    # Test that the new plugin is not in the list of plugins in the package
    assert "copy_data" not in plugins_list

    # Add the new plugin
    generators.add("copy_data", DummyCopyDataPlugin)

    # Load reference data
    X, _ = load_breast_cancer(return_X_y=True, as_frame=True)
    loader = GenericDataLoader(X)
    loader.dataframe()

    # Train the new plugin
    gen = generators.get("copy_data")
    gen.fit(loader)

    # Generate some new data to check the new plugin works
    gen.generate(count=10)

    # Test that the new plugin is now in the list of available plugins
    available_plugins = Plugins().list()
    assert "copy_data" in available_plugins