Skip to content

Commit

Permalink
Use websocket stream to retrieve runtime job results (#33)
Browse files Browse the repository at this point in the history
* retrieve job results w websockets

* update stream results method

* add mock for unit test

* update unit tests

* temp disable test case

* increase sleep time

* increase wait for final result

* temp increase wait on test final result

* temp disable final result test

* temp disable get result twice

* adjust wait times

* add test case back

* fix lint

* update test case

* update comments

* fix lint

* added fake wait for final state

* fix lint

* increase default timeout

* remove timeout

* remove import

* remove sleep

* move fake method to utils

* add mock function

* fix lint

* fix lint

* add statement back

* update test case

Co-authored-by: Rathish Cholarajan <rathishc24@gmail.com>
  • Loading branch information
kt474 and rathishcholarajan authored Feb 18, 2022
1 parent 7e29a7a commit 826834a
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 65 deletions.
39 changes: 11 additions & 28 deletions qiskit_ibm_runtime/runtime_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
"""Qiskit runtime job."""

from typing import Any, Optional, Callable, Dict, Type
import time
import logging
from concurrent import futures
import traceback
import queue
from datetime import datetime

from qiskit.providers.exceptions import JobTimeoutError
from qiskit.providers.backend import Backend
from qiskit.providers.jobstatus import JobStatus, JOB_FINAL_STATES

Expand Down Expand Up @@ -159,14 +157,12 @@ def interim_results(self, decoder: Optional[Type[ResultDecoder]] = None) -> Any:
def result(
self,
timeout: Optional[float] = None,
wait: float = 5,
decoder: Optional[Type[ResultDecoder]] = None,
) -> Any:
"""Return the results of the job.
Args:
timeout: Number of seconds to wait for job.
wait: Seconds between queries.
decoder: A :class:`ResultDecoder` subclass used to decode job results.
Returns:
Expand All @@ -177,7 +173,7 @@ def result(
"""
_decoder = decoder or self._result_decoder
if self._results is None or (_decoder != self._result_decoder):
self.wait_for_final_state(timeout=timeout, wait=wait)
self.wait_for_final_state(timeout=timeout)
if self._status == JobStatus.ERROR:
raise RuntimeJobFailureError(
f"Unable to retrieve job result. " f"{self.error_message()}"
Expand Down Expand Up @@ -222,29 +218,18 @@ def error_message(self) -> Optional[str]:
self._set_status_and_error_message()
return self._error_message

def wait_for_final_state(
self, timeout: Optional[float] = None, wait: float = 5
) -> None:
"""Poll the job status until it progresses to a final state such as ``DONE`` or ``ERROR``.
def wait_for_final_state(self, timeout: Optional[float] = None) -> None:
"""Use the websocket server to wait for the final the state of a job. The server
will remain open if the job is still running and the connection will be terminated
once the job completes. Then update and return the status of the job.
Args:
timeout: Seconds to wait for the job. If ``None``, wait indefinitely.
wait: Seconds between queries.
Raises:
JobTimeoutError: If the job does not reach a final state before the
specified timeout.
"""
start_time = time.time()
status = self.status()
while status not in JOB_FINAL_STATES:
elapsed_time = time.time() - start_time
if timeout is not None and elapsed_time >= timeout:
raise JobTimeoutError(
"Timeout while waiting for job {}.".format(self.job_id)
)
time.sleep(wait)
status = self.status()
if self._status not in JOB_FINAL_STATES:
self._ws_client_future = self._executor.submit(self._start_websocket_client)
self._ws_client_future.result(timeout)
self.status()

def stream_results(
self, callback: Callable, decoder: Optional[Type[ResultDecoder]] = None
Expand All @@ -264,14 +249,12 @@ def stream_results(
RuntimeInvalidStateError: If a callback function is already streaming results or
if the job already finished.
"""
if self._status in JOB_FINAL_STATES:
raise RuntimeInvalidStateError("Job already finished.")
if self._is_streaming():
raise RuntimeInvalidStateError(
"A callback function is already streaming results."
)

if self._status in JOB_FINAL_STATES:
raise RuntimeInvalidStateError("Job already finished.")

self._ws_client_future = self._executor.submit(self._start_websocket_client)
self._executor.submit(
self._stream_results,
Expand Down
9 changes: 5 additions & 4 deletions test/integration/test_interim_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,17 @@ def test_stream_results_done(self, service):

def result_callback(job_id, interim_result):
# pylint: disable=unused-argument
nonlocal called_back
called_back = True
nonlocal called_back_count
called_back_count += 1

called_back = False
called_back_count = 0
job = self._run_program(service, interim_results="foobar")
job.wait_for_final_state()
job._status = JobStatus.RUNNING # Allow stream_results()
job.stream_results(result_callback)
time.sleep(2)
self.assertFalse(called_back)
# Callback is expected twice because both interim and final results are returned
self.assertEqual(2, called_back_count)
self.assertIsNotNone(job._ws_client._server_close_code)

@run_integration_test
Expand Down
11 changes: 11 additions & 0 deletions test/unit/mock/fake_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def interim_results(self):
"""Return job interim results."""
return self._interim_results

def status(self):
"""Return job status."""
return self._status


class FailedRuntimeJob(BaseFakeRuntimeJob):
"""Class for faking a failed runtime job."""
Expand Down Expand Up @@ -451,6 +455,13 @@ def job_delete(self, job_id):
self._get_job(job_id)
del self._jobs[job_id]

def wait_for_final_state(self, job_id):
"""Wait for the final state of a program job."""
final_states = ["COMPLETED", "FAILED", "CANCELLED", "CANCELLED - RAN TOO LONG"]
status = self._get_job(job_id).status()
while status not in final_states:
status = self._get_job(job_id).status()

def _get_program(self, program_id):
"""Get program."""
if program_id not in self._programs:
Expand Down
9 changes: 6 additions & 3 deletions test/unit/test_job_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..ibm_test_case import IBMTestCase
from ..decorators import run_legacy_and_cloud_fake
from ..program import run_program, upload_program
from ..utils import mock_wait_for_final_state


class TestRetrieveJobs(IBMTestCase):
Expand Down Expand Up @@ -182,8 +183,9 @@ def test_jobs_filter_by_program_id(self, service):

job = run_program(service=service, program_id=program_id)
job_1 = run_program(service=service, program_id=program_id_1)
job.wait_for_final_state()
job_1.wait_for_final_state()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
job_1.wait_for_final_state()
rjobs = service.jobs(program_id=program_id)
self.assertEqual(program_id, rjobs[0].program_id)
self.assertEqual(1, len(rjobs))
Expand All @@ -195,7 +197,8 @@ def test_jobs_filter_by_instance(self):
instance = FakeRuntimeService.DEFAULT_HGPS[1]

job = run_program(service=service, program_id=program_id, instance=instance)
job.wait_for_final_state()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
rjobs = service.jobs(program_id=program_id, instance=instance)
self.assertTrue(rjobs)
self.assertEqual(program_id, rjobs[0].program_id)
Expand Down
68 changes: 38 additions & 30 deletions test/unit/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ..decorators import run_legacy_and_cloud_fake
from ..program import run_program, upload_program
from ..serialization import get_complex_types
from ..utils import mock_wait_for_final_state


class TestRuntimeJob(IBMTestCase):
Expand All @@ -51,9 +52,10 @@ def test_run_program(self, service):
self.assertIsInstance(job, RuntimeJob)
self.assertIsInstance(job.status(), JobStatus)
self.assertEqual(job.inputs, params)
job.wait_for_final_state()
self.assertEqual(job.status(), JobStatus.DONE)
self.assertTrue(job.result())
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
self.assertEqual(job.status(), JobStatus.DONE)
self.assertTrue(job.result())

@run_legacy_and_cloud_fake
def test_run_phantom_program(self, service):
Expand Down Expand Up @@ -148,9 +150,10 @@ def test_run_program_with_custom_runtime_image(self, service):
self.assertIsInstance(job, RuntimeJob)
self.assertIsInstance(job.status(), JobStatus)
self.assertEqual(job.inputs, params)
job.wait_for_final_state()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
self.assertTrue(job.result())
self.assertEqual(job.status(), JobStatus.DONE)
self.assertTrue(job.result())
self.assertEqual(job.image, image)

@run_legacy_and_cloud_fake
Expand All @@ -164,31 +167,33 @@ def test_run_program_with_custom_log_level(self, service):
def test_run_program_failed(self, service):
"""Test a failed program execution."""
job = run_program(service=service, job_classes=FailedRuntimeJob)
job.wait_for_final_state()
job_result_raw = service._api_client.job_results(job.job_id)
self.assertEqual(JobStatus.ERROR, job.status())
self.assertEqual(
API_TO_JOB_ERROR_MESSAGE["FAILED"].format(job.job_id, job_result_raw),
job.error_message(),
)
with self.assertRaises(RuntimeJobFailureError):
job.result()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
job_result_raw = service._api_client.job_results(job.job_id)
self.assertEqual(JobStatus.ERROR, job.status())
self.assertEqual(
API_TO_JOB_ERROR_MESSAGE["FAILED"].format(job.job_id, job_result_raw),
job.error_message(),
)
with self.assertRaises(RuntimeJobFailureError):
job.result()

@run_legacy_and_cloud_fake
def test_run_program_failed_ran_too_long(self, service):
"""Test a program that failed since it ran longer than maximum execution time."""
job = run_program(service=service, job_classes=FailedRanTooLongRuntimeJob)
job.wait_for_final_state()
job_result_raw = service._api_client.job_results(job.job_id)
self.assertEqual(JobStatus.ERROR, job.status())
self.assertEqual(
API_TO_JOB_ERROR_MESSAGE["CANCELLED - RAN TOO LONG"].format(
job.job_id, job_result_raw
),
job.error_message(),
)
with self.assertRaises(RuntimeJobFailureError):
job.result()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
job_result_raw = service._api_client.job_results(job.job_id)
self.assertEqual(JobStatus.ERROR, job.status())
self.assertEqual(
API_TO_JOB_ERROR_MESSAGE["CANCELLED - RAN TOO LONG"].format(
job.job_id, job_result_raw
),
job.error_message(),
)
with self.assertRaises(RuntimeJobFailureError):
job.result()

@run_legacy_and_cloud_fake
def test_program_params_namespace(self, service):
Expand All @@ -212,8 +217,9 @@ def test_cancel_job(self, service):
def test_final_result(self, service):
"""Test getting final result."""
job = run_program(service)
result = job.result()
self.assertTrue(result)
with mock_wait_for_final_state(service, job):
result = job.result()
self.assertTrue(result)

@run_legacy_and_cloud_fake
def test_interim_results(self, service):
Expand Down Expand Up @@ -248,7 +254,8 @@ def test_job_program_id(self, service):
def test_wait_for_final_state(self, service):
"""Test wait for final state."""
job = run_program(service)
job.wait_for_final_state()
with mock_wait_for_final_state(service, job):
job.wait_for_final_state()
self.assertEqual(JobStatus.DONE, job.status())

@run_legacy_and_cloud_fake
Expand All @@ -259,8 +266,9 @@ def test_get_result_twice(self, service):
job_cls.custom_result = custom_result

job = run_program(service=service, job_classes=job_cls)
_ = job.result()
_ = job.result()
with mock_wait_for_final_state(service, job):
_ = job.result()
_ = job.result()

@run_legacy_and_cloud_fake
def test_delete_job(self, service):
Expand Down
10 changes: 10 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import time
import unittest
from unittest import mock

from qiskit import QuantumCircuit
from qiskit.providers.jobstatus import JOB_FINAL_STATES, JobStatus
Expand Down Expand Up @@ -137,3 +138,12 @@ def get_real_device(service):
).name()
except QiskitBackendNotFoundError:
raise unittest.SkipTest("No real device") # cloud has no real device


def mock_wait_for_final_state(service, job):
"""replace `wait_for_final_state` with a mock function"""
return mock.patch.object(
RuntimeJob,
"wait_for_final_state",
side_effect=service._api_client.wait_for_final_state(job.job_id),
)

0 comments on commit 826834a

Please sign in to comment.