Skip to content

Commit

Permalink
Implement header union invaidation in the BMv2 back end.
Browse files Browse the repository at this point in the history
Signed-off-by: fruffy <fruffy@nyu.edu>
  • Loading branch information
fruffy committed Oct 25, 2024
1 parent 52ea499 commit c088263
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 21 deletions.
28 changes: 28 additions & 0 deletions backends/bmv2/common/action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ cstring ActionConverter::jsonAssignment(const IR::Type *type) {
return "assign"_cs;
}

namespace {
/// Invalidates all other headers in a header union except the provided source header.
/// Has no effect if the parent structure is not a header union.
void invalidateOtherHeaderUnionHeaders(const IR::Member *sourceHeader,
const ConversionContext &ctxt, Util::JsonArray *result,
const IR::StatOrDecl *sourceStatement) {
const auto *type = ctxt.typeMap->getType(sourceHeader->expr, true);
if (const auto *headerUnionType = type->to<IR::Type_HeaderUnion>()) {
for (const auto *field : headerUnionType->fields) {
// Do not set the source member invalid.
if (sourceHeader->member == field->name) {
continue;
}
auto *member = new IR::Member(field->type, sourceHeader->expr, field->name);
ctxt.typeMap->setType(member, field->type);
auto *primitive = mkPrimitive("remove_header"_cs, result);
primitive->emplace_non_null("source_info"_cs, sourceStatement->sourceInfoJsonObj());
primitive->emplace("parameters", new Util::JsonArray({ctxt.conv->convert(member)}));
}
}
}
} // namespace

void ActionConverter::convertActionBody(const IR::Vector<IR::StatOrDecl> *body,
Util::JsonArray *result) {
for (auto s : *body) {
Expand Down Expand Up @@ -146,6 +169,11 @@ void ActionConverter::convertActionBody(const IR::Vector<IR::StatOrDecl> *body,
prim = "add_header"_cs;
} else if (builtin->name == IR::Type_Header::setInvalid) {
prim = "remove_header"_cs;
// If setInvalid is called on any header in a header union, we need to
// invalidate all other headers in the union.
if (const auto *parentStructure = builtin->appliedTo->to<IR::Member>()) {
invalidateOtherHeaderUnionHeaders(parentStructure, *ctxt, result, s);
}
} else if (builtin->name == IR::Type_Stack::push_front) {
BUG_CHECK(mc->arguments->size() == 1, "Expected 1 argument for %1%", mc);
auto arg = ctxt->conv->convert(mc->arguments->at(0)->expression);
Expand Down
83 changes: 63 additions & 20 deletions backends/bmv2/common/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,39 @@ limitations under the License.
#include "frontends/p4-14/fromv1.0/v1model.h"
#include "frontends/p4/coreLibrary.h"
#include "lib/algorithm.h"
#include "lib/json.h"

namespace P4::BMV2 {

namespace {
/// Invalidates all other headers in a header union except the provided source header.
/// Has no effect if the parent structure is not a header union.
void invalidateOtherHeaderUnionHeaders(const IR::Member *sourceHeader,
const ConversionContext &ctxt, Util::JsonArray *result) {
const auto *type = ctxt.typeMap->getType(sourceHeader->expr, true);
if (const auto *headerUnionType = type->to<IR::Type_HeaderUnion>()) {
for (const auto *field : headerUnionType->fields) {
// Do not set the source member invalid.
if (sourceHeader->member == field->name) {
continue;
}
auto *member = new IR::Member(field->type, sourceHeader->expr, field->name);
ctxt.typeMap->setType(member, field->type);
auto *obj = new Util::JsonObject();
obj->emplace("op", "primitive");
auto *params = mkArrayField(obj, "parameters"_cs);
auto *paramsValue = new Util::JsonObject();
params->append(paramsValue);
auto *pp = mkArrayField(paramsValue, "parameters"_cs);
auto *biObj = ctxt.conv->convert(member);
pp->append(biObj);
paramsValue->emplace("op", "remove_header");
result->append(obj);
}
}
}
} // namespace

cstring ParserConverter::jsonAssignment(const IR::Type *type) {
if (type->is<IR::Type_HeaderUnion>()) return "assign_union"_cs;
if (type->is<IR::Type_Header>() || type->is<IR::Type_Struct>()) return "assign_header"_cs;
Expand All @@ -40,9 +70,10 @@ cstring ParserConverter::jsonAssignment(const IR::Type *type) {
return "set"_cs;
}

Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat) {
auto result = new Util::JsonObject();
auto params = mkArrayField(result, "parameters"_cs);
Util::JsonArray *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat) {
auto *result = new Util::JsonArray();
auto *obj = new Util::JsonObject();
auto *params = mkArrayField(obj, "parameters"_cs);
auto isR = false;
IR::MethodCallExpression *mce2 = nullptr;
if (stat->is<IR::AssignmentStatement>()) {
Expand Down Expand Up @@ -104,7 +135,7 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
auto assign = stat->to<IR::AssignmentStatement>();
auto type = ctxt->typeMap->getType(assign->left, true);
cstring operation = jsonAssignment(type);
result->emplace("op", operation);
obj->emplace("op", operation);
auto l = ctxt->conv->convertLeftValue(assign->left);
bool convertBool = type->is<IR::Type_Boolean>();
auto r = ctxt->conv->convert(assign->right, true, true, convertBool);
Expand All @@ -116,10 +147,11 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
auto wrap = new Util::JsonObject();
wrap->emplace("op", "primitive");
auto params = mkParameters(wrap);
params->append(result);
result = wrap;
params->append(obj);
obj = wrap;
}

result->push_back(obj);
return result;
} else if (stat->is<IR::MethodCallStatement>()) {
auto mce = stat->to<IR::MethodCallStatement>()->methodCall;
Expand All @@ -135,7 +167,7 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
}

cstring ename = argCount == 1 ? "extract"_cs : "extract_VL"_cs;
result->emplace("op", ename);
obj->emplace("op", ename);
auto arg = mce->arguments->at(0);
auto argtype = ctxt->typeMap->getType(arg->expression, true);
if (!argtype->is<IR::Type_Header>()) {
Expand Down Expand Up @@ -206,6 +238,7 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
rwrap->emplace("value", jexpr);
params->append(rwrap);
}
result->push_back(obj);
return result;
} else if (extmeth->method->name.name == corelib.packetIn.lookahead.name) {
// bare lookahead call -- should flag an error if there's not enough
Expand All @@ -218,8 +251,9 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
}
auto arg = mce->arguments->at(0);
auto jexpr = ctxt->conv->convert(arg->expression, true, false);
result->emplace("op", "advance");
obj->emplace("op", "advance");
params->append(jexpr);
result->push_back(obj);
return result;
} else if ((extmeth->originalExternType->name == "InternetChecksum" &&
(extmeth->method->name.name == "clear" ||
Expand All @@ -236,16 +270,17 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
json = ExternConverter::cvtExternObject(ctxt, extmeth, mce, stat, true);
}
if (json) {
result->emplace("op", "primitive");
obj->emplace("op", "primitive");
params->append(json);
}
result->push_back(obj);
return result;
}
} else if (minst->is<P4::ExternFunction>()) {
auto extfn = minst->to<P4::ExternFunction>();
auto extFuncName = extfn->method->name.name;
if (extFuncName == IR::ParserState::verify) {
result->emplace("op", "verify");
obj->emplace("op", "verify");
BUG_CHECK(mce->arguments->size() == 2, "%1%: Expected 2 arguments", mce);
{
auto cond = mce->arguments->at(0);
Expand All @@ -261,10 +296,9 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
auto jexpr = ctxt->conv->convert(error->expression, true, false);
params->append(jexpr);
}
return result;
} else if (extFuncName == "assert" || extFuncName == "assume") {
BUG_CHECK(mce->arguments->size() == 1, "%1%: Expected 1 argument ", mce);
result->emplace("op", "primitive");
obj->emplace("op", "primitive");
auto paramValue = new Util::JsonObject();
params->append(paramValue);
auto paramsArray = mkArrayField(paramValue, "parameters"_cs);
Expand All @@ -276,12 +310,13 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
} else if (extFuncName == P4V1::V1Model::instance.log_msg.name) {
BUG_CHECK(mce->arguments->size() == 2 || mce->arguments->size() == 1,
"%1%: Expected 1 or 2 arguments", mce);
result->emplace("op", "primitive");
obj->emplace("op", "primitive");
auto ef = minst->to<P4::ExternFunction>();
auto ijson = ExternConverter::cvtExternFunction(ctxt, ef, mce, stat, false);
params->append(ijson);
return result;
}
result->push_back(obj);
return result;
} else if (minst->is<P4::BuiltInMethod>()) {
/* example result:
{
Expand All @@ -293,21 +328,26 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
],
"op" : "primitive"
} */
result->emplace("op", "primitive");
obj->emplace("op", "primitive");

auto bi = minst->to<P4::BuiltInMethod>();
cstring primitive;
auto paramsValue = new Util::JsonObject();
params->append(paramsValue);

auto pp = mkArrayField(paramsValue, "parameters"_cs);
auto obj = ctxt->conv->convert(bi->appliedTo);
pp->append(obj);
auto biObj = ctxt->conv->convert(bi->appliedTo);
pp->append(biObj);

if (bi->name == IR::Type_Header::setValid) {
primitive = "add_header"_cs;
} else if (bi->name == IR::Type_Header::setInvalid) {
primitive = "remove_header"_cs;
// If setInvalid is called on any header in a header union, we need to
// invalidate all other headers in the union.
if (const auto *parentStructure = bi->appliedTo->to<IR::Member>()) {
invalidateOtherHeaderUnionHeaders(parentStructure, *ctxt, result);
}
} else if (bi->name == IR::Type_Stack::push_front ||
bi->name == IR::Type_Stack::pop_front) {
if (bi->name == IR::Type_Stack::push_front)
Expand All @@ -323,6 +363,7 @@ Util::IJson *ParserConverter::convertParserStatement(const IR::StatOrDecl *stat)
}

paramsValue->emplace("op", primitive);
result->push_back(obj);
return result;
}
}
Expand Down Expand Up @@ -560,9 +601,11 @@ bool ParserConverter::preorder(const IR::P4Parser *parser) {
// For the state we use the internal name, not the control-plane name
auto state_id = ctxt->json->add_parser_state(parser_id, state->name);
// convert statements
for (auto s : state->components) {
auto op = convertParserStatement(s);
if (op) ctxt->json->add_parser_op(state_id, op);
for (const auto *s : state->components) {
auto *op = convertParserStatement(s);
for (auto *o : *op) {
ctxt->json->add_parser_op(state_id, o);
}
}
// convert transitions
if (state->selectExpression != nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion backends/bmv2/common/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ParserConverter : public Inspector {
unsigned combine(const IR::Expression *keySet, const IR::ListExpression *select, big_int &value,
big_int &mask, bool &is_vset, cstring &vset_name) const;
Util::IJson *stateName(IR::ID state);
Util::IJson *convertParserStatement(const IR::StatOrDecl *stat);
Util::JsonArray *convertParserStatement(const IR::StatOrDecl *stat);
Util::IJson *convertSelectKey(const IR::SelectExpression *expr);
Util::IJson *convertPathExpression(const IR::PathExpression *expr);
Util::IJson *createDefaultTransition();
Expand Down

0 comments on commit c088263

Please sign in to comment.