Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job Cancelation #998

Merged
merged 16 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cirq-superstaq/cirq_superstaq/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ def status(self) -> str:
self._refresh_job()
return self._overall_status

def cancel(self, **kwargs: object) -> None:
"""Cancel the current job if it is not in a terminal state.

Args:
kwargs: Extra options needed to fetch jobs.

Raises:
SuperstaqServerException: If unable to get the status of the job from the API or
cancellations were unsuccessful.
"""
job_ids = self._job_id.split(",")
self._client.cancel_jobs(job_ids, **kwargs)

def target(self) -> str:
"""Gets the Superstaq target associated with this job.

Expand Down
7 changes: 7 additions & 0 deletions cirq-superstaq/cirq_superstaq/job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def patched_requests(*contents: object) -> mock._patch[mock.Mock]:
return mock.patch("requests.post", side_effect=responses)


def test_cancel(job: css.Job) -> None:
with mock.patch("requests.post", return_value=mock.MagicMock(ok=True)) as mock_post:
job.cancel()
new_job().cancel()
assert mock_post.call_count == 2


def test_job_fields(job: css.job.Job) -> None:
compiled_circuit = cirq.Circuit(cirq.H(cirq.q(0)), cirq.measure(cirq.q(0)))
job_dict = {
Expand Down
37 changes: 31 additions & 6 deletions general-superstaq/general_superstaq/superstaq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,43 @@ def create_job(
json_dict["method"] = method
if kwargs or self.client_kwargs:
json_dict["options"] = json.dumps({**self.client_kwargs, **kwargs})

return self.post_request("/jobs", json_dict)

def cancel_jobs(
self,
job_ids: Sequence[str],
**kwargs: object,
) -> list[str]:
"""Cancel jobs associated with given job ids.

Args:
job_ids: The UUIDs of the jobs (returned when the jobs were created).
kwargs: Extra options needed to fetch jobs.

Returns:
A list of the job ids of the jobs that successfully cancelled.

Raises:
SuperstaqServerException: For other API call failures.
"""
json_dict: dict[str, str | Sequence[str]] = {
"job_ids": job_ids,
}
if kwargs or self.client_kwargs:
json_dict["options"] = json.dumps({**self.client_kwargs, **kwargs})

return self.post_request("/cancel_jobs", json_dict)["succeeded"]

def fetch_jobs(
self,
job_ids: list[str],
**kwargs: Any,
**kwargs: object,
) -> dict[str, dict[str, str]]:
"""Get the job from the Superstaq API.

Args:
job_ids: The UUID of the job (returned when the job was created).
job_ids: The UUIDs of the jobs (returned when the jobs were created).
kwargs: Extra options needed to fetch jobs.
- cq_token: CQ Cloud credentials.

Returns:
The json body of the response as a dict.
Expand Down Expand Up @@ -748,6 +771,7 @@ def _handle_status_codes(self, response: requests.Response) -> None:
gss.SuperstaqServerException: If an error has occurred in making a request
to the Superstaq API.
"""

if response.status_code == requests.codes.unauthorized:
if response.json() == (
"You must accept the Terms of Use (superstaq.infleqtion.com/terms_of_use)."
Expand Down Expand Up @@ -778,7 +802,7 @@ def _handle_status_codes(self, response: requests.Response) -> None:

if response.status_code not in self.RETRIABLE_STATUS_CODES:
try:
json_content = response.json()
json_content = self._handle_response(response)
except requests.JSONDecodeError:
json_content = None

Expand Down Expand Up @@ -821,13 +845,14 @@ def _make_request(self, request: Callable[[], requests.Response]) -> requests.Re
TimeoutError: If the requests retried for more than `max_retry_seconds`.

Returns:
The `request.Response` from the final successful request call.
The `requests.Response` from the final successful request call.
"""
# Initial backoff of 100ms.
delay_seconds = 0.1
while True:
try:
response = request()

if response.ok:
return response

Expand Down
64 changes: 59 additions & 5 deletions general-superstaq/general_superstaq/superstaq_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,9 @@ def test_superstaq_client_create_job_not_retriable(mock_post: mock.MagicMock) ->

@mock.patch("requests.post")
def test_superstaq_client_create_job_retry(mock_post: mock.MagicMock) -> None:
response1 = mock.MagicMock()
response2 = mock.MagicMock()
response1 = mock.MagicMock(ok=False, status_code=requests.codes.service_unavailable)
response2 = mock.MagicMock(ok=True)
mock_post.side_effect = [response1, response2]
response1.ok = False
response1.status_code = requests.codes.service_unavailable
response2.ok = True
client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
Expand Down Expand Up @@ -597,6 +594,63 @@ def test_superstaq_client_fetch_jobs_retry(mock_post: mock.MagicMock) -> None:
assert mock_post.call_count == 2


@mock.patch("requests.post")
def test_superstaq_client_cancel_jobs_unauthorized(mock_post: mock.MagicMock) -> None:
mock_post.return_value.ok = False
mock_post.return_value.status_code = requests.codes.unauthorized

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
with pytest.raises(gss.SuperstaqServerException, match="Not authorized"):
_ = client.cancel_jobs(["job_id"])


@mock.patch("requests.post")
def test_superstaq_client_cancel_jobs_not_found(mock_post: mock.MagicMock) -> None:
(mock_post.return_value).ok = False
(mock_post.return_value).status_code = requests.codes.not_found

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
with pytest.raises(gss.SuperstaqServerException):
_ = client.cancel_jobs(["job_id"])


@mock.patch("requests.post")
def test_superstaq_client_get_cancel_jobs_retriable(mock_post: mock.MagicMock) -> None:
mock_post.return_value.ok = False
mock_post.return_value.status_code = requests.codes.bad_request

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
with pytest.raises(gss.SuperstaqServerException, match="Status code: 400"):
_ = client.cancel_jobs(["job_id"], cq_token=1)


@mock.patch("requests.post")
def test_superstaq_client_cancel_jobs_retry(mock_post: mock.MagicMock) -> None:
response1 = mock.MagicMock(ok=False, status_code=requests.codes.service_unavailable)
response2 = mock.MagicMock(ok=True)
mock_post.side_effect = [response1, response2]

client = gss.superstaq_client._SuperstaqClient(
client_name="general-superstaq",
remote_host="http://example.com",
api_key="to_my_heart",
)
_ = client.cancel_jobs(["job_id"])
assert mock_post.call_count == 2


@mock.patch("requests.post")
def test_superstaq_client_aqt_compile(mock_post: mock.MagicMock) -> None:
client = gss.superstaq_client._SuperstaqClient(
Expand Down
13 changes: 13 additions & 0 deletions qiskit-superstaq/qiskit_superstaq/superstaq_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,19 @@ def _check_if_stopped(self) -> None:
self._job_id, self._overall_status
)

def cancel(self, **kwargs: object) -> None:
"""Cancel the current job if it is not in a terminal state.

Args:
kwargs: Extra options needed to fetch jobs.

Raises:
SuperstaqServerException: If unable to get the status of the job from the API or
cancellations were unsuccessful.
"""
job_ids = self._job_id.split(",")
self._backend._provider._client.cancel_jobs(job_ids, **kwargs)

def _refresh_job(self) -> None:
"""Queries the server for an updated job result."""

Expand Down
7 changes: 7 additions & 0 deletions qiskit-superstaq/qiskit_superstaq/superstaq_job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def test_wait_for_results(backend: qss.SuperstaqBackend) -> None:
]


def test_cancel(backend: qss.SuperstaqBackend) -> None:
with mock.patch("requests.post", return_value=mock.MagicMock(ok=True)) as mock_post:
qss.SuperstaqJob(backend=backend, job_id="123abc").cancel()
qss.SuperstaqJob(backend=backend, job_id="123abc,456def").cancel()
assert mock_post.call_count == 2


def test_timeout(backend: qss.SuperstaqBackend) -> None:
job = qss.SuperstaqJob(backend=backend, job_id="123abc")

Expand Down