diff --git a/sdk/python/kfp/components/_component_store.py b/sdk/python/kfp/components/_component_store.py index 6bac769a575..ac036213c20 100644 --- a/sdk/python/kfp/components/_component_store.py +++ b/sdk/python/kfp/components/_component_store.py @@ -3,6 +3,7 @@ ] from pathlib import Path +import copy import requests from typing import Callable from . import _components as comp @@ -52,11 +53,38 @@ def load_component(self, name, digest=None, tag=None): #This function should be called load_task_factory since it returns a factory function. #The real load_component function should produce an object with component properties (e.g. name, description, inputs/outputs). #TODO: Change this function to return component spec object but it should be callable to construct tasks. + component_ref = ComponentReference(name=name, digest=digest, tag=tag) + component_ref = self._load_component_spec_in_component_ref(component_ref) + return comp._create_task_factory_from_component_spec( + component_spec=component_ref.spec, + component_ref=component_ref, + ) + + def _load_component_spec_in_component_ref( + self, + component_ref: ComponentReference, + ) -> ComponentReference: + '''Takes component_ref, finds the component spec and returns component_ref with .spec set to the component spec. + + See ComponentStore.load_component for the details of the search logic. + ''' + if component_ref.spec: + return component_ref + + component_ref = copy.copy(component_ref) + if component_ref.url: + component_ref.spec = comp._load_component_spec_from_url(component_ref.url) + return component_ref + + name = component_ref.name if not name: raise TypeError("name is required") if name.startswith('/') or name.endswith('/'): raise ValueError('Component name should not start or end with slash: "{}"'.format(name)) + digest = component_ref.digest + tag = component_ref.tag + tried_locations = [] if digest is not None and tag is not None: @@ -75,8 +103,10 @@ def load_component(self, name, digest=None, tag=None): component_path = Path(local_search_path, path_suffix) tried_locations.append(str(component_path)) if component_path.is_file(): - component_ref = ComponentReference(name=name, digest=digest, tag=tag) - return comp.load_component_from_file(str(component_path)) + # TODO: Verify that the content matches the digest (if specified). + component_ref._local_path = str(component_path) + component_ref.spec = comp._load_component_spec_from_file(str(component_path)) + return component_ref #Trying URL prefixes for url_search_prefix in self.url_search_prefixes: @@ -88,21 +118,16 @@ def load_component(self, name, digest=None, tag=None): except: continue if response.content: - component_ref = ComponentReference(name=name, digest=digest, tag=tag, url=url) - return comp._load_component_from_yaml_or_zip_bytes(response.content, url, component_ref) + # TODO: Verify that the content matches the digest (if specified). + component_ref.url = url + component_ref.spec = comp._load_component_spec_from_yaml_or_zip_bytes(response.content) + return component_ref raise RuntimeError('Component {} was not found. Tried the following locations:\n{}'.format(name, '\n'.join(tried_locations))) def _load_component_from_ref(self, component_ref: ComponentReference) -> Callable: - if component_ref.spec: - return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref) - if component_ref.url: - return self.load_component_from_url(component_ref.url) - return self.load_component( - name=component_ref.name, - digest=component_ref.digest, - tag=component_ref.tag, - ) + component_ref = self._load_component_spec_in_component_ref(component_ref) + return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref) ComponentStore.default_store = ComponentStore( diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index 8fa24e67e2d..fa91b9c9372 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -74,19 +74,14 @@ def load_component_from_url(url): A factory function with a strongly-typed signature. Once called with the required arguments, the factory constructs a pipeline task instance (ContainerOp). ''' - if url is None: - raise TypeError - - #Handling Google Cloud Storage URIs - if url.startswith('gs://'): - #Replacing the gs:// URI with https:// URI (works for public objects) - url = 'https://storage.googleapis.com/' + url[len('gs://'):] - - import requests - resp = requests.get(url) - resp.raise_for_status() + component_spec = _load_component_spec_from_url(url) + url = _fix_component_uri(url) component_ref = ComponentReference(url=url) - return _load_component_from_yaml_or_zip_bytes(resp.content, url, component_ref) + return _create_task_factory_from_component_spec( + component_spec=component_spec, + component_filename=url, + component_ref=component_ref, + ) def load_component_from_file(filename): @@ -100,10 +95,11 @@ def load_component_from_file(filename): A factory function with a strongly-typed signature. Once called with the required arguments, the factory constructs a pipeline task instance (ContainerOp). ''' - if filename is None: - raise TypeError - with open(filename, 'rb') as component_stream: - return _load_component_from_yaml_or_zip_stream(component_stream, filename) + component_spec = _load_component_spec_from_file(path=filename) + return _create_task_factory_from_component_spec( + component_spec=component_spec, + component_filename=filename, + ) def load_component_from_text(text): @@ -119,20 +115,47 @@ def load_component_from_text(text): ''' if text is None: raise TypeError - return _create_task_factory_from_component_text(text, None) + component_spec = _load_component_spec_from_component_text(text) + return _create_task_factory_from_component_spec(component_spec=component_spec) + + +def _fix_component_uri(uri: str) -> str: + #Handling Google Cloud Storage URIs + if uri.startswith('gs://'): + #Replacing the gs:// URI with https:// URI (works for public objects) + uri = 'https://storage.googleapis.com/' + uri[len('gs://'):] + return uri + + +def _load_component_spec_from_file(path) -> ComponentSpec: + with open(path, 'rb') as component_stream: + return _load_component_spec_from_yaml_or_zip_stream(component_stream) + + +def _load_component_spec_from_url(url: str): + if url is None: + raise TypeError + + url = _fix_component_uri(url) + + import requests + resp = requests.get(url) + resp.raise_for_status() + return _load_component_spec_from_yaml_or_zip_bytes(resp.content) _COMPONENT_FILE_NAME_IN_ARCHIVE = 'component.yaml' -def _load_component_from_yaml_or_zip_bytes(bytes, component_filename=None, component_ref: ComponentReference = None): +def _load_component_spec_from_yaml_or_zip_bytes(data: bytes): import io - component_stream = io.BytesIO(bytes) - return _load_component_from_yaml_or_zip_stream(component_stream, component_filename, component_ref) + component_stream = io.BytesIO(data) + return _load_component_spec_from_yaml_or_zip_stream(component_stream) + +def _load_component_spec_from_yaml_or_zip_stream(stream) -> ComponentSpec: + '''Loads component spec from a stream. -def _load_component_from_yaml_or_zip_stream(stream, component_filename=None, component_ref: ComponentReference = None): - '''Loads component from a stream and creates a task factory function. The stream can be YAML or a zip file with a component.yaml file inside. ''' import zipfile @@ -141,20 +164,18 @@ def _load_component_from_yaml_or_zip_stream(stream, component_filename=None, com stream.seek(0) with zipfile.ZipFile(stream) as zip_obj: with zip_obj.open(_COMPONENT_FILE_NAME_IN_ARCHIVE) as component_stream: - return _create_task_factory_from_component_text(component_stream, component_filename, component_ref) + return _load_component_spec_from_component_text( + text_or_file=component_stream, + ) else: stream.seek(0) - return _create_task_factory_from_component_text(stream, component_filename, component_ref) + return _load_component_spec_from_component_text(stream) -def _create_task_factory_from_component_text(text_or_file, component_filename=None, component_ref: ComponentReference = None): +def _load_component_spec_from_component_text(text_or_file) -> ComponentSpec: component_dict = load_yaml(text_or_file) - return _create_task_factory_from_component_dict(component_dict, component_filename, component_ref) - - -def _create_task_factory_from_component_dict(component_dict, component_filename=None, component_ref: ComponentReference = None): component_spec = ComponentSpec.from_dict(component_dict) - return _create_task_factory_from_component_spec(component_spec, component_filename, component_ref) + return component_spec _inputs_dir = '/tmp/inputs'