Skip to content

Commit

Permalink
Migrate to plugin design in compute
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Aug 14, 2024
1 parent adf1641 commit 9d72dc1
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 23 deletions.
3 changes: 0 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Move all plugins superduperdb/ext/* to /plugins
- Optimize the logic for file saving and retrieval in the artifact_store.



#### New Features & Functionality

- Modify the field name output to _outputs.predict_id in the model results of Ibis.
Expand All @@ -41,7 +39,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add testing utils for plugins
- Add `cache` field in Component


#### Bug Fixes

- Vector-search vector-loading bug fixed
Expand Down
3 changes: 2 additions & 1 deletion superduper/backends/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .artifacts import FileSystemArtifactStore as ArtifactStore
from .compute import LocalComputeBackend as ComputeBackend

__all__ = ["ArtifactStore"]
__all__ = ["ArtifactStore", "ComputeBackend"]
2 changes: 2 additions & 0 deletions superduper/backends/local/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
name: t.Optional[str] = None,
flavour: t.Optional[str] = None,
):
if conn.startswith('filesystem://'):
conn = conn.split('filesystem://')[-1]
super().__init__(conn, name, flavour)
if not os.path.exists(self.conn):
logging.info('Creating artifact store directory')
Expand Down
3 changes: 3 additions & 0 deletions superduper/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ class LocalComputeBackend(ComputeBackend):
:param uri: Optional uri param.
:param queue: Optional pluggable queue.
:param kwargs: Optional kwargs.
"""

def __init__(
    self,
    uri: t.Optional[str] = None,
    queue: t.Optional['BaseQueuePublisher'] = None,
    kwargs: t.Optional[t.Dict] = None,
):
    """Store the backend settings and start with an empty outputs mapping.

    :param uri: Optional uri param; kept on the instance as ``self.uri``.
    :param queue: Optional pluggable queue. A new ``LocalQueuePublisher``
                  is created when omitted.
    :param kwargs: Optional kwargs. A new empty dict is used when omitted.
    """
    self.__outputs: t.Dict = {}
    self.uri = uri
    # Build the defaults per call. A default written in the signature is
    # evaluated once at definition time, so every backend created without
    # arguments would share the same queue object and the same dict.
    self.queue = LocalQueuePublisher() if queue is None else queue
    self.kwargs = {} if kwargs is None else kwargs

@property
def remote(self) -> bool:
Expand Down
41 changes: 27 additions & 14 deletions superduper/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,33 @@ class _Loader:
not_supported: t.Tuple = ()

@classmethod
def create(cls, uri):
"""Helper method to create metadata backend."""
def match(cls, uri):
"""Check if the uri matches the pattern."""
plugin, flavour = None, None
for pattern in cls.patterns:
if re.match(pattern, uri) is not None:
plugin, flavour = cls.patterns[pattern]
if cls.not_supported and (plugin, flavour) in cls.not_supported:
raise ValueError(
f"{plugin} with flavour {flavour} not supported "
"to create metadata store."
)
impl = getattr(load_plugin(plugin), cls.impl)
return impl(uri, flavour=flavour)
raise ValueError(f"{cls.__name__} No support for uri: {uri}")
selection = cls.patterns[pattern]
if isinstance(selection, tuple):
plugin, flavour = selection
else:
assert isinstance(selection, str)
plugin = selection
break
if plugin is None:
raise ValueError(f"{cls.__name__} No support for uri: {uri}")
return plugin, flavour

@classmethod
def create(cls, uri):
    """Build the backend implementation that handles ``uri``.

    The plugin name and flavour are resolved with ``cls.match``. The pair
    is rejected when it is listed in ``cls.not_supported``. Otherwise the
    plugin is loaded, the attribute named by ``cls.impl`` is looked up on
    it, and that attribute is called with ``uri`` and the flavour.

    :param uri: Connection uri used to select and build the backend.
    :raises ValueError: If the matched plugin/flavour pair is listed in
                        ``cls.not_supported``.
    """
    plugin, flavour = cls.match(uri)

    # An empty ``not_supported`` short-circuits the membership test.
    rejected = cls.not_supported and (plugin, flavour) in cls.not_supported
    if rejected:
        raise ValueError(
            f"{plugin} with flavour {flavour} not supported "
            "to create metadata store."
        )

    plugin_module = load_plugin(plugin)
    backend_cls = getattr(plugin_module, cls.impl)
    return backend_cls(uri, flavour=flavour)


class _MetaDataLoader(_Loader):
Expand Down Expand Up @@ -92,9 +106,8 @@ def _build_compute(cfg):
:param cfg: SuperDuper config.
"""
from superduper.backends.local.compute import LocalComputeBackend

return LocalComputeBackend()
backend = getattr(load_plugin(cfg.cluster.compute.backend), 'ComputeBackend')
return backend(uri=cfg.cluster.compute.uri, **cfg.cluster.compute.kwargs)


def build_datalayer(cfg=None, **kwargs) -> Datalayer:
Expand Down
4 changes: 2 additions & 2 deletions superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ class Compute(BaseConfig):
"""Describes the configuration for distributed computing.
:param uri: The URI for the compute service.
:param compute_kwargs: The keyword arguments to pass to the compute service.
:param kwargs: The keyword arguments to pass to the compute service.
:param backend: Compute backend.
"""

uri: t.Optional[str] = None
compute_kwargs: t.Dict = dc.field(default_factory=dict)
kwargs: t.Dict = dc.field(default_factory=dict)
backend: str = 'local'


Expand Down
2 changes: 1 addition & 1 deletion superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
from superduper import CFG

compute_kwargs = CFG.cluster.compute.compute_kwargs
compute_kwargs = CFG.cluster.compute.kwargs
self.compute_kwargs = self.compute_kwargs or compute_kwargs
self._is_initialized = False
if not self.identifier:
Expand Down
1 change: 1 addition & 0 deletions superduper/components/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def form_template(self):
for k, v in self.template.items()
if k not in {'_blobs', 'identifier', '_path'}
},
'_path': self.template['_path'],
}

def execute(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion superduper/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(
db: t.Optional['Datalayer'] = None,
component: 'Component' = None,
):
self.compute_kwargs = compute_kwargs or CFG.cluster.compute.compute_kwargs
self.compute_kwargs = compute_kwargs or CFG.cluster.compute.kwargs

super().__init__(args=args, kwargs=kwargs, db=db, identifier=identifier)

Expand Down
2 changes: 1 addition & 1 deletion test/unittest/backends/local/test_artifact_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.fixture
def artifact_store(tmpdir):
artifact_store = FileSystemArtifactStore(tmpdir)
artifact_store = FileSystemArtifactStore(str(tmpdir))
yield artifact_store
artifact_store.drop(True)

Expand Down

0 comments on commit 9d72dc1

Please sign in to comment.