Description
Describe the bug
In the case when there are fewer than `num_spans_to_keep` total spans in the original text, some padding makes its way into the `top_span_scores` output of `SpanPruner` with scores of `-inf`. Even though the `top_spans_mask` output is correct, this is a problem because multiplying the scores by the mask produces `nan` in those slots instead of the desired `0.0`.
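For reference, this is ordinary IEEE floating-point behavior rather than anything AllenNLP-specific; a minimal standalone check:

```python
import torch

# Multiplying -inf by a 0.0 mask entry gives nan, not 0.0.
print(torch.tensor(float("-inf")) * torch.tensor(0.0))  # tensor(nan)
```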
To Reproduce
In a Python REPL:

```python
import torch
from allennlp.modules.span_pruner import SpanPruner

emb = torch.ones([1, 2, 1])  # batch size 1, 2 spans, embedding size 1
scorer = torch.nn.Linear(1, 1)
mask = torch.tensor([1, 0]).view(1, 2).float()  # only 1 span is present in the instance
pruner = SpanPruner(scorer)
_, _, _, scores = pruner(emb, mask, 2)
print(scores)
```
For me, this outputs:

```
tensor([[[ 0.5783],
         [   -inf]]])
```

though of course the non-inf number is arbitrary.
Expected behavior
I think in this case we should replace the `-inf`s with `-1`. Because of this issue I had a loss of `nan` that I had to debug until I found this. It should be an easy fix in `SpanPruner`. By the way, there's nothing particular to spans in `SpanPruner`, is there? Might as well just call it `Pruner`, right?
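A minimal sketch of the kind of fix I mean (a hypothetical helper, not the library's actual code), assuming the pruned scores have shape `(batch, num_spans_to_keep, 1)` and the corresponding mask has shape `(batch, num_spans_to_keep)`:

```python
import torch

def replace_padded_scores(top_scores: torch.Tensor,
                          top_mask: torch.Tensor,
                          fill_value: float = -1.0) -> torch.Tensor:
    # Overwrite scores in padded slots with a finite value so the usual
    # scores * mask pattern downstream gives fill_value * 0 = 0.0 instead of -inf * 0 = nan.
    keep = top_mask.unsqueeze(-1) > 0.5
    return torch.where(keep, top_scores, torch.full_like(top_scores, fill_value))
```

Something along these lines could be applied to `top_span_scores` just before `SpanPruner.forward` returns.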
System (please complete the following information):
- OS: Reproduced on macOS 10.13.3 on CPU as well as Ubuntu 16.04.4 LTS on GPU (V100).
- Python version: 3.6.5 and 3.6.6.
- AllenNLP version: v0.5.0, but looking at the source code on master I assume this remains an issue.
- PyTorch version: 0.4.0
Additional context
My guess is this hasn't come up before because span pruning was only used with long texts where this doesn't ever happen. It came up for me because I'm using span pruning with a sentence-level model where a few of the sentences have only 2 tokens and are batched with 4-token sentences.