From d74e0dce73adf45f15b873bf28236ef355d960ce Mon Sep 17 00:00:00 2001 From: Bicheng Ying Date: Tue, 17 Sep 2024 12:34:19 -0700 Subject: [PATCH] Add support for const sweep with None (#6729) * Add support for const sweep with None * Fix the test * Fix the test coverage * Fix the test coverage * Address the comment * Use const for all single sweep * Add more test for constant * Fix typecheck * Fix the lint --- .../cirq_google/api/v2/run_context.proto | 15 +++ .../cirq_google/api/v2/run_context_pb2.py | 14 +- .../cirq_google/api/v2/run_context_pb2.pyi | 126 ++++++++++++------ cirq-google/cirq_google/api/v2/sweeps.py | 39 +++++- cirq-google/cirq_google/api/v2/sweeps_test.py | 38 +++++- .../engine/engine_processor_test.py | 8 +- cirq-google/cirq_google/engine/engine_test.py | 11 +- 7 files changed, 192 insertions(+), 59 deletions(-) diff --git a/cirq-google/cirq_google/api/v2/run_context.proto b/cirq-google/cirq_google/api/v2/run_context.proto index 9b6fae82405..d996c9c57a8 100644 --- a/cirq-google/cirq_google/api/v2/run_context.proto +++ b/cirq-google/cirq_google/api/v2/run_context.proto @@ -179,6 +179,8 @@ message SingleSweep { Points points = 2; // Uniformly-spaced sampling over a range. Linspace linspace = 3; + // A constant value. + Const const = 5; } // Optional arguments for if this is a device parameter. @@ -186,6 +188,7 @@ message SingleSweep { DeviceParameter parameter = 4; } + // A list of explicit values. message Points { // The values. @@ -207,3 +210,15 @@ message Linspace { // the same. int64 num_points = 3; } + +// A constant value. +message Const { + // The values. + oneof value { + // This value should always be true if set, which represent the python None object. + bool is_none = 1; + float float_value = 2; + int64 int_value = 3; + string string_value = 4; + } +} diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.py b/cirq-google/cirq_google/api/v2/run_context_pb2.py index 0d7ab92c607..1cb390c16d6 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.py +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.py @@ -14,7 +14,7 @@ from . import program_pb2 as cirq__google_dot_api_dot_v2_dot_program__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\x1a cirq_google/api/v2/program.proto\"\x98\x01\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\x12L\n\x1a\x64\x65vice_parameters_override\x18\x02 \x01(\x0b\x32(.cirq.google.api.v2.DeviceParametersDiff\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xc6\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"C\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\"W\n\x0f\x44\x65viceParameter\x12\x0c\n\x04path\x18\x01 \x03(\t\x12\x10\n\x03idx\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x12\n\x05units\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04_idxB\x08\n\x06_units\"\xcf\x03\n\x14\x44\x65viceParametersDiff\x12\x46\n\x06groups\x18\x01 \x03(\x0b\x32\x36.cirq.google.api.v2.DeviceParametersDiff.ResourceGroup\x12>\n\x06params\x18\x02 \x03(\x0b\x32..cirq.google.api.v2.DeviceParametersDiff.Param\x12\x0c\n\x04strs\x18\x04 \x03(\t\x1a-\n\rResourceGroup\x12\x0e\n\x06parent\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x1a\x36\n\x0cGenericValue\x12\x17\n\x0ftype_descriptor\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x1a\xb9\x01\n\x05Param\x12\x16\n\x0eresource_group\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x12-\n\x05value\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12N\n\rgeneric_value\x18\x04 \x01(\x0b\x32\x35.cirq.google.api.v2.DeviceParametersDiff.GenericValueH\x00\x42\x0b\n\tparam_val\"\xc5\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x12\x36\n\tparameter\x18\x04 \x01(\x0b\x32#.cirq.google.api.v2.DeviceParameterB\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\x42\x32\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$cirq_google/api/v2/run_context.proto\x12\x12\x63irq.google.api.v2\x1a cirq_google/api/v2/program.proto\"\x98\x01\n\nRunContext\x12<\n\x10parameter_sweeps\x18\x01 \x03(\x0b\x32\".cirq.google.api.v2.ParameterSweep\x12L\n\x1a\x64\x65vice_parameters_override\x18\x02 \x01(\x0b\x32(.cirq.google.api.v2.DeviceParametersDiff\"O\n\x0eParameterSweep\x12\x13\n\x0brepetitions\x18\x01 \x01(\x05\x12(\n\x05sweep\x18\x02 \x01(\x0b\x32\x19.cirq.google.api.v2.Sweep\"\x86\x01\n\x05Sweep\x12;\n\x0esweep_function\x18\x01 \x01(\x0b\x32!.cirq.google.api.v2.SweepFunctionH\x00\x12\x37\n\x0csingle_sweep\x18\x02 \x01(\x0b\x32\x1f.cirq.google.api.v2.SingleSweepH\x00\x42\x07\n\x05sweep\"\xc6\x01\n\rSweepFunction\x12\x45\n\rfunction_type\x18\x01 \x01(\x0e\x32..cirq.google.api.v2.SweepFunction.FunctionType\x12)\n\x06sweeps\x18\x02 \x03(\x0b\x32\x19.cirq.google.api.v2.Sweep\"C\n\x0c\x46unctionType\x12\x1d\n\x19\x46UNCTION_TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07PRODUCT\x10\x01\x12\x07\n\x03ZIP\x10\x02\"W\n\x0f\x44\x65viceParameter\x12\x0c\n\x04path\x18\x01 \x03(\t\x12\x10\n\x03idx\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x12\n\x05units\x18\x03 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04_idxB\x08\n\x06_units\"\xcf\x03\n\x14\x44\x65viceParametersDiff\x12\x46\n\x06groups\x18\x01 \x03(\x0b\x32\x36.cirq.google.api.v2.DeviceParametersDiff.ResourceGroup\x12>\n\x06params\x18\x02 \x03(\x0b\x32..cirq.google.api.v2.DeviceParametersDiff.Param\x12\x0c\n\x04strs\x18\x04 \x03(\t\x1a-\n\rResourceGroup\x12\x0e\n\x06parent\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x1a\x36\n\x0cGenericValue\x12\x17\n\x0ftype_descriptor\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x1a\xb9\x01\n\x05Param\x12\x16\n\x0eresource_group\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\x05\x12-\n\x05value\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.ArgValueH\x00\x12N\n\rgeneric_value\x18\x04 \x01(\x0b\x32\x35.cirq.google.api.v2.DeviceParametersDiff.GenericValueH\x00\x42\x0b\n\tparam_val\"\xf1\x01\n\x0bSingleSweep\x12\x15\n\rparameter_key\x18\x01 \x01(\t\x12,\n\x06points\x18\x02 \x01(\x0b\x32\x1a.cirq.google.api.v2.PointsH\x00\x12\x30\n\x08linspace\x18\x03 \x01(\x0b\x32\x1c.cirq.google.api.v2.LinspaceH\x00\x12*\n\x05\x63onst\x18\x05 \x01(\x0b\x32\x19.cirq.google.api.v2.ConstH\x00\x12\x36\n\tparameter\x18\x04 \x01(\x0b\x32#.cirq.google.api.v2.DeviceParameterB\x07\n\x05sweep\"\x18\n\x06Points\x12\x0e\n\x06points\x18\x01 \x03(\x02\"G\n\x08Linspace\x12\x13\n\x0b\x66irst_point\x18\x01 \x01(\x02\x12\x12\n\nlast_point\x18\x02 \x01(\x02\x12\x12\n\nnum_points\x18\x03 \x01(\x03\"g\n\x05\x43onst\x12\x11\n\x07is_none\x18\x01 \x01(\x08H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x03H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x42\x07\n\x05valueB2\n\x1d\x63om.google.cirq.google.api.v2B\x0fRunContextProtoP\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -43,9 +43,11 @@ _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_start=1036 _globals['_DEVICEPARAMETERSDIFF_PARAM']._serialized_end=1221 _globals['_SINGLESWEEP']._serialized_start=1224 - _globals['_SINGLESWEEP']._serialized_end=1421 - _globals['_POINTS']._serialized_start=1423 - _globals['_POINTS']._serialized_end=1447 - _globals['_LINSPACE']._serialized_start=1449 - _globals['_LINSPACE']._serialized_end=1520 + _globals['_SINGLESWEEP']._serialized_end=1465 + _globals['_POINTS']._serialized_start=1467 + _globals['_POINTS']._serialized_end=1491 + _globals['_LINSPACE']._serialized_start=1493 + _globals['_LINSPACE']._serialized_end=1564 + _globals['_CONST']._serialized_start=1566 + _globals['_CONST']._serialized_end=1669 # @@protoc_insertion_point(module_scope) diff --git a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi index 89a31694549..c8b8c32448b 100644 --- a/cirq-google/cirq_google/api/v2/run_context_pb2.pyi +++ b/cirq-google/cirq_google/api/v2/run_context_pb2.pyi @@ -2,6 +2,7 @@ @generated by mypy-protobuf. Do not edit manually! isort:skip_file """ + import builtins import cirq_google.api.v2.program_pb2 import collections.abc @@ -19,7 +20,7 @@ else: DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -@typing_extensions.final +@typing.final class RunContext(google.protobuf.message.Message): """The context for running a quantum program.""" @@ -30,6 +31,7 @@ class RunContext(google.protobuf.message.Message): @property def parameter_sweeps(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ParameterSweep]: """The parameters for operations in a program.""" + @property def device_parameters_override(self) -> global___DeviceParametersDiff: """Optional override of select device parameters before program @@ -38,18 +40,19 @@ class RunContext(google.protobuf.message.Message): If the same parameter is supplied in both places, the provision here in device_parameters_override will have no effect. """ + def __init__( self, *, parameter_sweeps: collections.abc.Iterable[global___ParameterSweep] | None = ..., device_parameters_override: global___DeviceParametersDiff | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["device_parameters_override", b"device_parameters_override"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["device_parameters_override", b"device_parameters_override", "parameter_sweeps", b"parameter_sweeps"]) -> None: ... + def HasField(self, field_name: typing.Literal["device_parameters_override", b"device_parameters_override"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["device_parameters_override", b"device_parameters_override", "parameter_sweeps", b"parameter_sweeps"]) -> None: ... global___RunContext = RunContext -@typing_extensions.final +@typing.final class ParameterSweep(google.protobuf.message.Message): """Specifies how to repeatedly sample a circuit, with or without sweeping over varying parameter-dicts. @@ -72,18 +75,19 @@ class ParameterSweep(google.protobuf.message.Message): no parameterization is assumed (and the program must have no args with symbols). """ + def __init__( self, *, repetitions: builtins.int = ..., sweep: global___Sweep | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["sweep", b"sweep"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["repetitions", b"repetitions", "sweep", b"sweep"]) -> None: ... + def HasField(self, field_name: typing.Literal["sweep", b"sweep"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["repetitions", b"repetitions", "sweep", b"sweep"]) -> None: ... global___ParameterSweep = ParameterSweep -@typing_extensions.final +@typing.final class Sweep(google.protobuf.message.Message): """A sweep over all of the parameters in a program.""" @@ -101,13 +105,13 @@ class Sweep(google.protobuf.message.Message): sweep_function: global___SweepFunction | None = ..., single_sweep: global___SingleSweep | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["single_sweep", b"single_sweep", "sweep", b"sweep", "sweep_function", b"sweep_function"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["single_sweep", b"single_sweep", "sweep", b"sweep", "sweep_function", b"sweep_function"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["sweep", b"sweep"]) -> typing_extensions.Literal["sweep_function", "single_sweep"] | None: ... + def HasField(self, field_name: typing.Literal["single_sweep", b"single_sweep", "sweep", b"sweep", "sweep_function", b"sweep_function"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["single_sweep", b"single_sweep", "sweep", b"sweep", "sweep_function", b"sweep_function"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["sweep", b"sweep"]) -> typing.Literal["sweep_function", "single_sweep"] | None: ... global___Sweep = Sweep -@typing_extensions.final +@typing.final class SweepFunction(google.protobuf.message.Message): """A function that takes multiple sweeps and produces more sweeps.""" @@ -117,7 +121,7 @@ class SweepFunction(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _FunctionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[SweepFunction._FunctionType.ValueType], builtins.type): # noqa: F821 + class _FunctionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[SweepFunction._FunctionType.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor FUNCTION_TYPE_UNSPECIFIED: SweepFunction._FunctionType.ValueType # 0 """The function type is not specified. Should never be used.""" @@ -196,32 +200,34 @@ class SweepFunction(google.protobuf.message.Message): @property def sweeps(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Sweep]: """The argument sweeps to the function.""" + def __init__( self, *, function_type: global___SweepFunction.FunctionType.ValueType = ..., sweeps: collections.abc.Iterable[global___Sweep] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["function_type", b"function_type", "sweeps", b"sweeps"]) -> None: ... + def ClearField(self, field_name: typing.Literal["function_type", b"function_type", "sweeps", b"sweeps"]) -> None: ... global___SweepFunction = SweepFunction -@typing_extensions.final +@typing.final class DeviceParameter(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor PATH_FIELD_NUMBER: builtins.int IDX_FIELD_NUMBER: builtins.int UNITS_FIELD_NUMBER: builtins.int - @property - def path(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """Path to the parameter key""" idx: builtins.int """If the value is an array, the index of the array to change.""" units: builtins.str """String representation of the units, if any. Examples: "GHz", "ns", etc. """ + @property + def path(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Path to the parameter key""" + def __init__( self, *, @@ -229,16 +235,16 @@ class DeviceParameter(google.protobuf.message.Message): idx: builtins.int | None = ..., units: builtins.str | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["_idx", b"_idx", "_units", b"_units", "idx", b"idx", "units", b"units"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["_idx", b"_idx", "_units", b"_units", "idx", b"idx", "path", b"path", "units", b"units"]) -> None: ... + def HasField(self, field_name: typing.Literal["_idx", b"_idx", "_units", b"_units", "idx", b"idx", "units", b"units"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_idx", b"_idx", "_units", b"_units", "idx", b"idx", "path", b"path", "units", b"units"]) -> None: ... @typing.overload - def WhichOneof(self, oneof_group: typing_extensions.Literal["_idx", b"_idx"]) -> typing_extensions.Literal["idx"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_idx", b"_idx"]) -> typing.Literal["idx"] | None: ... @typing.overload - def WhichOneof(self, oneof_group: typing_extensions.Literal["_units", b"_units"]) -> typing_extensions.Literal["units"] | None: ... + def WhichOneof(self, oneof_group: typing.Literal["_units", b"_units"]) -> typing.Literal["units"] | None: ... global___DeviceParameter = DeviceParameter -@typing_extensions.final +@typing.final class DeviceParametersDiff(google.protobuf.message.Message): """A bundle of multiple DeviceParameters and their values. The main use case is to set those parameters with the @@ -252,7 +258,7 @@ class DeviceParametersDiff(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - @typing_extensions.final + @typing.final class ResourceGroup(google.protobuf.message.Message): """A resource group a device parameter belongs to. The identifier of a resource group is DeviceParameter.path without the @@ -273,9 +279,9 @@ class DeviceParametersDiff(google.protobuf.message.Message): parent: builtins.int = ..., name: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["name", b"name", "parent", b"parent"]) -> None: ... + def ClearField(self, field_name: typing.Literal["name", b"name", "parent", b"parent"]) -> None: ... - @typing_extensions.final + @typing.final class GenericValue(google.protobuf.message.Message): """Param value whose type is not among proto field types supported by ArgValue. In other words, it is the responsibility of the client codes @@ -300,9 +306,9 @@ class DeviceParametersDiff(google.protobuf.message.Message): type_descriptor: builtins.str = ..., value: builtins.bytes = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["type_descriptor", b"type_descriptor", "value", b"value"]) -> None: ... + def ClearField(self, field_name: typing.Literal["type_descriptor", b"type_descriptor", "value", b"value"]) -> None: ... - @typing_extensions.final + @typing.final class Param(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -321,11 +327,13 @@ class DeviceParametersDiff(google.protobuf.message.Message): """this param's new value, as message ArgValue to allow types of bool, string, double, float and arrays. """ + @property def generic_value(self) -> global___DeviceParametersDiff.GenericValue: """this param's new value, and its type is not among the variants supported by ArgValue. """ + def __init__( self, *, @@ -334,9 +342,9 @@ class DeviceParametersDiff(google.protobuf.message.Message): value: cirq_google.api.v2.program_pb2.ArgValue | None = ..., generic_value: global___DeviceParametersDiff.GenericValue | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["generic_value", b"generic_value", "param_val", b"param_val", "value", b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["generic_value", b"generic_value", "name", b"name", "param_val", b"param_val", "resource_group", b"resource_group", "value", b"value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["param_val", b"param_val"]) -> typing_extensions.Literal["value", "generic_value"] | None: ... + def HasField(self, field_name: typing.Literal["generic_value", b"generic_value", "param_val", b"param_val", "value", b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["generic_value", b"generic_value", "name", b"name", "param_val", b"param_val", "resource_group", b"resource_group", "value", b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["param_val", b"param_val"]) -> typing.Literal["value", "generic_value"] | None: ... GROUPS_FIELD_NUMBER: builtins.int PARAMS_FIELD_NUMBER: builtins.int @@ -350,6 +358,7 @@ class DeviceParametersDiff(google.protobuf.message.Message): """List of all key, dir, and deletion names in these contents. ResourceGroup.name, Param.name, and Deletion.name are indexes into this list. """ + def __init__( self, *, @@ -357,11 +366,11 @@ class DeviceParametersDiff(google.protobuf.message.Message): params: collections.abc.Iterable[global___DeviceParametersDiff.Param] | None = ..., strs: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["groups", b"groups", "params", b"params", "strs", b"strs"]) -> None: ... + def ClearField(self, field_name: typing.Literal["groups", b"groups", "params", b"params", "strs", b"strs"]) -> None: ... global___DeviceParametersDiff = DeviceParametersDiff -@typing_extensions.final +@typing.final class SingleSweep(google.protobuf.message.Message): """A set of values to loop over for a particular parameter.""" @@ -370,6 +379,7 @@ class SingleSweep(google.protobuf.message.Message): PARAMETER_KEY_FIELD_NUMBER: builtins.int POINTS_FIELD_NUMBER: builtins.int LINSPACE_FIELD_NUMBER: builtins.int + CONST_FIELD_NUMBER: builtins.int PARAMETER_FIELD_NUMBER: builtins.int parameter_key: builtins.str """The parameter key being varied. This cannot be the empty string. @@ -378,29 +388,37 @@ class SingleSweep(google.protobuf.message.Message): @property def points(self) -> global___Points: """An explicit list of points to try.""" + @property def linspace(self) -> global___Linspace: """Uniformly-spaced sampling over a range.""" + + @property + def const(self) -> global___Const: + """A constant value.""" + @property def parameter(self) -> global___DeviceParameter: """Optional arguments for if this is a device parameter. (as opposed to a circuit symbol) """ + def __init__( self, *, parameter_key: builtins.str = ..., points: global___Points | None = ..., linspace: global___Linspace | None = ..., + const: global___Const | None = ..., parameter: global___DeviceParameter | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["linspace", b"linspace", "parameter", b"parameter", "points", b"points", "sweep", b"sweep"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["linspace", b"linspace", "parameter", b"parameter", "parameter_key", b"parameter_key", "points", b"points", "sweep", b"sweep"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["sweep", b"sweep"]) -> typing_extensions.Literal["points", "linspace"] | None: ... + def HasField(self, field_name: typing.Literal["const", b"const", "linspace", b"linspace", "parameter", b"parameter", "points", b"points", "sweep", b"sweep"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["const", b"const", "linspace", b"linspace", "parameter", b"parameter", "parameter_key", b"parameter_key", "points", b"points", "sweep", b"sweep"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["sweep", b"sweep"]) -> typing.Literal["points", "linspace", "const"] | None: ... global___SingleSweep = SingleSweep -@typing_extensions.final +@typing.final class Points(google.protobuf.message.Message): """A list of explicit values.""" @@ -410,16 +428,17 @@ class Points(google.protobuf.message.Message): @property def points(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]: """The values.""" + def __init__( self, *, points: collections.abc.Iterable[builtins.float] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["points", b"points"]) -> None: ... + def ClearField(self, field_name: typing.Literal["points", b"points"]) -> None: ... global___Points = Points -@typing_extensions.final +@typing.final class Linspace(google.protobuf.message.Message): """A range of evenly-spaced values. @@ -449,6 +468,35 @@ class Linspace(google.protobuf.message.Message): last_point: builtins.float = ..., num_points: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["first_point", b"first_point", "last_point", b"last_point", "num_points", b"num_points"]) -> None: ... + def ClearField(self, field_name: typing.Literal["first_point", b"first_point", "last_point", b"last_point", "num_points", b"num_points"]) -> None: ... global___Linspace = Linspace + +@typing.final +class Const(google.protobuf.message.Message): + """A constant value.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + IS_NONE_FIELD_NUMBER: builtins.int + FLOAT_VALUE_FIELD_NUMBER: builtins.int + INT_VALUE_FIELD_NUMBER: builtins.int + STRING_VALUE_FIELD_NUMBER: builtins.int + is_none: builtins.bool + """This value should always be true if set, which represent the python None object.""" + float_value: builtins.float + int_value: builtins.int + string_value: builtins.str + def __init__( + self, + *, + is_none: builtins.bool = ..., + float_value: builtins.float = ..., + int_value: builtins.int = ..., + string_value: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["float_value", b"float_value", "int_value", b"int_value", "is_none", b"is_none", "string_value", b"string_value", "value", b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["float_value", b"float_value", "int_value", b"int_value", "is_none", b"is_none", "string_value", b"string_value", "value", b"value"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["value", b"value"]) -> typing.Literal["is_none", "float_value", "int_value", "string_value"] | None: ... + +global___Const = Const diff --git a/cirq-google/cirq_google/api/v2/sweeps.py b/cirq-google/cirq_google/api/v2/sweeps.py index c325b5bb896..cdbe7929cb6 100644 --- a/cirq-google/cirq_google/api/v2/sweeps.py +++ b/cirq-google/cirq_google/api/v2/sweeps.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Dict, List, Optional +from typing import Any, cast, Dict, List, Optional import sympy @@ -21,6 +21,34 @@ from cirq_google.study.device_parameter import DeviceParameter +def _build_sweep_const(value: Any) -> run_context_pb2.Const: + """Build the sweep const message from a value.""" + if value is None: + return run_context_pb2.Const(is_none=True) + elif isinstance(value, float): + return run_context_pb2.Const(float_value=value) + elif isinstance(value, int): + return run_context_pb2.Const(int_value=value) + elif isinstance(value, str): + return run_context_pb2.Const(string_value=value) + else: + raise ValueError( + f"Unsupported type for serializing const sweep: {value=} and {type(value)=}" + ) + + +def _recover_sweep_const(const_pb: run_context_pb2.Const) -> Any: + """Recover a const value from the sweep const message.""" + if const_pb.WhichOneof('value') == 'is_none': + return None + if const_pb.WhichOneof('value') == 'float_value': + return const_pb.float_value + if const_pb.WhichOneof('value') == 'int_value': + return const_pb.int_value + if const_pb.WhichOneof('value') == 'string_value': + return const_pb.string_value + + def sweep_to_proto( sweep: cirq.Sweep, *, out: Optional[run_context_pb2.Sweep] = None ) -> run_context_pb2.Sweep: @@ -63,7 +91,10 @@ def sweep_to_proto( out.single_sweep.parameter.units = sweep.metadata.units elif isinstance(sweep, cirq.Points) and not isinstance(sweep.key, sympy.Expr): out.single_sweep.parameter_key = sweep.key - out.single_sweep.points.points.extend(sweep.points) + if len(sweep.points) == 1: + out.single_sweep.const.MergeFrom(_build_sweep_const(sweep.points[0])) + else: + out.single_sweep.points.points.extend(sweep.points) # Use duck-typing to support google-internal Parameter objects if sweep.metadata and getattr(sweep.metadata, 'path', None): out.single_sweep.parameter.path.extend(sweep.metadata.path) @@ -128,6 +159,10 @@ def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep: ) if msg.single_sweep.WhichOneof('sweep') == 'points': return cirq.Points(key=key, points=msg.single_sweep.points.points, metadata=metadata) + if msg.single_sweep.WhichOneof('sweep') == 'const': + return cirq.Points( + key=key, points=[_recover_sweep_const(msg.single_sweep.const)], metadata=metadata + ) raise ValueError(f'single sweep type not set: {msg}') diff --git a/cirq-google/cirq_google/api/v2/sweeps_test.py b/cirq-google/cirq_google/api/v2/sweeps_test.py index 53de71d8807..e27f6ee0df6 100644 --- a/cirq-google/cirq_google/api/v2/sweeps_test.py +++ b/cirq-google/cirq_google/api/v2/sweeps_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import math from typing import Iterator import pytest @@ -68,6 +68,17 @@ def _values(self) -> Iterator[float]: + (cirq.Points('g', [1, 2]) * cirq.Points('h', [-1, 0, 1])) ) ), + # Sweep with constant. Type ignore is because cirq.Points type annotated with floats. + cirq.Points('a', [None]), # type: ignore[list-item] + cirq.Points('a', [None]) * cirq.Points('b', [1, 2, 3]), # type: ignore[list-item] + cirq.Points('a', [None]) + cirq.Points('b', [2]), # type: ignore[list-item] + cirq.Points('a', [1]), + cirq.Points('b', [1.0]), + cirq.Points('c', ["abc"]), # type: ignore[list-item] + ( + cirq.Points('a', [1]) * cirq.Points('b', [1.0]) + + cirq.Points('c', ["abc"]) * cirq.Points("d", [1, 2, 3, 4]) # type: ignore[list-item] + ), ], ) def test_sweep_to_proto_roundtrip(sweep): @@ -98,6 +109,20 @@ def test_sweep_to_proto_linspace(): ) +@pytest.mark.parametrize("val", [None, 1, 1.5, 's']) +def test_build_recover_const(val): + val2 = v2.sweeps._recover_sweep_const(v2.sweeps._build_sweep_const(val)) + if isinstance(val, float): + assert math.isclose(val, val2) # avoid the floating precision issue. + else: + assert val2 == val + + +def test_build_const_unsupported_type(): + with pytest.raises(ValueError, match='Unsupported type for serializing const sweep'): + v2.sweeps._build_sweep_const((1, 2)) + + def test_list_sweep_bad_expression(): with pytest.raises(TypeError, match='formula'): _ = cirq.ListSweep([cirq.ParamResolver({sympy.Symbol('a') + sympy.Symbol('b'): 4.0})]) @@ -111,7 +136,7 @@ def test_symbol_to_string_conversion(): expected.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP p1 = expected.sweep_function.sweeps.add() p1.single_sweep.parameter_key = 'a' - p1.single_sweep.points.points.extend([4.0]) + p1.single_sweep.const.float_value = 4.0 assert proto == expected @@ -131,6 +156,15 @@ def test_sweep_to_proto_unit(): assert not proto.HasField('sweep_function') +def test_sweep_to_none_const(): + proto = v2.sweep_to_proto(cirq.Points('foo', [None])) + assert isinstance(proto, v2.run_context_pb2.Sweep) + assert proto.HasField('single_sweep') + assert proto.single_sweep.parameter_key == 'foo' + assert proto.single_sweep.WhichOneof('sweep') == 'const' + assert proto.single_sweep.const.is_none + + def test_sweep_from_proto_unknown_sweep_type(): with pytest.raises(ValueError, match='cannot convert to v2 Sweep proto'): v2.sweep_to_proto(UnknownSweep('foo')) diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index a120f309fab..e273c5ce26f 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -871,9 +871,9 @@ def test_run_sweep_params_with_unary_rpcs(client): client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 - for i, v in enumerate([1.0, 2.0]): + for i, v in enumerate([1, 2]): assert sweeps[i].repetitions == 1 - assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] + assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.const.int_value == v client().get_job_async.assert_called_once() client().get_job_results_async.assert_called_once() @@ -912,9 +912,9 @@ def test_run_sweep_params_with_stream_rpcs(client): client().run_job_over_stream.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 - for i, v in enumerate([1.0, 2.0]): + for i, v in enumerate([1, 2]): assert sweeps[i].repetitions == 1 - assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] + assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.const.int_value == v @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 80487b75452..22c90761cb5 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -492,9 +492,9 @@ def test_run_sweep_params_with_unary_rpcs(client): client().create_job_async.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 - for i, v in enumerate([1.0, 2.0]): + for i, v in enumerate([1, 2]): assert sweeps[i].repetitions == 1 - assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] + assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.const.int_value == v client().get_job_async.assert_called_once() client().get_job_results_async.assert_called_once() @@ -522,9 +522,9 @@ def test_run_sweep_params_with_stream_rpcs(client): client().run_job_over_stream.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 - for i, v in enumerate([1.0, 2.0]): + for i, v in enumerate([1, 2]): assert sweeps[i].repetitions == 1 - assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] + assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.const.int_value == v @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) @@ -560,8 +560,7 @@ def test_run_multiple_times(client): assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} assert len(sweeps1) == 1 assert sweeps1[0].repetitions == 1 - points1 = sweeps1[0].sweep.sweep_function.sweeps[0].single_sweep.points - assert points1.points == [1] + assert sweeps1[0].sweep.sweep_function.sweeps[0].single_sweep.const.int_value == 1 assert len(sweeps2) == 1 assert sweeps2[0].repetitions == 2 assert sweeps2[0].sweep.single_sweep.points.points == [3, 4]