Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Assign attributes of transformer operators (#20902)
Browse files Browse the repository at this point in the history
  • Loading branch information
PawelGlomski-Intel authored Mar 4, 2022
1 parent 6ebd3bb commit edba375
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,15 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
new_sym.outputs.emplace_back(last_node);
std::ostringstream node_name;
std::string op_name;
MKLDNNSelfAttParam new_param;
DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
if (node->op() &&
(node->op()->name == SELFATT_QK ||
node->op()->name == SELFATT_VALATT)) {
op_name = node->op()->name;
auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
new_param.heads = param.heads;
new_param.quantized = false;
new_param.enable_float_output = false;
n->attrs.dict["heads"] = std::to_string(param.heads);
n->attrs.dict["quantized"] = "False";
n->attrs.dict["enable_float_output"] = "False";
}
});
node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
Expand All @@ -109,7 +108,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
n->attrs.op = Op::Get(OpMapping.at(op_name));
CHECK(n->attrs.op);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
n->attrs.parsed = new_param;
n->op()->attr_parser(&(n->attrs));
return n;
}

Expand Down

0 comments on commit edba375

Please sign in to comment.