Only convert inputs to FP16 when FP16 stage is used #335
This is great! It's funny that we already had the `downcast` arg but just never used it.
I approve, but I suggest that you ask @vgodsoe-groq to test the changes within GroqFlow on a GroqNode before merging (just to avoid any potential back-and-forth with additional PRs).
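For illustration, here is a minimal sketch of the behavior named in the PR title: the `downcast` flag and the presence of an FP16 stage together decide whether the expected input dtypes change. It is not GroqFlow's implementation; the stage name `fp16_conversion`, the function `expected_input_dtypes`, and the use of plain dtype strings are all hypothetical assumptions made to keep the example self-contained.

```python
# Hypothetical sketch: the stage name, function name, and dtype strings are
# illustrative assumptions, not GroqFlow's actual API.
FP16_STAGE = "fp16_conversion"  # hypothetical name for the FP16 stage


def expected_input_dtypes(input_dtypes, stages, downcast=True):
    """Return the input dtypes a built model should expect.

    float32 inputs are only downcast to float16 when downcasting is enabled
    AND the FP16 stage is part of the build; otherwise the original dtypes
    are returned unchanged.
    """
    if not (downcast and FP16_STAGE in stages):
        return dict(input_dtypes)
    # Only floating-point inputs are downcast; integer inputs keep their dtype.
    return {
        name: "float16" if dtype == "float32" else dtype
        for name, dtype in input_dtypes.items()
    }


inputs = {"x": "float32", "mask": "int64"}
print(expected_input_dtypes(inputs, ["export_onnx", FP16_STAGE]))
# {'x': 'float16', 'mask': 'int64'}
print(expected_input_dtypes(inputs, ["export_onnx"]))
# {'x': 'float32', 'mask': 'int64'}
```

Gating on both conditions means a build without the FP16 stage keeps its original input dtypes even when `downcast` is left at its default.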
I have a decision to make about this PR and my Magic 8 ball seems to be broken. Which solution is the most sound here?
Issue:
Potential solutions:
Closes #336
Description
This PR ensures that `expected_input_dtypes` are updated when the FP16 stage is used.
Testing
A test was added to `build_model.py`. You can also test the changes using the script below: