Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow creating new gatesets with added gates #2458

Merged
merged 11 commits into from
Nov 6, 2019
2 changes: 1 addition & 1 deletion cirq/google/api/v2/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type, TYPE_CHECKING
from typing import TYPE_CHECKING

from cirq import devices, ops

Expand Down
35 changes: 28 additions & 7 deletions cirq/google/serializable_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.
"""Support for serializing and deserializing cirq.api.google.v2 protos."""

from collections import defaultdict

from typing import cast, Dict, Iterable, List, Optional, Tuple, Type, Union, \
TYPE_CHECKING
from typing import (cast, Dict, Iterable, List, Optional, Tuple, Type, Union,
TYPE_CHECKING)

from google.protobuf import json_format

Expand Down Expand Up @@ -51,12 +49,35 @@ def __init__(self, gate_set_name: str,
forms of gates to GateOperations.
"""
self.gate_set_name = gate_set_name
self.serializers = defaultdict(
list) # type: Dict[Type, List[op_serializer.GateOpSerializer]]
self.serializers: Dict[Type, List[op_serializer.GateOpSerializer]] = {}
for s in serializers:
self.serializers[s.gate_type].append(s)
self.serializers.setdefault(s.gate_type, []).append(s)
self.deserializers = {d.serialized_gate_id: d for d in deserializers}

def with_added_gates(
self,
*,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this asterisk? Is there a reason to have an empty args?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forces using keyword args, which I think helps clarity. You're correct that caling this with no arguments is not very useful because it just returns an equivalent gateset. I could also make serializers and deserializers be positional args and put the optional name last. Do you have a preference?

gate_set_name: Optional[str] = None,
serializers: Iterable[op_serializer.GateOpSerializer] = (),
deserializers: Iterable[op_deserializer.GateOpDeserializer] = (),
) -> 'SerializableGateSet':
"""Creates a new gateset with additional (de)serializers.

Args:
gate_set_name: Optional new name of the gateset. If not given, use
the same name as this gateset.
serializers: Serializers to add to those in this gateset.
deserializers: Deserializers to add to those in this gateset.
"""
# Iterate over all serializers in this gateset.
curr_serializers = (serializer
for serializers in self.serializers.values()
for serializer in serializers)
return SerializableGateSet(
gate_set_name or self.gate_set_name,
serializers=[*curr_serializers, *serializers],
deserializers=[*self.deserializers.values(), *deserializers])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need _all_serializers() but not deserializers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Serializers are stored as a map from gate_id to list of serializers, so we have to flatten those into one list. Deserializers are stored as a map from gate id to deserializer, so we can just use .values(). I will refactor to just inline the expression, since it's probably not worth having a whole method for the serializer flattening.


def supported_gate_types(self) -> Tuple:
return tuple(self.serializers.keys())

Expand Down
91 changes: 71 additions & 20 deletions cirq/google/serializable_gate_set_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,57 @@
import cirq
import cirq.google as cg

X_SERIALIZER = cg.GateOpSerializer(gate_type=cirq.XPowGate,
serialized_gate_id='x_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent')
])

X_DESERIALIZER = cg.GateOpDeserializer(serialized_gate_id='x_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent')
])

MY_GATE_SET = cg.SerializableGateSet(gate_set_name='my_gate_set',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER])
X_SERIALIZER = cg.GateOpSerializer(
gate_type=cirq.XPowGate,
serialized_gate_id='x_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent',
)
],
)

X_DESERIALIZER = cg.GateOpDeserializer(
serialized_gate_id='x_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent',
)
],
)

Y_SERIALIZER = cg.GateOpSerializer(
gate_type=cirq.YPowGate,
serialized_gate_id='y_pow',
args=[
cg.SerializingArg(
serialized_name='half_turns',
serialized_type=float,
gate_getter='exponent',
)
],
)

Y_DESERIALIZER = cg.GateOpDeserializer(
serialized_gate_id='y_pow',
gate_constructor=cirq.XPowGate,
args=[
cg.DeserializingArg(
serialized_name='half_turns',
constructor_arg_name='exponent',
)
],
)

MY_GATE_SET = cg.SerializableGateSet(
gate_set_name='my_gate_set',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER],
)


def test_supported_gate_types():
Expand Down Expand Up @@ -338,6 +369,26 @@ def test_multiple_serializers():
assert gate_set.serialize_op(cirq.X(q0)**0.5).gate.id == 'x_pow'


def test_gateset_with_added_gates():
x_gateset = cg.SerializableGateSet(
gate_set_name='x',
serializers=[X_SERIALIZER],
deserializers=[X_DESERIALIZER],
)
xy_gateset = x_gateset.with_added_gates(
gate_set_name='xy',
serializers=[Y_SERIALIZER],
deserializers=[Y_DESERIALIZER],
)
assert x_gateset.gate_set_name == 'x'
assert x_gateset.is_supported_gate(cirq.X)
assert not x_gateset.is_supported_gate(cirq.Y)

assert xy_gateset.gate_set_name == 'xy'
assert xy_gateset.is_supported_gate(cirq.X)
assert xy_gateset.is_supported_gate(cirq.Y)


def test_deserialize_op_invalid_gate():
proto = {
'gate': {},
Expand Down