Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arithmetic and bitwise operations are typechecked the same way #3716

Merged
merged 2 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 63 additions & 81 deletions frontends/p4/typeChecking/typeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1871,8 +1871,17 @@ const IR::Node* TypeInference::postorder(IR::Concat* expression) {
expression->right);
return expression;
}
if (auto se = ltype->to<IR::Type_SerEnum>()) ltype = getTypeType(se->type);
if (auto se = rtype->to<IR::Type_SerEnum>()) rtype = getTypeType(se->type);

bool castLeft = false;
bool castRight = false;
if (auto se = ltype->to<IR::Type_SerEnum>()) {
ltype = getTypeType(se->type);
castLeft = true;
}
if (auto se = rtype->to<IR::Type_SerEnum>()) {
rtype = getTypeType(se->type);
castRight = true;
}
if (ltype == nullptr || rtype == nullptr) {
// getTypeType should have already taken care of the error message
return expression;
Expand All @@ -1885,6 +1894,22 @@ const IR::Node* TypeInference::postorder(IR::Concat* expression) {
auto bl = ltype->to<IR::Type_Bits>();
auto br = rtype->to<IR::Type_Bits>();
const IR::Type* resultType = IR::Type_Bits::get(bl->size + br->size, bl->isSigned);

if (castLeft) {
auto e = expression->clone();
e->left = new IR::Cast(e->left->srcInfo, bl, e->left);
if (isCompileTimeConstant(expression->left)) setCompileTimeConstant(e->left);
setType(e->left, ltype);
expression = e;
}
if (castRight) {
auto e = expression->clone();
e->right = new IR::Cast(e->right->srcInfo, br, e->right);
if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(e->right);
setType(e->right, rtype);
expression = e;
}

resultType = canonicalize(resultType);
if (resultType != nullptr) {
setType(getOriginal(), resultType);
Expand Down Expand Up @@ -2236,9 +2261,17 @@ const IR::Node* TypeInference::binaryArith(const IR::Operation_Binary* expressio
auto ltype = getType(expression->left);
auto rtype = getType(expression->right);
if (ltype == nullptr || rtype == nullptr) return expression;
bool castLeft = false;
bool castRight = false;

if (auto se = ltype->to<IR::Type_SerEnum>()) ltype = getTypeType(se->type);
if (auto se = rtype->to<IR::Type_SerEnum>()) rtype = getTypeType(se->type);
if (auto se = ltype->to<IR::Type_SerEnum>()) {
ltype = getTypeType(se->type);
castLeft = true;
}
if (auto se = rtype->to<IR::Type_SerEnum>()) {
rtype = getTypeType(se->type);
castRight = true;
}
BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType");

const IR::Type_Bits* bl = ltype->to<IR::Type_Bits>();
Expand Down Expand Up @@ -2272,23 +2305,37 @@ const IR::Node* TypeInference::binaryArith(const IR::Operation_Binary* expressio
typeError("%1%: Cannot operate on values with different signs", expression);
return expression;
}
} else if (bl == nullptr && br != nullptr) {
}
if ((bl == nullptr && br != nullptr) || castLeft) {
// must insert cast on the left
auto leftResultType = br;
if (castLeft && !br) leftResultType = bl;
auto e = expression->clone();
e->left = new IR::Cast(e->left->srcInfo, br, e->left);
setType(e->left, rtype);
e->left = new IR::Cast(e->left->srcInfo, leftResultType, e->left);
setType(e->left, leftResultType);
if (isCompileTimeConstant(expression->left)) {
e->left = constantFold(e->left);
setCompileTimeConstant(e->left);
setType(e->left, leftResultType);
}
expression = e;
resultType = rtype;
setType(expression, resultType);
} else if (bl != nullptr && br == nullptr) {
resultType = leftResultType;
}
if ((bl != nullptr && br == nullptr) || castRight) {
auto e = expression->clone();
e->right = new IR::Cast(e->right->srcInfo, bl, e->right);
setType(e->right, ltype);
auto rightResultType = bl;
if (castRight && !bl) rightResultType = br;
e->right = new IR::Cast(e->right->srcInfo, rightResultType, e->right);
setType(e->right, rightResultType);
if (isCompileTimeConstant(expression->right)) {
e->right = constantFold(e->right);
setCompileTimeConstant(e->right);
setType(e->right, rightResultType);
}
expression = e;
resultType = ltype;
setType(expression, resultType);
} else {
setType(expression, resultType);
resultType = rightResultType;
}

setType(getOriginal(), resultType);
setType(expression, resultType);
if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) {
Expand Down Expand Up @@ -2409,71 +2456,6 @@ const IR::Node* TypeInference::shift(const IR::Operation_Binary* expression) {
return expression;
}

const IR::Node* TypeInference::bitwise(const IR::Operation_Binary* expression) {
if (done()) return expression;
auto ltype = getType(expression->left);
auto rtype = getType(expression->right);
if (ltype == nullptr || rtype == nullptr) return expression;

if (auto se = ltype->to<IR::Type_SerEnum>()) ltype = getTypeType(se->type);
if (auto se = rtype->to<IR::Type_SerEnum>()) rtype = getTypeType(se->type);
BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType");

const IR::Type_Bits* bl = ltype->to<IR::Type_Bits>();
const IR::Type_Bits* br = rtype->to<IR::Type_Bits>();
if (bl == nullptr && !ltype->is<IR::Type_InfInt>()) {
typeError("%1%: cannot be applied to expression '%2%' with type '%3%'",
expression->getStringOp(), expression->left, ltype->toString());
return expression;
} else if (br == nullptr && !rtype->is<IR::Type_InfInt>()) {
typeError("%1%: cannot be applied to expressio '%2%' with type '%3%'",
expression->getStringOp(), expression->right, rtype->toString());
return expression;
} else if (ltype->is<IR::Type_InfInt>() && rtype->is<IR::Type_InfInt>()) {
auto t = new IR::Type_InfInt();
setType(getOriginal(), t);
auto result = constantFold(expression);
setType(result, t);
setCompileTimeConstant(result);
setCompileTimeConstant(getOriginal<IR::Expression>());
return result;
}

const IR::Type* resultType = ltype;
if (bl != nullptr && br != nullptr) {
if (!typeMap->equivalent(bl, br)) {
typeError("%1%: Cannot operate on values with different types %2% and %3%", expression,
bl->toString(), br->toString());
return expression;
}
} else if (bl == nullptr && br != nullptr) {
auto e = expression->clone();
auto cst = expression->left->to<IR::Constant>();
CHECK_NULL(cst);
e->left = new IR::Constant(cst->srcInfo, rtype, cst->value, cst->base);
setType(e->left, rtype);
setCompileTimeConstant(e->left);
expression = e;
resultType = rtype;
} else if (bl != nullptr && br == nullptr) {
auto e = expression->clone();
auto cst = expression->right->to<IR::Constant>();
CHECK_NULL(cst);
e->right = new IR::Constant(cst->srcInfo, ltype, cst->value, cst->base);
setType(e->right, ltype);
setCompileTimeConstant(e->right);
expression = e;
resultType = ltype;
}
setType(expression, resultType);
setType(getOriginal(), resultType);
if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) {
setCompileTimeConstant(expression);
setCompileTimeConstant(getOriginal<IR::Expression>());
}
return expression;
}

// Handle .. and &&&
const IR::Node* TypeInference::typeSet(const IR::Operation_Binary* expression) {
if (done()) return expression;
Expand Down
7 changes: 3 additions & 4 deletions frontends/p4/typeChecking/typeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ class TypeInference : public Transform {
const IR::Node* binaryArith(const IR::Operation_Binary* op);
const IR::Node* unsBinaryArith(const IR::Operation_Binary* op);
const IR::Node* shift(const IR::Operation_Binary* op);
const IR::Node* bitwise(const IR::Operation_Binary* op);
const IR::Node* typeSet(const IR::Operation_Binary* op);

const IR::Type* cloneWithFreshTypeVariables(const IR::IMayBeGenericType* type);
Expand Down Expand Up @@ -288,9 +287,9 @@ class TypeInference : public Transform {
const IR::Node* postorder(IR::Mod* expression) override { return unsBinaryArith(expression); }
const IR::Node* postorder(IR::Shl* expression) override { return shift(expression); }
const IR::Node* postorder(IR::Shr* expression) override { return shift(expression); }
const IR::Node* postorder(IR::BXor* expression) override { return bitwise(expression); }
const IR::Node* postorder(IR::BAnd* expression) override { return bitwise(expression); }
const IR::Node* postorder(IR::BOr* expression) override { return bitwise(expression); }
const IR::Node* postorder(IR::BXor* expression) override { return binaryArith(expression); }
const IR::Node* postorder(IR::BAnd* expression) override { return binaryArith(expression); }
const IR::Node* postorder(IR::BOr* expression) override { return binaryArith(expression); }
const IR::Node* postorder(IR::Mask* expression) override { return typeSet(expression); }
const IR::Node* postorder(IR::Range* expression) override { return typeSet(expression); }
const IR::Node* postorder(IR::LNot* expression) override;
Expand Down
2 changes: 1 addition & 1 deletion testdata/p4_16_errors_outputs/binary_e.p4-stderr
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ binary_e.p4(32): [--Werror=type-error] error: Indexing a[2] applied to non-array
binary_e.p4(33): [--Werror=type-error] error: Array index d must be an integer, but it has type bool
c = stack[d]; // indexing with bool
^
binary_e.p4(35): [--Werror=type-error] error: &: Cannot operate on values with different types bit<2> and bit<4>
binary_e.p4(35): [--Werror=type-error] error: &: Cannot operate on values with different widths 2 and 4
f = e & f; // different width
^^^^^
binary_e.p4(38): [--Werror=type-error] error: <: not defined on bool and bool
Expand Down
2 changes: 1 addition & 1 deletion testdata/p4_16_samples/enumCast.p4
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ parser p(packet_in packet, out O o) {
bb = bb && (b == 0);

a = (E1) b; // OK
a = (E1)(E1.e1 + 1); // Final explicit casting makes the assinment legal
a = (E1)(E1.e1 + 1); // Final explicit casting makes the assignment legal
a = (E1)(E2.e1 + E2.e2); // Final explicit casting makes the assignment legal

packet.extract(o.b);
Expand Down
16 changes: 16 additions & 0 deletions testdata/p4_16_samples/issue3635.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
enum bit<4> e
{
a = 1,
b = 2,
c = 3
}
void f()
{
bit<8> good1;
good1 = e.a ++ e.b;
}

const bit<8> good = ((bit<4>)e.a) ++ ((bit<4>)e.b);
const bit<4> bad = e.a + e.b;
const bit<8> bad1 = e.a ++ e.b;
const bit<4> bad2 = e.a & e.b;
4 changes: 2 additions & 2 deletions testdata/p4_16_samples_outputs/enumCast-first.p4
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ parser p(packet_in packet, out O o) {
bb = bb && a == 8w0;
bb = bb && b == 8w0;
a = (E1)b;
a = (E1)(E1.e1 + 8w1);
a = (E1)(E2.e1 + E2.e2);
a = (E1)8w1;
a = (E1)8w21;
packet.extract<B>(o.b);
transition select(o.b.x) {
X.Zero &&& 32w0x1: accept;
Expand Down
4 changes: 2 additions & 2 deletions testdata/p4_16_samples_outputs/issue3056-first.p4
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ control compute() {
bit<8> x2 = -e;
bit<4> x3 = e[3:0];
bool x4 = 8w0 == e;
bit<8> x5 = e;
bit<8> x5 = (bit<8>)e;
bit<8> x6 = e << 3;
bit<16> x7 = e ++ 8w0;
bit<16> x7 = (bit<8>)e ++ 8w0;
bit<4> x8 = 4w0;
}
}
Expand Down
14 changes: 14 additions & 0 deletions testdata/p4_16_samples_outputs/issue3635-first.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
enum bit<4> e {
a = 4w1,
b = 4w2,
c = 4w3
}

void f() {
bit<8> good1;
good1 = 8w18;
}
const bit<8> good = 8w18;
const bit<4> bad = 4w3;
const bit<8> bad1 = 8w18;
const bit<4> bad2 = 4w0;
Empty file.
14 changes: 14 additions & 0 deletions testdata/p4_16_samples_outputs/issue3635.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
enum bit<4> e {
a = 1,
b = 2,
c = 3
}

void f() {
bit<8> good1;
good1 = e.a ++ e.b;
}
const bit<8> good = (bit<4>)e.a ++ (bit<4>)e.b;
const bit<4> bad = e.a + e.b;
const bit<8> bad1 = e.a ++ e.b;
const bit<4> bad2 = e.a & e.b;
1 change: 1 addition & 0 deletions testdata/p4_16_samples_outputs/issue3635.p4-stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[--Wwarn=missing] warning: Program does not contain a `main' module