Skip to content

Commit

Permalink
Unify RayClient usage (#1534)
Browse files Browse the repository at this point in the history
  • Loading branch information
korgan00 authored Nov 7, 2024
1 parent 509e0c2 commit fd4ffc0
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 55 deletions.
59 changes: 19 additions & 40 deletions client/qiskit_serverless/core/clients/local_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
Configuration,
)
from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction
from qiskit_serverless.core.local_functions_store import LocalFunctionsStore
from qiskit_serverless.exception import QiskitServerlessException
from qiskit_serverless.serializers.program_serializers import (
QiskitObjectsEncoder,
Expand All @@ -69,7 +70,7 @@ def __init__(self):
super().__init__("local-client")
self.in_test = os.getenv("IN_TEST")
self._jobs = {}
self._patterns = []
self._functions = LocalFunctionsStore(self)

@classmethod
def from_dict(cls, dictionary: dict):
Expand All @@ -92,33 +93,30 @@ def run(
config: Optional[Configuration] = None,
) -> Job:
# pylint: disable=too-many-locals
title = ""
if isinstance(program, QiskitFunction):
title = program.title
else:
title = str(program)

for pattern in self._patterns:
if pattern["title"] == title:
saved_program = pattern
if saved_program[ # pylint: disable=possibly-used-before-assignment
"dependencies"
]:
dept = json.loads(saved_program["dependencies"])
for dependency in dept:
title = program.title if isinstance(program, QiskitFunction) else str(program)

saved_program = self.function(title)

if not saved_program:
raise QiskitServerlessException(
"QiskitFunction provided is not uploaded to the client. Use upload() first."
)

if saved_program.dependencies:
for dependency in saved_program.dependencies:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", dependency]
)
arguments = arguments or {}
env_vars = {
**(saved_program["env_vars"] or {}),
**{OT_PROGRAM_NAME: saved_program["title"]},
**(saved_program.env_vars or {}),
**{OT_PROGRAM_NAME: saved_program.title},
**{"PATH": os.environ["PATH"]},
**{ENV_JOB_ARGUMENTS: json.dumps(arguments, cls=QiskitObjectsEncoder)},
}

with Popen(
["python", saved_program["working_dir"] + saved_program["entrypoint"]],
["python", saved_program.working_dir + saved_program.entrypoint],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
Expand Down Expand Up @@ -165,31 +163,12 @@ def filtered_logs(self, job_id: str, **kwargs):

def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]:
# check if entrypoint exists
if not os.path.exists(os.path.join(program.working_dir, program.entrypoint)):
raise QiskitServerlessException(
f"Entrypoint file [{program.entrypoint}] does not exist "
f"in [{program.working_dir}] working directory."
)

pattern = {
"title": program.title,
"provider": program.provider,
"entrypoint": program.entrypoint,
"working_dir": program.working_dir,
"env_vars": program.env_vars,
"arguments": json.dumps({}),
"dependencies": json.dumps(program.dependencies or []),
"client": self,
}
self._patterns.append(pattern)
return RunnableQiskitFunction.from_json(pattern)
return self._functions.upload(program)

def functions(self, **kwargs) -> List[RunnableQiskitFunction]:
"""Returns list of programs."""
return [RunnableQiskitFunction.from_json(program) for program in self._patterns]
return self._functions.functions()

def function(
self, title: str, provider: Optional[str] = None
) -> Optional[RunnableQiskitFunction]:
functions = {function.title: function for function in self.functions()}
return functions.get(title)
return self._functions.function(title)
33 changes: 19 additions & 14 deletions client/qiskit_serverless/core/clients/ray_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
"""
# pylint: disable=duplicate-code
import json
import warnings
from typing import Optional, List, Dict, Any, Union
from uuid import uuid4

Expand All @@ -43,6 +42,8 @@
Job,
)
from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction
from qiskit_serverless.core.local_functions_store import LocalFunctionsStore
from qiskit_serverless.exception import QiskitServerlessException
from qiskit_serverless.serializers.program_serializers import (
QiskitObjectsEncoder,
)
Expand All @@ -64,6 +65,7 @@ def __init__(self, host: str):
"""
super().__init__("ray-client", host)
self.job_submission_client = JobSubmissionClient(host)
self._functions = LocalFunctionsStore(self)

@classmethod
def from_dict(cls, dictionary: dict):
Expand Down Expand Up @@ -104,29 +106,32 @@ def run(
arguments: Optional[Dict[str, Any]] = None,
config: Optional[Configuration] = None,
) -> Job:
if not isinstance(program, QiskitFunction):
warnings.warn(
"`run` doesn't support program str yet. "
"Send a QiskitFunction instead. "
# pylint: disable=too-many-locals
title = program.title if isinstance(program, QiskitFunction) else str(program)

saved_program = self.function(title)

if not saved_program:
raise QiskitServerlessException(
"QiskitFunction provided is not uploaded to the client. Use upload() first."
)
raise NotImplementedError

arguments = arguments or {}
entrypoint = f"python {program.entrypoint}"
entrypoint = f"python {saved_program.entrypoint}"

# set program name so OT can use it as parent span name
env_vars = {
**(program.env_vars or {}),
**{OT_PROGRAM_NAME: program.title},
**(saved_program.env_vars or {}),
**{OT_PROGRAM_NAME: saved_program.title},
**{ENV_JOB_ARGUMENTS: json.dumps(arguments, cls=QiskitObjectsEncoder)},
}

job_id = self.job_submission_client.submit_job(
entrypoint=entrypoint,
submission_id=f"qs_{uuid4()}",
runtime_env={
"working_dir": program.working_dir,
"pip": program.dependencies,
"working_dir": saved_program.working_dir,
"pip": saved_program.dependencies,
"env_vars": env_vars,
},
)
Expand Down Expand Up @@ -160,14 +165,14 @@ def filtered_logs(self, job_id: str, **kwargs) -> str:

def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]:
"""Uploads program."""
raise NotImplementedError("Upload is not available for RayClient.")
return self._functions.upload(program)

def functions(self, **kwargs) -> List[RunnableQiskitFunction]:
"""Returns list of available programs."""
raise NotImplementedError("get_programs is not available for RayClient.")
return self._functions.functions()

def function(
self, title: str, provider: Optional[str] = None
) -> Optional[RunnableQiskitFunction]:
"""Returns program based on parameters."""
raise NotImplementedError("get_program is not available for RayClient.")
return self._functions.function(title)
73 changes: 73 additions & 0 deletions client/qiskit_serverless/core/local_functions_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
================================================
Provider (:mod:`qiskit_serverless.core.client`)
================================================
.. currentmodule:: qiskit_serverless.core.client
Qiskit Serverless provider
===========================
.. autosummary::
:toctree: ../stubs/
LocalFunctionsStore
"""
# pylint: disable=duplicate-code
import os.path
import os
from typing import Optional, List
from qiskit_serverless.core.client import BaseClient
from qiskit_serverless.core.function import QiskitFunction, RunnableQiskitFunction
from qiskit_serverless.exception import QiskitServerlessException


class LocalFunctionsStore:
"""LocalClient."""

def __init__(self, client: BaseClient):
self.client = client
self._functions: List[RunnableQiskitFunction] = []

def upload(self, program: QiskitFunction) -> Optional[RunnableQiskitFunction]:
"""Save a function in the store"""
if not os.path.exists(os.path.join(program.working_dir, program.entrypoint)):
raise QiskitServerlessException(
f"Entrypoint file [{program.entrypoint}] does not exist "
f"in [{program.working_dir}] working directory."
)

pattern = {
"title": program.title,
"provider": program.provider,
"entrypoint": program.entrypoint,
"working_dir": program.working_dir,
"env_vars": program.env_vars,
"arguments": {},
"dependencies": program.dependencies or [],
"client": self.client,
}
runnable_function = RunnableQiskitFunction.from_json(pattern)
self._functions.append(runnable_function)
return runnable_function

def functions(self) -> List[RunnableQiskitFunction]:
"""Returns list of functions."""
return list(self._functions)

def function(self, title: str) -> Optional[RunnableQiskitFunction]:
"""Returns a function with the provided title."""
functions = {function.title: function for function in self.functions()}
return functions.get(title)
3 changes: 2 additions & 1 deletion client/tests/core/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def test_program():
description="description",
version="0.0.1",
)
uploaded_program = serverless.upload(program)

job = serverless.run(program)
job = serverless.run(uploaded_program)

assert isinstance(job, Job)

Expand Down

0 comments on commit fd4ffc0

Please sign in to comment.