Support enh_s2t joint training on multi-speaker data by Emrys365 · Pull Request #4566 · espnet/espnet · GitHub

Support enh_s2t joint training on multi-speaker data #4566


Merged
merged 13 commits into master on Aug 31, 2022

Conversation

@Emrys365 (Collaborator) commented Aug 11, 2022

As discussed with @simpleoier and @YoshikiMas, I made this PR to add support for joint training of Enh and ASR tasks on multi-speaker data. (Compatibility with other S2T tasks is preserved.)

Most changes are made in two scripts:

  • egs2/TEMPLATE/enh_asr1/enh_asr.sh

    • I replaced all text data files with speaker-specific ones (text_spk1, text_spk2, etc.); see the illustrative data layout after this list.
    • I added support for loading noise and dereverberation reference signals during training.
    • The ASR evaluation stage is also modified accordingly to support multi-speaker data.
  • espnet2/enh/espnet_enh_s2t_model.py

    • I replaced the single text input with per-speaker ones (text_spk1, text_spk2, etc.).
    • Permutation-invariant training (PIT) related code is added.
    • NOTE: the multi-condition training code should be rewritten to utilize the information from utt2category.
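For illustration, a multi-speaker data directory prepared by the modified enh_asr.sh might look like the layout below (the spk*.scp / noise1.scp / dereverb1.scp names follow the usual ESPnet2 enh conventions; the exact set of files depends on the recipe):

data/train/
  wav.scp        # input mixture audio
  spk1.scp       # reference signal of speaker 1
  spk2.scp       # reference signal of speaker 2
  noise1.scp     # noise reference (if the noise loss is used)
  dereverb1.scp  # dereverberation reference (if used)
  text_spk1      # transcription of speaker 1
  text_spk2      # transcription of speaker 2
  utt2spk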

Note:

  • Currently, the bypass_enh_prob flag cannot be used for multi-speaker cases, as the multi-condition training function is not yet designed for them. We can add support for this after the multi-condition training part is refactored.

TODO:

  • The inference-related code should also be modified accordingly.

I also added a recipe in egs2/wsj0_2mix_spatialized/enh_asr1 for training MIMO-Speech-style models.

@Emrys365 added the Recipe, ESPnet2, and SE (Speech enhancement) labels Aug 11, 2022
@Emrys365 requested a review from simpleoier Aug 11, 2022
@sw005320 added this to the v.202209 milestone Aug 11, 2022
@codecov bot commented Aug 12, 2022

Codecov Report

Merging #4566 (8af19a3) into master (24b12f8) will increase coverage by 0.60%.
The diff coverage is 82.42%.

@@            Coverage Diff             @@
##           master    #4566      +/-   ##
==========================================
+ Coverage   82.46%   83.07%   +0.60%     
==========================================
  Files         487      508      +21     
  Lines       42112    43790    +1678     
==========================================
+ Hits        34729    36379    +1650     
- Misses       7383     7411      +28     
Flag Coverage Δ
test_integration_espnet1 66.36% <ø> (-0.01%) ⬇️
test_integration_espnet2 49.53% <55.15%> (+1.26%) ⬆️
test_python 70.62% <65.45%> (+0.96%) ⬆️
test_utils 23.28% <ø> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
espnet2/enh/espnet_enh_s2t_model.py 81.72% <76.19%> (-0.64%) ⬇️
espnet2/bin/asr_inference.py 84.98% <82.92%> (+0.08%) ⬆️
espnet2/asr/espnet_model.py 81.38% <85.71%> (+0.27%) ⬆️
espnet2/enh/espnet_model.py 86.56% <100.00%> (+1.08%) ⬆️
espnet2/enh/loss/wrappers/fixed_order.py 91.30% <100.00%> (+0.39%) ⬆️
espnet2/tasks/abs_task.py 75.52% <100.00%> (+0.03%) ⬆️
espnet2/tasks/enh_s2t.py 96.66% <100.00%> (+0.05%) ⬆️
...pnet/nets/pytorch_backend/transformer/attention.py 96.11% <0.00%> (-0.04%) ⬇️
espnet2/tasks/asr.py 91.76% <0.00%> (ø)
espnet2/asr/transducer/joint_network.py
... and 31 more


@simpleoier (Collaborator) left a comment:

Thanks @Emrys365. If the training experiment works, it would be good.

@@ -329,6 +401,97 @@ def nll(

batchify_nll = ESPnetASRModel.batchify_nll

def asr_pit_loss(self, speech, speech_lengths, text, text_lengths):
Collaborator:

I have a rough idea about implementing all of the PIT loss within a PIT class. Maybe we can implement this later. This PIT class could compute the PIT loss for all criterions, determine the optimal permutation, reorder the data, etc. The main goal is to keep the code minimal and clean.
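As a rough sketch of that idea (hypothetical interface and names, not code from this PR), such a criterion-agnostic wrapper might look like:

from itertools import permutations

import torch


class PIT:
    def __init__(self, criterion):
        # criterion(est, ref) must return a per-sample loss of shape (batch,)
        self.criterion = criterion

    def __call__(self, estimates, references):
        # estimates/references: lists of num_spk tensors, each (batch, ...)
        num_spk = len(estimates)
        perms = list(permutations(range(num_spk)))
        # losses[k, b]: total loss of permutation k for sample b
        losses = torch.stack(
            [
                sum(self.criterion(estimates[i], references[p[i]])
                    for i in range(num_spk))
                for p in perms
            ]
        )
        min_loss, min_idx = losses.min(dim=0)
        best_perm = [perms[k] for k in min_idx.tolist()]
        return min_loss.mean() / num_spk, best_perm

    @staticmethod
    def reorder(references, best_perm):
        # Reorder per-speaker tensors so that references[i][b] is the
        # reference assigned to estimate i for sample b.
        num_spk = len(references)
        return [
            torch.stack([references[p[i]][b] for b, p in enumerate(best_perm)])
            for i in range(num_spk)
        ]

Any per-utterance criterion (CTC, attention CE, or a signal-level loss) could then share the same permutation search and reordering.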

Collaborator Author (@Emrys365):

Sure. Indeed, this part can be wrapped as a standalone module. We can add such a PIT module later; for now we can always use enh_model to determine the permutation.

Contributor:

Related to another comment from @simpleoier, can we use the permutation defined by asr_pit_loss for the enhancement loss?

Collaborator Author (@Emrys365):

can we use the permutation defined by asr_pit_loss for the enhancement loss?

This is possible, but the implementation might be complex. We can consider adding this function later.

Contributor:

I see, we can do it later.

@@ -41,5 +41,5 @@ word_vocab_size=65000
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
-    --bpe_train_text "data/${train_set}/text" \
+    --bpe_train_text "data/${train_set}/text_spk1" \
     --lm_train_text "data/${train_set}/text data/local/other_text/text" "$@"
Collaborator:

I think we can keep text here, the same as in chime4/asr1/run.sh. But either is fine.

Collaborator Author (@Emrys365):

Yes, I made a symbolic link in local/data.sh for chime4. Here I just want to explicitly use text_spk1 to remind users, because for multi-speaker data the use of text could cause errors.

Contributor:

text_spk1 looks good to me. This can emphasize the integration to users.

@sw005320 (Contributor) left a comment:

Thanks a lot!
The overall design looks good to me.

@@ -558,3 +558,36 @@ def _calc_transducer_loss(
)

return loss_transducer, cer_transducer, wer_transducer

def _calc_batch_ctc_loss(
Contributor:

Is this function included in the unit test?

Collaborator Author (@Emrys365):

Yes, it is implicitly covered by test_enh_asr_model_2spk in test/espnet2/enh/test_espnet_enh_s2t_model.py when calc_enh_loss is True.
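For reference, that test can be run on its own with pytest's -k filter (assuming a standard pytest setup; path and test name as given above):

pytest test/espnet2/enh/test_espnet_enh_s2t_model.py -k test_enh_asr_model_2spk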

@Emrys365 (Collaborator Author):

Hi @simpleoier and @YoshikiMas, could you review the new changes?

@Emrys365 (Collaborator Author):

Thanks @Emrys365. If the training experiment works, it would be good.

I've verified that the training, inference, and scoring stages can be run successfully.

@YoshikiMas (Contributor) left a comment:

I left a minor question.

@Emrys365 (Collaborator Author):

One point that needs discussion:

  • Is it better to add a scaling factor for loss_enh when both Enh and ASR losses are used?

Because the Enh and ASR losses may treat the sequence-length dimension differently (averaging vs. summing), the resulting loss scales can be very different, which impedes joint training.
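Concretely, the joint objective under discussion has the form below (a sketch of the weighting, not the exact expression in the code; w is the scaling factor in question):

    loss = w * loss_enh + loss_asr

A summed ASR loss grows with the label-sequence length, while a signal-level loss such as SI-SNR stays in a roughly fixed range, so w compensates for the scale mismatch.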

@YoshikiMas (Contributor):

I think weight in wrapper_conf for enh_criterions can adjust the ratio between the two losses.
Also, can length_normalized_loss help reduce such a discrepancy further? I haven't tried it yet.

@Emrys365 (Collaborator Author):

I think weight in wrapper_conf for enh_criterions can adjust the ratio between the two losses. Also, can length_normalized_loss help reduce such a discrepancy further? I haven't tried it yet.

Oh yes, enh_criterions can be used to scale loss_enh.

length_normalized_loss should also be helpful, but we may have to take care of all ASR losses if this flag is used, because it only impacts loss_att.

@YoshikiMas (Contributor):

length_normalized_loss should also be helpful, but we may have to take care of all ASR losses if this flag is used, because it only impacts loss_att.

I see. If the training works with the single scale parameter (weight), the current one is good enough for me as an initial implementation.

if text_lengths is not None:
    assert text_lengths.dim() == 1, text_lengths.shape
if speech_lengths is not None and text_lengths is not None:
if "text" in kwargs:
Collaborator:

Can we make it compact here? What do you think about the following:

if any(key.startswith("text_spk") for key in kwargs):  # check if there are text inputs using "text_spk*"
    # count only the reference keys, not their "*_lengths" companions
    num_refs = len(
        [k for k in kwargs if k.startswith("text_spk") and not k.endswith("lengths")]
    )
    # text_spk keys are 1-indexed: text_spk1, text_spk2, ...
    text_ref_keys = [f"text_spk{i + 1}" for i in range(num_refs)]
else:
    text_ref_keys = ["text"]

text = []
text_ref_lengths = []
for key in text_ref_keys:
    text.append(kwargs.get(key, None))
    text_ref_lengths.append(kwargs.get(key + "_lengths", None))

# remember to check that the number of speakers is consistent with the enhancement model
assert len(text) == self.enh_model.num_spk

Collaborator Author (@Emrys365):

For the legacy case (where text is used), the variables text and text_ref_lengths should be tensors instead of lists. I feel it is a bit complicated to process both cases with unified code.


@Emrys365 added the Enhancement and ASR (Automatic speech recognition) labels Aug 22, 2022
@Emrys365 (Collaborator Author) commented Aug 22, 2022:

If the training works with the single scale parameter (weight), the current one is good enough for me as an initial implementation.

@YoshikiMas, I think you can work with it now. Below are my training curves based on conf/tuning/train_enh_asr_beamformer_fbank_transformer.yaml:

The anechoic 8-channel data is used for training.
The weight for loss_enh is 10.0.
No multi-condition data is used.
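For reference, that weight is the criterion-wrapper weight mentioned earlier. A minimal sketch of how it might appear in the config (criterion name and other values are illustrative, not copied from the actual yaml):

enh_criterions:
  - name: si_snr
    conf: {}
    wrapper: pit
    wrapper_conf:
      weight: 10.0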

(attached figure: joint Enh/ASR training curves, mtl_enh_asr)

It took 4 days to finish 50 epochs with a single TITAN RTX (24 GB).

I am now training the language model. After that, I can obtain the decoding performance.

@sw005320 (Contributor):

Hi @Emrys365, great PR!
I think we can merge this PR now, and then we can do some tuning.

I have one major comment that we can consider in the future.
We may think of the multi-speaker case as a new ASR or a new SE-ASR joint model instead of hacking the enh_s2t model (like we made asr_model and enh_s2t_model separately).
In the future, we should also think of adding a target-speaker model or a recursive speaker model, and these separated models (or we could even make a new task) could be more flexible.

@Emrys365 (Collaborator Author):

Hi @Emrys365, great PR! I think we can merge this PR now, and then we can do some tuning.

I have one major comment that we can consider in the future. We may think of the multi-speaker case as a new ASR or a new SE-ASR joint model instead of hacking the enh_s2t model (like we made asr_model and enh_s2t_model separately). In the future, we should also think of adding a target-speaker model or a recursive speaker model, and these separated models (or we could even make a new task) could be more flexible.

@sw005320 Thanks for your comments! I totally agree that it is better to split the single-speaker and multi-speaker functions into different scripts.

For the target-speaker extraction function, I have an ESPnet2-based implementation at hand, but it needs more work to tidy up. Anyway, we may need to discuss the overall design carefully some time later.

@mergify bot added the README label Aug 26, 2022
@Emrys365 (Collaborator Author):

I just updated the experimental results in egs2/wsj0_2mix_spatialized/enh_asr1.

Labels
ASR (Automatic speech recognition) · Enhancement · ESPnet2 · README · Recipe · SE (Speech enhancement)
4 participants