From 95dc7a69b5647608fa6824332e09db0b6b64a72f Mon Sep 17 00:00:00 2001 From: John Siirola Date: Thu, 31 Aug 2023 09:54:41 -0600 Subject: [PATCH 1/3] Ensure templatize_constraint always returns an expression --- pyomo/core/expr/relational_expr.py | 7 +++++++ pyomo/core/expr/template_expr.py | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyomo/core/expr/relational_expr.py b/pyomo/core/expr/relational_expr.py index 2909be95c5a..6e4831d5c0c 100644 --- a/pyomo/core/expr/relational_expr.py +++ b/pyomo/core/expr/relational_expr.py @@ -458,3 +458,10 @@ def _generate_relational_expression(etype, lhs, rhs): ) else: return InequalityExpression((lhs, rhs), _relational_op[etype][2]) + + +def tuple_to_relational_expr(args): + if len(args) == 2: + return EqualityExpression(args) + else: + return inequality(*args) diff --git a/pyomo/core/expr/template_expr.py b/pyomo/core/expr/template_expr.py index 4682844fef0..5009a2497d7 100644 --- a/pyomo/core/expr/template_expr.py +++ b/pyomo/core/expr/template_expr.py @@ -34,6 +34,7 @@ value, is_constant, ) +from pyomo.core.expr.relational_expr import tuple_to_relational_expr from pyomo.core.expr.visitor import ( ExpressionReplacementVisitor, StreamBasedExpressionVisitor, @@ -1173,4 +1174,7 @@ def templatize_rule(block, rule, index_set): def templatize_constraint(con): - return templatize_rule(con.parent_block(), con.rule, con.index_set()) + expr, indices = templatize_rule(con.parent_block(), con.rule, con.index_set()) + if expr.__class__ is tuple: + expr = tuple_to_relational_expr(expr) + return expr, indices From 677289a3a94b80bff23c26ce6209f724db7b0638 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Thu, 31 Aug 2023 09:54:48 -0600 Subject: [PATCH 2/3] Add missing import --- pyomo/core/expr/template_expr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyomo/core/expr/template_expr.py b/pyomo/core/expr/template_expr.py index 5009a2497d7..6103e9c429c 100644 --- a/pyomo/core/expr/template_expr.py +++ b/pyomo/core/expr/template_expr.py @@ -25,6 +25,7 @@ Numeric_NPV_Mixin, register_arg_type, ARG_TYPE, + _balanced_parens, ) from pyomo.core.expr.numvalue import ( NumericValue, From ecbd7095ef12c615cb26f2b612ba4ec947b4e15d Mon Sep 17 00:00:00 2001 From: John Siirola Date: Thu, 31 Aug 2023 10:00:35 -0600 Subject: [PATCH 3/3] Add tests for templatizing rules returning tuples --- pyomo/core/tests/unit/test_template_expr.py | 47 +++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pyomo/core/tests/unit/test_template_expr.py b/pyomo/core/tests/unit/test_template_expr.py index 069acc907cc..4b4ea494b0e 100644 --- a/pyomo/core/tests/unit/test_template_expr.py +++ b/pyomo/core/tests/unit/test_template_expr.py @@ -353,6 +353,53 @@ def c(m, i): indices[0].set_value(2) self.assertEqual(str(resolve_template(template)), 'x[2] <= 0') + def test_tuple_rules(self): + m = ConcreteModel() + m.I = RangeSet(3) + m.x = Var(m.I) + + @m.Constraint(m.I) + def c(m, i): + return (None, m.x[i], 0) + + template, indices = templatize_constraint(m.c) + self.assertEqual(len(indices), 1) + self.assertIs(indices[0]._set, m.I) + self.assertEqual(str(template), "x[_1] <= 0") + # Test that the RangeSet iterator was put back + self.assertEqual(list(m.I), list(range(1, 4))) + # Evaluate the template + indices[0].set_value(2) + self.assertEqual(str(resolve_template(template)), 'x[2] <= 0') + + @m.Constraint(m.I) + def d(m, i): + return (0, m.x[i], 10) + + template, indices = templatize_constraint(m.d) + self.assertEqual(len(indices), 1) + self.assertIs(indices[0]._set, m.I) + self.assertEqual(str(template), "0 <= x[_1] <= 10") + # Test that the RangeSet iterator was put back + self.assertEqual(list(m.I), list(range(1, 4))) + # Evaluate the template + indices[0].set_value(2) + self.assertEqual(str(resolve_template(template)), '0 <= x[2] <= 10') + + @m.Constraint(m.I) + def e(m, i): + return (m.x[i], 0) + + template, indices = templatize_constraint(m.e) + self.assertEqual(len(indices), 1) + self.assertIs(indices[0]._set, m.I) + self.assertEqual(str(template), "x[_1] == 0") + # Test that the RangeSet iterator was put back + self.assertEqual(list(m.I), list(range(1, 4))) + # Evaluate the template + indices[0].set_value(2) + self.assertEqual(str(resolve_template(template)), 'x[2] == 0') + def test_simple_rule_nonfinite_set(self): m = ConcreteModel() m.x = Var(Integers, dense=False)