[Bug] BCELoss should not be masked · Issue #1192 · coqui-ai/TTS
Closed
@iamanigeeit

Description

I have trained Tacotron2, but during eval / inference it often doesn't know when to stop decoding. This is a known issue in seq2seq models; I was trying to solve it in TensorFlowTTS but gave up due to TensorFlow problems.

Training with enable_bos_eos=True helps a bit, but for shorter audio the output is still 3x the ground-truth mel length: see length_data_eos.csv vs length_data_no_eos.csv

One reason is the BCELossMasked criterion: in its current form, it encourages the model never to stop decoding once it has passed mel_length. Some of the loss results don't quite make sense, as seen below:

import torch
def sequence_mask(sequence_length, max_len=None):
    """Return a B x T_max boolean mask that is True where position < sequence_length."""
    if max_len is None:
        max_len = sequence_length.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask

from torch.nn import functional
length = torch.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = torch.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = torch.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits on early stopping
late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits on late stopping

# if we mask
>>> functional.binary_cross_entropy_with_logits(mask * true_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(3.4657)  # Should be zero! It's not zero because of trailing zeros in the mask
>>> functional.binary_cross_entropy_with_logits(mask * zero_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)
>>> functional.binary_cross_entropy_with_logits(mask * late_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)  # Stopping late should be better than not stopping at all, but the mask zeroes every frame past mel_length, so the two look identical
>>> functional.binary_cross_entropy_with_logits(mask * early_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(203.4657)  # Early stopping should be worse than late stopping because the audio will be cut

# if we don't mask
>>> functional.binary_cross_entropy_with_logits(true_x, target, pos_weight=pos_weight, reduction='sum')
tensor(0.)  # correct
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=pos_weight, reduction='sum')
tensor(3000.)  # correct
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=pos_weight, reduction='sum')
tensor(1000.)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=pos_weight, reduction='sum')
tensor(200.)  # still wrong: early stopping (200) is penalized less than late stopping (1000)

# pos_weight should be < 1 so that early stopping is penalized more than late stopping
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(120.0000)
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(40.0000)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(200.)  # correct

For now I am passing length=None to avoid the mask and setting pos_weight=0.2 as an experiment. I will update this issue with the training results.
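
As a rough check on where the constant 3.4657 in the masked results above comes from, the sketch below repeats the 95-frame setup with reduction='none' so each frame's loss is visible. The names per_frame, padded and real are introduced here for illustration only, and the mask is rebuilt inline with torch.arange rather than by calling sequence_mask.

```python
import math
import torch
from torch.nn import functional

# Same setup as the example above: 95 real frames padded to 100.
length = torch.tensor([95])
positions = torch.arange(100).unsqueeze(0)                # 1 x T_max
mask = positions < length.unsqueeze(1)                    # True for the 95 real frames
target = 1.0 - (positions < (length - 1).unsqueeze(1)).float()
true_x = target * 200 - 100                               # "perfect" logits

# Per-frame losses instead of the sum.
per_frame = functional.binary_cross_entropy_with_logits(
    mask * true_x, mask * target, pos_weight=torch.tensor([5.0]), reduction="none"
)

padded = per_frame[~mask]   # the 5 frames zeroed out by the mask
real = per_frame[mask]      # the 95 frames inside mel_length

print(padded)               # each entry is log(2), about 0.6931
print(real.sum())           # effectively zero
print(5 * math.log(2))      # about 3.4657
```

Each padded frame ends up with logit 0 and target 0, so it contributes log(2) regardless of what the model predicted there. That is why the same offset shows up in every masked result.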

Additional context

I would also propose renaming stop_tokens to either stop_probs or stop_logits depending on context. Currently, inference() produces stop_tokens that represent stop probabilities, while forward() produces the logits before sigmoid. Confusingly, both are called stop_tokens.
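
To make the naming distinction concrete, here is a minimal sketch of how the two quantities relate. The tensor values, the 0.5 threshold, and the variable names stop_logits, stop_probs and stop_frame are assumptions for illustration, not taken from the TTS code base.

```python
import torch

# Made-up decoder outputs for one utterance of 6 frames.
stop_logits = torch.tensor([[-4.0, -3.0, -2.5, -0.5, 2.0, 5.0]])  # pre-sigmoid, as forward() returns
stop_probs = torch.sigmoid(stop_logits)                           # in (0, 1), as inference() returns

# A logit threshold of 0 is equivalent to a probability threshold of 0.5.
threshold = 0.5  # assumed value for this sketch
assert torch.equal(stop_logits > 0, stop_probs > threshold)

# Index of the first frame whose stop probability exceeds the threshold.
stop_frame = (stop_probs[0] > threshold).nonzero()[0].item()
print(stop_frame)  # 4
```

With distinct names, it is clear at a glance whether a threshold should be compared against 0 (logits) or 0.5 (probabilities).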
