Add SkipLayerNormalization and SkipRMSNormalization to ONNX Opset 24 by shubhambhokare1 · Pull Request #6669 · onnx/onnx · GitHub

130 changes: 130 additions & 0 deletions docs/Changelog.md
@@ -29756,6 +29756,136 @@ This version of the operator has been available since version 23 of the default
<dd>Constrain input and output types to all tensor types up to IRv11.</dd>
</dl>

## Version 24 of the default ONNX operator set
### <a name="SkipLayerNormalization-24"></a>**SkipLayerNormalization-24**</a>

Applies LayerNormalization to an expanded skip connection, as described in https://arxiv.org/pdf/2105.07205v1.
The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
* `F(input)`: the output of a particular layer.
* `scaling_factor`: a modulating scalar that adjusts the importance of the skip connection.
* `Bias`: a bias term added to the output of the skip connection.

LayerNorm is then applied to xSkip as follows:
```
output = LayerNormalization(xSkip)
```

#### Version

This version of the operator has been available since version 24 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for layer normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>

#### Inputs (3 - 5)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where N is the batch size, C is the number of channels, and D1 through Dn are the spatial dimension sizes.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>beta</tt> (optional) : T</dt>
<dd>1D tensor representing bias input of layer normalization with shape of the spatial dimension along which layer normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which layer normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with the same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (plus the bias, if provided). Same shape as X.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input types and output Y type to float tensors.</dd>
<dt><tt>U</tt> : tensor(float)</dt>
<dd>Constrain mean and inv_std_var to float tensors.</dd>
</dl>
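The semantics above can be sketched in NumPy. This is a hypothetical reference helper, not part of the ONNX API; it assumes `axis` selects the trailing normalization dimensions (as in LayerNormalization) and that `scaling_factor` multiplies the skip input S, per the attribute description.

```python
import numpy as np

def skip_layer_normalization(X, S, gamma, beta=None, B=None,
                             axis=-1, epsilon=1e-5, scaling_factor=1):
    """Sketch of SkipLayerNormalization; returns (Y, InputSkipBiasSum)."""
    # Expanded skip connection: xSkip = scaling_factor * S + X (+ B)
    x_skip = scaling_factor * S + X
    if B is not None:
        x_skip = x_skip + B
    # Normalize over dimensions from `axis` through the last dimension
    axes = tuple(range(axis % x_skip.ndim, x_skip.ndim))
    mean = x_skip.mean(axis=axes, keepdims=True)
    var = x_skip.var(axis=axes, keepdims=True)
    normed = (x_skip - mean) / np.sqrt(var + epsilon)
    Y = normed * gamma
    if beta is not None:
        Y = Y + beta
    return Y, x_skip
```

With `gamma = 1` and no `beta`, the output has zero mean along the normalized axis, which is a quick sanity check on the sketch.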

### <a name="SkipRMSNormalization-24"></a>**SkipRMSNormalization-24**</a>

Applies RMSNormalization to an expanded skip connection, analogous to SkipLayerNormalization.
The expanded skip connection is defined as follows:
```
xSkip = (scaling_factor * input) + F(input) + Bias
```
where:
* `F(input)`: the output of a particular layer.
* `scaling_factor`: a modulating scalar that adjusts the importance of the skip connection.
* `Bias`: a bias term added to the output of the skip connection.

RMSNorm is then applied to xSkip as follows:
```
output = RMSNormalization(xSkip)
```

#### Version

This version of the operator has been available since version 24 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>axis</tt> : int (default is -1)</dt>
<dd>The dimension for rms normalization. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back.</dd>
<dt><tt>epsilon</tt> : float (default is 1e-05)</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>scaling_factor</tt> : int (default is 1)</dt>
<dd>Modulating scalar by which the skip input is multiplied.</dd>
</dl>

#### Inputs (3 - 4)

<dl>
<dt><tt>X</tt> : T</dt>
<dd>The output of the layer for which the skip connection is being created. In general, the shape is (N, C, D1, D2, ... , Dn) for n-dimensional data, where N is the batch size, C is the number of channels, and D1 through Dn are the spatial dimension sizes.</dd>
<dt><tt>S</tt> : T</dt>
<dd>Skip input with same shape as X. This is the input to the layer for which the skip connection is being created.</dd>
<dt><tt>gamma</tt> : T</dt>
<dd>1D tensor representing scale input of rms normalization with shape of the spatial dimension along which rms normalization is applied.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd>1D bias tensor for the skip connection with shape of the spatial dimension along which rms normalization is applied.</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T</dt>
<dd>Output tensor with the same shape as X.</dd>
<dt><tt>InputSkipBiasSum</tt> (optional) : T</dt>
<dd>Sum of the input and skip inputs (plus the bias, if provided). Same shape as X.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input X, S (skip) and B (bias) types to float tensors.</dd>
<dt><tt>V</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain output Y and input gamma type to float tensors.</dd>
<dt><tt>U</tt> : tensor(float)</dt>
<dd>Constrain mean and inv_std_var to float tensors.</dd>
</dl>
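A matching NumPy sketch for SkipRMSNormalization, under the same assumptions as the SkipLayerNormalization sketch (hypothetical helper; `axis` selects the trailing normalization dimensions, `scaling_factor` multiplies the skip input S). The only change is that RMS normalization divides by the root mean square instead of subtracting the mean and dividing by the standard deviation, and there is no `beta` input.

```python
import numpy as np

def skip_rms_normalization(X, S, gamma, B=None,
                           axis=-1, epsilon=1e-5, scaling_factor=1):
    """Sketch of SkipRMSNormalization; returns (Y, InputSkipBiasSum)."""
    # Expanded skip connection: xSkip = scaling_factor * S + X (+ B)
    x_skip = scaling_factor * S + X
    if B is not None:
        x_skip = x_skip + B
    # RMS over dimensions from `axis` through the last dimension
    axes = tuple(range(axis % x_skip.ndim, x_skip.ndim))
    rms = np.sqrt(np.mean(np.square(x_skip), axis=axes, keepdims=True) + epsilon)
    Y = (x_skip / rms) * gamma
    return Y, x_skip
```

A quick check: with `gamma = 1`, the mean of the squared outputs along the normalized axis is approximately 1 (exactly 1 up to the epsilon term).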

# ai.onnx.preview.training
## Version 1 of the 'ai.onnx.preview.training' operator set
### <a name="ai.onnx.preview.training.Adagrad-1"></a>**ai.onnx.preview.training.Adagrad-1**</a>