update docs
Ubuntu committed Feb 3, 2023
1 parent 4b0afd8 commit 72890de
Showing 2 changed files with 56 additions and 1 deletion.
53 changes: 53 additions & 0 deletions docs/ContribOperators.md
@@ -29,6 +29,7 @@ Do not modify directly.*
* <a href="#com.microsoft.FusedConv">com.microsoft.FusedConv</a>
* <a href="#com.microsoft.FusedGemm">com.microsoft.FusedGemm</a>
* <a href="#com.microsoft.FusedMatMul">com.microsoft.FusedMatMul</a>
* <a href="#com.microsoft.GatedRelativePositionBias">com.microsoft.GatedRelativePositionBias</a>
* <a href="#com.microsoft.GatherND">com.microsoft.GatherND</a>
* <a href="#com.microsoft.Gelu">com.microsoft.Gelu</a>
* <a href="#com.microsoft.GemmFastGelu">com.microsoft.GemmFastGelu</a>
@@ -1573,6 +1574,58 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GatedRelativePositionBias"></a><a name="com.microsoft.gatedrelativepositionbias">**com.microsoft.GatedRelativePositionBias**</a>

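Applies a learned gate to the relative position bias, following this PyTorch-style pseudocode:

    # (batch_size, seq_len, num_heads * head_size) -> (batch_size, num_heads, seq_len, head_size)
    query_layer = (query_layer + query_bias).reshape(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
    # gate_u and gate_r each have shape (batch_size, num_heads, seq_len, 1)
    gate_u, gate_r = torch.sigmoid(
        self.gate_ur_linear(query_layer).view(batch_size, num_heads, seq_len, 2, D // 2).sum(-1, keepdim=False)
    ).chunk(2, dim=-1)
    gate_u_1 = gate_u * (gate_r * self.eco_a - 1.0) + 2.0
    # broadcasts against rel_pos to (batch_size, num_heads, seq_len, seq_len)
    rel_pos_bias = gate_u_1 * rel_pos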
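
The pseudocode above can be sketched in NumPy to make the intermediate shapes concrete. This is a minimal reference sketch, not the ONNX Runtime kernel: the function name `gated_relative_position_bias` and the helper `sigmoid` are hypothetical, and the sketch assumes that `gate_ur_linear` computes `x @ weight + bias` with `weight` stored as `(head_size, D)`, as the input descriptions state.

```python
import numpy as np


def sigmoid(x):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def gated_relative_position_bias(query_layer, query_bias, rel_pos,
                                 weight, bias, eco_a, num_heads):
    """Illustrative NumPy version of the pseudocode (assumed semantics)."""
    batch_size, seq_len, hidden_size = query_layer.shape
    head_size = hidden_size // num_heads
    d = weight.shape[1]  # D, assumed to be divisible by 2

    # (B, S, N*H) -> (B, S, N, H) -> (B, N, S, H)
    q = (query_layer + query_bias).reshape(
        batch_size, seq_len, num_heads, head_size
    ).transpose(0, 2, 1, 3)

    # Assumed linear layer: (B, N, S, H) @ (H, D) + (D,) -> (B, N, S, D)
    gates = q @ weight + bias

    # Split D into two halves and sum each half: (B, N, S, 2)
    gates = gates.reshape(batch_size, num_heads, seq_len, 2, d // 2).sum(axis=-1)
    gates = sigmoid(gates)

    # Equivalent of chunk(2, dim=-1): two arrays of shape (B, N, S, 1)
    gate_u, gate_r = np.split(gates, 2, axis=-1)

    # eco_a has shape (1, N, 1, 1) and broadcasts over batch and sequence
    gate_u_1 = gate_u * (gate_r * eco_a - 1.0) + 2.0

    # (B, N, S, 1) * (1, N, S, S) -> (B, N, S, S)
    return gate_u_1 * rel_pos


# Small usage example with arbitrary sizes
rng = np.random.default_rng(0)
B, S, N, H, D = 2, 5, 4, 8, 6
out = gated_relative_position_bias(
    query_layer=rng.standard_normal((B, S, N * H)).astype(np.float32),
    query_bias=rng.standard_normal(N * H).astype(np.float32),
    rel_pos=rng.standard_normal((1, N, S, S)).astype(np.float32),
    weight=rng.standard_normal((H, D)).astype(np.float32),
    bias=rng.standard_normal(D).astype(np.float32),
    eco_a=rng.standard_normal((1, N, 1, 1)).astype(np.float32),
    num_heads=N,
)
print(out.shape)  # (2, 4, 5, 5)
```

Because both gates pass through a sigmoid, each lies in (0, 1), so the multiplier applied to `rel_pos` depends only on the query at each position and on the per-head scale `eco_a`.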

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>

#### Inputs

<dl>
<dt><tt>query_layer</tt> : T</dt>
<dd>tensor with shape (batch_size, seq_len, num_heads x head_size)</dd>
<dt><tt>query_bias</tt> : T</dt>
<dd>1-d tensor with shape (num_heads x head_size)</dd>
<dt><tt>rel_pos</tt> : T</dt>
<dd>tensor with shape (1, num_heads, seq_len, seq_len)</dd>
<dt><tt>weight</tt> : T</dt>
<dd>GEMM weight for gate_ur_linear with shape (head_size, D), where D is divisible by 2</dd>
<dt><tt>bias</tt> : T</dt>
<dd>bias for gate_ur_linear with shape (D)</dd>
<dt><tt>eco_a</tt> : T</dt>
<dd>tensor of shape (1, num_heads, 1, 1)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>output tensor with shape (batch_size, num_heads, seq_len, seq_len)</dd>
</dl>
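
The shape constraints listed under Inputs and Outputs can be checked before building a model. The helper below is a hypothetical illustration in plain Python (`infer_output_shape` is not part of ONNX Runtime): it takes a dictionary of input shapes, validates them against the documented shapes, and returns the documented output shape.

```python
def infer_output_shape(shapes, num_heads):
    """Validate documented input shapes and return the output shape.

    `shapes` maps each input name to a tuple of ints.
    """
    batch_size, seq_len, hidden_size = shapes["query_layer"]
    if hidden_size % num_heads != 0:
        raise ValueError("last dim of query_layer must be divisible by num_heads")
    head_size = hidden_size // num_heads

    d = shapes["weight"][-1]
    if d % 2 != 0:
        raise ValueError("D must be divisible by 2")

    # Expected shapes, taken from the input descriptions
    expected = {
        "query_bias": (hidden_size,),
        "rel_pos": (1, num_heads, seq_len, seq_len),
        "weight": (head_size, d),
        "bias": (d,),
        "eco_a": (1, num_heads, 1, 1),
    }
    for name, want in expected.items():
        got = tuple(shapes[name])
        if got != want:
            raise ValueError(f"{name}: expected shape {want}, got {got}")

    return (batch_size, num_heads, seq_len, seq_len)


example = {
    "query_layer": (2, 5, 32),
    "query_bias": (32,),
    "rel_pos": (1, 4, 5, 5),
    "weight": (8, 6),
    "bias": (6,),
    "eco_a": (1, 4, 1, 1),
}
print(infer_output_shape(example, num_heads=4))  # (2, 4, 5, 5)
```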

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


### <a name="com.microsoft.GatherND"></a><a name="com.microsoft.gathernd">**com.microsoft.GatherND**</a>

Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -802,6 +802,7 @@ Do not modify directly.*
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
@@ -1087,7 +1088,8 @@ Do not modify directly.*
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|