Skip to content

Commit

Permalink
Ability to pass in options to SuperstaqProvider and Service (#708)
Browse files Browse the repository at this point in the history
This PR is the beginning of having `options` be added at the top level.
The inspiration for this change is
#592.

New way of calling should look like

```python
import json
import os
import qiskit
import qiskit_superstaq as qss


with open(os.path.expanduser('~/.coldquanta/cq_auth.json'), 'r') as file:
    cq_token = json.load(file)

superstaq = qss.superstaq_provider.SuperstaqProvider(token, cq_token=cq_token)

backend = superstaq.get_backend("cq_hilbert_simulator")


qc = qiskit.QuantumCircuit(1, 1)

qc.h(0)
qc.measure(0, 0)

job = backend.run(qc, shots=1)
print(job.result().get_counts())
```
This PR makes is so that `get_job` also has acess to the `cq_token`
that's passed in .

Corresponding PR
Infleqtion/server-superstaq#2663

---------

Co-authored-by: Emilio <epelaez@uchicago.edu>
Co-authored-by: richrines1 <85512171+richrines1@users.noreply.github.com>
Co-authored-by: Emilio Peláez <63567458+epelaaez@users.noreply.github.com>
Co-authored-by: Emilio Pelaez <epelaaez@gmail.com>
  • Loading branch information
5 people authored Aug 31, 2023
1 parent 7158072 commit 3e891cc
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 32 deletions.
8 changes: 6 additions & 2 deletions cirq-superstaq/cirq_superstaq/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
api_version: str = gss.API_VERSION,
max_retry_seconds: int = 3600,
verbose: bool = False,
**kwargs: object,
) -> None:
"""Creates the Service to access Superstaq's API.
Expand All @@ -143,19 +144,23 @@ def __init__(
api_version: Version of the api.
max_retry_seconds: The number of seconds to retry calls for. Defaults to one hour.
verbose: Whether to print to stdio and stderr on retriable errors.
kwargs: Other optimization and execution parameters.
- qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ
devices.
- cq_token: Token from CQ cloud.
Raises:
EnvironmentError: If an API key was not provided and could not be found.
"""
self.default_target = default_target

self._client = superstaq_client._SuperstaqClient(
client_name="cirq-superstaq",
remote_host=remote_host,
api_key=api_key,
api_version=api_version,
max_retry_seconds=max_retry_seconds,
verbose=verbose,
**kwargs,
)

def _resolve_target(self, target: Union[str, None]) -> str:
Expand Down Expand Up @@ -271,7 +276,6 @@ def create_job(
serialized_circuits = css.serialization.serialize_circuits(circuit)

target = self._resolve_target(target)

result = self._client.create_job(
serialized_circuits={"cirq_circuits": serialized_circuits},
repetitions=repetitions,
Expand Down
40 changes: 36 additions & 4 deletions general-superstaq/general_superstaq/superstaq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
api_version: str = gss.API_VERSION,
max_retry_seconds: float = 60, # 1 minute
verbose: bool = False,
**kwargs: Any,
):
"""Creates the SuperstaqClient.
Expand All @@ -66,6 +67,10 @@ def __init__(
which is the most recent version when this client was downloaded.
max_retry_seconds: The time to continue retriable responses. Defaults to 3600.
verbose: Whether to print to stderr and stdio any retriable errors that are encountered.
kwargs: Other optimization and execution parameters.
- qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ
devices.
- cq_token: Token from CQ cloud.
"""

self.api_key = api_key or gss.superstaq_client.find_api_key()
Expand Down Expand Up @@ -93,6 +98,7 @@ def __init__(
"X-Client-Name": self.client_name,
"X-Client-Version": self.api_version,
}
self._client_kwargs = kwargs

def get_request(self, endpoint: str) -> Any:
"""Performs a GET request on a given endpoint.
Expand Down Expand Up @@ -193,12 +199,11 @@ def create_job(
"target": target,
"shots": int(repetitions),
}

if method is not None:
json_dict["method"] = method
if kwargs or self._client_kwargs:
json_dict["options"] = json.dumps({**self._client_kwargs, **kwargs})

if kwargs:
json_dict["options"] = json.dumps(kwargs)
return self.post_request("/jobs", json_dict)

def get_job(self, job_id: str) -> Dict[str, str]:
Expand All @@ -213,7 +218,34 @@ def get_job(self, job_id: str) -> Dict[str, str]:
Raises:
SuperstaqServerException: For other API call failures.
"""
return self.get_request(f"/job/{job_id}")
return self.fetch_jobs([job_id])[job_id]

def fetch_jobs(
self,
job_ids: List[str],
**kwargs: Any,
) -> Dict[str, Dict[str, str]]:
"""Get the job from the Superstaq API.
Args:
job_ids: The UUID of the job (returned when the job was created).
kwargs: Extra options needed to fetch jobs.
- cq_token: CQ Cloud credentials.
Returns:
The json body of the response as a dict.
Raises:
SuperstaqServerException: For other API call failures.
"""

json_dict: Dict[str, Any] = {
"job_ids": job_ids,
}
if kwargs or self._client_kwargs:
json_dict["options"] = json.dumps({**self._client_kwargs, **kwargs})

return self.post_request("/fetch_jobs", json_dict)

def get_balance(self) -> Dict[str, float]:
"""Get the querying user's account balance in USD.
Expand Down
61 changes: 35 additions & 26 deletions general-superstaq/general_superstaq/superstaq_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,14 @@ def test_supertstaq_client_create_job(mock_post: mock.MagicMock) -> None:
remote_host="http://example.com",
api_key="to_my_heart",
)

response = client.create_job(
serialized_circuits={"Hello": "World"},
repetitions=200,
target="ss_example_qpu",
method="dry-run",
qiskit_pulse=True,
cq_token={"@type": "RefreshFlowState", "access_token": "123"},
)
assert response == {"foo": "bar"}

Expand All @@ -166,7 +168,9 @@ def test_supertstaq_client_create_job(mock_post: mock.MagicMock) -> None:
"target": "ss_example_qpu",
"shots": 200,
"method": "dry-run",
"options": json.dumps({"qiskit_pulse": True}),
"options": json.dumps(
{"qiskit_pulse": True, "cq_token": {"@type": "RefreshFlowState", "access_token": "123"}}
),
}
mock_post.assert_called_with(
f"http://example.com/{API_VERSION}/jobs",
Expand Down Expand Up @@ -287,20 +291,25 @@ def test_superstaq_client_create_job_json(mock_post: mock.MagicMock) -> None:
)


@mock.patch("requests.get")
def test_superstaq_client_get_job(mock_get: mock.MagicMock) -> None:
mock_get.return_value.ok = True
mock_get.return_value.json.return_value = {"foo": "bar"}
@mock.patch("requests.post")
def test_superstaq_client_fetch_jobs(mock_post: mock.MagicMock) -> None:
mock_post.return_value.ok = True
mock_post.return_value.json.return_value = {"my_id": {"foo": "bar"}}
client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
response = client.get_job(job_id="job_id")
assert response == {"foo": "bar"}

mock_get.assert_called_with(
f"http://example.com/{API_VERSION}/job/job_id", headers=EXPECTED_HEADERS, verify=False
response = client.fetch_jobs(job_ids=["job_id"], cq_token={"access_token": "token"})
assert response == {"my_id": {"foo": "bar"}}
mock_post.assert_called_with(
f"http://example.com/{API_VERSION}/fetch_jobs",
json={
"job_ids": ["job_id"],
"options": '{"cq_token": {"access_token": "token"}}',
},
headers=EXPECTED_HEADERS,
verify=False,
)


Expand Down Expand Up @@ -451,10 +460,10 @@ def test_superstaq_client_get_targets(mock_get: mock.MagicMock) -> None:
)


@mock.patch("requests.get")
def test_superstaq_client_get_job_unauthorized(mock_get: mock.MagicMock) -> None:
mock_get.return_value.ok = False
mock_get.return_value.status_code = requests.codes.unauthorized
@mock.patch("requests.post")
def test_superstaq_client_get_job_unauthorized(mock_post: mock.MagicMock) -> None:
mock_post.return_value.ok = False
mock_post.return_value.status_code = requests.codes.unauthorized

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
Expand All @@ -465,10 +474,10 @@ def test_superstaq_client_get_job_unauthorized(mock_get: mock.MagicMock) -> None
_ = client.get_job("job_id")


@mock.patch("requests.get")
def test_superstaq_client_get_job_not_found(mock_get: mock.MagicMock) -> None:
(mock_get.return_value).ok = False
(mock_get.return_value).status_code = requests.codes.not_found
@mock.patch("requests.post")
def test_superstaq_client_get_job_not_found(mock_post: mock.MagicMock) -> None:
(mock_post.return_value).ok = False
(mock_post.return_value).status_code = requests.codes.not_found

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
Expand All @@ -479,10 +488,10 @@ def test_superstaq_client_get_job_not_found(mock_get: mock.MagicMock) -> None:
_ = client.get_job("job_id")


@mock.patch("requests.get")
def test_superstaq_client_get_job_not_retriable(mock_get: mock.MagicMock) -> None:
mock_get.return_value.ok = False
mock_get.return_value.status_code = requests.codes.bad_request
@mock.patch("requests.post")
def test_superstaq_client_get_job_not_retriable(mock_post: mock.MagicMock) -> None:
mock_post.return_value.ok = False
mock_post.return_value.status_code = requests.codes.bad_request

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
Expand All @@ -493,11 +502,11 @@ def test_superstaq_client_get_job_not_retriable(mock_get: mock.MagicMock) -> Non
_ = client.get_job("job_id")


@mock.patch("requests.get")
def test_superstaq_client_get_job_retry(mock_get: mock.MagicMock) -> None:
@mock.patch("requests.post")
def test_superstaq_client_get_job_retry(mock_post: mock.MagicMock) -> None:
response1 = mock.MagicMock()
response2 = mock.MagicMock()
mock_get.side_effect = [response1, response2]
mock_post.side_effect = [response1, response2]
response1.ok = False
response1.status_code = requests.codes.service_unavailable
response2.ok = True
Expand All @@ -507,7 +516,7 @@ def test_superstaq_client_get_job_retry(mock_get: mock.MagicMock) -> None:
api_key="to_my_heart",
)
_ = client.get_job("job_id")
assert mock_get.call_count == 2
assert mock_post.call_count == 2


@mock.patch("requests.post")
Expand Down
6 changes: 6 additions & 0 deletions qiskit-superstaq/qiskit_superstaq/superstaq_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
api_version: str = gss.API_VERSION,
max_retry_seconds: int = 3600,
verbose: bool = False,
**kwargs: Any,
) -> None:
"""Initializes a SuperstaqProvider.
Expand All @@ -70,6 +71,10 @@ def __init__(
api_version: The version of the API.
max_retry_seconds: The number of seconds to retry calls for. Defaults to one hour.
verbose: Whether to print to stdio and stderr on retriable errors.
kwargs: Other optimization and execution parameters.
- qiskit_pulse: Whether to use Superstaq's pulse-level optimizations for IBMQ
devices.
- cq_token: Token from CQ cloud.
Raises:
EnvironmentError: If an API key was not provided and could not be found.
Expand All @@ -83,6 +88,7 @@ def __init__(
api_version=api_version,
max_retry_seconds=max_retry_seconds,
verbose=verbose,
**kwargs,
)

def __str__(self) -> str:
Expand Down

0 comments on commit 3e891cc

Please sign in to comment.