use torch.finfo for dtype other than float by wenzhe-nrv · Pull Request #4584 · espnet/espnet

Merged: 2 commits merged into espnet:master on Aug 20, 2022

Conversation

wenzhe-nrv (Contributor)

We encountered an issue when the data type is not float32; in our case, we were using bfloat16.

  File "/espnet/espnet/nets/pytorch_backend/transformer/attention.py", line 83, in forward_attention
    scores = scores.masked_fill(mask, min_value)
RuntimeError: value cannot be converted to type at::BFloat16 without overflow

To fix it, we use torch.finfo to pick the minimum value that matches the tensor's torch.dtype. As the output below shows, the minimum of torch.float32 is smaller than the minimum of torch.bfloat16, so the float32 constant cannot be represented in bfloat16 and triggers the overflow error.

In [4]: torch.finfo(torch.float32).min
Out[4]: -3.4028234663852886e+38

In [5]: torch.finfo(torch.bfloat16).min
Out[5]: -3.3895313892515355e+38
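
A minimal sketch of the dtype-aware fill (the helper name masked_softmax_scores is hypothetical; the actual change lives inside forward_attention in espnet/nets/pytorch_backend/transformer/attention.py and may differ in detail):

import torch

def masked_softmax_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Illustration only, not the ESPnet source.
    # torch.finfo(scores.dtype).min is always representable in scores' dtype,
    # so masked_fill no longer overflows for bfloat16 (or float16) inputs.
    min_value = torch.finfo(scores.dtype).min
    return scores.masked_fill(mask, min_value)

# bfloat16 scores no longer raise "value cannot be converted to type at::BFloat16".
scores = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)
mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
out = masked_softmax_scores(scores, mask)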

mergify bot added the ESPnet1 label on Aug 18, 2022
b-flo (Member) commented Aug 19, 2022

Hi,

Thanks for the fix! Hmm, we also rely on an inf constant and a cast to float64 here, which may be an issue:

scores = scores.masked_fill(mask, float("-inf"))

P.S.: You need to remove the numpy import, otherwise flake8 complains.

wenzhe-nrv (Contributor, Author) commented Aug 19, 2022

we also rely on an inf constant and a cast to float64 here, which may be an issue:

scores = scores.masked_fill(mask, float("-inf"))

Yes, I think so. It's better to use torch.finfo based on the input dtype.
@b-flo, are you suggesting I change that place as well? I could do that in this PR. Please let me know. Thanks.
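
Applied to the quoted line, the suggestion would look roughly like this (a sketch only; the Transducer code is not touched in this PR):

scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)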

codecov bot commented Aug 19, 2022

Codecov Report

Merging #4584 (592de81) into master (8331902) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4584      +/-   ##
==========================================
- Coverage   83.06%   83.06%   -0.01%     
==========================================
  Files         508      508              
  Lines       43655    43654       -1     
==========================================
- Hits        36262    36260       -2     
- Misses       7393     7394       +1     
Flag Coverage Δ
test_integration_espnet1 66.36% <100.00%> (-0.01%) ⬇️
test_integration_espnet2 49.53% <100.00%> (-0.01%) ⬇️
test_python 70.60% <100.00%> (-0.01%) ⬇️
test_utils 23.28% <100.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
espnet/nets/pytorch_backend/transformer/attention.py 96.11% <100.00%> (-0.04%) ⬇️
espnet/distributed/pytorch_backend/launch.py 82.75% <0.00%> (-1.15%) ⬇️


sw005320 added this to the v.202209 milestone on Aug 20, 2022
b-flo (Member) commented Aug 20, 2022

@b-flo, are you suggesting I change that place as well? I could do that in this PR. Please let me know. Thanks.

You don't need to, thanks! That code snippet is for the Transducer model, which doesn't support half precision outside of dynamic quantization.
I'll make the modification later, passing the selected dtype during initialization in inference so that torch.finfo isn't called multiple times.
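
A rough sketch of that idea, with a hypothetical module name (not the actual ESPnet Transducer code): resolve the fill value once from the dtype chosen at initialization, then reuse it at every decoding step.

import torch

class MaskedScorer:
    """Hypothetical inference helper: cache the dtype minimum once at init."""

    def __init__(self, dtype: torch.dtype = torch.float32) -> None:
        # Resolve the most negative representable value a single time,
        # instead of calling torch.finfo on every decoding step.
        self.fill_value = torch.finfo(dtype).min

    def mask(self, scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Same effect as filling with float("-inf"), but always representable
        # in the configured dtype.
        return scores.masked_fill(mask, self.fill_value)

scorer = MaskedScorer(dtype=torch.bfloat16)
scores = torch.randn(4, 10, dtype=torch.bfloat16)
out = scorer.mask(scores, torch.zeros(4, 10, dtype=torch.bool))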

sw005320 (Contributor)

Thanks a lot, @wenzhe-nrv!

sw005320 merged commit 939da5d into espnet:master on Aug 20, 2022
wenzhe-nrv deleted the fix_dtype_min_overflow branch on Aug 22, 2022