Implementation of ESIM model by matt-peters · Pull Request #1469 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Implementation of ESIM model #1469

Merged
Merged 28 commits on Jul 10, 2018
Changes from all commits
28 commits
3e9b5ae
WIP: ESIM model
matt-peters May 3, 2018
8127486
WIP: ESIM model for SNLI
matt-peters May 3, 2018
5d775f7
WIP: ESIM
matt-peters May 3, 2018
9829160
WIP: ESIM
matt-peters May 3, 2018
8c51111
WIP: ESIM
matt-peters May 3, 2018
3e1faac
WIP: ESIM
matt-peters May 4, 2018
9db872d
ESLM model with ELMo
matt-peters May 7, 2018
c79bd99
Add a ESIM predictor that works with SNLI formatted files
matt-peters May 14, 2018
fb4d93f
:Merge branch 'mp/esim' of github.com:matt-peters/allennlp into mp/esim
matt-peters May 14, 2018
8c10588
Merge remote-tracking branch 'upstream/master' into mp/esim
matt-peters May 30, 2018
2708665
Merge conflict
matt-peters May 30, 2018
3f173be
Move ESIM predictor
matt-peters May 30, 2018
fe1cbd2
Merge branch 'master' into mp/esim2
matt-peters Jul 9, 2018
2b722ce
Clean up
matt-peters Jul 9, 2018
fa5e670
Add test for ESIM
matt-peters Jul 9, 2018
3e336a5
Add predictor for ESIM
matt-peters Jul 9, 2018
23bfeca
pylint
matt-peters Jul 9, 2018
1d0c905
pylint
matt-peters Jul 9, 2018
12325be
mypy
matt-peters Jul 9, 2018
d9730f4
fix the docs
matt-peters Jul 9, 2018
4f6d37f
ESIM predictor
matt-peters Jul 9, 2018
7b57e42
Add comment to esim training config
matt-peters Jul 9, 2018
7ea3e47
Move InputVariationalDropout
matt-peters Jul 9, 2018
54db604
pylint
matt-peters Jul 9, 2018
9ae74aa
Fix the docs
matt-peters Jul 9, 2018
a3cf48d
fix the docs
matt-peters Jul 9, 2018
101a71c
Remove ESIM predictor
matt-peters Jul 9, 2018
9b901f9
Scrub all of ESIMPredictor
matt-peters Jul 9, 2018
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -18,3 +18,4 @@
from allennlp.models.semantic_parsing.wikitables.wikitables_erm_semantic_parser import WikiTablesErmSemanticParser
from allennlp.models.semantic_role_labeler import SemanticRoleLabeler
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.models.esim import ESIM
248 changes: 248 additions & 0 deletions allennlp/models/esim.py
@@ -0,0 +1,248 @@
from typing import Dict, Optional, List, Any

import torch

from allennlp.common import Params
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, InputVariationalDropout
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TextFieldEmbedder
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum, replace_masked_values
from allennlp.training.metrics import CategoricalAccuracy


@Model.register("esim")
class ESIM(Model):
"""
This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference"
<https://www.semanticscholar.org/paper/Enhanced-LSTM-for-Natural-Language-Inference-Chen-Zhu/83e7654d545fbbaaf2328df365a781fb67b841b4>`_
by Chen et al., 2017.

Parameters
----------
vocab : ``Vocabulary``
text_field_embedder : ``TextFieldEmbedder``
Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the
model.
encoder : ``Seq2SeqEncoder``
Used to encode the premise and hypothesis.
similarity_function : ``SimilarityFunction``
This is the similarity function used when computing the similarity matrix between encoded
words in the premise and words in the hypothesis.
projection_feedforward : ``FeedForward``
The feedforward network used to project down the encoded and enhanced premise and hypothesis.
inference_encoder : ``Seq2SeqEncoder``
Used to encode the projected premise and hypothesis for prediction.
output_feedforward : ``FeedForward``
Used to prepare the concatenated premise and hypothesis for prediction.
output_logit : ``FeedForward``
This feedforward network computes the output logits.
dropout : ``float``, optional (default=0.5)
Dropout probability to use.
initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
Used to initialize the model parameters.
regularizer : ``RegularizerApplicator``, optional (default=``None``)
If provided, will be used to calculate the regularization penalty during training.
"""
def __init__(self, vocab: Vocabulary,
text_field_embedder: TextFieldEmbedder,
encoder: Seq2SeqEncoder,
similarity_function: SimilarityFunction,
projection_feedforward: FeedForward,
inference_encoder: Seq2SeqEncoder,
output_feedforward: FeedForward,
output_logit: FeedForward,
dropout: float = 0.5,
initializer: InitializerApplicator = InitializerApplicator(),
regularizer: Optional[RegularizerApplicator] = None) -> None:
super().__init__(vocab, regularizer)

self._text_field_embedder = text_field_embedder
self._encoder = encoder

self._matrix_attention = LegacyMatrixAttention(similarity_function)
self._projection_feedforward = projection_feedforward

self._inference_encoder = inference_encoder

if dropout:
self.dropout = torch.nn.Dropout(dropout)
self.rnn_input_dropout = InputVariationalDropout(dropout)
else:
self.dropout = None
self.rnn_input_dropout = None

self._output_feedforward = output_feedforward
self._output_logit = output_logit

self._num_labels = vocab.get_vocab_size(namespace="labels")

check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
"text field embedding dim", "encoder input dim")
check_dimensions_match(encoder.get_output_dim() * 4, projection_feedforward.get_input_dim(),
"encoder output dim", "projection feedforward input")
check_dimensions_match(projection_feedforward.get_output_dim(), inference_encoder.get_input_dim(),
"proj feedforward output dim", "inference lstm input dim")

self._accuracy = CategoricalAccuracy()
self._loss = torch.nn.CrossEntropyLoss()

initializer(self)

def forward(self, # type: ignore
premise: Dict[str, torch.LongTensor],
hypothesis: Dict[str, torch.LongTensor],
label: torch.IntTensor = None,
metadata: List[Dict[str, Any]] = None # pylint:disable=unused-argument
) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
Parameters
----------
premise : Dict[str, torch.LongTensor]
From a ``TextField``
hypothesis : Dict[str, torch.LongTensor]
From a ``TextField``
label : torch.IntTensor, optional (default = None)
From a ``LabelField``
metadata : ``List[Dict[str, Any]]``, optional, (default = None)
Metadata containing the original tokenization of the premise and
hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

Returns
-------
An output dictionary consisting of:

label_logits : torch.FloatTensor
A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
probabilities of the entailment label.
label_probs : torch.FloatTensor
A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
entailment label.
loss : torch.FloatTensor, optional
A scalar loss to be optimised.
"""
embedded_premise = self._text_field_embedder(premise)
embedded_hypothesis = self._text_field_embedder(hypothesis)
premise_mask = get_text_field_mask(premise).float()
hypothesis_mask = get_text_field_mask(hypothesis).float()

# apply dropout for LSTM
if self.rnn_input_dropout:
embedded_premise = self.rnn_input_dropout(embedded_premise)
embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

# encode premise and hypothesis
encoded_premise = self._encoder(embedded_premise, premise_mask)
encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

# Shape: (batch_size, premise_length, hypothesis_length)
similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

# Shape: (batch_size, premise_length, hypothesis_length)
p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
# Shape: (batch_size, premise_length, embedding_dim)
attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

# Shape: (batch_size, hypothesis_length, premise_length)
h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
# Shape: (batch_size, hypothesis_length, embedding_dim)
attended_premise = weighted_sum(encoded_premise, h2p_attention)

# the "enhancement" layer
premise_enhanced = torch.cat(
[encoded_premise, attended_hypothesis,
encoded_premise - attended_hypothesis,
encoded_premise * attended_hypothesis],
dim=-1
)
hypothesis_enhanced = torch.cat(
[encoded_hypothesis, attended_premise,
encoded_hypothesis - attended_premise,
encoded_hypothesis * attended_premise],
dim=-1
)

# The projection layer down to the model dimension. Dropout is not applied before
# projection.
projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

# Run the inference layer
if self.rnn_input_dropout:
projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
Contributor
marginally more informative variable names would be better here

Contributor Author
These names follow the notation in the original paper and are helpful if someone wanted to align the code with equations.
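For reference, the pooling these variables feed into corresponds, roughly and in the paper's notation, to:

v_{a,\mathrm{ave}} = \sum_{i=1}^{\ell_a} \frac{v_{a,i}}{\ell_a}, \qquad
v_{a,\max} = \max_{i=1}^{\ell_a} v_{a,i}, \qquad
v_{b,\mathrm{ave}} = \sum_{j=1}^{\ell_b} \frac{v_{b,j}}{\ell_b}, \qquad
v_{b,\max} = \max_{j=1}^{\ell_b} v_{b,j}

v = [v_{a,\mathrm{ave}}; v_{a,\max}; v_{b,\mathrm{ave}}; v_{b,\max}]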

v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

# The pooling layer -- max and avg pooling.
# (batch_size, model_dim)
v_a_max, _ = replace_masked_values(
v_ai, premise_mask.unsqueeze(-1), -1e7
).max(dim=1)
v_b_max, _ = replace_masked_values(
v_bi, hypothesis_mask.unsqueeze(-1), -1e7
).max(dim=1)

v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
premise_mask, 1, keepdim=True
)
v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
hypothesis_mask, 1, keepdim=True
)

# Now concat
# (batch_size, model_dim * 2 * 4)
v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

# the final MLP -- apply dropout to input, and MLP applies to output & hidden
if self.dropout:
v_all = self.dropout(v_all)

output_hidden = self._output_feedforward(v_all)
label_logits = self._output_logit(output_hidden)
label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

output_dict = {"label_logits": label_logits, "label_probs": label_probs}

if label is not None:
loss = self._loss(label_logits, label.long().view(-1))
self._accuracy(label_logits, label)
output_dict["loss"] = loss

return output_dict

def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return {'accuracy': self._accuracy.get_metric(reset)}

@classmethod
def from_params(cls, vocab: Vocabulary, params: Params) -> 'ESIM':
embedder_params = params.pop("text_field_embedder")
text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)

encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
projection_feedforward = FeedForward.from_params(params.pop('projection_feedforward'))
inference_encoder = Seq2SeqEncoder.from_params(params.pop("inference_encoder"))
output_feedforward = FeedForward.from_params(params.pop('output_feedforward'))
output_logit = FeedForward.from_params(params.pop('output_logit'))
initializer = InitializerApplicator.from_params(params.pop('initializer', []))
regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))

dropout = params.pop("dropout", 0)

params.assert_empty(cls.__name__)
return cls(vocab=vocab,
text_field_embedder=text_field_embedder,
encoder=encoder,
similarity_function=similarity_function,
projection_feedforward=projection_feedforward,
inference_encoder=inference_encoder,
output_feedforward=output_feedforward,
output_logit=output_logit,
dropout=dropout,
initializer=initializer,
regularizer=regularizer)
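
For readers following the tensor shapes, the sketch below mirrors the attention, "enhancement", and pooling steps above on random tensors. It is a standalone illustration only: the dimensions are made up, a plain unmasked softmax stands in for the masked attention, and the projection feed-forward and inference BiLSTM are skipped so the shape bookkeeping stays visible.

import torch

# Illustrative shapes (assumptions, not the PR's training configuration).
batch_size, premise_len, hypothesis_len, dim = 2, 7, 5, 8

encoded_premise = torch.randn(batch_size, premise_len, dim)
encoded_hypothesis = torch.randn(batch_size, hypothesis_len, dim)

# Dot-product similarity matrix: (batch_size, premise_len, hypothesis_len).
similarity = torch.bmm(encoded_premise, encoded_hypothesis.transpose(1, 2))

# Soft alignment in both directions (masking omitted here).
p2h_attention = torch.softmax(similarity, dim=-1)
h2p_attention = torch.softmax(similarity.transpose(1, 2), dim=-1)
attended_hypothesis = torch.bmm(p2h_attention, encoded_hypothesis)  # (batch, premise_len, dim)
attended_premise = torch.bmm(h2p_attention, encoded_premise)        # (batch, hypothesis_len, dim)

# The "enhancement" concatenation [a; b; a - b; a * b] quadruples the feature size,
# which is why projection_feedforward must accept encoder output dim * 4.
premise_enhanced = torch.cat(
    [encoded_premise, attended_hypothesis,
     encoded_premise - attended_hypothesis,
     encoded_premise * attended_hypothesis], dim=-1)
assert premise_enhanced.shape == (batch_size, premise_len, 4 * dim)

# Max and average pooling over time; in the full model this is applied to the
# inference encoder output, and the premise and hypothesis vectors are then
# concatenated before the output feed-forward.
v_a_max = premise_enhanced.max(dim=1).values
v_a_avg = premise_enhanced.mean(dim=1)
v_a = torch.cat([v_a_avg, v_a_max], dim=-1)
assert v_a.shape == (batch_size, 2 * 4 * dim)
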
1 change: 1 addition & 0 deletions allennlp/modules/__init__.py
@@ -21,3 +21,4 @@
from allennlp.modules.token_embedders import TokenEmbedder, Embedding
from allennlp.modules.matrix_attention import MatrixAttention
from allennlp.modules.attention import Attention
from allennlp.modules.input_variational_dropout import InputVariationalDropout
34 changes: 34 additions & 0 deletions allennlp/modules/input_variational_dropout.py
@@ -0,0 +1,34 @@
import torch

class InputVariationalDropout(torch.nn.Dropout):
"""
Apply the dropout technique in Gal and Ghahramani, "Dropout as a Bayesian Approximation:
Representing Model Uncertainty in Deep Learning" (https://arxiv.org/abs/1506.02142) to a
3D tensor.

This module accepts a 3D tensor of shape ``(batch_size, num_timesteps, embedding_dim)``
and samples a single dropout mask of shape ``(batch_size, embedding_dim)`` and applies
it to every time step.
"""
def forward(self, input_tensor):
# pylint: disable=arguments-differ
"""
Apply dropout to input tensor.

Parameters
----------
input_tensor: ``torch.FloatTensor``
A tensor of shape ``(batch_size, num_timesteps, embedding_dim)``

Returns
-------
output: ``torch.FloatTensor``
A tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` with dropout applied.
"""
ones = input_tensor.data.new_ones(input_tensor.shape[0], input_tensor.shape[-1])
dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
if self.inplace:
input_tensor *= dropout_mask.unsqueeze(1)
return None
else:
return dropout_mask.unsqueeze(1) * input_tensor
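
A minimal usage sketch of the new module, assuming the export added to allennlp/modules/__init__.py in this PR; the shapes are arbitrary. The point it illustrates is that one (batch_size, embedding_dim) mask is broadcast over the time dimension, so a dropped embedding dimension is zeroed at every time step.

import torch
from allennlp.modules import InputVariationalDropout  # exported by this PR

dropout = InputVariationalDropout(p=0.5)
dropout.train()  # dropout is only active in training mode

x = torch.ones(4, 10, 16)  # (batch_size, num_timesteps, embedding_dim)
y = dropout(x)

# The same mask is reused at every time step, so the zeroed dimensions match
# between the first and last time steps.
assert torch.equal(y[:, 0, :] == 0, y[:, -1, :] == 0)
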
17 changes: 16 additions & 1 deletion allennlp/nn/initializers.py
@@ -145,6 +145,19 @@ def block_orthogonal(tensor: torch.Tensor,
data[block_slice] = torch.nn.init.orthogonal_(tensor[block_slice].contiguous(), gain=gain)


def zero(tensor: torch.Tensor) -> None:
return tensor.data.zero_()

def lstm_hidden_bias(tensor: torch.Tensor) -> None:
"""
Initialize the biases of the forget gate to 1, and all other gates to 0,
following Jozefowicz et al., An Empirical Exploration of Recurrent Network Architectures
"""
# gates are (b_hi|b_hf|b_hg|b_ho) of shape (4*hidden_size)
tensor.data.zero_()
hidden_size = tensor.shape[0] // 4
tensor.data[hidden_size:(2 * hidden_size)] = 1.0

def _initializer_wrapper(init_function: Callable[..., None]) -> Type[Initializer]:
class Init(Initializer):
_initializer_wrapper = True
@@ -176,7 +189,9 @@ def from_params(cls, params: Params):
"sparse": _initializer_wrapper(torch.nn.init.sparse_),
"eye": _initializer_wrapper(torch.nn.init.eye_),
"block_orthogonal": _initializer_wrapper(block_orthogonal),
"uniform_unit_scaling": _initializer_wrapper(uniform_unit_scaling)
"uniform_unit_scaling": _initializer_wrapper(uniform_unit_scaling),
"zero": _initializer_wrapper(zero),
"lstm_hidden_bias": _initializer_wrapper(lstm_hidden_bias),
}
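
A quick sanity check of the new lstm_hidden_bias initializer, applied here directly to a PyTorch LSTM's hidden-to-hidden bias; in a training config it would instead be selected through the "lstm_hidden_bias" key registered above, with the parameter regex left to the config author.

import torch
from allennlp.nn.initializers import lstm_hidden_bias  # added in this PR

hidden_size = 16
lstm = torch.nn.LSTM(input_size=8, hidden_size=hidden_size)

# bias_hh_l0 packs the gate biases as (b_hi|b_hf|b_hg|b_ho), each of length
# hidden_size, so the forget-gate slice is [hidden_size:2 * hidden_size].
lstm_hidden_bias(lstm.bias_hh_l0)

assert torch.all(lstm.bias_hh_l0[hidden_size:2 * hidden_size] == 1.0)
assert torch.all(lstm.bias_hh_l0[:hidden_size] == 0.0)
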

