Support enh_s2t joint training on multi-speaker data #4566
Conversation
Codecov Report
```
@@            Coverage Diff             @@
##           master    #4566      +/-   ##
==========================================
+ Coverage   82.46%   83.07%   +0.60%
==========================================
  Files         487      508      +21
  Lines       42112    43790    +1678
==========================================
+ Hits        34729    36379    +1650
- Misses       7383     7411      +28
```
Thanks @Emrys365. If the training experiment works, it would be good.
```diff
@@ -329,6 +401,97 @@ def nll(
     batchify_nll = ESPnetASRModel.batchify_nll

+    def asr_pit_loss(self, speech, speech_lengths, text, text_lengths):
```
I have a rough idea about implementing all the PIT losses within a dedicated PIT class; maybe we can implement this later. Such a class could compute the PIT loss for all criteria, determine the optimal permutation, reorder the data, and so on. The main goal is to keep the code minimal and clean.
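For illustration, here is a minimal sketch of what such a PIT class might look like (the name `PITWrapper`, its signatures, and the `reorder` helper are all hypothetical, not an existing ESPnet API): it scores an arbitrary per-pair criterion under every speaker permutation, picks the best permutation per sample, and returns it for reuse.

```python
from itertools import permutations

import torch


class PITWrapper(torch.nn.Module):
    """Hypothetical PIT wrapper (a sketch, not the ESPnet API).

    `criterion(inf, ref)` is any per-utterance loss mapping one inferred
    tensor and one reference tensor of shape (B, ...) to a (B,) loss vector.
    """

    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, inf_list, ref_list):
        # inf_list, ref_list: lists of num_spk tensors, each (B, ...)
        num_spk = len(inf_list)
        perms = list(permutations(range(num_spk)))
        losses = []
        for perm in perms:
            # average the pairwise losses under this permutation -> (B,)
            pair = [self.criterion(inf_list[i], ref_list[j]) for i, j in enumerate(perm)]
            losses.append(torch.stack(pair, dim=0).mean(dim=0))
        losses = torch.stack(losses, dim=1)           # (B, num_perms)
        min_loss, min_idx = torch.min(losses, dim=1)  # best permutation per sample
        best_perms = [perms[i] for i in min_idx.tolist()]
        return min_loss.mean(), best_perms

    @staticmethod
    def reorder(ref_list, best_perms):
        # reorder the references sample-by-sample with the chosen permutations
        stacked = torch.stack(ref_list, dim=1)        # (B, num_spk, ...)
        idx = torch.tensor(best_perms, device=stacked.device)
        batch = torch.arange(stacked.size(0), device=stacked.device)
        out = stacked[batch.unsqueeze(1), idx]        # (B, num_spk, ...)
        return [out[:, i] for i in range(len(ref_list))]
```

Returning `best_perms` explicitly is what would later allow other losses to reuse the permutation instead of repeating the search.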
Sure. Indeed, this part can be wrapped as a standalone module. We can add support for PIT later; for now, we can always use `enh_model` to determine the permutation.
Related to another comment from @simpleoier: can we use the permutation defined by `asr_pit_loss` for the enhancement loss?
> can we use the permutation defined by `asr_pit_loss` for the enhancement loss?

This is possible, but the implementation might be complex. We can consider adding this function later.
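As a rough sketch of that idea (all names below, `asr_pit`, `enh_criterion`, `speech_pre_list`, and `speech_ref_list`, are placeholders building on the hypothetical `PITWrapper` above, not the actual implementation), the enhancement loss could skip its own permutation search:

```python
# run PIT only once, on the ASR side, and reuse the chosen permutation
loss_asr, best_perms = asr_pit(decoder_out_list, text_ref_list)
# align the speech references with the same permutation before the Enh loss
speech_ref_list = PITWrapper.reorder(speech_ref_list, best_perms)
loss_enh = torch.stack(
    [enh_criterion(p, r) for p, r in zip(speech_pre_list, speech_ref_list)]
).mean()
```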
I see, we can do it later.
`egs2/wsj0_2mix_spatialized/enh_asr1/conf/tuning/train_enh_asr_beamformer_fbank_transformer.yaml` (outdated; resolved)
```diff
@@ -41,5 +41,5 @@ word_vocab_size=65000
     --train_set "${train_set}" \
     --valid_set "${valid_set}" \
     --test_sets "${test_sets}" \
-    --bpe_train_text "data/${train_set}/text" \
+    --bpe_train_text "data/${train_set}/text_spk1" \
     --lm_train_text "data/${train_set}/text data/local/other_text/text" "$@"
```
I think we can keep `text` here, to be the same as chime4/asr1/run.sh. But either is fine.
Yes, I made a symbolic link in local/data.sh for chime4. Here I just want to explicitly use `text_spk1` to remind users, because for multi-speaker data the use of `text` could cause errors.
`text_spk1` looks good to me. It makes the integration explicit to users.
Thanks a lot!
The overall design looks good to me.
```diff
@@ -558,3 +558,36 @@ def _calc_transducer_loss(
         )

         return loss_transducer, cer_transducer, wer_transducer

+    def _calc_batch_ctc_loss(
```
Is this function included in the unit test?
Yes, it is implicitly covered by `test_enh_asr_model_2spk` in test/espnet2/enh/test_espnet_enh_s2t_model.py when `calc_enh_loss` is True.
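For intuition, a batched CTC helper in the spirit of `_calc_batch_ctc_loss` (the function name comes from the diff above, but the body below is only a guess at the general technique, not the actual implementation) can return unreduced per-sample losses, which PIT code may then reshape to score every hypothesis/reference pair in one call:

```python
import torch.nn.functional as F


def batch_ctc_loss(log_probs, targets, input_lengths, target_lengths):
    """log_probs: (T, B, V) log-softmax outputs; returns a (B,) loss vector."""
    return F.ctc_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        reduction="none",    # keep per-sample losses for permutation scoring
        zero_infinity=True,  # guard against invalid alignments
    )
```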
Hi @simpleoier and @YoshikiMas, could you review the new changes? I've verified that the training, inference, and scoring stages can be run successfully.
I left a minor question.
One point that needs discussion: because the Enh and ASR losses may treat the sequence-length dimension differently (averaging vs. summing), the resulting loss scales can be very different, which impedes joint training.
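As a toy illustration of the scale gap (the numbers are made up, not from this PR), switching a CTC loss between sum and mean reduction alone changes its magnitude by roughly the target length, so a single weight between the Enh and ASR terms has to absorb that factor:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = torch.randn(100, 4, 30).log_softmax(-1)  # (T, B, vocab)
targets = torch.randint(1, 30, (4, 20))              # (B, S)
in_lens = torch.full((4,), 100, dtype=torch.long)
tgt_lens = torch.full((4,), 20, dtype=torch.long)

# "sum" grows with the sequence length; "mean" is normalized per target token
ctc_sum = F.ctc_loss(log_probs, targets, in_lens, tgt_lens, reduction="sum")
ctc_mean = F.ctc_loss(log_probs, targets, in_lens, tgt_lens, reduction="mean")
print(ctc_sum.item(), ctc_mean.item())  # differ by ~two orders of magnitude
```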
I think …

Oh yes, …

I see. If the training works with the single scale parameter (…
```python
if text_lengths is not None:
    assert text_lengths.dim() == 1, text_lengths.shape
if speech_lengths is not None and text_lengths is not None:
    ...
if "text" in kwargs:
    ...
```
Can we make it compact here? What do you think about the following:
```python
# check whether the text inputs use the "text_spk*" naming scheme
if any(key.startswith("text_spk") for key in kwargs):
    # count only the reference keys, not their "*_lengths" companions
    filtered_keys = [
        key
        for key in kwargs
        if key.startswith("text_spk") and not key.endswith("_lengths")
    ]
    num_refs = len(filtered_keys)
    text_ref_keys = [f"text_spk{i}" for i in range(1, num_refs + 1)]
else:
    text_ref_keys = ["text"]

text, text_ref_lengths = [], []
for key in text_ref_keys:
    text.append(kwargs.get(key, None))
    text_ref_lengths.append(kwargs.get(key + "_lengths", None))

# remember to check that the number of speakers is consistent
# with the enhancement model
assert len(text) == self.enh_model.num_spk
```
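A quick sanity check of that parsing with dummy kwargs (run outside the model class, so the final `num_spk` assertion is omitted):

```python
import torch

kwargs = {
    "text_spk1": torch.randint(0, 10, (4, 12)),
    "text_spk1_lengths": torch.full((4,), 12),
    "text_spk2": torch.randint(0, 10, (4, 9)),
    "text_spk2_lengths": torch.full((4,), 9),
}
filtered_keys = [
    k for k in kwargs if k.startswith("text_spk") and not k.endswith("_lengths")
]
assert [f"text_spk{i}" for i in range(1, len(filtered_keys) + 1)] == [
    "text_spk1",
    "text_spk2",
]
```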
For the legacy case (where `text` is used), the variables `text` and `text_ref_lengths` should be tensors instead of lists. I feel it is a bit complicated to handle both cases with unified code.
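One possible (purely hypothetical) way to unify the two cases would be to wrap the legacy tensors into one-element lists early on, at the cost of changing the typing assumption mentioned above:

```python
import torch

# hypothetical normalization for the legacy single-speaker case: wrap the
# lone tensors so the multi-speaker code path can treat both cases uniformly
if isinstance(text, torch.Tensor):
    text = [text]
    text_ref_lengths = [text_ref_lengths]
assert len(text) == len(text_ref_lengths)
```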
> For the legacy case (where `text` is used), the variables `text` and `text_ref_lengths` should be tensors instead of lists. I feel it is a bit complicated to handle both cases with unified code.

@YoshikiMas, I think you can work with it now. Below are my training curves based on … It took 4 days to finish 50 epochs with a single TITAN RTX (24 GB).
Hi @Emrys365, great PR! I have one major comment that we can consider in the future.
@sw005320 Thanks for your comments! I totally agree that it is better to split the single-speaker and multi-speaker functions into different scripts. For the target speaker extraction function, I have an espnet2-based implementation at hand, but it needs more work to tidy up. Anyway, we may need to carefully discuss the overall design some time later.
I just updated the experimental results in …
As discussed with @simpleoier and @YoshikiMas, I made this PR to add support for joint training of Enh and ASR tasks on multi-speaker data. (The compatibility with other S2T tasks is preserved as before.)
Most changes are made in two scripts:

- `egs2/TEMPLATE/enh_asr1/enh_asr.sh`: replaces the `text` data files with speaker-related ones (`text_spk1`, `text_spk2`, etc.)
- `espnet2/enh/espnet_enh_s2t_model.py`: replaces the `text` data files with speaker-related ones (`text_spk1`, `text_spk2`, etc.) and handles multi-condition data via `utt2category`.

Note:

- The `bypass_enh_prob` flag cannot be used for multi-speaker cases, as the multi-condition training function is not well designed for now. We can add support for this after the multi-condition training part is refactored.

TODO:

I also added a recipe in `egs2/wsj0_2mix_spatialized/enh_asr1` for training MIMO-Speech style models.