Add support for Numba FP16 RNNT Loss #6991
```diff
@@ -231,8 +231,8 @@ def cost_and_grad_kernel(
         )

         # Scale llForward by FastEmit lambda
-        llForward *= 1.0 + self.fastemit_lambda_
-        llBackward *= 1.0 + self.fastemit_lambda_
+        llForward += llForward * self.fastemit_lambda_
+        llBackward += llBackward * self.fastemit_lambda_

         diff = (llForward - llBackward).abs()
         if diff > 0.1:
```
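The hunk above rewrites the in-place FastEmit scale `ll *= 1.0 + λ` as `ll += ll * λ`. The two forms are mathematically identical; a plausible motivation (an assumption, not stated in the diff) is that the rewritten form never materializes the constant `1.0 + λ` in reduced precision, where a small lambda can be rounded away entirely. A minimal sketch with illustrative values:

```python
import numpy as np

# Hypothetical values; fastemit_lambda is typically a small regularizer weight.
ll_forward = -42.5
lam = 1e-3

scaled = ll_forward * (1.0 + lam)            # original form
accumulated = ll_forward + ll_forward * lam  # form used in the PR

# In double precision the two forms agree to rounding error.
assert abs(scaled - accumulated) < 1e-9

# In float16, the constant 1.0 + lam rounds back to exactly 1.0 once
# lam drops below half of fp16's spacing at 1.0 (~4.9e-4), so the
# scale factor silently becomes a no-op.
assert np.float16(1.0) + np.float16(1e-4) == np.float16(1.0)
```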
```diff
@@ -300,6 +300,10 @@ def compute_betas_and_grads(
         Returns:
             Loglikelihood of the forward variable and inplace updates the grad tensor.
         """
+        # Patch for CPU + fp16
+        if log_probs.dtype == torch.float16 and not log_probs.is_cuda:
+            log_probs = log_probs.float()
+
         idx = CpuRNNT_index(U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first)
         betas[idx(T - 1, U - 1)] = log_probs[idx(T - 1, U - 1) * 2]

```
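The added guard upcasts half-precision inputs before the CPU recursion runs, since many CPU kernels have historically lacked fp16 implementations. A standalone NumPy sketch of the same pattern (the function name here is illustrative, not from the PR):

```python
import numpy as np

def cpu_cost_fn(log_probs: np.ndarray) -> np.ndarray:
    """Illustrative stand-in for the fp16 guard in compute_betas_and_grads."""
    # Patch for CPU + fp16: upcast before the log-domain recursion,
    # mirroring `log_probs = log_probs.float()` in the PR.
    if log_probs.dtype == np.float16:
        log_probs = log_probs.astype(np.float32)
    # ... the forward/backward recursion would run here in float32 ...
    return log_probs

out = cpu_cost_fn(np.zeros((2, 3), dtype=np.float16))
assert out.dtype == np.float32
```

fp32 inputs pass through unchanged, so the guard only affects the CPU + fp16 combination the PR targets.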
```diff
@@ -30,6 +30,7 @@
 import math
 from typing import Optional, Tuple

+import numba
 import torch
 from numba import cuda

```

Check notice — Code scanning / CodeQL: Module 'numba' is imported with both 'import' and 'import from'.

Reviewer: address this?

Author: Hmm, that's just numba call style. Autocomplete doesn't work well if you do `numba.cuda.*`.

Reviewer: Oh OK.
```diff
@@ -112,7 +113,7 @@ def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda
     if idx < length:
         copy_data_1d(source, dest, idx)
         dest[idx] *= -1.0
-        dest[idx] *= 1.0 + fastemit_lambda
+        dest[idx] *= numba.float32(1.0 + fastemit_lambda)


 def get_workspace_size(
```
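Wrapping the constant in `numba.float32(...)` pins the scale factor to single precision. Without it, `1.0 + fastemit_lambda` is a Python float (a 64-bit double), and mixing a double into the kernel's arithmetic can promote the computation to fp64, which is slow on GPUs and at odds with a reduced-precision pipeline. A NumPy stand-in for the narrowing (this reading of the motivation is an assumption, not stated in the diff):

```python
import numpy as np

fastemit_lambda = 0.001

c_double = 1.0 + fastemit_lambda              # plain Python float: 64-bit
c_single = np.float32(1.0 + fastemit_lambda)  # analogue of numba.float32(...)

assert isinstance(c_double, float)
assert c_single.dtype == np.float32

# The narrowed constant keeps elementwise math in single precision.
dest = np.full(4, -2.0, dtype=np.float32)
dest *= c_single
assert dest.dtype == np.float32
```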