diff --git a/cirq-superstaq/cirq_superstaq/job.py b/cirq-superstaq/cirq_superstaq/job.py index 536ab2608..0521b522a 100644 --- a/cirq-superstaq/cirq_superstaq/job.py +++ b/cirq-superstaq/cirq_superstaq/job.py @@ -154,6 +154,19 @@ def status(self) -> str: self._refresh_job() return self._overall_status + def cancel(self, **kwargs: object) -> None: + """Cancel the current job if it is not in a terminal state. + + Args: + kwargs: Extra options needed to fetch jobs. + + Raises: + SuperstaqServerException: If unable to get the status of the job from the API or + cancellations were unsuccessful. + """ + job_ids = self._job_id.split(",") + self._client.cancel_jobs(job_ids, **kwargs) + def target(self) -> str: """Gets the Superstaq target associated with this job. diff --git a/cirq-superstaq/cirq_superstaq/job_test.py b/cirq-superstaq/cirq_superstaq/job_test.py index 957ae5ad8..590c04556 100644 --- a/cirq-superstaq/cirq_superstaq/job_test.py +++ b/cirq-superstaq/cirq_superstaq/job_test.py @@ -95,6 +95,13 @@ def patched_requests(*contents: object) -> mock._patch[mock.Mock]: return mock.patch("requests.post", side_effect=responses) +def test_cancel(job: css.Job) -> None: + with mock.patch("requests.post", return_value=mock.MagicMock(ok=True)) as mock_post: + job.cancel() + new_job().cancel() + assert mock_post.call_count == 2 + + def test_job_fields(job: css.job.Job) -> None: compiled_circuit = cirq.Circuit(cirq.H(cirq.q(0)), cirq.measure(cirq.q(0))) job_dict = { diff --git a/general-superstaq/general_superstaq/superstaq_client.py b/general-superstaq/general_superstaq/superstaq_client.py index 631df3da8..642396c04 100644 --- a/general-superstaq/general_superstaq/superstaq_client.py +++ b/general-superstaq/general_superstaq/superstaq_client.py @@ -175,20 +175,43 @@ def create_job( json_dict["method"] = method if kwargs or self.client_kwargs: json_dict["options"] = json.dumps({**self.client_kwargs, **kwargs}) - return self.post_request("/jobs", json_dict) + def cancel_jobs( + self, + job_ids: Sequence[str], + **kwargs: object, + ) -> list[str]: + """Cancel jobs associated with given job ids. + + Args: + job_ids: The UUIDs of the jobs (returned when the jobs were created). + kwargs: Extra options needed to fetch jobs. + + Returns: + A list of the job ids of the jobs that successfully cancelled. + + Raises: + SuperstaqServerException: For other API call failures. + """ + json_dict: dict[str, str | Sequence[str]] = { + "job_ids": job_ids, + } + if kwargs or self.client_kwargs: + json_dict["options"] = json.dumps({**self.client_kwargs, **kwargs}) + + return self.post_request("/cancel_jobs", json_dict)["succeeded"] + def fetch_jobs( self, job_ids: list[str], - **kwargs: Any, + **kwargs: object, ) -> dict[str, dict[str, str]]: """Get the job from the Superstaq API. Args: - job_ids: The UUID of the job (returned when the job was created). + job_ids: The UUIDs of the jobs (returned when the jobs were created). kwargs: Extra options needed to fetch jobs. - - cq_token: CQ Cloud credentials. Returns: The json body of the response as a dict. @@ -748,6 +771,7 @@ def _handle_status_codes(self, response: requests.Response) -> None: gss.SuperstaqServerException: If an error has occurred in making a request to the Superstaq API. """ + if response.status_code == requests.codes.unauthorized: if response.json() == ( "You must accept the Terms of Use (superstaq.infleqtion.com/terms_of_use)." @@ -778,7 +802,7 @@ def _handle_status_codes(self, response: requests.Response) -> None: if response.status_code not in self.RETRIABLE_STATUS_CODES: try: - json_content = response.json() + json_content = self._handle_response(response) except requests.JSONDecodeError: json_content = None @@ -821,13 +845,14 @@ def _make_request(self, request: Callable[[], requests.Response]) -> requests.Re TimeoutError: If the requests retried for more than `max_retry_seconds`. Returns: - The `request.Response` from the final successful request call. + The `requests.Response` from the final successful request call. """ # Initial backoff of 100ms. delay_seconds = 0.1 while True: try: response = request() + if response.ok: return response diff --git a/general-superstaq/general_superstaq/superstaq_client_test.py b/general-superstaq/general_superstaq/superstaq_client_test.py index 62babb11d..92e99ba7e 100644 --- a/general-superstaq/general_superstaq/superstaq_client_test.py +++ b/general-superstaq/general_superstaq/superstaq_client_test.py @@ -294,12 +294,9 @@ def test_superstaq_client_create_job_not_retriable(mock_post: mock.MagicMock) -> @mock.patch("requests.post") def test_superstaq_client_create_job_retry(mock_post: mock.MagicMock) -> None: - response1 = mock.MagicMock() - response2 = mock.MagicMock() + response1 = mock.MagicMock(ok=False, status_code=requests.codes.service_unavailable) + response2 = mock.MagicMock(ok=True) mock_post.side_effect = [response1, response2] - response1.ok = False - response1.status_code = requests.codes.service_unavailable - response2.ok = True client = gss.superstaq_client._SuperstaqClient( client_name="general-superstaq", remote_host="http://example.com", @@ -597,6 +594,63 @@ def test_superstaq_client_fetch_jobs_retry(mock_post: mock.MagicMock) -> None: assert mock_post.call_count == 2 +@mock.patch("requests.post") +def test_superstaq_client_cancel_jobs_unauthorized(mock_post: mock.MagicMock) -> None: + mock_post.return_value.ok = False + mock_post.return_value.status_code = requests.codes.unauthorized + + client = gss.superstaq_client._SuperstaqClient( + client_name="general-superstaq", + remote_host="http://example.com", + api_key="to_my_heart", + ) + with pytest.raises(gss.SuperstaqServerException, match="Not authorized"): + _ = client.cancel_jobs(["job_id"]) + + +@mock.patch("requests.post") +def test_superstaq_client_cancel_jobs_not_found(mock_post: mock.MagicMock) -> None: + (mock_post.return_value).ok = False + (mock_post.return_value).status_code = requests.codes.not_found + + client = gss.superstaq_client._SuperstaqClient( + client_name="general-superstaq", + remote_host="http://example.com", + api_key="to_my_heart", + ) + with pytest.raises(gss.SuperstaqServerException): + _ = client.cancel_jobs(["job_id"]) + + +@mock.patch("requests.post") +def test_superstaq_client_get_cancel_jobs_retriable(mock_post: mock.MagicMock) -> None: + mock_post.return_value.ok = False + mock_post.return_value.status_code = requests.codes.bad_request + + client = gss.superstaq_client._SuperstaqClient( + client_name="general-superstaq", + remote_host="http://example.com", + api_key="to_my_heart", + ) + with pytest.raises(gss.SuperstaqServerException, match="Status code: 400"): + _ = client.cancel_jobs(["job_id"], cq_token=1) + + +@mock.patch("requests.post") +def test_superstaq_client_cancel_jobs_retry(mock_post: mock.MagicMock) -> None: + response1 = mock.MagicMock(ok=False, status_code=requests.codes.service_unavailable) + response2 = mock.MagicMock(ok=True) + mock_post.side_effect = [response1, response2] + + client = gss.superstaq_client._SuperstaqClient( + client_name="general-superstaq", + remote_host="http://example.com", + api_key="to_my_heart", + ) + _ = client.cancel_jobs(["job_id"]) + assert mock_post.call_count == 2 + + @mock.patch("requests.post") def test_superstaq_client_aqt_compile(mock_post: mock.MagicMock) -> None: client = gss.superstaq_client._SuperstaqClient( diff --git a/qiskit-superstaq/qiskit_superstaq/superstaq_job.py b/qiskit-superstaq/qiskit_superstaq/superstaq_job.py index 2c03a0c5a..450ecdeca 100644 --- a/qiskit-superstaq/qiskit_superstaq/superstaq_job.py +++ b/qiskit-superstaq/qiskit_superstaq/superstaq_job.py @@ -177,6 +177,19 @@ def _check_if_stopped(self) -> None: self._job_id, self._overall_status ) + def cancel(self, **kwargs: object) -> None: + """Cancel the current job if it is not in a terminal state. + + Args: + kwargs: Extra options needed to fetch jobs. + + Raises: + SuperstaqServerException: If unable to get the status of the job from the API or + cancellations were unsuccessful. + """ + job_ids = self._job_id.split(",") + self._backend._provider._client.cancel_jobs(job_ids, **kwargs) + def _refresh_job(self) -> None: """Queries the server for an updated job result.""" diff --git a/qiskit-superstaq/qiskit_superstaq/superstaq_job_test.py b/qiskit-superstaq/qiskit_superstaq/superstaq_job_test.py index e2ec03746..49d17b3e1 100644 --- a/qiskit-superstaq/qiskit_superstaq/superstaq_job_test.py +++ b/qiskit-superstaq/qiskit_superstaq/superstaq_job_test.py @@ -75,6 +75,13 @@ def test_wait_for_results(backend: qss.SuperstaqBackend) -> None: ] +def test_cancel(backend: qss.SuperstaqBackend) -> None: + with mock.patch("requests.post", return_value=mock.MagicMock(ok=True)) as mock_post: + qss.SuperstaqJob(backend=backend, job_id="123abc").cancel() + qss.SuperstaqJob(backend=backend, job_id="123abc,456def").cancel() + assert mock_post.call_count == 2 + + def test_timeout(backend: qss.SuperstaqBackend) -> None: job = qss.SuperstaqJob(backend=backend, job_id="123abc")