Ensure contiguous initial state tensors in `_EncoderBase(stateful=True)` by rloganiv · Pull Request #2451 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Ensure contiguous initial state tensors in `_EncoderBase(stateful=True)` #2451

Merged
DeNeutoy merged 10 commits into allenai:master on Mar 18, 2019

Conversation

rloganiv
Contributor

This PR fixes the following bug:

**Describe the bug** If successive batches of inputs containing a zero-length sequence are passed to a stateful encoder (e.g. a child of `_EncoderBase` with the `stateful` parameter set to `True`), then the following error is raised:

```
RuntimeError: rnn: hx is not contiguous
```

**To Reproduce**
```python
import torch

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper


# A stateful wrapper caches the final hidden state between forward calls.
lstm = torch.nn.LSTM(input_size=2,
                     hidden_size=2,
                     num_layers=3,
                     dropout=0.1,
                     batch_first=True)
encoder = PytorchSeq2SeqWrapper(lstm, stateful=True)
encoder = encoder.cuda(0)

inputs = torch.randn(4, 4, 2).cuda(0)
mask = torch.ones(4, 4).cuda(0)
mask[2, :] = 0  # make the third sequence zero-length

encoder(inputs, mask)
encoder(inputs, mask)  # second call reuses the cached state and crashes
```
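
The mechanism, in miniature: when a batch contains zero-length sequences, restoring the cached state involves narrowing the state tensor along its batch dimension, and that narrowing yields a non-contiguous view, which `torch.nn.LSTM` rejects as `hx`. Below is a minimal sketch of that behavior in plain PyTorch; the exact indexing inside `_EncoderBase` may differ, so treat this as illustrative rather than the literal code path.

```python
import torch

# Narrowing a cached state along the batch dimension (as happens when some
# sequences in the previous batch had zero length) produces a non-contiguous
# view. Passing such a view to an RNN as the initial hidden state triggers
# "RuntimeError: rnn: hx is not contiguous".
hx = torch.randn(3, 4, 2)        # (num_layers, batch_size, hidden_size)
restored = hx[:, :2, :]          # keep only the rows for valid sequences
print(restored.is_contiguous())  # False

# The fix, conceptually: materialize the state before handing it to the RNN.
print(restored.contiguous().is_contiguous())  # True
```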

Contributor
@DeNeutoy DeNeutoy left a comment


Hi, this looks good, but can you add your little snippet as a test here?

@rloganiv
Contributor Author
rloganiv commented Feb 6, 2019

Sure thing! I added checks to the test written for #1493, since it essentially deals with the same problem (only for non-stateful encoders). Also, since I can only provoke the issue on the GPU, I added a GPU-only copy of the tests.

@DeNeutoy
Contributor

Hi! Sorry, could you guard your test using the pytest decorator rather than just the if statement inside the test: `@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")` or something?
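
Such a guard might look like the following sketch; the test name and body here are illustrative (adapted from the repro in the PR description), not the exact test merged in this PR.

```python
import pytest
import torch

from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
def test_stateful_encoder_handles_zero_length_sequences():
    # Hypothetical test name; mirrors the repro from the PR description.
    lstm = torch.nn.LSTM(input_size=2, hidden_size=2, num_layers=3,
                         dropout=0.1, batch_first=True)
    encoder = PytorchSeq2SeqWrapper(lstm, stateful=True).cuda(0)

    inputs = torch.randn(4, 4, 2).cuda(0)
    mask = torch.ones(4, 4).cuda(0)
    mask[2, :] = 0  # zero-length sequence

    # Before the fix, the second call raised
    # "RuntimeError: rnn: hx is not contiguous".
    encoder(inputs, mask)
    encoder(inputs, mask)
```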

@DeNeutoy DeNeutoy merged commit 3cdb7e2 into allenai:master Mar 18, 2019
reiyw pushed a commit to reiyw/allennlp that referenced this pull request Nov 12, 2019
Ensure contiguous initial state tensors in `_EncoderBase(stateful=True)` (allenai#2451)

TalSchuster pushed a commit to TalSchuster/allennlp-MultiLang that referenced this pull request Feb 20, 2020
Ensure contiguous initial state tensors in `_EncoderBase(stateful=True)` (allenai#2451)
