diff --git a/superduper/backends/local/__init__.py b/superduper/backends/local/__init__.py index 6298af57a..966cb0664 100644 --- a/superduper/backends/local/__init__.py +++ b/superduper/backends/local/__init__.py @@ -1,3 +1,4 @@ from .artifacts import FileSystemArtifactStore as ArtifactStore +from .compute import LocalComputeBackend as ComputeBackend __all__ = ["ArtifactStore"] diff --git a/superduper/backends/local/artifacts.py b/superduper/backends/local/artifacts.py index c31f56a38..3d5baeed2 100644 --- a/superduper/backends/local/artifacts.py +++ b/superduper/backends/local/artifacts.py @@ -26,6 +26,8 @@ def __init__( flavour: t.Optional[str] = None, ): super().__init__(conn, name, flavour) + if conn.startswith('filesystem://'): + conn = conn[7:] if not os.path.exists(self.conn): logging.info('Creating artifact store directory') os.makedirs(self.conn, exist_ok=True) diff --git a/superduper/backends/local/compute.py b/superduper/backends/local/compute.py index 96048a858..29f40a5e1 100644 --- a/superduper/backends/local/compute.py +++ b/superduper/backends/local/compute.py @@ -18,10 +18,12 @@ def __init__( self, uri: t.Optional[str] = None, queue: BaseQueuePublisher = LocalQueuePublisher(), + kwargs: t.Dict = {}, ): self.__outputs: t.Dict = {} self.uri = uri self.queue = queue + self.kwargs = kwargs @property def remote(self) -> bool: diff --git a/superduper/base/build.py b/superduper/base/build.py index c0e50ee79..ad0478fa2 100644 --- a/superduper/base/build.py +++ b/superduper/base/build.py @@ -17,19 +17,33 @@ class _Loader: not_supported: t.Tuple = () @classmethod - def create(cls, uri): - """Helper method to create metadata backend.""" + def match(cls, uri): + """Check if the uri matches the pattern.""" + plugin, flavour = None, None for pattern in cls.patterns: if re.match(pattern, uri) is not None: - plugin, flavour = cls.patterns[pattern] - if cls.not_supported and (plugin, flavour) in cls.not_supported: - raise ValueError( - f"{plugin} with flavour {flavour} not supported " - "to create metadata store." - ) - impl = getattr(load_plugin(plugin), cls.impl) - return impl(uri, flavour=flavour) - raise ValueError(f"{cls.__name__} No support for uri: {uri}") + selection = cls.patterns[pattern] + if isinstance(selection, tuple): + plugin, flavour = selection + else: + assert isinstance(selection, str) + plugin = selection + break + if plugin is None: + raise ValueError(f"{cls.__name__} No support for uri: {uri}") + return plugin, flavour + + @classmethod + def create(cls, uri): + """Helper method to create metadata backend.""" + plugin, flavour = cls.match(uri) + if cls.not_supported and (plugin, flavour) in cls.not_supported: + raise ValueError( + f"{plugin} with flavour {flavour} not supported " + "to create metadata store." + ) + impl = getattr(load_plugin(plugin), cls.impl) + return impl(uri, flavour=flavour) class _MetaDataLoader(_Loader): @@ -92,9 +106,12 @@ def _build_compute(cfg): :param cfg: SuperDuper config. """ - from superduper.backends.local.compute import LocalComputeBackend - return LocalComputeBackend() + backend = getattr(load_plugin(cfg.cluster.compute.backend), 'ComputeBackend') + return backend( + uri=cfg.cluster.compute.uri, + **cfg.cluster.compute.kwargs + ) def build_datalayer(cfg=None, **kwargs) -> Datalayer: diff --git a/superduper/base/config.py b/superduper/base/config.py index af7c0f2c5..ab04c5df1 100644 --- a/superduper/base/config.py +++ b/superduper/base/config.py @@ -186,7 +186,7 @@ class Compute(BaseConfig): """ uri: t.Optional[str] = None - compute_kwargs: t.Dict = dc.field(default_factory=dict) + kwargs: t.Dict = dc.field(default_factory=dict) backend: str = 'local' diff --git a/superduper/components/model.py b/superduper/components/model.py index 6752e616e..2b5e89d9d 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -504,7 +504,7 @@ def __post_init__(self, db, artifacts): super().__post_init__(db, artifacts) from superduper import CFG - compute_kwargs = CFG.cluster.compute.compute_kwargs + compute_kwargs = CFG.cluster.compute.kwargs self.compute_kwargs = self.compute_kwargs or compute_kwargs self._is_initialized = False if not self.identifier: diff --git a/superduper/jobs/job.py b/superduper/jobs/job.py index 05a464521..957f32807 100644 --- a/superduper/jobs/job.py +++ b/superduper/jobs/job.py @@ -187,7 +187,7 @@ def __init__( db: t.Optional['Datalayer'] = None, component: 'Component' = None, ): - self.compute_kwargs = compute_kwargs or CFG.cluster.compute.compute_kwargs + self.compute_kwargs = compute_kwargs or CFG.cluster.compute.kwargs super().__init__(args=args, kwargs=kwargs, db=db, identifier=identifier)