This repository was archived by the owner on Dec 16, 2022. It is now read-only.
Implementation of ESIM model #1469
Merged
Changes from all commits
28 commits (all by matt-peters):
3e9b5ae  WIP: ESIM model
8127486  WIP: ESIM model for SNLI
5d775f7  WIP: ESIM
9829160  WIP: ESIM
8c51111  WIP: ESIM
3e1faac  WIP: ESIM
9db872d  ESLM model with ELMo
c79bd99  Add a ESIM predictor that works with SNLI formatted files
fb4d93f  Merge branch 'mp/esim' of github.com:matt-peters/allennlp into mp/esim
8c10588  Merge remote-tracking branch 'upstream/master' into mp/esim
2708665  Merge conflict
3f173be  Move ESIM predictor
fe1cbd2  Merge branch 'master' into mp/esim2
2b722ce  Clean up
fa5e670  Add test for ESIM
3e336a5  Add predictor for ESIM
23bfeca  pylint
1d0c905  pylint
12325be  mypy
d9730f4  fix the docs
4f6d37f  ESIM predictor
7b57e42  Add comment to esim training config
7ea3e47  Move InputVariationalDropout
54db604  pylint
9ae74aa  Fix the docs
a3cf48d  fix the docs
101a71c  Remove ESIM predictor
9b901f9  Scrub all of ESIMPredictor
@@ -0,0 +1,248 @@
from typing import Dict, Optional, List, Any

import torch

from allennlp.common import Params
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, InputVariationalDropout
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TextFieldEmbedder
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum, replace_masked_values
from allennlp.training.metrics import CategoricalAccuracy


@Model.register("esim")
class ESIM(Model):
    """
    This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference"
    <https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
    by Chen et al., 2017.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
        model.
    encoder : ``Seq2SeqEncoder``
        Used to encode the premise and hypothesis.
    similarity_function : ``SimilarityFunction``
        This is the similarity function used when computing the similarity matrix between encoded
        words in the premise and words in the hypothesis.
    projection_feedforward : ``FeedForward``
        The feedforward network used to project down the encoded and enhanced premise and hypothesis.
    inference_encoder : ``Seq2SeqEncoder``
        Used to encode the projected premise and hypothesis for prediction.
    output_feedforward : ``FeedForward``
        Used to prepare the concatenated premise and hypothesis for prediction.
    output_logit : ``FeedForward``
        This feedforward network computes the output logits.
    dropout : ``float``, optional (default=0.5)
        Dropout percentage to use.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForward,
                 output_logit: FeedForward,
                 dropout: float = 0.5,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        check_dimensions_match(encoder.get_output_dim() * 4, projection_feedforward.get_input_dim(),
                               "encoder output dim", "projection feedforward input")
        check_dimensions_match(projection_feedforward.get_output_dim(), inference_encoder.get_input_dim(),
                               "proj feedforward output dim", "inference lstm input dim")

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'accuracy': self._accuracy.get_metric(reset)}

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ESIM':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)

        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
        projection_feedforward = FeedForward.from_params(params.pop('projection_feedforward'))
        inference_encoder = Seq2SeqEncoder.from_params(params.pop("inference_encoder"))
        output_feedforward = FeedForward.from_params(params.pop('output_feedforward'))
        output_logit = FeedForward.from_params(params.pop('output_logit'))
        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))

        dropout = params.pop("dropout", 0)

        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   encoder=encoder,
                   similarity_function=similarity_function,
                   projection_feedforward=projection_feedforward,
                   inference_encoder=inference_encoder,
                   output_feedforward=output_feedforward,
                   output_logit=output_logit,
                   dropout=dropout,
                   initializer=initializer,
                   regularizer=regularizer)
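
As a side note for readers of this diff: the core of the ``forward`` pass above -- the bidirectional attention, the "enhancement" features, and the final pooling -- can be sketched in plain PyTorch on dummy tensors. The snippet below is only an illustration and is not part of the PR: it ignores masking, substitutes a plain dot-product similarity for the configurable ``SimilarityFunction``, and pools the enhanced vectors directly instead of running them through the projection feed-forward and inference encoder.

import torch

# Illustrative sketch only: ESIM-style attention, enhancement, and pooling on
# random unmasked tensors. Sizes are arbitrary.
batch_size, premise_len, hypothesis_len, dim = 2, 7, 5, 8
encoded_premise = torch.randn(batch_size, premise_len, dim)
encoded_hypothesis = torch.randn(batch_size, hypothesis_len, dim)

# Dot-product similarity, shape (batch_size, premise_len, hypothesis_len).
similarity_matrix = torch.bmm(encoded_premise, encoded_hypothesis.transpose(1, 2))

# Each premise token attends over the hypothesis, and vice versa.
p2h_attention = torch.softmax(similarity_matrix, dim=-1)
h2p_attention = torch.softmax(similarity_matrix.transpose(1, 2), dim=-1)
attended_hypothesis = torch.bmm(p2h_attention, encoded_hypothesis)  # (batch, premise_len, dim)
attended_premise = torch.bmm(h2p_attention, encoded_premise)        # (batch, hypothesis_len, dim)

# "Enhancement": concatenate the encoding, the attended vector, and their
# difference and elementwise product, giving 4 * dim features per token.
premise_enhanced = torch.cat([encoded_premise, attended_hypothesis,
                              encoded_premise - attended_hypothesis,
                              encoded_premise * attended_hypothesis], dim=-1)
hypothesis_enhanced = torch.cat([encoded_hypothesis, attended_premise,
                                 encoded_hypothesis - attended_premise,
                                 encoded_hypothesis * attended_premise], dim=-1)

# Max and average pooling over time, then concatenation (mirroring v_all above).
v_a_max, _ = premise_enhanced.max(dim=1)
v_b_max, _ = hypothesis_enhanced.max(dim=1)
v_a_avg = premise_enhanced.mean(dim=1)
v_b_avg = hypothesis_enhanced.mean(dim=1)
v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)
print(v_all.shape)  # torch.Size([2, 128])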
@@ -0,0 +1,34 @@
import torch

class InputVariationalDropout(torch.nn.Dropout):
    """
    Apply the dropout technique in Gal and Ghahramani, "Dropout as a Bayesian Approximation:
    Representing Model Uncertainty in Deep Learning" (https://arxiv.org/abs/1506.02142) to a
    3D tensor.

    This module accepts a 3D tensor of shape ``(batch_size, num_timesteps, embedding_dim)``
    and samples a single dropout mask of shape ``(batch_size, embedding_dim)`` and applies
    it to every time step.
    """
    def forward(self, input_tensor):
        # pylint: disable=arguments-differ
        """
        Apply dropout to input tensor.

        Parameters
        ----------
        input_tensor: ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_timesteps, embedding_dim)``

        Returns
        -------
        output: ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` with dropout applied.
        """
        ones = input_tensor.data.new_ones(input_tensor.shape[0], input_tensor.shape[-1])
        dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
        if self.inplace:
            input_tensor *= dropout_mask.unsqueeze(1)
            return None
        else:
            return dropout_mask.unsqueeze(1) * input_tensor
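
The defining property of ``InputVariationalDropout`` is that, unlike ordinary dropout, the mask it samples is shared across all timesteps of a sequence. A small usage sketch (assuming the class above is in scope, e.g. imported from ``allennlp.modules`` as in the ESIM file) makes this visible; it is an illustration, not part of the PR.

import torch

# Usage sketch: a feature dropped at one timestep is dropped at every timestep
# of that sequence, whereas ordinary Dropout would drop cells independently.
torch.manual_seed(0)
dropout = InputVariationalDropout(p=0.5)
dropout.train()  # dropout is only active in training mode

x = torch.ones(2, 4, 6)  # (batch_size, num_timesteps, embedding_dim)
y = dropout(x)

# The zero pattern is identical at every timestep within a sequence,
# and kept entries are scaled by 1 / (1 - p) = 2.0.
zero_pattern = (y == 0)
assert torch.equal(zero_pattern[:, 0, :], zero_pattern[:, 1, :])
print(y[0])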
Review comment:
marginally more informative variable names would be better here
Reply:
These names follow the notation in the original paper and are helpful if someone wanted to align the code with equations.
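For readers aligning the code with Chen et al. (2017), the correspondence appears to be roughly: ``encoded_premise`` / ``encoded_hypothesis`` are the input-encoding BiLSTM outputs (a-bar, b-bar in the paper); ``attended_hypothesis`` / ``attended_premise`` are the soft-aligned vectors (a-tilde, b-tilde); ``premise_enhanced`` / ``hypothesis_enhanced`` are the concatenated feature vectors m_a and m_b; ``v_ai`` / ``v_bi`` are the composition-layer outputs v_a,t and v_b,t; and ``v_all`` is the pooled vector v = [v_a,ave; v_a,max; v_b,ave; v_b,max] fed to the final MLP.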