From 3e891ccca92739d4d180cad2d685fd40ed19413d Mon Sep 17 00:00:00 2001 From: Victory Omole Date: Thu, 31 Aug 2023 17:34:48 -0500 Subject: [PATCH] Ability to pass in `options` to `SuperstaqProvider` and `Service` (#708) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is the beginning of having `options` be added at the top level. The inspiration for this change is https://github.com/Infleqtion/client-superstaq/issues/592. New way of calling should look like ```python import json import os import qiskit import qiskit_superstaq as qss with open(os.path.expanduser('~/.coldquanta/cq_auth.json'), 'r') as file: cq_token = json.load(file) superstaq = qss.superstaq_provider.SuperstaqProvider(token, cq_token=cq_token) backend = superstaq.get_backend("cq_hilbert_simulator") qc = qiskit.QuantumCircuit(1, 1) qc.h(0) qc.measure(0, 0) job = backend.run(qc, shots=1) print(job.result().get_counts()) ``` This PR makes is so that `get_job` also has acess to the `cq_token` that's passed in . Corresponding PR https://github.com/Infleqtion/server-superstaq/pull/2663 --------- Co-authored-by: Emilio Co-authored-by: richrines1 <85512171+richrines1@users.noreply.github.com> Co-authored-by: Emilio Peláez <63567458+epelaaez@users.noreply.github.com> Co-authored-by: Emilio Pelaez --- cirq-superstaq/cirq_superstaq/service.py | 8 ++- .../general_superstaq/superstaq_client.py | 40 ++++++++++-- .../superstaq_client_test.py | 61 +++++++++++-------- .../qiskit_superstaq/superstaq_provider.py | 6 ++ 4 files changed, 83 insertions(+), 32 deletions(-) diff --git a/cirq-superstaq/cirq_superstaq/service.py b/cirq-superstaq/cirq_superstaq/service.py index 7fe1edfbd..015112195 100644 --- a/cirq-superstaq/cirq_superstaq/service.py +++ b/cirq-superstaq/cirq_superstaq/service.py @@ -117,6 +117,7 @@ def __init__( api_version: str = gss.API_VERSION, max_retry_seconds: int = 3600, verbose: bool = False, + **kwargs: object, ) -> None: """Creates the Service to access Superstaq's API. @@ -143,12 +144,15 @@ def __init__( api_version: Version of the api. max_retry_seconds: The number of seconds to retry calls for. Defaults to one hour. verbose: Whether to print to stdio and stderr on retriable errors. + kwargs: Other optimization and execution parameters. + - qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ + devices. + - cq_token: Token from CQ cloud. Raises: EnvironmentError: If an API key was not provided and could not be found. """ self.default_target = default_target - self._client = superstaq_client._SuperstaqClient( client_name="cirq-superstaq", remote_host=remote_host, @@ -156,6 +160,7 @@ def __init__( api_version=api_version, max_retry_seconds=max_retry_seconds, verbose=verbose, + **kwargs, ) def _resolve_target(self, target: Union[str, None]) -> str: @@ -271,7 +276,6 @@ def create_job( serialized_circuits = css.serialization.serialize_circuits(circuit) target = self._resolve_target(target) - result = self._client.create_job( serialized_circuits={"cirq_circuits": serialized_circuits}, repetitions=repetitions, diff --git a/general-superstaq/general_superstaq/superstaq_client.py b/general-superstaq/general_superstaq/superstaq_client.py index ec72c1d61..59c5da09b 100644 --- a/general-superstaq/general_superstaq/superstaq_client.py +++ b/general-superstaq/general_superstaq/superstaq_client.py @@ -48,6 +48,7 @@ def __init__( api_version: str = gss.API_VERSION, max_retry_seconds: float = 60, # 1 minute verbose: bool = False, + **kwargs: Any, ): """Creates the SuperstaqClient. @@ -66,6 +67,10 @@ def __init__( which is the most recent version when this client was downloaded. max_retry_seconds: The time to continue retriable responses. Defaults to 3600. verbose: Whether to print to stderr and stdio any retriable errors that are encountered. + kwargs: Other optimization and execution parameters. + - qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ + devices. + - cq_token: Token from CQ cloud. """ self.api_key = api_key or gss.superstaq_client.find_api_key() @@ -93,6 +98,7 @@ def __init__( "X-Client-Name": self.client_name, "X-Client-Version": self.api_version, } + self._client_kwargs = kwargs def get_request(self, endpoint: str) -> Any: """Performs a GET request on a given endpoint. @@ -193,12 +199,11 @@ def create_job( "target": target, "shots": int(repetitions), } - if method is not None: json_dict["method"] = method + if kwargs or self._client_kwargs: + json_dict["options"] = json.dumps({**self._client_kwargs, **kwargs}) - if kwargs: - json_dict["options"] = json.dumps(kwargs) return self.post_request("/jobs", json_dict) def get_job(self, job_id: str) -> Dict[str, str]: @@ -213,7 +218,34 @@ def get_job(self, job_id: str) -> Dict[str, str]: Raises: SuperstaqServerException: For other API call failures. """ - return self.get_request(f"/job/{job_id}") + return self.fetch_jobs([job_id])[job_id] + + def fetch_jobs( + self, + job_ids: List[str], + **kwargs: Any, + ) -> Dict[str, Dict[str, str]]: + """Get the job from the Superstaq API. + + Args: + job_ids: The UUID of the job (returned when the job was created). + kwargs: Extra options needed to fetch jobs. + - cq_token: CQ Cloud credentials. + + Returns: + The json body of the response as a dict. + + Raises: + SuperstaqServerException: For other API call failures. + """ + + json_dict: Dict[str, Any] = { + "job_ids": job_ids, + } + if kwargs or self._client_kwargs: + json_dict["options"] = json.dumps({**self._client_kwargs, **kwargs}) + + return self.post_request("/fetch_jobs", json_dict) def get_balance(self) -> Dict[str, float]: """Get the querying user's account balance in USD. diff --git a/general-superstaq/general_superstaq/superstaq_client_test.py b/general-superstaq/general_superstaq/superstaq_client_test.py index e76214132..0b7c011d8 100644 --- a/general-superstaq/general_superstaq/superstaq_client_test.py +++ b/general-superstaq/general_superstaq/superstaq_client_test.py @@ -152,12 +152,14 @@ def test_supertstaq_client_create_job(mock_post: mock.MagicMock) -> None: remote_host="http://example.com", api_key="to_my_heart", ) + response = client.create_job( serialized_circuits={"Hello": "World"}, repetitions=200, target="ss_example_qpu", method="dry-run", qiskit_pulse=True, + cq_token={"@type": "RefreshFlowState", "access_token": "123"}, ) assert response == {"foo": "bar"} @@ -166,7 +168,9 @@ def test_supertstaq_client_create_job(mock_post: mock.MagicMock) -> None: "target": "ss_example_qpu", "shots": 200, "method": "dry-run", - "options": json.dumps({"qiskit_pulse": True}), + "options": json.dumps( + {"qiskit_pulse": True, "cq_token": {"@type": "RefreshFlowState", "access_token": "123"}} + ), } mock_post.assert_called_with( f"http://example.com/{API_VERSION}/jobs", @@ -287,20 +291,25 @@ def test_superstaq_client_create_job_json(mock_post: mock.MagicMock) -> None: ) -@mock.patch("requests.get") -def test_superstaq_client_get_job(mock_get: mock.MagicMock) -> None: - mock_get.return_value.ok = True - mock_get.return_value.json.return_value = {"foo": "bar"} +@mock.patch("requests.post") +def test_superstaq_client_fetch_jobs(mock_post: mock.MagicMock) -> None: + mock_post.return_value.ok = True + mock_post.return_value.json.return_value = {"my_id": {"foo": "bar"}} client = gss.superstaq_client._SuperstaqClient( client_name="general-superstaq", remote_host="http://example.com", api_key="to_my_heart", ) - response = client.get_job(job_id="job_id") - assert response == {"foo": "bar"} - - mock_get.assert_called_with( - f"http://example.com/{API_VERSION}/job/job_id", headers=EXPECTED_HEADERS, verify=False + response = client.fetch_jobs(job_ids=["job_id"], cq_token={"access_token": "token"}) + assert response == {"my_id": {"foo": "bar"}} + mock_post.assert_called_with( + f"http://example.com/{API_VERSION}/fetch_jobs", + json={ + "job_ids": ["job_id"], + "options": '{"cq_token": {"access_token": "token"}}', + }, + headers=EXPECTED_HEADERS, + verify=False, ) @@ -451,10 +460,10 @@ def test_superstaq_client_get_targets(mock_get: mock.MagicMock) -> None: ) -@mock.patch("requests.get") -def test_superstaq_client_get_job_unauthorized(mock_get: mock.MagicMock) -> None: - mock_get.return_value.ok = False - mock_get.return_value.status_code = requests.codes.unauthorized +@mock.patch("requests.post") +def test_superstaq_client_get_job_unauthorized(mock_post: mock.MagicMock) -> None: + mock_post.return_value.ok = False + mock_post.return_value.status_code = requests.codes.unauthorized client = gss.superstaq_client._SuperstaqClient( client_name="general-superstaq", @@ -465,10 +474,10 @@ def test_superstaq_client_get_job_unauthorized(mock_get: mock.MagicMock) -> None _ = client.get_job("job_id") -@mock.patch("requests.get") -def test_superstaq_client_get_job_not_found(mock_get: mock.MagicMock) -> None: - (mock_get.return_value).ok = False - (mock_get.return_value).status_code = requests.codes.not_found +@mock.patch("requests.post") +def test_superstaq_client_get_job_not_found(mock_post: mock.MagicMock) -> None: + (mock_post.return_value).ok = False + (mock_post.return_value).status_code = requests.codes.not_found client = gss.superstaq_client._SuperstaqClient( client_name="general-superstaq", @@ -479,10 +488,10 @@ def test_superstaq_client_get_job_not_found(mock_get: mock.MagicMock) -> None: _ = client.get_job("job_id") -@mock.patch("requests.get") -def test_superstaq_client_get_job_not_retriable(mock_get: mock.MagicMock) -> None: - mock_get.return_value.ok = False - mock_get.return_value.status_code = requests.codes.bad_request +@mock.patch("requests.post") +def test_superstaq_client_get_job_not_retriable(mock_post: mock.MagicMock) -> None: + mock_post.return_value.ok = False + mock_post.return_value.status_code = requests.codes.bad_request client = gss.superstaq_client._SuperstaqClient( client_name="general-superstaq", @@ -493,11 +502,11 @@ def test_superstaq_client_get_job_not_retriable(mock_get: mock.MagicMock) -> Non _ = client.get_job("job_id") -@mock.patch("requests.get") -def test_superstaq_client_get_job_retry(mock_get: mock.MagicMock) -> None: +@mock.patch("requests.post") +def test_superstaq_client_get_job_retry(mock_post: mock.MagicMock) -> None: response1 = mock.MagicMock() response2 = mock.MagicMock() - mock_get.side_effect = [response1, response2] + mock_post.side_effect = [response1, response2] response1.ok = False response1.status_code = requests.codes.service_unavailable response2.ok = True @@ -507,7 +516,7 @@ def test_superstaq_client_get_job_retry(mock_get: mock.MagicMock) -> None: api_key="to_my_heart", ) _ = client.get_job("job_id") - assert mock_get.call_count == 2 + assert mock_post.call_count == 2 @mock.patch("requests.post") diff --git a/qiskit-superstaq/qiskit_superstaq/superstaq_provider.py b/qiskit-superstaq/qiskit_superstaq/superstaq_provider.py index c3fc68dcb..55bb2bac0 100644 --- a/qiskit-superstaq/qiskit_superstaq/superstaq_provider.py +++ b/qiskit-superstaq/qiskit_superstaq/superstaq_provider.py @@ -47,6 +47,7 @@ def __init__( api_version: str = gss.API_VERSION, max_retry_seconds: int = 3600, verbose: bool = False, + **kwargs: Any, ) -> None: """Initializes a SuperstaqProvider. @@ -70,6 +71,10 @@ def __init__( api_version: The version of the API. max_retry_seconds: The number of seconds to retry calls for. Defaults to one hour. verbose: Whether to print to stdio and stderr on retriable errors. + kwargs: Other optimization and execution parameters. + - qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ + devices. + - cq_token: Token from CQ cloud. Raises: EnvironmentError: If an API key was not provided and could not be found. @@ -83,6 +88,7 @@ def __init__( api_version=api_version, max_retry_seconds=max_retry_seconds, verbose=verbose, + **kwargs, ) def __str__(self) -> str: