Add SkipLayerNormalization and SkipRMSNormalization to ONNX Opset 24 #6669
base: main
Conversation
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #6669      +/-   ##
==========================================
- Coverage   56.49%   56.33%   -0.16%
==========================================
  Files         509      513       +4
  Lines       32724    32878     +154
  Branches     3097     3101       +4
==========================================
+ Hits        18487    18523      +36
- Misses      13379    13493     +114
- Partials      858      862       +4

☔ View full report in Codecov by Sentry.
@@ -209,7 +209,9 @@
 from onnx.reference.ops.op_sin import Sin
 from onnx.reference.ops.op_sinh import Sinh
 from onnx.reference.ops.op_size import Size
+from onnx.reference.ops.op_skip_layer_normalization import SkipLayerNormalization
+from onnx.reference.ops.op_skip_rms_normalization import SkipRMSNormalization
Check notice (Code scanning / CodeQL): Unused import
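For context on these new reference ops: a minimal NumPy sketch of the computation SkipLayerNormalization is generally understood to perform (following the ONNX Runtime contrib-op semantics, where the skip and an optional bias are added to the input before a standard layer normalization). The function name and signature here are illustrative, not the actual reference-op code:

```python
import numpy as np

def skip_layer_norm(x, skip, gamma, beta=None, bias=None, epsilon=1e-5):
    # Sum the residual (skip) branch and optional bias into the input.
    h = x + skip
    if bias is not None:
        h = h + bias
    # Standard layer normalization over the last axis.
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    normed = (h - mean) / np.sqrt(var + epsilon)
    out = normed * gamma
    if beta is not None:
        out = out + beta
    # The contrib op can also expose the pre-normalization sum
    # (input + skip [+ bias]) as an extra output.
    return out, h
```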
Force-pushed from b54b3ef to 1e9922b (Compare)
Signed-off-by: shubhambhokare1 <shubhambhokare@gmail.com>
Force-pushed from b9cf847 to f8d405a (Compare)
Signed-off-by: shubhambhokare1 <shubhambhokare@gmail.com>
ONNX_OPERATOR_SET_SCHEMA(
    SkipLayerNormalization,

Check notice (Code scanning / CodeQL): Unused static variable
ONNX_OPERATOR_SET_SCHEMA(
    SkipRMSNormalization,

Check notice (Code scanning / CodeQL): Unused static variable
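Likewise, a sketch of SkipRMSNormalization, under the assumption that it mirrors RMSNormalization (#6443) applied to input + skip, i.e. no mean subtraction and scaling by the root mean square:

```python
import numpy as np

def skip_rms_norm(x, skip, gamma, epsilon=1e-5):
    # Residual add, then RMS normalization over the last axis.
    h = x + skip
    rms = np.sqrt((h * h).mean(axis=-1, keepdims=True) + epsilon)
    return (h / rms) * gamma, h
```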
Signed-off-by: shubhambhokare1 <shubhambhokare@gmail.com>
onnx/defs/nn/defs.cc (Outdated)
.TypeConstraint(
    "T",
    {"tensor(float)", "tensor(float16)"},
    "Constrain input and output types to float or half tensors.")
.TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
I feel like this should be aligned with the type constraints of the normal LayerNormalization:
https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#type-constraints
bfloat16 in particular, and for completeness double, although I don't see the need for double.
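For comparison, the LayerNormalization constraints the link above describes can be inspected programmatically with the onnx.defs API; the exact set of allowed types depends on the installed onnx version:

```python
import onnx.defs

schema = onnx.defs.get_schema("LayerNormalization")
for tc in schema.type_constraints:
    # Prints each type parameter ("T", "U") and its allowed tensor types.
    print(tc.type_param_str, sorted(tc.allowed_type_strs))
# Expected to show "T" including tensor(bfloat16) and tensor(double),
# which the SkipLayerNormalization constraint above currently omits.
```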
Following up on this point, we are also missing the stash_type attribute of LayerNormalization. My concern is that this ends up duplicating much of the apparatus of our existing normalization ops, and any future changes to those ops would need to be reflected back in these skip-norm ops too.
Do you know the main use case of these ops and the motivation for introducing them as contrib ops? If the motivation is to suggest useful fusion patterns, would it be simple enough for backends to handle this with pattern matching?
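To illustrate the pattern-matching alternative: the decomposed form is just an Add feeding a LayerNormalization, which a backend could recognize and fuse. A hypothetical sketch built with onnx.helper (tensor names and dimensions are made up for illustration):

```python
import onnx
from onnx import helper, TensorProto

# Decomposed pattern: out = LayerNormalization(input + skip)
add = helper.make_node("Add", ["input", "skip"], ["hidden"])
ln = helper.make_node("LayerNormalization",
                      ["hidden", "gamma", "beta"], ["out"], axis=-1)

graph = helper.make_graph(
    [add, ln], "skip_ln_pattern",
    inputs=[helper.make_tensor_value_info(n, TensorProto.FLOAT, ["B", "S", "H"])
            for n in ("input", "skip")] +
           [helper.make_tensor_value_info(n, TensorProto.FLOAT, ["H"])
            for n in ("gamma", "beta")],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, ["B", "S", "H"])],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)
# A backend's fuser would match this Add -> LayerNormalization subgraph
# and replace it with a single fused skip-layer-norm kernel.
```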
Signed-off-by: shubhambhokare1 <shubhambhokare@gmail.com>
PR adding the following ops to ONNX Opset 24: SkipLayerNormalization and SkipRMSNormalization.
RMSNormalization was merged in #6443.