Resize T5 Vocab by dirkgr · Pull Request #5497 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Resize T5 Vocab #5497

Merged
merged 7 commits on Dec 8, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added a way to resize the vocabulary in the T5 module

### Fixed

- Fixed the docstring information for the `FBetaMultiLabelMeasure` metric.
80 changes: 75 additions & 5 deletions allennlp/modules/transformer/t5.py
@@ -4,8 +4,7 @@
""" # noqa: E401

import logging
from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING, NamedTuple

from typing import Optional, Tuple, List, Union, Dict, TYPE_CHECKING, NamedTuple, Callable

import torch
from torch import nn
@@ -428,6 +427,36 @@ def get_head_mask(head_mask: Optional[torch.BoolTensor], num_hidden_layers: int)
        head_mask = [None] * num_hidden_layers
        return head_mask

    def resize_token_embeddings(
        self, new_size: int, *, init_fn: Callable = torch.nn.init.normal_
    ) -> None:
        old_size, embedding_dim = tuple(self.token_embeddings.weight.shape)
        if old_size == new_size:
            return
        if old_size > new_size:
            logger.warning(
                "Shrinking vocabulary from size %d to size %d. This is probably not what you want?",
                old_size,
                new_size,
            )

        result = torch.nn.Embedding(
            new_size,
            embedding_dim,
            self.token_embeddings.padding_idx,
            self.token_embeddings.max_norm,
            self.token_embeddings.norm_type,
            self.token_embeddings.scale_grad_by_freq,
            self.token_embeddings.sparse,
            device=self.token_embeddings.weight.device,
            dtype=self.token_embeddings.weight.dtype,
        )
        copy_size = min(old_size, new_size)
        result.weight.data[:copy_size, ...] = self.token_embeddings.weight.data[:copy_size, ...]
        if new_size > old_size:
            init_fn(result.weight.data[copy_size:, ...])
        self.token_embeddings = result
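
Not part of the diff: a minimal sketch of how a caller could supply a custom init_fn, relying only on the contract shown above (the callable receives the slice of newly created rows and fills it in place). The name small_normal_init_ is made up for illustration.

import functools

import torch

# New rows get a tighter spread than torch.nn.init.normal_'s default std of 1.0.
small_normal_init_ = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.05)

# The contract in action: the callable fills a tensor in place.
rows = torch.empty(4, 8)
small_normal_init_(rows)

# With an encoder or decoder stack from this module, it would be passed as:
# stack.resize_token_embeddings(32_200, init_fn=small_normal_init_)
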

    def forward(
        self,
        input_ids: Optional[torch.IntTensor] = None,
@@ -759,8 +788,8 @@ class T5(TransformerModule, Registrable):
    def __init__(
        self,
        token_embeddings: Optional[nn.Embedding] = None,
        encoder: Lazy[T5EncoderStack] = Lazy(T5EncoderStack),
        decoder: Lazy[T5DecoderStack] = Lazy(T5DecoderStack),
        encoder: Lazy[T5EncoderStack] = Lazy(T5EncoderStack.basic_encoder),
        decoder: Lazy[T5DecoderStack] = Lazy(T5DecoderStack.basic_decoder),
        decoder_start_token_id: int = 0,
        pad_token_id: int = 0,  # These are both 0 in t5-(small|base|large). Go figure.
        eos_token_id: int = 1,
@@ -806,6 +835,47 @@ def __init__(

        self.beam_search = beam_search.construct(end_index=self.eos_token_id)

    def resize_token_embeddings(
        self, new_size: int, *, init_fn: Callable = torch.nn.init.normal_
    ) -> None:
        """
        Resizes the token embeddings in the model.

        This takes care of the token embeddings for the encoder, the decoder, and the LM head.

        new_size : `int`
            The new size of the token embeddings
        init_fn : `Callable`
            The function to use to initialize new embeddings. This function will be called with a
            single argument, the tensor to initialize, and it is expected to initialize the tensor
            in place. Many of the functions in `torch.nn.init` fit this contract.
        """
        self.encoder.resize_token_embeddings(new_size, init_fn=init_fn)
        # If encoder and decoder share embeddings, this is a no-op the second time.
        self.decoder.resize_token_embeddings(new_size, init_fn=init_fn)

        # resize lm head
        old_size = self.lm_head.out_features
        if old_size == new_size:
Contributor
Minor thing: Maybe we can do this check first thing? It'll avoid the 2 empty calls.

Contributor
Or perhaps not? We want to resize the embedding dim and the output dim to the same new_size. Will the old embedding_dim and the old lm_head.out_features be different?

Member Author
It's a bit of defensive programming. It should never happen, but if the three places where the size of the token embeddings matters get out of sync, this call will sync them all back up.

            return
        new_lm_head = torch.nn.Linear(
            self.lm_head.in_features,
            new_size,
            bias=self.lm_head.bias is not None,
            device=self.lm_head.weight.device,
            dtype=self.lm_head.weight.dtype,
        )
        copy_size = min(old_size, new_size)
        new_lm_head.weight.data[:copy_size, ...] = self.lm_head.weight.data[:copy_size, ...]
        if self.lm_head.bias is not None and new_lm_head.bias is not None:
            new_lm_head.bias.data[:copy_size, ...] = self.lm_head.bias.data[:copy_size, ...]
        if new_size > old_size:
            init_fn(new_lm_head.weight.data[copy_size:, ...])
            if new_lm_head.bias is not None:
                init_fn(new_lm_head.bias.data[copy_size:, ...])

        self.lm_head = new_lm_head
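
A usage sketch, not part of the PR: growing a loaded model's vocabulary by a couple of tokens. T5.from_pretrained_module("t5-small") is assumed here as the loading path; the new size is derived from the LM head, which the method above keeps in sync with the encoder and decoder embeddings.

from allennlp.modules.transformer.t5 import T5

model = T5.from_pretrained_module("t5-small")  # assumed loading helper

# Make room for two newly added task-specific tokens. The fresh embedding and
# LM-head rows are filled with torch.nn.init.normal_ unless init_fn is given.
new_size = model.lm_head.out_features + 2
model.resize_token_embeddings(new_size)
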

    def _post_load_state_dict(
        self, missing_keys: List[str], unexpected_keys: List[str]
    ) -> Tuple[List[str], List[str]]:
@@ -954,7 +1024,7 @@ def forward(
            logits = self._get_lm_logits(decoder_outputs.last_hidden_state)  # type: ignore[union-attr]

            # Shape: (1,)
            loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.to(torch.long).view(-1))
        elif self.training:
            raise ValueError("'labels' required during training")

14 changes: 14 additions & 0 deletions tests/modules/transformer/t5_test.py
@@ -1,4 +1,5 @@
import pytest
import torch
from transformers.models import t5 as hf_t5

from allennlp.modules.transformer.t5 import T5
@@ -135,3 +136,16 @@ def _test_distributed_load_state_dict(global_rank, world_size, gpu_id):
@requires_multi_gpu
def test_distributed_load_state_dict():
run_distributed_test([0, 1], func=_test_distributed_load_state_dict)


@pytest.mark.parametrize("tie_word_embeddings", [True, False])
def test_t5_resize_token_embeddings(tie_word_embeddings: bool):
    module = T5(tie_word_embeddings=tie_word_embeddings)

    labels = torch.IntTensor([[1, 2, 3]])
    module(torch.IntTensor([[129, 130, 131]]), labels=labels)
    module.resize_token_embeddings(128)
    with pytest.raises(IndexError):
        module(torch.IntTensor([[129, 130, 131]]), labels=labels)
    module.resize_token_embeddings(1024)
    module(torch.IntTensor([[129, 130, 131]]), labels=labels)
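
An aside on the IndexError the test expects (illustration only, not part of the PR): after shrinking the vocabulary to 128 entries, ids 129-131 fall outside the embedding table, and a bare nn.Embedding fails the same way.

import torch

emb = torch.nn.Embedding(128, 16)  # same situation as the shrunken vocabulary above
try:
    emb(torch.LongTensor([[129, 130, 131]]))
except IndexError as err:
    print(f"Lookup failed as expected: {err}")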