Skip to content

Commit

Permalink
Migrate to plugin design
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Aug 13, 2024
1 parent cc5a04d commit d531b0f
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 16 deletions.
1 change: 1 addition & 0 deletions superduper/backends/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .artifacts import FileSystemArtifactStore as ArtifactStore
from .compute import LocalComputeBackend as ComputeBackend

__all__ = ["ArtifactStore"]
2 changes: 2 additions & 0 deletions superduper/backends/local/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(
flavour: t.Optional[str] = None,
):
super().__init__(conn, name, flavour)
if conn.startswith('filesystem://'):
conn = conn[7:]
if not os.path.exists(self.conn):
logging.info('Creating artifact store directory')
os.makedirs(self.conn, exist_ok=True)
Expand Down
2 changes: 2 additions & 0 deletions superduper/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def __init__(
self,
uri: t.Optional[str] = None,
queue: BaseQueuePublisher = LocalQueuePublisher(),
kwargs: t.Dict = {},
):
self.__outputs: t.Dict = {}
self.uri = uri
self.queue = queue
self.kwargs = kwargs

@property
def remote(self) -> bool:
Expand Down
43 changes: 30 additions & 13 deletions superduper/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,33 @@ class _Loader:
not_supported: t.Tuple = ()

@classmethod
def create(cls, uri):
"""Helper method to create metadata backend."""
def match(cls, uri):
"""Check if the uri matches the pattern."""
plugin, flavour = None, None
for pattern in cls.patterns:
if re.match(pattern, uri) is not None:
plugin, flavour = cls.patterns[pattern]
if cls.not_supported and (plugin, flavour) in cls.not_supported:
raise ValueError(
f"{plugin} with flavour {flavour} not supported "
"to create metadata store."
)
impl = getattr(load_plugin(plugin), cls.impl)
return impl(uri, flavour=flavour)
raise ValueError(f"{cls.__name__} No support for uri: {uri}")
selection = cls.patterns[pattern]
if isinstance(selection, tuple):
plugin, flavour = selection
else:
assert isinstance(selection, str)
plugin = selection
break
if plugin is None:
raise ValueError(f"{cls.__name__} No support for uri: {uri}")
return plugin, flavour

@classmethod
def create(cls, uri):
"""Helper method to create metadata backend."""
plugin, flavour = cls.match(uri)
if cls.not_supported and (plugin, flavour) in cls.not_supported:
raise ValueError(
f"{plugin} with flavour {flavour} not supported "
"to create metadata store."
)
impl = getattr(load_plugin(plugin), cls.impl)
return impl(uri, flavour=flavour)


class _MetaDataLoader(_Loader):
Expand Down Expand Up @@ -92,9 +106,12 @@ def _build_compute(cfg):
:param cfg: SuperDuper config.
"""
from superduper.backends.local.compute import LocalComputeBackend

return LocalComputeBackend()
backend = getattr(load_plugin(cfg.cluster.compute.backend), 'ComputeBackend')
return backend(
uri=cfg.cluster.compute.uri,
**cfg.cluster.compute.kwargs
)


def build_datalayer(cfg=None, **kwargs) -> Datalayer:
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class Compute(BaseConfig):
"""

uri: t.Optional[str] = None
compute_kwargs: t.Dict = dc.field(default_factory=dict)
kwargs: t.Dict = dc.field(default_factory=dict)
backend: str = 'local'


Expand Down
2 changes: 1 addition & 1 deletion superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)
from superduper import CFG

compute_kwargs = CFG.cluster.compute.compute_kwargs
compute_kwargs = CFG.cluster.compute.kwargs
self.compute_kwargs = self.compute_kwargs or compute_kwargs
self._is_initialized = False
if not self.identifier:
Expand Down
2 changes: 1 addition & 1 deletion superduper/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(
db: t.Optional['Datalayer'] = None,
component: 'Component' = None,
):
self.compute_kwargs = compute_kwargs or CFG.cluster.compute.compute_kwargs
self.compute_kwargs = compute_kwargs or CFG.cluster.compute.kwargs

super().__init__(args=args, kwargs=kwargs, db=db, identifier=identifier)

Expand Down

0 comments on commit d531b0f

Please sign in to comment.