Add necessary implicit embedding extension for transfer-modules api and vocab extension by HarshTrivedi · Pull Request #2431 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Add necessary implicit embedding extension for transfer-modules api and vocab extension #2431

Merged: 25 commits into allenai:master from HarshTrivedi:fix-embedding-extension-load on Feb 12, 2019

Conversation

HarshTrivedi
Contributor

Please don't review this yet; I still need to change some code and add tests. This is a follow-up on #2374 and #2387, but I will probably get back to it after #2395 is finalized.

@HarshTrivedi
Contributor Author

Okay, this is up for review. @joelgrus, @matt-gardner

Two main things changed here:

  1. If vocab extension is ON during training, embedding extension should also happen implicitly; otherwise it raises an error. This mirrors the implicit embedding extension that already happens when fine-tuning with vocab extension ON.
  2. If one is using the transfer-modules api and has transferred an embedding/text_field_embedder (e.g. from an old archive) with vocab extension ON, loading the new archive currently doesn't work. The new archive's state dict has the extended embedding, but from_params would still load the embedding from the old archive before copying the state dict from the new-archive model, so the parameter shapes don't match. It's therefore necessary to do a precautionary embedding extension before copying the state dict (sketched below).

Other change: we need to make sure embedding extension is a no-op unless we are sure it applies (e.g. it's incorrect to default to the "tokens" namespace). An incorrect implicit assumption here makes some tests fail.
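To make the shape problem in point 2 concrete, here is a minimal sketch of the precautionary extension in plain PyTorch; the function name and initialization scheme are illustrative assumptions, not the actual allennlp implementation:

    import torch

    def extend_embedding_weight(weight: torch.Tensor,
                                extended_num_embeddings: int) -> torch.Tensor:
        # Keep the rows for tokens the old archive knew about and append
        # freshly initialized rows for tokens added by vocab extension, so
        # that copying the new archive's (larger) weight via load_state_dict
        # no longer fails with a size mismatch.
        old_num_embeddings, embedding_dim = weight.shape
        extra_rows = torch.empty(extended_num_embeddings - old_num_embeddings,
                                 embedding_dim)
        torch.nn.init.normal_(extra_rows)  # illustrative choice of init
        return torch.cat([weight, extra_rows], dim=0)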

@HarshTrivedi changed the title from "[WIP] Fix loading of model which was vocab + embedding extended." to "Add necessary implicit embedding extension for transfer-modules api and vocab extension" on Feb 7, 2019
Contributor
@matt-gardner left a comment

A couple of minor questions; otherwise looks very good, thanks for the PR!

logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")
# It's not safe to default to 'tokens' when we aren't sure that 'tokens'
# need to be extended. (Without this, several tests fail.)
logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Extension will be no-op'.")
Contributor

In what circumstances will this actually emit a warning? Will almost everyone who loads or trains a model in practice see this warning? If so, it should be at the info level (or even debug).

Contributor Author

The warning won't be seen for any models trained after #2374, because _vocab_namespace is stored. For previously trained models it would almost always be seen, because the extend_vocab call is now implicit. Will change it to the info level.

    return

    extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
    if extended_num_embeddings <= self.num_embeddings:
Contributor

Shouldn't this be ==, not <=?

Contributor Author

Yes, if vocab and embedding are already in sync it should be ==. I had kept <= as a precaution against an incorrect vocab namespace. But on second thought, if the user passed an incorrect vocab namespace, it's better to raise an error than to silently no-op. Will change it.

Contributor Author

Is it worth explicitly raising an error for the < case? It's only possible if the user explicitly passed an incorrect vocab namespace whose size is smaller than the embedding itself.

Contributor Author

Okay, I have separated the == and < cases, making the first a no-op and raising a ConfigurationError in the second.
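In code, the agreed-on control flow looks roughly like the following sketch (not the exact merged diff; ConfigurationError here is allennlp's allennlp.common.checks.ConfigurationError):

    from allennlp.common.checks import ConfigurationError

    def extend_vocab(self, extended_vocab, vocab_namespace):
        extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
        if extended_num_embeddings == self.num_embeddings:
            # Vocab and embedding are already in sync: extension is a no-op.
            return
        if extended_num_embeddings < self.num_embeddings:
            # Only possible if the caller passed an incorrect namespace,
            # so fail loudly instead of silently doing nothing.
            raise ConfigurationError(
                f"Namespace '{vocab_namespace}' has size "
                f"{extended_num_embeddings}, which is smaller than the "
                f"embedding ({self.num_embeddings}); the namespace is "
                f"probably incorrect."
            )
        # Otherwise, grow the weight matrix to extended_num_embeddings rows.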

Contributor Author

@matt-gardner Correcting this revealed a subtle issue: defaulting to the tokens and token_characters namespaces can be problematic when num_embeddings was used instead of a vocab namespace to decide the embedding size, as illustrated below. Fixed this in the last commit.

This is up for another look now.
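To illustrate the subtlety with a hypothetical configuration (not a case taken from the PR's tests): an Embedding can be sized by an explicit num_embeddings that was never derived from any vocab namespace, and comparing that against the size of a defaulted namespace like 'tokens' is then meaningless:

    from allennlp.data import Vocabulary
    from allennlp.modules.token_embedders import Embedding

    vocab = Vocabulary()  # the 'tokens' namespace holds only padding/OOV here
    # num_embeddings chosen by hand, unrelated to any vocab namespace:
    embedding = Embedding(num_embeddings=10, embedding_dim=5)
    # Defaulting extend_vocab to the 'tokens' namespace would compare
    # vocab.get_vocab_size('tokens') against a num_embeddings that was never
    # tied to that namespace, so extension must stay a no-op unless an
    # explicit namespace is supplied.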

Contributor Author

.. oops, didn't realize you already reviewed again!

Contributor
@matt-gardner left a comment

A few minor wording tweaks, and this is good to merge. Thanks for the PR, for this and all of the related functionality! I think it's turned out quite nicely.

logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")
# It's not safe to default to 'tokens' when we aren't sure that 'tokens'
# need to be extended. (Without this, several tests fail.)
logging.info("No vocab_namespace provided to Embedder.extend_vocab. Extension will be no-op'.")
Contributor

To make this more obvious, I'd recommend a message like "Loading a model trained before embedding extension was implemented; pass an explicit vocab namespace if you want to extend the vocabulary."

vocab_namespace = "tokens"
logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")
# It's not safe to default to 'tokens' when we aren't sure that 'tokens'
# need to be extended. (Without this, several tests fail.)
Contributor

No need to reference failing tests in comments in the code - just give the justification (that it's not safe to default to "tokens").

@HarshTrivedi
Contributor Author

@matt-gardner, Thanks for the review! This should be good to merge now.

@matt-gardner merged commit 39413f2 into allenai:master on Feb 12, 2019
@HarshTrivedi deleted the fix-embedding-extension-load branch on February 12, 2019 03:36