diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h index f022bccc24ac..8228f4e0634b 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -90,16 +90,15 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { new_sym.outputs.emplace_back(last_node); std::ostringstream node_name; std::string op_name; - MKLDNNSelfAttParam new_param; DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { if (node->op() && (node->op()->name == SELFATT_QK || node->op()->name == SELFATT_VALATT)) { op_name = node->op()->name; auto param = nnvm::get(node->attrs.parsed); - new_param.heads = param.heads; - new_param.quantized = false; - new_param.enable_float_output = false; + n->attrs.dict["heads"] = std::to_string(param.heads); + n->attrs.dict["quantized"] = "False"; + n->attrs.dict["enable_float_output"] = "False"; } }); node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id); @@ -109,7 +108,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { n->attrs.op = Op::Get(OpMapping.at(op_name)); CHECK(n->attrs.op); n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); - n->attrs.parsed = new_param; + n->op()->attr_parser(&(n->attrs)); return n; }