Add SkipLayerNormalization and SkipRMSNormalization to ONNX Opset 24 #6669


Open · shubhambhokare1 wants to merge 6 commits into main
Conversation

@shubhambhokare1 (Contributor) commented on Jan 31, 2025

This PR adds the following ops to ONNX Opset 24:

  • SkipLayerNormalization
  • SkipRMSNormalization

RMSNormalization itself was previously merged in #6443.
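
For context, here is a minimal NumPy sketch of the intended semantics, assuming these ops follow the ONNX Runtime contrib-op pattern of fusing a residual add into the normalization. The function names, the optional `bias` input, and the `epsilon` default below are illustrative assumptions, not the final spec:

```python
import numpy as np

# Hypothetical sketch: LayerNorm over the last axis of (input + skip [+ bias]).
def skip_layer_normalization(x, skip, gamma, beta=None, bias=None, epsilon=1e-5):
    s = x + skip  # fused residual addition
    if bias is not None:
        s = s + bias
    mean = s.mean(axis=-1, keepdims=True)
    var = s.var(axis=-1, keepdims=True)
    out = (s - mean) / np.sqrt(var + epsilon) * gamma
    if beta is not None:
        out = out + beta
    return out

# Hypothetical sketch: RMSNorm over the last axis of (input + skip).
def skip_rms_normalization(x, skip, gamma, epsilon=1e-5):
    s = x + skip
    rms = np.sqrt(np.mean(s * s, axis=-1, keepdims=True) + epsilon)
    return s / rms * gamma
```

In other words, each skip variant is the corresponding normalization applied to the residual sum, which is what makes it attractive as a single fused kernel.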

codecov bot commented on Jan 31, 2025

Codecov Report

Attention: Patch coverage is 23.37662% with 118 lines in your changes missing coverage. Please review.

Project coverage is 56.33%. Comparing base (94e8207) to head (51f047b).
Report is 14 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...x/backend/test/case/node/skiplayernormalization.py | 0.00% | 57 Missing ⚠️ |
| ...nnx/backend/test/case/node/skiprmsnormalization.py | 0.00% | 53 Missing ⚠️ |
| onnx/reference/ops/op_skip_layer_normalization.py | 76.47% | 2 Missing and 2 partials ⚠️ |
| onnx/reference/ops/op_skip_rms_normalization.py | 76.47% | 2 Missing and 2 partials ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #6669      +/-   ##
==========================================
- Coverage   56.49%   56.33%   -0.16%
==========================================
  Files         509      513       +4
  Lines       32724    32878     +154
  Branches     3097     3101       +4
==========================================
+ Hits        18487    18523      +36
- Misses      13379    13493     +114
- Partials      858      862       +4
```

☔ View full report in Codecov by Sentry.

```diff
@@ -209,7 +209,9 @@
 from onnx.reference.ops.op_sin import Sin
 from onnx.reference.ops.op_sinh import Sinh
 from onnx.reference.ops.op_size import Size
+from onnx.reference.ops.op_skip_layer_normalization import SkipLayerNormalization
+from onnx.reference.ops.op_skip_rms_normalization import SkipRMSNormalization
```

Check notice (Code scanning / CodeQL): Unused import. Import of 'SkipLayerNormalization' is not used.
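
For reviewers who want to exercise the new reference implementations, here is a hypothetical usage sketch with `onnx.reference.ReferenceEvaluator`. The input signature (`x`, `skip`, `gamma`, `beta`) and the opset-24 registration are assumptions based on this PR, not settled API:

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# Build a one-node model using the proposed op (input names are assumed).
node = helper.make_node(
    "SkipLayerNormalization", ["x", "skip", "gamma", "beta"], ["out"]
)
graph = helper.make_graph(
    [node],
    "skip_ln_example",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4]),
        helper.make_tensor_value_info("skip", TensorProto.FLOAT, [2, 3, 4]),
        helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("beta", TensorProto.FLOAT, [4]),
    ],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 3, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 24)])

# Evaluate with the pure-Python reference runtime.
sess = ReferenceEvaluator(model)
rng = np.random.default_rng(0)
feeds = {
    "x": rng.standard_normal((2, 3, 4)).astype(np.float32),
    "skip": rng.standard_normal((2, 3, 4)).astype(np.float32),
    "gamma": np.ones(4, dtype=np.float32),
    "beta": np.zeros(4, dtype=np.float32),
}
(out,) = sess.run(None, feeds)
print(out.shape)  # (2, 3, 4)
```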
@shubhambhokare1 changed the title from "Add SkipLayerNormalization and RMSNormalization to ONNX Opset 23" to "Add SkipLayerNormalization and SkipRMSNormalization to ONNX Opset 24" on Apr 30, 2025
@shubhambhokare1 marked this pull request as ready for review on April 30, 2025 08:59
@shubhambhokare1 requested review from a team as code owners on April 30, 2025 08:59
Comment on lines +3869 to +3870
```cpp
ONNX_OPERATOR_SET_SCHEMA(
    SkipLayerNormalization,
```

Check notice (Code scanning / CodeQL): Unused static variable. Static variable dbg_count_check_Onnx_24_verSkipLayerNormalization is never read.
Comment on lines +4025 to +4026
```cpp
ONNX_OPERATOR_SET_SCHEMA(
    SkipRMSNormalization,
```

Check notice (Code scanning / CodeQL): Unused static variable. Static variable dbg_count_check_Onnx_24_verSkipRMSNormalization is never read.
Comment on lines 3924 to 3928
```cpp
.TypeConstraint(
    "T",
    {"tensor(float)", "tensor(float16)"},
    "Constrain input and output types to float or half tensors.")
.TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
```


I feel like this should be aligned with the type constraints of the normal LayerNormalization (https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#type-constraints): bfloat16 in particular, and, just for completeness, double, although I don't see the need for double.

Contributor

Following up on this point, we are also missing the stash_type attribute of LayerNormalization. My concern is that this ends up duplicating much of the apparatus of our existing normalization ops, and any future changes to those ops would need to be reflected back in these skip-norm ops too.

Do you know the main use case for these ops, and what motivated introducing them as contrib ops? If the motivation is to suggest useful fusion patterns, would it be simple enough for backends to handle this with pattern matching?
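
To illustrate the pattern-matching alternative raised above, here is a sketch of the unfused subgraph a backend could match and rewrite into a single fused kernel. It uses the standard `onnx.helper` API; the tensor names and attribute values are made up for the example:

```python
from onnx import helper

# The unfused pattern: a residual Add feeding LayerNormalization.
# A backend that recognizes this pair can substitute a fused kernel
# without needing a dedicated SkipLayerNormalization op.
add_node = helper.make_node("Add", ["x", "skip"], ["x_plus_skip"])
ln_node = helper.make_node(
    "LayerNormalization",
    ["x_plus_skip", "gamma", "beta"],
    ["out"],
    axis=-1,
    epsilon=1e-5,
)
```

The trade-off is between a guaranteed fusion contract in the standard (a dedicated op) and keeping the opset small while relying on each backend's graph rewriter to find the pattern.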

Labels: None yet
Projects: Status: In progress
3 participants