Skip to content

Commit

Permalink
[T5 optimization] fuse rel_pos_bias and remove extended mask (#14645)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

1. fuse rel_pos_bias in T5.
2. remove extended masks in T5 decoder and decoder_init since they
generate all zeros
3. fix a bug in onnx_model.py


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
  • Loading branch information
wangyems and Ubuntu authored Feb 14, 2023
1 parent 3703397 commit 2a4c9a5
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 11 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/transformers/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def has_same_value(tensor1: TensorProto, tensor2: TensorProto) -> bool:
return False
if tensor1.HasField("raw_data") and tensor2.HasField("raw_data"):
return tensor1.raw_data == tensor2.raw_data
return numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)
return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all()

def remove_duplicated_initializer(self):
"""Remove initializers with duplicated values, and only keep the first one.
Expand Down
171 changes: 161 additions & 10 deletions onnxruntime/python/tools/transformers/onnx_model_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import logging
from typing import Union

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_skiplayernorm import FusionSkipLayerNormalization
from onnx import NodeProto
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

Expand Down Expand Up @@ -48,16 +50,89 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
return


# It's much easier to export it with the custom op. TODO: revisit later
class FusionRelativePositionBiasBlock(Fusion):
def __init__(self, model: OnnxModel, max_distance: int, is_bidirectional: bool):
super().__init__(model, "RelativePositionBias", "Add")
def __init__(self, model: OnnxModel, max_distance: int):
super().__init__(model, "RelativePositionBias", ["Add", "Slice"])
self.max_distance = max_distance
self.is_bidirectional = is_bidirectional
# bidirectional=(not self.is_decoder)
self.is_bidirectional = False

def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Not implemented yet
return
# TODO: Optimization opportunity: only last dimension of relative_position_bias is used in decoder.
# Cuda kernel can be optimized to only compute last dimension.
if node.op_type != "Add" and node.op_type != "Slice":
return

compute_bias_nodes = self.model.match_parent_path(
node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1]
)
if compute_bias_nodes is None:
compute_bias_nodes = self.model.match_parent_path(
node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1]
)
if compute_bias_nodes is None:
return

gather = compute_bias_nodes[2]
where = compute_bias_nodes[-1]
unsqueeze = compute_bias_nodes[0]

compute_buckets_nodes = self.model.match_parent_path(
where,
["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"],
[2, 1, 0, 0, 0, 0, 0, 0, 0],
)
if compute_buckets_nodes is None:
return

div = compute_buckets_nodes[-1]

range_nodes = self.model.match_parent_path(
div,
["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"],
[0, 0, 0, 1, 0, 0, 0, 0],
)
if range_nodes is None:
range_nodes = self.model.match_parent_path(
div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0]
)
self.is_bidirectional = True
if range_nodes is None:
return

range_node = range_nodes[-1]

self.nodes_to_remove.extend(compute_bias_nodes)
self.nodes_to_remove.extend(compute_buckets_nodes)
self.nodes_to_remove.extend(range_nodes)

node_name_prefix = "encoder" if self.is_bidirectional else "decoder"

table_weight_i = self.model.get_initializer(gather.input[0])
table_weight = NumpyHelper.to_array(table_weight_i)
table_weight_t = np.transpose(table_weight)
bias_table = helper.make_tensor(
name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix),
data_type=TensorProto.FLOAT,
dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
vals=table_weight_t.flatten().tolist(),
)

self.model.add_initializer(bias_table, self.this_graph_name)
inputs = [bias_table.name, range_node.input[1], range_node.input[1]]
outputs = [unsqueeze.output[0]]
rpb_node = helper.make_node(
"RelativePositionBias",
inputs=inputs,
outputs=outputs,
name=self.model.create_node_name("RelativePositionBias", name_prefix=node_name_prefix),
)
rpb_node.domain = "com.microsoft"
rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)])
rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", self.is_bidirectional)])

self.nodes_to_add.append(rpb_node)
self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name


class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
Expand All @@ -77,16 +152,92 @@ def __init__(self, model, num_heads, hidden_size):
self.attention_mask = AttentionMask(self)
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
# TODO: hardcode for now. double check later
self.rpb_fusion = FusionRelativePositionBiasBlock(self, 32, True)
# TODO: consider retrive max_distance from model.
# math.log(max_distance / (num_buckets // 2))
self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128)

def fuse_attention(self):
self.attention_fusion.apply()

def fuse_skip_layer_norm(self):
self.skip_layer_norm_fusion.apply()

# Remove get_extended_attention_mask() since it generates all zeros.
def remove_extended_mask_decoder_init(self):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == "Add":
extended_mask_nodes = self.match_parent_path(
node,
[
"Mul",
"Sub",
"Mul",
"Unsqueeze",
"Cast",
"LessOrEqual",
"Tile",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if extended_mask_nodes is None:
continue

rpb_nodes = self.match_parent_path(node, ["RelativePositionBias"], [0])
if rpb_nodes is None:
continue

rpb_node = rpb_nodes[0]
rpb_node.output[0] = node.output[0]

nodes_to_remove.extend(extended_mask_nodes)
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)

def remove_extended_mask_decoder(self):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == "Add":
extended_mask_nodes = self.match_parent_path(
node,
[
"Mul",
"Sub",
"Mul",
"Unsqueeze",
"Concat",
"Cast",
"LessOrEqual",
"Tile",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
],
[1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0],
)
if extended_mask_nodes is None:
continue

rpb_nodes = self.match_parent_path(node, ["Slice", "RelativePositionBias"], [0, 0])
if rpb_nodes is None:
continue

rpb_node = rpb_nodes[0]
rpb_node.output[0] = node.output[0]

nodes_to_remove.extend(extended_mask_nodes)
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)

def postprocess(self):
self.rpb_fusion.apply()
self.clean_graph()
# remove get_extended_attention_mask() since it generates all zeros.
self.remove_extended_mask_decoder_init()
self.remove_extended_mask_decoder()

self.prune_graph()

0 comments on commit 2a4c9a5

Please sign in to comment.