-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sympy Serialization #850
Conversation
unit tests to ensure round trip of each datatype using both numeric and sympy inputs.
Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
Sympy expressions can now be serialized through a recursive proto.
The qpe_hubbard_model is too large to test and causes developer friction. We have disabled the tests for now
These tests are very large and will cause problems when running. These tests are now disabled.
Move the changes from bloq_finder.py to bloq_report_card.py. This way, only the serialization portion of the test is affected.
Rather than building on top of quantumlib#849, we keep the changes separate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add more tests, refactor the sympy serialization code into it's module and rename things for consistency.
qualtran/serialization/bloq.py
Outdated
def decompose_sympy(expr: sympy.Expr): | ||
function = _get_sympy_function_type(expr) | ||
operands = [] | ||
if function == sympy_pb2.Function.NONE: | ||
parameter = _get_sympy_operand(expr) | ||
operands.append(sympy_pb2.Operand(parameter=parameter)) | ||
|
||
else: | ||
for term in expr.args: | ||
inner_term = decompose_sympy(term) | ||
|
||
operands.append(sympy_pb2.Operand(term=inner_term)) | ||
|
||
return sympy_pb2.Term(function=function, operands=operands) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add all the sympy related method to a new file qualtran/serialization/sympy.py
and add corresponding tests to qualtran/serialization/sympy_test.py
Also, please add roundtrip unit tests to convert different types of sympy expression to / from proto. There are virtually no tests for sympy serialization verification right now.
qualtran/serialization/bloq.py
Outdated
return sympy_pb2.Parameter(const_irrat=expr.name) | ||
|
||
|
||
def decompose_sympy(expr: sympy.Expr): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we call this method sympy_expr_to_proto
? Also add return type annotations
def decompose_sympy(expr: sympy.Expr): | |
def sympy_expr_to_proto(expr: sympy.Expr) -> sympy_pb2.Term: |
qualtran/serialization/bloq.py
Outdated
return function(parameters) | ||
|
||
|
||
def compose_sympy(expr: sympy.Expr): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the type annotations are incorrect and the expr
should be a sympy_pb2.Term
? Please also rename the method to sympy_expr_from_proto
to keep the naming consistent with other serialization methods. Also add return type annotations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good, left another round of comments. We are getting close I think. Thanks for the patience and the iterations!
qualtran/bloqs/factoring/mod_mul.py
Outdated
def namespace(self) -> str: | ||
return "qualtran." + self.__module__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this should be needed now?
In [5]: from qualtran.bloqs.factoring.mod_mul import CtrlModMul
In [6]: CtrlModMul(5, 10, 10).__module__
Out[6]: 'qualtran.bloqs.factoring.mod_mul'
qualtran/serialization/sympy.py
Outdated
else: | ||
raise NotImplementedError(f"{term.function} has not been fully implimented.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
else: | |
raise NotImplementedError(f"{term.function} has not been fully implimented.") | |
raise NotImplementedError(f"{term.function} has not been fully implemented.") |
qualtran/serialization/sympy.py
Outdated
return deserialized_parameter | ||
|
||
|
||
def sympy_expr_from_proto(term: sympy_pb2.Term) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use a more concrete return type instead of Any
?
qualtran/serialization/sympy.py
Outdated
|
||
|
||
def sympy_expr_from_proto(term: sympy_pb2.Term) -> Any: | ||
"""Deserialize a sympy expression.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the caveats to the docstring. When does it raise an error? Which terms can it return?
qualtran/protos/sympy.proto
Outdated
} | ||
|
||
// Represents a constant, rational number. | ||
message Fraction { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename this to Rational
for consistency with sympy
qualtran/serialization/sympy.py
Outdated
return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.Infinity) | ||
if isinstance(expr, sympy.core.numbers.ImaginaryUnit): | ||
return sympy_pb2.Parameter(const_symbol=sympy_pb2.ConstSymbol.ImaginaryUnit) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Else is not needed.
else: |
qualtran/serialization/sympy.py
Outdated
if parameter_type == "symbol": | ||
deserialized_parameter = sympy.symbols(serialized_parameter.symbol) | ||
elif parameter_type == "const_int": | ||
deserialized_parameter = serialized_parameter.const_int | ||
elif parameter_type == "const_rat": | ||
fraction = serialized_parameter.const_rat | ||
numerator = _get_parameter(fraction.numerator) | ||
denominator = _get_parameter(fraction.denominator) | ||
deserialized_parameter = sympy.Rational(numerator, denominator) | ||
elif parameter_type == "const_float": | ||
deserialized_parameter = serialized_parameter.const_float | ||
elif parameter_type == "const_symbol": | ||
deserialized_parameter = _get_sympy_const_from_enum(serialized_parameter.const_symbol) | ||
else: | ||
raise TypeError(f"Type is not supported for {serialized_input}") | ||
|
||
return deserialized_parameter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Directly return the deserialized_parameter
and raise a TypeError if none of the if
clauses are satisfied.
if parameter_type == "symbol": | |
deserialized_parameter = sympy.symbols(serialized_parameter.symbol) | |
elif parameter_type == "const_int": | |
deserialized_parameter = serialized_parameter.const_int | |
elif parameter_type == "const_rat": | |
fraction = serialized_parameter.const_rat | |
numerator = _get_parameter(fraction.numerator) | |
denominator = _get_parameter(fraction.denominator) | |
deserialized_parameter = sympy.Rational(numerator, denominator) | |
elif parameter_type == "const_float": | |
deserialized_parameter = serialized_parameter.const_float | |
elif parameter_type == "const_symbol": | |
deserialized_parameter = _get_sympy_const_from_enum(serialized_parameter.const_symbol) | |
else: | |
raise TypeError(f"Type is not supported for {serialized_input}") | |
return deserialized_parameter | |
if parameter_type == "symbol": | |
return sympy.symbols(serialized_parameter.symbol) | |
if parameter_type == "const_int": | |
return serialized_parameter.const_int | |
if parameter_type == "const_rat": | |
fraction = serialized_parameter.const_rat | |
numerator = _get_parameter(fraction.numerator) | |
denominator = _get_parameter(fraction.denominator) | |
return sympy.Rational(numerator, denominator) | |
if parameter_type == "const_float": | |
return serialized_parameter.const_float | |
if parameter_type == "const_symbol": | |
return _get_sympy_const_from_enum(serialized_parameter.const_symbol) | |
raise TypeError(f"Type is not supported for {serialized_input}") | |
return deserialized_parameter |
qualtran/serialization/sympy_test.py
Outdated
@pytest.mark.parametrize( | ||
'expr', | ||
[ | ||
( | ||
sympy.parse_expr("5") | ||
+ sympy.symbols("x") | ||
+ sympy.parse_expr("1/2") | ||
+ sympy.pi | ||
+ sympy.parse_expr("2j") | ||
), | ||
(sympy.parse_expr("(-b + sqrt(-4*a*c + b**2))/(2*a)")), | ||
], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's not depend upon parse_expr
for tests. Instead, let's create a list of expressions; potentially generated by small functions; which we'd like to test and parameterize over it. For example:
@pytest.mark.parametrize( | |
'expr', | |
[ | |
( | |
sympy.parse_expr("5") | |
+ sympy.symbols("x") | |
+ sympy.parse_expr("1/2") | |
+ sympy.pi | |
+ sympy.parse_expr("2j") | |
), | |
(sympy.parse_expr("(-b + sqrt(-4*a*c + b**2))/(2*a)")), | |
], | |
) | |
x, N = sympy.Symbol('x', positive=True), sympy.Symbol('N') | |
# These should return a `sympy_pb2.Parameter` proto object? | |
sympy_parameters_to_test = [ | |
# Only symbols | |
sympy.Symbol('x'), | |
sympy.Symbol('N'), | |
sympy.Symbol('E'), | |
# Sympy constants | |
sympy.pi, sympy.Infinity, sympy.E, sympy.I # etc. | |
# Integers, Floats, Rationals | |
sympy.Integer(5), | |
sympy.Float(0.1), | |
sympy.Rational("1/2"), | |
sympy.Rational('1/10'), | |
] | |
sympy_expr_to_test = [ | |
5 * x + sympy.sqrt(N), | |
# Add different types of functions that are supported, maybe in a single expression or multiple expressions. | |
] | |
@pytest.mark.parametrize('expr', sympy_parameters_to_test + sympy_exprs_to_test) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a final round of comments. We can merge once these are resolved.
Call sympy serialization directly rather than calling it through bloq.
Addressed nit comments. Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
Adds Protos to hold sympy expressions an a nested structure. Currently the following basic operations are supported: multiplication, addition, power, and mod. More operations will be added in a follow-up cl.