diff --git a/dev_tools/qualtran_dev_tools/bloq_report_card.py b/dev_tools/qualtran_dev_tools/bloq_report_card.py index 053a31872..4af1d66d5 100644 --- a/dev_tools/qualtran_dev_tools/bloq_report_card.py +++ b/dev_tools/qualtran_dev_tools/bloq_report_card.py @@ -105,6 +105,7 @@ def get_bloq_report_card( bexamples: Optional[Iterable[BloqExample]] = None, package_prefix: str = 'qualtran.bloqs.', ) -> pd.DataFrame: + if bclasses is None: bclasses = get_bloq_classes() if bexamples is None: diff --git a/qualtran/protos/__init__.py b/qualtran/protos/__init__.py index a1c0721cb..ce76a8d10 100644 --- a/qualtran/protos/__init__.py +++ b/qualtran/protos/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/qualtran/protos/bloq.proto b/qualtran/protos/bloq.proto index 03e9e54f3..c5e0bc30a 100644 --- a/qualtran/protos/bloq.proto +++ b/qualtran/protos/bloq.proto @@ -21,6 +21,7 @@ import "qualtran/protos/args.proto"; import "qualtran/protos/registers.proto"; import "qualtran/protos/data_types.proto"; import "qualtran/protos/ctrl_spec.proto"; +import "qualtran/protos/sympy.proto"; package qualtran; @@ -32,7 +33,7 @@ message BloqArg { double float_val = 3; string string_val = 4; // Sympy expression generated using str(expr). - string sympy_expr = 5; + Term sympy_expr = 5; // N-dimensional numpy array stored as bytes. NDArray ndarray = 6; // Integer reference of a subbloq. Assumes access to a BloqLibrary. diff --git a/qualtran/protos/bloq_pb2.py b/qualtran/protos/bloq_pb2.py index c8b631096..bec27c2e7 100644 --- a/qualtran/protos/bloq_pb2.py +++ b/qualtran/protos/bloq_pb2.py @@ -17,9 +17,10 @@ from qualtran.protos import registers_pb2 as qualtran_dot_protos_dot_registers__pb2 from qualtran.protos import data_types_pb2 as qualtran_dot_protos_dot_data__types__pb2 from qualtran.protos import ctrl_spec_pb2 as qualtran_dot_protos_dot_ctrl__spec__pb2 +from qualtran.protos import sympy_pb2 as qualtran_dot_protos_dot_sympy__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aqualtran/protos/bloq.proto\x12\x08qualtran\x1a!qualtran/protos/annotations.proto\x1a\x1aqualtran/protos/args.proto\x1a\x1fqualtran/protos/registers.proto\x1a qualtran/protos/data_types.proto\x1a\x1fqualtran/protos/ctrl_spec.proto\"\x95\x03\n\x07\x42loqArg\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\x07int_val\x18\x02 \x01(\x03H\x00\x12\x13\n\tfloat_val\x18\x03 \x01(\x01H\x00\x12\x14\n\nstring_val\x18\x04 \x01(\tH\x00\x12\x14\n\nsympy_expr\x18\x05 \x01(\tH\x00\x12$\n\x07ndarray\x18\x06 \x01(\x0b\x32\x11.qualtran.NDArrayH\x00\x12\x11\n\x07subbloq\x18\x07 \x01(\x05H\x00\x12\x18\n\x0e\x63irq_json_gzip\x18\x08 \x01(\x0cH\x00\x12)\n\nqdata_type\x18\t \x01(\x0b\x32\x13.qualtran.QDataTypeH\x00\x12&\n\x08register\x18\n \x01(\x0b\x32\x12.qualtran.RegisterH\x00\x12(\n\tregisters\x18\x0b \x01(\x0b\x32\x13.qualtran.RegistersH\x00\x12\'\n\tctrl_spec\x18\x0c \x01(\x0b\x32\x12.qualtran.CtrlSpecH\x00\x12(\n\x0b\x63omplex_val\x18\r \x01(\x0b\x32\x11.qualtran.ComplexH\x00\x42\x05\n\x03val\"\xe8\x02\n\x0b\x42loqLibrary\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\x05table\x18\x02 \x03(\x0b\x32+.qualtran.BloqLibrary.BloqWithDecomposition\x1a\x8e\x02\n\x15\x42loqWithDecomposition\x12\x0f\n\x07\x62loq_id\x18\x01 \x01(\x05\x12+\n\rdecomposition\x18\x02 \x03(\x0b\x32\x14.qualtran.Connection\x12P\n\x0b\x62loq_counts\x18\x03 \x03(\x0b\x32;.qualtran.BloqLibrary.BloqWithDecomposition.BloqCountsEntry\x12\x1c\n\x04\x62loq\x18\x04 \x01(\x0b\x32\x0e.qualtran.Bloq\x1aG\n\x0f\x42loqCountsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.qualtran.IntOrSympy:\x02\x38\x01\"\x8a\x01\n\x04\x42loq\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1f\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x11.qualtran.BloqArg\x12&\n\tregisters\x18\x03 \x01(\x0b\x32\x13.qualtran.Registers\x12+\n\x0ct_complexity\x18\x04 \x01(\x0b\x32\x15.qualtran.TComplexity\"4\n\x0c\x42loqInstance\x12\x13\n\x0binstance_id\x18\x01 \x01(\x05\x12\x0f\n\x07\x62loq_id\x18\x02 \x01(\x05\"\x8d\x01\n\x06Soquet\x12/\n\rbloq_instance\x18\x01 \x01(\x0b\x32\x16.qualtran.BloqInstanceH\x00\x12\x14\n\ndangling_t\x18\x02 \x01(\tH\x00\x12$\n\x08register\x18\x03 \x01(\x0b\x32\x12.qualtran.Register\x12\r\n\x05index\x18\x04 \x03(\x05\x42\x07\n\x05\x62inst\"M\n\nConnection\x12\x1e\n\x04left\x18\x01 \x01(\x0b\x32\x10.qualtran.Soquet\x12\x1f\n\x05right\x18\x02 \x01(\x0b\x32\x10.qualtran.Soquetb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1aqualtran/protos/bloq.proto\x12\x08qualtran\x1a!qualtran/protos/annotations.proto\x1a\x1aqualtran/protos/args.proto\x1a\x1fqualtran/protos/registers.proto\x1a qualtran/protos/data_types.proto\x1a\x1fqualtran/protos/ctrl_spec.proto\x1a\x1bqualtran/protos/sympy.proto\"\xa5\x03\n\x07\x42loqArg\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\x07int_val\x18\x02 \x01(\x03H\x00\x12\x13\n\tfloat_val\x18\x03 \x01(\x01H\x00\x12\x14\n\nstring_val\x18\x04 \x01(\tH\x00\x12$\n\nsympy_expr\x18\x05 \x01(\x0b\x32\x0e.qualtran.TermH\x00\x12$\n\x07ndarray\x18\x06 \x01(\x0b\x32\x11.qualtran.NDArrayH\x00\x12\x11\n\x07subbloq\x18\x07 \x01(\x05H\x00\x12\x18\n\x0e\x63irq_json_gzip\x18\x08 \x01(\x0cH\x00\x12)\n\nqdata_type\x18\t \x01(\x0b\x32\x13.qualtran.QDataTypeH\x00\x12&\n\x08register\x18\n \x01(\x0b\x32\x12.qualtran.RegisterH\x00\x12(\n\tregisters\x18\x0b \x01(\x0b\x32\x13.qualtran.RegistersH\x00\x12\'\n\tctrl_spec\x18\x0c \x01(\x0b\x32\x12.qualtran.CtrlSpecH\x00\x12(\n\x0b\x63omplex_val\x18\r \x01(\x0b\x32\x11.qualtran.ComplexH\x00\x42\x05\n\x03val\"\xe8\x02\n\x0b\x42loqLibrary\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\x05table\x18\x02 \x03(\x0b\x32+.qualtran.BloqLibrary.BloqWithDecomposition\x1a\x8e\x02\n\x15\x42loqWithDecomposition\x12\x0f\n\x07\x62loq_id\x18\x01 \x01(\x05\x12+\n\rdecomposition\x18\x02 \x03(\x0b\x32\x14.qualtran.Connection\x12P\n\x0b\x62loq_counts\x18\x03 \x03(\x0b\x32;.qualtran.BloqLibrary.BloqWithDecomposition.BloqCountsEntry\x12\x1c\n\x04\x62loq\x18\x04 \x01(\x0b\x32\x0e.qualtran.Bloq\x1aG\n\x0f\x42loqCountsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.qualtran.IntOrSympy:\x02\x38\x01\"\x8a\x01\n\x04\x42loq\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1f\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x11.qualtran.BloqArg\x12&\n\tregisters\x18\x03 \x01(\x0b\x32\x13.qualtran.Registers\x12+\n\x0ct_complexity\x18\x04 \x01(\x0b\x32\x15.qualtran.TComplexity\"4\n\x0c\x42loqInstance\x12\x13\n\x0binstance_id\x18\x01 \x01(\x05\x12\x0f\n\x07\x62loq_id\x18\x02 \x01(\x05\"\x8d\x01\n\x06Soquet\x12/\n\rbloq_instance\x18\x01 \x01(\x0b\x32\x16.qualtran.BloqInstanceH\x00\x12\x14\n\ndangling_t\x18\x02 \x01(\tH\x00\x12$\n\x08register\x18\x03 \x01(\x0b\x32\x12.qualtran.Register\x12\r\n\x05index\x18\x04 \x03(\x05\x42\x07\n\x05\x62inst\"M\n\nConnection\x12\x1e\n\x04left\x18\x01 \x01(\x0b\x32\x10.qualtran.Soquet\x12\x1f\n\x05right\x18\x02 \x01(\x0b\x32\x10.qualtran.Soquetb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,20 +29,20 @@ DESCRIPTOR._options = None _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._options = None _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._serialized_options = b'8\001' - _globals['_BLOQARG']._serialized_start=204 - _globals['_BLOQARG']._serialized_end=609 - _globals['_BLOQLIBRARY']._serialized_start=612 - _globals['_BLOQLIBRARY']._serialized_end=972 - _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION']._serialized_start=702 - _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION']._serialized_end=972 - _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._serialized_start=901 - _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._serialized_end=972 - _globals['_BLOQ']._serialized_start=975 - _globals['_BLOQ']._serialized_end=1113 - _globals['_BLOQINSTANCE']._serialized_start=1115 - _globals['_BLOQINSTANCE']._serialized_end=1167 - _globals['_SOQUET']._serialized_start=1170 - _globals['_SOQUET']._serialized_end=1311 - _globals['_CONNECTION']._serialized_start=1313 - _globals['_CONNECTION']._serialized_end=1390 + _globals['_BLOQARG']._serialized_start=233 + _globals['_BLOQARG']._serialized_end=654 + _globals['_BLOQLIBRARY']._serialized_start=657 + _globals['_BLOQLIBRARY']._serialized_end=1017 + _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION']._serialized_start=747 + _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION']._serialized_end=1017 + _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._serialized_start=946 + _globals['_BLOQLIBRARY_BLOQWITHDECOMPOSITION_BLOQCOUNTSENTRY']._serialized_end=1017 + _globals['_BLOQ']._serialized_start=1020 + _globals['_BLOQ']._serialized_end=1158 + _globals['_BLOQINSTANCE']._serialized_start=1160 + _globals['_BLOQINSTANCE']._serialized_end=1212 + _globals['_SOQUET']._serialized_start=1215 + _globals['_SOQUET']._serialized_end=1356 + _globals['_CONNECTION']._serialized_start=1358 + _globals['_CONNECTION']._serialized_end=1435 # @@protoc_insertion_point(module_scope) diff --git a/qualtran/protos/bloq_pb2.pyi b/qualtran/protos/bloq_pb2.pyi index aed8f8c4c..6f3d44727 100644 --- a/qualtran/protos/bloq_pb2.pyi +++ b/qualtran/protos/bloq_pb2.pyi @@ -26,6 +26,7 @@ import qualtran.protos.args_pb2 import qualtran.protos.ctrl_spec_pb2 import qualtran.protos.data_types_pb2 import qualtran.protos.registers_pb2 +import qualtran.protos.sympy_pb2 import sys if sys.version_info >= (3, 8): @@ -56,8 +57,9 @@ class BloqArg(google.protobuf.message.Message): int_val: builtins.int float_val: builtins.float string_val: builtins.str - sympy_expr: builtins.str - """Sympy expression generated using str(expr).""" + @property + def sympy_expr(self) -> qualtran.protos.sympy_pb2.Term: + """Sympy expression generated using str(expr).""" @property def ndarray(self) -> qualtran.protos.args_pb2.NDArray: """N-dimensional numpy array stored as bytes.""" @@ -86,7 +88,7 @@ class BloqArg(google.protobuf.message.Message): int_val: builtins.int = ..., float_val: builtins.float = ..., string_val: builtins.str = ..., - sympy_expr: builtins.str = ..., + sympy_expr: qualtran.protos.sympy_pb2.Term | None = ..., ndarray: qualtran.protos.args_pb2.NDArray | None = ..., subbloq: builtins.int = ..., cirq_json_gzip: builtins.bytes = ..., diff --git a/qualtran/protos/sympy.proto b/qualtran/protos/sympy.proto new file mode 100644 index 000000000..108de5fc2 --- /dev/null +++ b/qualtran/protos/sympy.proto @@ -0,0 +1,75 @@ +/* + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +syntax = "proto3"; +package qualtran; + +// A function sympy expression. +enum Function { + // Each Term has an associated function. A "NONE" function means that the term + // is made up of a single parameter and can not be decomposed further. + NONE = 0; + Add = 1; + Mul = 2; + Pow = 3; + Mod = 4; + Log = 5; + Floor = 6; + Ceiling = 7; + Max = 8; + Min = 9; + Sin = 10; + Cos = 11; + Tan = 12; +} + +// Represents a constant, rational number. +message Rational { + Parameter numerator = 1; + Parameter denominator = 2; +} + +enum ConstSymbol { + Pi = 0; + E = 1; + EulerGamma = 2; + Infinity = 3; + ImaginaryUnit = 4; +} + +// A single parameter of a sympy expression. +message Parameter { + oneof parameter { + int32 const_int = 1; + string symbol = 2; + Rational const_rat = 3; + float const_float = 4; + ConstSymbol const_symbol = 5; + + } +} + +message Operand { + oneof operand { + Term term = 1; + Parameter parameter = 2; + } +} + +message Term { + Function function = 1; + repeated Operand operands = 2; +} diff --git a/qualtran/protos/sympy_pb2.py b/qualtran/protos/sympy_pb2.py new file mode 100644 index 000000000..d090f4f74 --- /dev/null +++ b/qualtran/protos/sympy_pb2.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: qualtran/protos/sympy.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1bqualtran/protos/sympy.proto\x12\x08qualtran\"\\\n\x08Rational\x12&\n\tnumerator\x18\x01 \x01(\x0b\x32\x13.qualtran.Parameter\x12(\n\x0b\x64\x65nominator\x18\x02 \x01(\x0b\x32\x13.qualtran.Parameter\"\xae\x01\n\tParameter\x12\x13\n\tconst_int\x18\x01 \x01(\x05H\x00\x12\x10\n\x06symbol\x18\x02 \x01(\tH\x00\x12\'\n\tconst_rat\x18\x03 \x01(\x0b\x32\x12.qualtran.RationalH\x00\x12\x15\n\x0b\x63onst_float\x18\x04 \x01(\x02H\x00\x12-\n\x0c\x63onst_symbol\x18\x05 \x01(\x0e\x32\x15.qualtran.ConstSymbolH\x00\x42\x0b\n\tparameter\"^\n\x07Operand\x12\x1e\n\x04term\x18\x01 \x01(\x0b\x32\x0e.qualtran.TermH\x00\x12(\n\tparameter\x18\x02 \x01(\x0b\x32\x13.qualtran.ParameterH\x00\x42\t\n\x07operand\"Q\n\x04Term\x12$\n\x08\x66unction\x18\x01 \x01(\x0e\x32\x12.qualtran.Function\x12#\n\x08operands\x18\x02 \x03(\x0b\x32\x11.qualtran.Operand*\x86\x01\n\x08\x46unction\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03\x41\x64\x64\x10\x01\x12\x07\n\x03Mul\x10\x02\x12\x07\n\x03Pow\x10\x03\x12\x07\n\x03Mod\x10\x04\x12\x07\n\x03Log\x10\x05\x12\t\n\x05\x46loor\x10\x06\x12\x0b\n\x07\x43\x65iling\x10\x07\x12\x07\n\x03Max\x10\x08\x12\x07\n\x03Min\x10\t\x12\x07\n\x03Sin\x10\n\x12\x07\n\x03\x43os\x10\x0b\x12\x07\n\x03Tan\x10\x0c*M\n\x0b\x43onstSymbol\x12\x06\n\x02Pi\x10\x00\x12\x05\n\x01\x45\x10\x01\x12\x0e\n\nEulerGamma\x10\x02\x12\x0c\n\x08Infinity\x10\x03\x12\x11\n\rImaginaryUnit\x10\x04\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'qualtran.protos.sympy_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_FUNCTION']._serialized_start=492 + _globals['_FUNCTION']._serialized_end=626 + _globals['_CONSTSYMBOL']._serialized_start=628 + _globals['_CONSTSYMBOL']._serialized_end=705 + _globals['_RATIONAL']._serialized_start=41 + _globals['_RATIONAL']._serialized_end=133 + _globals['_PARAMETER']._serialized_start=136 + _globals['_PARAMETER']._serialized_end=310 + _globals['_OPERAND']._serialized_start=312 + _globals['_OPERAND']._serialized_end=406 + _globals['_TERM']._serialized_start=408 + _globals['_TERM']._serialized_end=489 +# @@protoc_insertion_point(module_scope) diff --git a/qualtran/protos/sympy_pb2.pyi b/qualtran/protos/sympy_pb2.pyi new file mode 100644 index 000000000..79aea4d17 --- /dev/null +++ b/qualtran/protos/sympy_pb2.pyi @@ -0,0 +1,194 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file + +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _Function: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _FunctionEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Function.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + NONE: _Function.ValueType # 0 + """Each Term has an associated function. A "NONE" function means that the term + is made up of a single parameter and can not be decomposed further. + """ + Add: _Function.ValueType # 1 + Mul: _Function.ValueType # 2 + Pow: _Function.ValueType # 3 + Mod: _Function.ValueType # 4 + Log: _Function.ValueType # 5 + Floor: _Function.ValueType # 6 + Ceiling: _Function.ValueType # 7 + Max: _Function.ValueType # 8 + Min: _Function.ValueType # 9 + Sin: _Function.ValueType # 10 + Cos: _Function.ValueType # 11 + Tan: _Function.ValueType # 12 + +class Function(_Function, metaclass=_FunctionEnumTypeWrapper): + """A function sympy expression.""" + +NONE: Function.ValueType # 0 +"""Each Term has an associated function. A "NONE" function means that the term +is made up of a single parameter and can not be decomposed further. +""" +Add: Function.ValueType # 1 +Mul: Function.ValueType # 2 +Pow: Function.ValueType # 3 +Mod: Function.ValueType # 4 +Log: Function.ValueType # 5 +Floor: Function.ValueType # 6 +Ceiling: Function.ValueType # 7 +Max: Function.ValueType # 8 +Min: Function.ValueType # 9 +Sin: Function.ValueType # 10 +Cos: Function.ValueType # 11 +Tan: Function.ValueType # 12 +global___Function = Function + +class _ConstSymbol: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _ConstSymbolEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_ConstSymbol.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + Pi: _ConstSymbol.ValueType # 0 + E: _ConstSymbol.ValueType # 1 + EulerGamma: _ConstSymbol.ValueType # 2 + Infinity: _ConstSymbol.ValueType # 3 + ImaginaryUnit: _ConstSymbol.ValueType # 4 + +class ConstSymbol(_ConstSymbol, metaclass=_ConstSymbolEnumTypeWrapper): ... + +Pi: ConstSymbol.ValueType # 0 +E: ConstSymbol.ValueType # 1 +EulerGamma: ConstSymbol.ValueType # 2 +Infinity: ConstSymbol.ValueType # 3 +ImaginaryUnit: ConstSymbol.ValueType # 4 +global___ConstSymbol = ConstSymbol + +@typing_extensions.final +class Rational(google.protobuf.message.Message): + """Represents a constant, rational number.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NUMERATOR_FIELD_NUMBER: builtins.int + DENOMINATOR_FIELD_NUMBER: builtins.int + @property + def numerator(self) -> global___Parameter: ... + @property + def denominator(self) -> global___Parameter: ... + def __init__( + self, + *, + numerator: global___Parameter | None = ..., + denominator: global___Parameter | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["denominator", b"denominator", "numerator", b"numerator"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["denominator", b"denominator", "numerator", b"numerator"]) -> None: ... + +global___Rational = Rational + +@typing_extensions.final +class Parameter(google.protobuf.message.Message): + """A single parameter of a sympy expression.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CONST_INT_FIELD_NUMBER: builtins.int + SYMBOL_FIELD_NUMBER: builtins.int + CONST_RAT_FIELD_NUMBER: builtins.int + CONST_FLOAT_FIELD_NUMBER: builtins.int + CONST_SYMBOL_FIELD_NUMBER: builtins.int + const_int: builtins.int + symbol: builtins.str + @property + def const_rat(self) -> global___Rational: ... + const_float: builtins.float + const_symbol: global___ConstSymbol.ValueType + def __init__( + self, + *, + const_int: builtins.int = ..., + symbol: builtins.str = ..., + const_rat: global___Rational | None = ..., + const_float: builtins.float = ..., + const_symbol: global___ConstSymbol.ValueType = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["const_float", b"const_float", "const_int", b"const_int", "const_rat", b"const_rat", "const_symbol", b"const_symbol", "parameter", b"parameter", "symbol", b"symbol"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["const_float", b"const_float", "const_int", b"const_int", "const_rat", b"const_rat", "const_symbol", b"const_symbol", "parameter", b"parameter", "symbol", b"symbol"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["parameter", b"parameter"]) -> typing_extensions.Literal["const_int", "symbol", "const_rat", "const_float", "const_symbol"] | None: ... + +global___Parameter = Parameter + +@typing_extensions.final +class Operand(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TERM_FIELD_NUMBER: builtins.int + PARAMETER_FIELD_NUMBER: builtins.int + @property + def term(self) -> global___Term: ... + @property + def parameter(self) -> global___Parameter: ... + def __init__( + self, + *, + term: global___Term | None = ..., + parameter: global___Parameter | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["operand", b"operand", "parameter", b"parameter", "term", b"term"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["operand", b"operand", "parameter", b"parameter", "term", b"term"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["operand", b"operand"]) -> typing_extensions.Literal["term", "parameter"] | None: ... + +global___Operand = Operand + +@typing_extensions.final +class Term(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + FUNCTION_FIELD_NUMBER: builtins.int + OPERANDS_FIELD_NUMBER: builtins.int + function: global___Function.ValueType + @property + def operands(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Operand]: ... + def __init__( + self, + *, + function: global___Function.ValueType = ..., + operands: collections.abc.Iterable[global___Operand] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["function", b"function", "operands", b"operands"]) -> None: ... + +global___Term = Term diff --git a/qualtran/serialization/bloq.py b/qualtran/serialization/bloq.py index c1a4de5bd..c2d4dd3da 100644 --- a/qualtran/serialization/bloq.py +++ b/qualtran/serialization/bloq.py @@ -20,7 +20,6 @@ import cirq import numpy as np import sympy -from sympy.parsing.sympy_parser import parse_expr from qualtran import ( Bloq, @@ -47,6 +46,7 @@ registers, resolver_dict, ) +from qualtran.serialization.sympy import sympy_expr_from_proto, sympy_expr_to_proto def arg_to_proto(*, name: str, val: Any) -> bloq_pb2.BloqArg: @@ -57,7 +57,7 @@ def arg_to_proto(*, name: str, val: Any) -> bloq_pb2.BloqArg: if isinstance(val, str): return bloq_pb2.BloqArg(name=name, string_val=val) if isinstance(val, sympy.Expr): - return bloq_pb2.BloqArg(name=name, sympy_expr=str(val)) + return bloq_pb2.BloqArg(name=name, sympy_expr=sympy_expr_to_proto(val)) if isinstance(val, Register): return bloq_pb2.BloqArg(name=name, register=registers.register_to_proto(val)) if isinstance(val, tuple) and all(isinstance(x, Register) for x in val): @@ -83,7 +83,7 @@ def arg_from_proto(arg: bloq_pb2.BloqArg) -> Dict[str, Any]: if arg.HasField("string_val"): return {arg.name: arg.string_val} if arg.HasField("sympy_expr"): - return {arg.name: parse_expr(arg.sympy_expr)} + return {arg.name: sympy_expr_from_proto(arg.sympy_expr)} if arg.HasField("register"): return {arg.name: registers.register_from_proto(arg.register)} if arg.HasField("registers"): diff --git a/qualtran/serialization/bloq_test.py b/qualtran/serialization/bloq_test.py index 485298ede..fffd55bed 100644 --- a/qualtran/serialization/bloq_test.py +++ b/qualtran/serialization/bloq_test.py @@ -30,6 +30,7 @@ from qualtran.protos import registers_pb2 from qualtran.serialization import bloq as bloq_serialization from qualtran.serialization import resolver_dict +from qualtran.serialization.bloq import arg_from_proto @pytest.mark.parametrize( @@ -123,7 +124,7 @@ def test_cbloq_to_proto_test_two_cswap(): assert cswap_proto.name.split('.')[-1] == "TestCSwap" assert len(cswap_proto.args) == 1 assert cswap_proto.args[0].name == "bitsize" - assert sympy.parse_expr(cswap_proto.args[0].sympy_expr) == bitsize + assert arg_from_proto(cswap_proto.args[0])['bitsize'] == bitsize assert len(cswap_proto.registers.registers) == 3 assert TestCSwap(bitsize) in bloq_serialization.bloqs_from_proto(cswap_proto_lib) diff --git a/qualtran/serialization/sympy.py b/qualtran/serialization/sympy.py new file mode 100644 index 000000000..d5e5b8f59 --- /dev/null +++ b/qualtran/serialization/sympy.py @@ -0,0 +1,230 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Union + +import sympy + +from qualtran.protos import sympy_pb2 + + +def _get_sympy_function_type(expr: sympy.Expr) -> int: + """ + Helper function for serializing a sympy function. + + This method converts a sympy function to its sympy_pb2.Function enum representation. + """ + if isinstance(expr, sympy.core.mul.Mul): + return sympy_pb2.Function.Mul + elif isinstance(expr, sympy.core.add.Add): + return sympy_pb2.Function.Add + elif isinstance(expr, sympy.core.power.Pow): + return sympy_pb2.Function.Pow + elif isinstance(expr, sympy.core.Mod): + return sympy_pb2.Function.Mod + elif isinstance(expr, sympy.functions.elementary.exponential.log): + return sympy_pb2.Function.Log + elif isinstance(expr, sympy.functions.elementary.integers.floor): + return sympy_pb2.Function.Floor + elif isinstance(expr, sympy.functions.elementary.integers.ceiling): + return sympy_pb2.Function.Ceiling + elif isinstance(expr, sympy.functions.elementary.miscellaneous.Max): + return sympy_pb2.Function.Max + elif isinstance(expr, sympy.functions.elementary.miscellaneous.Min): + return sympy_pb2.Function.Min + elif isinstance(expr, sympy.functions.elementary.trigonometric.sin): + return sympy_pb2.Function.Sin + elif isinstance(expr, sympy.functions.elementary.trigonometric.cos): + return sympy_pb2.Function.Cos + elif isinstance(expr, sympy.functions.elementary.trigonometric.tan): + return sympy_pb2.Function.Tan + else: + return sympy_pb2.Function.NONE + + +def _get_sympy_function_from_enum(enum: int) -> Any: + """ + Helper function for sympy function deserialization. + + Sympy functions are represented as a sympy_pb2.Function enum. This method converts + this int enum. + """ + enum_to_sympy = { + sympy_pb2.Function.Mul: sympy.core.mul.Mul, + sympy_pb2.Function.Add: sympy.core.add.Add, + sympy_pb2.Function.Pow: sympy.core.power.Pow, + sympy_pb2.Function.Mod: sympy.core.Mod, + sympy_pb2.Function.Log: sympy.functions.elementary.exponential.log, + sympy_pb2.Function.Floor: sympy.functions.elementary.integers.floor, + sympy_pb2.Function.Ceiling: sympy.functions.elementary.integers.ceiling, + sympy_pb2.Function.Max: sympy.functions.elementary.miscellaneous.Max, + sympy_pb2.Function.Min: sympy.functions.elementary.miscellaneous.Min, + sympy_pb2.Function.Sin: sympy.functions.elementary.trigonometric.sin, + sympy_pb2.Function.Cos: sympy.functions.elementary.trigonometric.cos, + sympy_pb2.Function.Tan: sympy.functions.elementary.trigonometric.tan, + sympy_pb2.Function.NONE: None, + } + + return enum_to_sympy[enum] + + +def _get_sympy_const_from_enum(enum: int) -> Any: + """Helper function for deserializing a sympy symbolic constant. + + Symbolic constants are serialzed as an enum of type sympy_pb2.ConstSymbol. This method converts the + enum representation back to its original sympy representation. + """ + enum_to_sympy = { + sympy_pb2.ConstSymbol.Pi: sympy.pi, + sympy_pb2.ConstSymbol.E: sympy.E, + sympy_pb2.ConstSymbol.EulerGamma: sympy.EulerGamma, + sympy_pb2.ConstSymbol.Infinity: sympy.core.numbers.Infinity(), + sympy_pb2.ConstSymbol.ImaginaryUnit: sympy.core.numbers.ImaginaryUnit(), + } + return enum_to_sympy[enum] + + +def _get_const_symbolic_operand(expr: sympy.Expr) -> sympy_pb2.Parameter: + """ + Helper function for serializing a symbolic constant from a sympy expression. + + Currently supported symbolic constants are: pi, natural exponent, EulerGamma, sqrt(-1), and infinity. + """ + if expr == sympy.pi: + return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.Pi) + if expr == sympy.E: + return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.E) + if expr == sympy.EulerGamma: + return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.EulerGamma) + if isinstance(expr, sympy.core.numbers.Infinity): + return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.Infinity) + if isinstance(expr, sympy.core.numbers.ImaginaryUnit): + return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.ImaginaryUnit) + raise NotImplementedError(f"Sympy expression {str(expr)} cannot be serialized.") + + +def _get_sympy_operand(expr: Union[sympy.Expr, int, float]) -> sympy_pb2.Parameter: + """ + Converts the input to a serializable sympy_pb2 Parameter. + + A parameter represents a single, irreducable numeric entity such as a variable, constant, or an explicit number. + """ + + # Expression is a single, symbolic variable. + if isinstance(expr, sympy.core.symbol.Symbol): + return sympy_pb2.Parameter(symbol=str(expr)) + + # Expression is an integer + if issubclass(expr.__class__, sympy.core.numbers.Integer): + result = expr.numerator + if not isinstance(result, int): + raise NotImplementedError(f"Sympy expression {str(expr)} cannot be serialized.") + return sympy_pb2.Parameter(const_int=result) + + # Expression cannot be broken down further, but is a constant. + if issubclass(expr.__class__, sympy.core.numbers.Number): + if isinstance(expr, sympy.core.numbers.Float): + return sympy_pb2.Parameter(const_float=float(expr)) + if isinstance(expr, sympy.core.numbers.Rational): + numerator = _get_sympy_operand(expr.numerator) + denominator = _get_sympy_operand(expr.denominator) + fraction = sympy_pb2.Rational(numerator=numerator, denominator=denominator) + return sympy_pb2.Parameter(const_rat=fraction) + else: + raise NotImplementedError(f"Sympy expression {str(expr)} cannot be serialized.") + if type(expr) is int: + return sympy_pb2.Parameter(const_int=expr) + if type(expr) is float: + return sympy_pb2.Parameter(const_float=expr) + return _get_const_symbolic_operand(expr) + + +def sympy_expr_to_proto(expr: sympy.Expr) -> sympy_pb2.Term: + """Serializes a sympy expression.""" + + function = _get_sympy_function_type(expr) + operands = [] + if function == sympy_pb2.Function.NONE: + parameter = _get_sympy_operand(expr) + operands.append(sympy_pb2.Operand(parameter=parameter)) + + else: + for term in expr.args: + inner_term = sympy_expr_to_proto(term) + + operands.append(sympy_pb2.Operand(term=inner_term)) + + return sympy_pb2.Term(function=function, operands=operands) + + +def _get_parameter( + serialized_input: Union[sympy_pb2.Operand, sympy_pb2.Parameter] +) -> Union[sympy.core.AtomicExpr, int, float]: + """ + Deserializes a parameter. + + Deserializes a parameter or operand into either its sympy representation or its + python primative numeric representation. + """ + if isinstance(serialized_input, sympy_pb2.Operand): + serialized_parameter = serialized_input.parameter + else: + serialized_parameter = serialized_input + + parameter_type = serialized_parameter.WhichOneof("parameter") + if parameter_type == "symbol": + return sympy.symbols(serialized_parameter.symbol) + if parameter_type == "const_int": + return serialized_parameter.const_int + if parameter_type == "const_rat": + fraction = serialized_parameter.const_rat + numerator = _get_parameter(fraction.numerator) + denominator = _get_parameter(fraction.denominator) + return sympy.Rational(numerator, denominator) + if parameter_type == "const_float": + return serialized_parameter.const_float + if parameter_type == "const_symbol": + return _get_sympy_const_from_enum(serialized_parameter.const_symbol) + + raise TypeError(f"Type is not supported for {serialized_input}") + + +def sympy_expr_from_proto(term: sympy_pb2.Term) -> Union[sympy.core.AtomicExpr, int, float]: + """Deserialize a sympy expression. + + This will take a sympy_pb2.Term which will contain a function and + one or more operands. These operands can also be Terms which will cause + this method to be called recursively. + + An error will be raised if a Term contains a function that is not listed in + the sympy_pb2.Function enum in sympy.proto. Additionally, an error will be + raised if an operand cannot be converted into one of the following: a + sympy symbol, a sympy constant, a nested Term, or a native python numeric type. + """ + function = _get_sympy_function_from_enum(term.function) + parameters = [] + for operand in term.operands: + if operand.HasField("term"): + parameters.append(sympy_expr_from_proto(operand.term)) + else: + parameters.append(_get_parameter(operand)) + + if function: + return function(*parameters) + + # If a term has no function, then it must be a single parameter. + if len(parameters) == 1: + return parameters[0] + + raise NotImplementedError(f"{term.function} has not been fully implemented.") diff --git a/qualtran/serialization/sympy_test.py b/qualtran/serialization/sympy_test.py new file mode 100644 index 000000000..a7e63f774 --- /dev/null +++ b/qualtran/serialization/sympy_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy + +from qualtran.serialization.sympy import sympy_expr_from_proto, sympy_expr_to_proto + +x = sympy.Symbol('x', positive=True) +a, b, c = sympy.symbols("a b c") + +# These should return a `sympy_pb2.Parameter` proto object +sympy_parameters_to_test = [ + # Only symbols + sympy.Symbol('x'), + sympy.Symbol('N'), + sympy.Symbol('E'), + # Sympy constants + sympy.pi, + sympy.oo, # infinity + sympy.E, + sympy.I, + sympy.EulerGamma, + # Integers, Floats, Rationals + sympy.Integer(5), + sympy.Float(0.1), + sympy.Rational("1/2"), + sympy.Rational('1/10'), +] +sympy_exprs_to_test = [ + 5 * x + sympy.sqrt(a), + # Complex Fractions + sympy.Rational("1/10") * sympy.I + 5, + # Basic operations + a / b + c - 5, + # Trig operations + sympy.sin(a) + sympy.cos(b) + sympy.tan(c), + # Integer operations + sympy.floor(5.43) + sympy.ceiling(a), + sympy.Max(a, b), + # Nested Operations + a ** (b * c**2), +] + + +@pytest.mark.parametrize('expr', sympy_parameters_to_test + sympy_exprs_to_test) +def parameter_test(expr: sympy.Expr): + """ + Test types of expressions including fraction, complex, and constant symbol (such as pi). + """ + + serialized = sympy_expr_to_proto(expr) + expr_clone = sympy_expr_from_proto(serialized)['test'] + assert expr == expr_clone + + +def float_fraction_test(): + """ + Test that floats and fractions can be properly combined and serialzed. + """ + float_const = sympy.parse_expr("1.4") + fraction = sympy.parse_expr("1/2") + expr = float_const * fraction + + serialized = sympy_expr_to_proto(expr) + expr_clone = sympy_expr_from_proto(serialized)['test'] + assert abs(expr - expr_clone) < 0.001