diff --git a/frontends/p4/typeChecking/typeChecker.cpp b/frontends/p4/typeChecking/typeChecker.cpp index 6567d9204c..1725643802 100644 --- a/frontends/p4/typeChecking/typeChecker.cpp +++ b/frontends/p4/typeChecking/typeChecker.cpp @@ -1871,8 +1871,17 @@ const IR::Node* TypeInference::postorder(IR::Concat* expression) { expression->right); return expression; } - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (auto se = rtype->to()) rtype = getTypeType(se->type); + + bool castLeft = false; + bool castRight = false; + if (auto se = ltype->to()) { + ltype = getTypeType(se->type); + castLeft = true; + } + if (auto se = rtype->to()) { + rtype = getTypeType(se->type); + castRight = true; + } if (ltype == nullptr || rtype == nullptr) { // getTypeType should have already taken care of the error message return expression; @@ -1885,6 +1894,22 @@ const IR::Node* TypeInference::postorder(IR::Concat* expression) { auto bl = ltype->to(); auto br = rtype->to(); const IR::Type* resultType = IR::Type_Bits::get(bl->size + br->size, bl->isSigned); + + if (castLeft) { + auto e = expression->clone(); + e->left = new IR::Cast(e->left->srcInfo, bl, e->left); + if (isCompileTimeConstant(expression->left)) setCompileTimeConstant(e->left); + setType(e->left, ltype); + expression = e; + } + if (castRight) { + auto e = expression->clone(); + e->right = new IR::Cast(e->right->srcInfo, br, e->right); + if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(e->right); + setType(e->right, rtype); + expression = e; + } + resultType = canonicalize(resultType); if (resultType != nullptr) { setType(getOriginal(), resultType); @@ -2236,9 +2261,17 @@ const IR::Node* TypeInference::binaryArith(const IR::Operation_Binary* expressio auto ltype = getType(expression->left); auto rtype = getType(expression->right); if (ltype == nullptr || rtype == nullptr) return expression; + bool castLeft = false; + bool castRight = false; - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (auto se = rtype->to()) rtype = getTypeType(se->type); + if (auto se = ltype->to()) { + ltype = getTypeType(se->type); + castLeft = true; + } + if (auto se = rtype->to()) { + rtype = getTypeType(se->type); + castRight = true; + } BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); const IR::Type_Bits* bl = ltype->to(); @@ -2272,23 +2305,37 @@ const IR::Node* TypeInference::binaryArith(const IR::Operation_Binary* expressio typeError("%1%: Cannot operate on values with different signs", expression); return expression; } - } else if (bl == nullptr && br != nullptr) { + } + if ((bl == nullptr && br != nullptr) || castLeft) { + // must insert cast on the left + auto leftResultType = br; + if (castLeft && !br) leftResultType = bl; auto e = expression->clone(); - e->left = new IR::Cast(e->left->srcInfo, br, e->left); - setType(e->left, rtype); + e->left = new IR::Cast(e->left->srcInfo, leftResultType, e->left); + setType(e->left, leftResultType); + if (isCompileTimeConstant(expression->left)) { + e->left = constantFold(e->left); + setCompileTimeConstant(e->left); + setType(e->left, leftResultType); + } expression = e; - resultType = rtype; - setType(expression, resultType); - } else if (bl != nullptr && br == nullptr) { + resultType = leftResultType; + } + if ((bl != nullptr && br == nullptr) || castRight) { auto e = expression->clone(); - e->right = new IR::Cast(e->right->srcInfo, bl, e->right); - setType(e->right, ltype); + auto rightResultType = bl; + if (castRight && !bl) rightResultType = br; + e->right = new IR::Cast(e->right->srcInfo, rightResultType, e->right); + setType(e->right, rightResultType); + if (isCompileTimeConstant(expression->right)) { + e->right = constantFold(e->right); + setCompileTimeConstant(e->right); + setType(e->right, rightResultType); + } expression = e; - resultType = ltype; - setType(expression, resultType); - } else { - setType(expression, resultType); + resultType = rightResultType; } + setType(getOriginal(), resultType); setType(expression, resultType); if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { @@ -2409,71 +2456,6 @@ const IR::Node* TypeInference::shift(const IR::Operation_Binary* expression) { return expression; } -const IR::Node* TypeInference::bitwise(const IR::Operation_Binary* expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (auto se = rtype->to()) rtype = getTypeType(se->type); - BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); - - const IR::Type_Bits* bl = ltype->to(); - const IR::Type_Bits* br = rtype->to(); - if (bl == nullptr && !ltype->is()) { - typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->left, ltype->toString()); - return expression; - } else if (br == nullptr && !rtype->is()) { - typeError("%1%: cannot be applied to expressio '%2%' with type '%3%'", - expression->getStringOp(), expression->right, rtype->toString()); - return expression; - } else if (ltype->is() && rtype->is()) { - auto t = new IR::Type_InfInt(); - setType(getOriginal(), t); - auto result = constantFold(expression); - setType(result, t); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - - const IR::Type* resultType = ltype; - if (bl != nullptr && br != nullptr) { - if (!typeMap->equivalent(bl, br)) { - typeError("%1%: Cannot operate on values with different types %2% and %3%", expression, - bl->toString(), br->toString()); - return expression; - } - } else if (bl == nullptr && br != nullptr) { - auto e = expression->clone(); - auto cst = expression->left->to(); - CHECK_NULL(cst); - e->left = new IR::Constant(cst->srcInfo, rtype, cst->value, cst->base); - setType(e->left, rtype); - setCompileTimeConstant(e->left); - expression = e; - resultType = rtype; - } else if (bl != nullptr && br == nullptr) { - auto e = expression->clone(); - auto cst = expression->right->to(); - CHECK_NULL(cst); - e->right = new IR::Constant(cst->srcInfo, ltype, cst->value, cst->base); - setType(e->right, ltype); - setCompileTimeConstant(e->right); - expression = e; - resultType = ltype; - } - setType(expression, resultType); - setType(getOriginal(), resultType); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - // Handle .. and &&& const IR::Node* TypeInference::typeSet(const IR::Operation_Binary* expression) { if (done()) return expression; diff --git a/frontends/p4/typeChecking/typeChecker.h b/frontends/p4/typeChecking/typeChecker.h index ffb206f60b..3044346bd1 100644 --- a/frontends/p4/typeChecking/typeChecker.h +++ b/frontends/p4/typeChecking/typeChecker.h @@ -165,7 +165,6 @@ class TypeInference : public Transform { const IR::Node* binaryArith(const IR::Operation_Binary* op); const IR::Node* unsBinaryArith(const IR::Operation_Binary* op); const IR::Node* shift(const IR::Operation_Binary* op); - const IR::Node* bitwise(const IR::Operation_Binary* op); const IR::Node* typeSet(const IR::Operation_Binary* op); const IR::Type* cloneWithFreshTypeVariables(const IR::IMayBeGenericType* type); @@ -288,9 +287,9 @@ class TypeInference : public Transform { const IR::Node* postorder(IR::Mod* expression) override { return unsBinaryArith(expression); } const IR::Node* postorder(IR::Shl* expression) override { return shift(expression); } const IR::Node* postorder(IR::Shr* expression) override { return shift(expression); } - const IR::Node* postorder(IR::BXor* expression) override { return bitwise(expression); } - const IR::Node* postorder(IR::BAnd* expression) override { return bitwise(expression); } - const IR::Node* postorder(IR::BOr* expression) override { return bitwise(expression); } + const IR::Node* postorder(IR::BXor* expression) override { return binaryArith(expression); } + const IR::Node* postorder(IR::BAnd* expression) override { return binaryArith(expression); } + const IR::Node* postorder(IR::BOr* expression) override { return binaryArith(expression); } const IR::Node* postorder(IR::Mask* expression) override { return typeSet(expression); } const IR::Node* postorder(IR::Range* expression) override { return typeSet(expression); } const IR::Node* postorder(IR::LNot* expression) override; diff --git a/testdata/p4_16_errors_outputs/binary_e.p4-stderr b/testdata/p4_16_errors_outputs/binary_e.p4-stderr index 1589cbd2c7..228dfd38c9 100644 --- a/testdata/p4_16_errors_outputs/binary_e.p4-stderr +++ b/testdata/p4_16_errors_outputs/binary_e.p4-stderr @@ -4,7 +4,7 @@ binary_e.p4(32): [--Werror=type-error] error: Indexing a[2] applied to non-array binary_e.p4(33): [--Werror=type-error] error: Array index d must be an integer, but it has type bool c = stack[d]; // indexing with bool ^ -binary_e.p4(35): [--Werror=type-error] error: &: Cannot operate on values with different types bit<2> and bit<4> +binary_e.p4(35): [--Werror=type-error] error: &: Cannot operate on values with different widths 2 and 4 f = e & f; // different width ^^^^^ binary_e.p4(38): [--Werror=type-error] error: <: not defined on bool and bool diff --git a/testdata/p4_16_samples/enumCast.p4 b/testdata/p4_16_samples/enumCast.p4 index 7d5b6542f6..a08b8d6dfc 100644 --- a/testdata/p4_16_samples/enumCast.p4 +++ b/testdata/p4_16_samples/enumCast.p4 @@ -41,7 +41,7 @@ parser p(packet_in packet, out O o) { bb = bb && (b == 0); a = (E1) b; // OK - a = (E1)(E1.e1 + 1); // Final explicit casting makes the assinment legal + a = (E1)(E1.e1 + 1); // Final explicit casting makes the assignment legal a = (E1)(E2.e1 + E2.e2); // Final explicit casting makes the assignment legal packet.extract(o.b); diff --git a/testdata/p4_16_samples/issue3635.p4 b/testdata/p4_16_samples/issue3635.p4 new file mode 100644 index 0000000000..ee2d0aa0b1 --- /dev/null +++ b/testdata/p4_16_samples/issue3635.p4 @@ -0,0 +1,16 @@ +enum bit<4> e +{ + a = 1, + b = 2, + c = 3 +} +void f() +{ + bit<8> good1; + good1 = e.a ++ e.b; +} + +const bit<8> good = ((bit<4>)e.a) ++ ((bit<4>)e.b); +const bit<4> bad = e.a + e.b; +const bit<8> bad1 = e.a ++ e.b; +const bit<4> bad2 = e.a & e.b; diff --git a/testdata/p4_16_samples_outputs/enumCast-first.p4 b/testdata/p4_16_samples_outputs/enumCast-first.p4 index a3d15cefc2..f38360a66f 100644 --- a/testdata/p4_16_samples_outputs/enumCast-first.p4 +++ b/testdata/p4_16_samples_outputs/enumCast-first.p4 @@ -42,8 +42,8 @@ parser p(packet_in packet, out O o) { bb = bb && a == 8w0; bb = bb && b == 8w0; a = (E1)b; - a = (E1)(E1.e1 + 8w1); - a = (E1)(E2.e1 + E2.e2); + a = (E1)8w1; + a = (E1)8w21; packet.extract(o.b); transition select(o.b.x) { X.Zero &&& 32w0x1: accept; diff --git a/testdata/p4_16_samples_outputs/issue3056-first.p4 b/testdata/p4_16_samples_outputs/issue3056-first.p4 index 46f00c1735..f56f49072c 100644 --- a/testdata/p4_16_samples_outputs/issue3056-first.p4 +++ b/testdata/p4_16_samples_outputs/issue3056-first.p4 @@ -10,9 +10,9 @@ control compute() { bit<8> x2 = -e; bit<4> x3 = e[3:0]; bool x4 = 8w0 == e; - bit<8> x5 = e; + bit<8> x5 = (bit<8>)e; bit<8> x6 = e << 3; - bit<16> x7 = e ++ 8w0; + bit<16> x7 = (bit<8>)e ++ 8w0; bit<4> x8 = 4w0; } } diff --git a/testdata/p4_16_samples_outputs/issue3635-first.p4 b/testdata/p4_16_samples_outputs/issue3635-first.p4 new file mode 100644 index 0000000000..f7c5fc8cb2 --- /dev/null +++ b/testdata/p4_16_samples_outputs/issue3635-first.p4 @@ -0,0 +1,14 @@ +enum bit<4> e { + a = 4w1, + b = 4w2, + c = 4w3 +} + +void f() { + bit<8> good1; + good1 = 8w18; +} +const bit<8> good = 8w18; +const bit<4> bad = 4w3; +const bit<8> bad1 = 8w18; +const bit<4> bad2 = 4w0; diff --git a/testdata/p4_16_samples_outputs/issue3635-frontend.p4 b/testdata/p4_16_samples_outputs/issue3635-frontend.p4 new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testdata/p4_16_samples_outputs/issue3635.p4 b/testdata/p4_16_samples_outputs/issue3635.p4 new file mode 100644 index 0000000000..56c98c6efb --- /dev/null +++ b/testdata/p4_16_samples_outputs/issue3635.p4 @@ -0,0 +1,14 @@ +enum bit<4> e { + a = 1, + b = 2, + c = 3 +} + +void f() { + bit<8> good1; + good1 = e.a ++ e.b; +} +const bit<8> good = (bit<4>)e.a ++ (bit<4>)e.b; +const bit<4> bad = e.a + e.b; +const bit<8> bad1 = e.a ++ e.b; +const bit<4> bad2 = e.a & e.b; diff --git a/testdata/p4_16_samples_outputs/issue3635.p4-stderr b/testdata/p4_16_samples_outputs/issue3635.p4-stderr new file mode 100644 index 0000000000..7e57a518ff --- /dev/null +++ b/testdata/p4_16_samples_outputs/issue3635.p4-stderr @@ -0,0 +1 @@ +[--Wwarn=missing] warning: Program does not contain a `main' module