Skip to content

Commit

Permalink
Client change google api (#2530)
Browse files Browse the repository at this point in the history
* Change client code to send cirq.google.api messages.

* Change tests to use old proto and fix coverage
  • Loading branch information
dstrain115 authored Nov 13, 2019
1 parent d7927cb commit a167b64
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 78 deletions.
10 changes: 5 additions & 5 deletions cirq/api/google/v2/proto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
"""Check protobuf modules initialize successfully."""

# pylint: disable=unused-import
from cirq.api.google.v2 import device_pb2
from cirq.api.google.v2 import metrics_pb2
from cirq.api.google.v2 import program_pb2
from cirq.api.google.v2 import result_pb2
from cirq.api.google.v2 import run_context_pb2
from cirq.google.api.v2 import device_pb2
from cirq.google.api.v2 import metrics_pb2
from cirq.google.api.v2 import program_pb2
from cirq.google.api.v2 import result_pb2
from cirq.google.api.v2 import run_context_pb2
# pylint: enable=unused-import
2 changes: 1 addition & 1 deletion cirq/google/api/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from collections import OrderedDict
import numpy as np

from cirq.api.google.v2 import result_pb2
from cirq.google.api import v2
from cirq.google.api.v2 import result_pb2
from cirq import circuits
from cirq import devices
from cirq import ops
Expand Down
9 changes: 4 additions & 5 deletions cirq/google/api/v2/results_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest

import cirq
from cirq.api.google.v2 import result_pb2
from cirq.google.api import v2


Expand Down Expand Up @@ -169,7 +168,7 @@ def test_results_to_proto():
],
]
proto = v2.results_to_proto(trial_results, measurements)
assert isinstance(proto, result_pb2.Result)
assert isinstance(proto, v2.result_pb2.Result)
assert len(proto.sweep_results) == 2
deserialized = v2.results_from_proto(proto, measurements)
assert len(deserialized) == 2
Expand Down Expand Up @@ -213,7 +212,7 @@ def test_results_from_proto_qubit_ordering():
slot=0,
invert_mask=[False, False, False])
]
proto = result_pb2.Result()
proto = v2.result_pb2.Result()
sr = proto.sweep_results.add()
sr.repetitions = 8
pr = sr.parameterized_results.add()
Expand Down Expand Up @@ -254,7 +253,7 @@ def test_results_from_proto_duplicate_qubit():
slot=0,
invert_mask=[False, False, False])
]
proto = result_pb2.Result()
proto = v2.result_pb2.Result()
sr = proto.sweep_results.add()
sr.repetitions = 8
pr = sr.parameterized_results.add()
Expand All @@ -274,7 +273,7 @@ def test_results_from_proto_duplicate_qubit():


def test_results_from_proto_default_ordering():
proto = result_pb2.Result()
proto = v2.result_pb2.Result()
sr = proto.sweep_results.add()
sr.repetitions = 8
pr = sr.parameterized_results.add()
Expand Down
2 changes: 1 addition & 1 deletion cirq/google/api/v2/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional, Dict, List

from cirq import value
from cirq.api.google.v2 import run_context_pb2
from cirq.google.api.v2 import run_context_pb2
from cirq.study import sweeps

def sweep_to_proto(
Expand Down
15 changes: 7 additions & 8 deletions cirq/google/api/v2/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import sympy

import cirq
from cirq.api.google.v2 import run_context_pb2
from cirq.google.api import v2
from cirq.study import sweeps

Expand Down Expand Up @@ -58,7 +57,7 @@ def test_sweep_to_proto_roundtrip(sweep):

def test_sweep_to_proto_linspace():
proto = v2.sweep_to_proto(cirq.Linspace('foo', 0, 1, 20))
assert isinstance(proto, run_context_pb2.Sweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
assert proto.HasField('single_sweep')
assert proto.single_sweep.parameter_key == 'foo'
assert proto.single_sweep.WhichOneof('sweep') == 'linspace'
Expand All @@ -69,7 +68,7 @@ def test_sweep_to_proto_linspace():

def test_sweep_to_proto_points():
proto = v2.sweep_to_proto(cirq.Points('foo', [-1, 0, 1, 1.5]))
assert isinstance(proto, run_context_pb2.Sweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
assert proto.HasField('single_sweep')
assert proto.single_sweep.parameter_key == 'foo'
assert proto.single_sweep.WhichOneof('sweep') == 'points'
Expand All @@ -78,7 +77,7 @@ def test_sweep_to_proto_points():

def test_sweep_to_proto_unit():
proto = v2.sweep_to_proto(cirq.UnitSweep)
assert isinstance(proto, run_context_pb2.Sweep)
assert isinstance(proto, v2.run_context_pb2.Sweep)
assert not proto.HasField('single_sweep')
assert not proto.HasField('sweep_function')

Expand All @@ -89,14 +88,14 @@ def test_sweep_from_proto_unknown_sweep_type():


def test_sweep_from_proto_sweep_function_not_set():
proto = run_context_pb2.Sweep()
proto = v2.run_context_pb2.Sweep()
proto.sweep_function.sweeps.add()
with pytest.raises(ValueError, match='invalid sweep function type'):
v2.sweep_from_proto(proto)


def test_sweep_from_proto_single_sweep_type_not_set():
proto = run_context_pb2.Sweep()
proto = v2.run_context_pb2.Sweep()
proto.single_sweep.parameter_key = 'foo'
with pytest.raises(ValueError, match='single sweep type not set'):
v2.sweep_from_proto(proto)
Expand All @@ -105,8 +104,8 @@ def test_sweep_from_proto_single_sweep_type_not_set():
def test_sweep_with_list_sweep():
ls = cirq.study.to_sweep([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
proto = v2.sweep_to_proto(ls)
expected = run_context_pb2.Sweep()
expected.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
expected = v2.run_context_pb2.Sweep()
expected.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP
p1 = expected.sweep_function.sweeps.add()
p1.single_sweep.parameter_key = 'a'
p1.single_sweep.points.points.extend([1, 3])
Expand Down
2 changes: 1 addition & 1 deletion cirq/google/engine/calibration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

_CALIBRATION_DATA = {
'@type':
'type.googleapis.com/cirq.api.google.v2.MetricsSnapshot',
'type.googleapis.com/cirq.google.api.v2.MetricsSnapshot',
'timestampMs':
'1562544000021',
'metrics': [{
Expand Down
29 changes: 13 additions & 16 deletions cirq/google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@
from google.protobuf import any_pb2

from cirq import circuits, optimizers, schedules, study, value
from cirq.api.google import v1, v2
from cirq.google import gate_sets, serializable_gate_set
from cirq.google.api import v1 as api_v1
from cirq.google.api import v2 as api_v2
from cirq.google.api import v1, v2
from cirq.google.engine import (calibration, engine_job, engine_program,
engine_sampler)

Expand Down Expand Up @@ -450,16 +448,15 @@ def _serialize_run_context(
context_dict = {} # type: Dict[str, Any]
context_dict['@type'] = TYPE_PREFIX + context_descriptor.full_name
context_dict['parameter_sweeps'] = [
api_v1.sweep_to_proto_dict(sweep, repetitions)
for sweep in sweeps
v1.sweep_to_proto_dict(sweep, repetitions) for sweep in sweeps
]
return context_dict
elif proto_version == ProtoVersion.V2:
run_context = v2.run_context_pb2.RunContext()
for sweep in sweeps:
sweep_proto = run_context.parameter_sweeps.add()
sweep_proto.repetitions = repetitions
api_v2.sweep_to_proto(sweep, out=sweep_proto.sweep)
v2.sweep_to_proto(sweep, out=sweep_proto.sweep)

return _any_dict_from_msg(run_context)
else:
Expand Down Expand Up @@ -516,7 +513,7 @@ def _serialize_program(
program_dict = {} # type: Dict[str, Any]
program_dict['@type'] = TYPE_PREFIX + program_descriptor.full_name
program_dict['operations'] = [
op for op in api_v1.schedule_to_proto_dicts(schedule)
op for op in v1.schedule_to_proto_dicts(schedule)
]
return program_dict
elif self.proto_version == ProtoVersion.V2:
Expand Down Expand Up @@ -575,15 +572,15 @@ def get_job_results(self,
parent=job_resource_name))
result = response['result']
result_type = result['@type'][len(TYPE_PREFIX):]
if result_type == 'cirq.api.google.v1.Result':
return self._get_job_results_v1(result)
if result_type == 'cirq.api.google.v2.Result':
return self._get_job_results_v2(result)
if result_type == 'cirq.google.api.v1.Result':
return self._get_job_results_v1(result)
if result_type == 'cirq.google.api.v2.Result':
# Pretend the path is the other one until we switch over
result['@type'] = 'type.googleapis.com/cirq.api.google.v2.Result'
return self._get_job_results_v2(result)
if result_type == 'cirq.api.google.v1.Result':
return self._get_job_results_v1(result)
if result_type == 'cirq.api.google.v2.Result':
# Change path to the new path
result['@type'] = 'type.googleapis.com/cirq.google.api.v2.Result'
return self._get_job_results_v2(result)
raise ValueError('invalid result proto version: {}'.format(
self.proto_version))
Expand All @@ -597,8 +594,8 @@ def _get_job_results_v1(self,
for m in sweep_result['measurementKeys']]
for result in sweep_result['parameterizedResults']:
data = base64.standard_b64decode(result['measurementResults'])
measurements = api_v1.unpack_results(data, sweep_repetitions,
key_sizes)
measurements = v1.unpack_results(data, sweep_repetitions,
key_sizes)

trial_results.append(
study.TrialResult.from_single_parameter_set(
Expand All @@ -613,7 +610,7 @@ def _get_job_results_v2(self, result_dict: Dict[str, Any]
gp.json_format.ParseDict(result_dict, result_any)
result = v2.result_pb2.Result()
result_any.Unpack(result)
sweep_results = api_v2.results_from_proto(result)
sweep_results = v2.results_from_proto(result)
# Flatten to single list to match to sampler api.
return [
trial_result for sweep_result in sweep_results
Expand Down
24 changes: 12 additions & 12 deletions cirq/google/engine/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

_A_RESULT = {
'@type':
'type.googleapis.com/cirq.api.google.v1.Result',
'type.googleapis.com/cirq.google.api.v1.Result',
'sweepResults': [{
'repetitions':
1,
Expand All @@ -55,7 +55,7 @@

_RESULTS = {
'@type':
'type.googleapis.com/cirq.api.google.v1.Result',
'type.googleapis.com/cirq.google.api.v1.Result',
'sweepResults': [{
'repetitions':
1,
Expand Down Expand Up @@ -86,7 +86,7 @@

_RESULTS_V2 = {
'@type':
'type.googleapis.com/cirq.api.google.v2.Result',
'type.googleapis.com/cirq.google.api.v2.Result',
'sweepResults': [
{
'repetitions':
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_run_circuit(build):
}
},
'run_context': {
'@type': 'type.googleapis.com/cirq.api.google.v1.RunContext',
'@type': 'type.googleapis.com/cirq.google.api.v1.RunContext',
'parameter_sweeps': [{
'repetitions': 1
}]
Expand Down Expand Up @@ -467,13 +467,13 @@ def test_run_sweep_params(build):


@mock.patch.object(discovery, 'build')
def test_run_sweep_params_new_proto(build):
def test_run_sweep_params_old_proto(build):
service = mock.Mock()
build.return_value = service
programs = service.projects().programs()
jobs = programs.jobs()
results_new_proto = copy.deepcopy(_RESULTS)
results_new_proto['@type'] = 'type.googleapis.com/cirq.google.api.v1.Result'
results_old_proto = copy.deepcopy(_RESULTS)
results_old_proto['@type'] = 'type.googleapis.com/cirq.api.google.v1.Result'
programs.create().execute.return_value = {
'name': 'projects/project-id/programs/test'
}
Expand All @@ -489,7 +489,7 @@ def test_run_sweep_params_new_proto(build):
'state': 'SUCCESS'
}
}
jobs.getResult().execute.return_value = {'result': results_new_proto}
jobs.getResult().execute.return_value = {'result': results_old_proto}

engine = cg.Engine(project_id='project-id')
job = engine.run_sweep(
Expand Down Expand Up @@ -681,7 +681,7 @@ def test_run_sweep_v2(build):


@mock.patch.object(discovery, 'build')
def test_run_sweep_v2_new_proto(build):
def test_run_sweep_v2_old_proto(build):
service = mock.Mock()
build.return_value = service
programs = service.projects().programs()
Expand All @@ -701,9 +701,9 @@ def test_run_sweep_v2_new_proto(build):
'state': 'SUCCESS'
}
}
results_new_proto = copy.deepcopy(_RESULTS_V2)
results_new_proto['@type'] = 'type.googleapis.com/cirq.google.api.v2.Result'
jobs.getResult().execute.return_value = {'result': results_new_proto}
results_old_proto = copy.deepcopy(_RESULTS_V2)
results_old_proto['@type'] = 'type.googleapis.com/cirq.api.google.v2.Result'
jobs.getResult().execute.return_value = {'result': results_old_proto}

engine = cg.Engine(
project_id='project-id',
Expand Down
11 changes: 5 additions & 6 deletions cirq/google/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@

from google.protobuf import json_format

from cirq.api.google import v2

from cirq.google.api import v2 as api_v2
from cirq.google.api import v2
from cirq.google import arg_func_langs

if TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +87,8 @@ def __init__(self,

def from_proto_dict(self, proto: Dict, *, arg_function_language: str = ''
) -> 'cirq.GateOperation':
"""Turns a cirq.api.google.v2.Operation proto into a GateOperation."""
"""Turns a cirq.google.api.v2.Operation proto into a GateOperation."""

msg = v2.program_pb2.Operation()
json_format.ParseDict(proto, msg)
return self.from_proto(msg, arg_function_language=arg_function_language)
Expand All @@ -98,8 +97,8 @@ def from_proto(self,
proto: v2.program_pb2.Operation,
*,
arg_function_language: str = '') -> 'cirq.GateOperation':
"""Turns a cirq.api.google.v2.Operation proto into a GateOperation."""
qubits = [api_v2.grid_qubit_from_proto_id(q.id) for q in proto.qubits]
"""Turns a cirq.google.api.v2.Operation proto into a GateOperation."""
qubits = [v2.grid_qubit_from_proto_id(q.id) for q in proto.qubits]
args = self._args_from_proto(
proto, arg_function_language=arg_function_language)
if self.num_qubits_param is not None:
Expand Down
9 changes: 4 additions & 5 deletions cirq/google/op_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from google.protobuf import json_format

from cirq import devices, ops
from cirq.api.google import v2

from cirq.google.api import v2 as api_v2
from cirq.google.api import v2
from cirq.google import arg_func_langs
from cirq.google.arg_func_langs import _arg_to_proto

Expand Down Expand Up @@ -117,7 +115,8 @@ def to_proto(
*,
arg_function_language: Optional[str] = '',
) -> Optional[v2.program_pb2.Operation]:
"""Returns the cirq.api.google.v2.Operation message as a proto dict."""
"""Returns the cirq.google.api.v2.Operation message as a proto dict."""

if not all(isinstance(qubit, devices.GridQubit) for qubit in op.qubits):
raise ValueError('All qubits must be GridQubits')
gate = op.gate
Expand All @@ -134,7 +133,7 @@ def to_proto(

msg.gate.id = self.serialized_gate_id
for qubit in op.qubits:
msg.qubits.add().id = api_v2.qubit_to_proto_id(
msg.qubits.add().id = v2.qubit_to_proto_id(
cast(devices.GridQubit, qubit))
for arg in self.args:
value = self._value_from_gate(gate, arg)
Expand Down
Loading

0 comments on commit a167b64

Please sign in to comment.