add diff command by epwalsh · Pull Request #5109 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

add diff command #5109

Merged
epwalsh merged 23 commits into main from model-diff on May 7, 2021

Conversation

@epwalsh (Member) commented Apr 9, 2021

Adds a diff command for comparing two arbitrary model checkpoints. This is analogous to git diff, except that instead of comparing lines between two files, we are comparing (key, tensor) pairs from the checkpoints' state dictionaries.

Example

For example, run this to compare the pre-trained RoBERTa large weights to a version that has been fine-tuned on the SQuAD 2.0 task:

allennlp diff \
      hf://_/roberta-large/pytorch_model.bin \
      https://storage.googleapis.com/allennlp-public-models/transformer-qa-2020-10-03.tar.gz \
      --strip-prefix-1 'roberta.' \
      --strip-prefix-2 '_text_field_embedder.token_embedder_tokens.transformer_model.'

[Screenshot: example diff output]

The △ = x.xxx means a tensor was modified by this amount (this is the 2-norm distance between the corresponding tensor from the first checkpoint and the one from the second).
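For a single pair of tensors, that distance amounts to something like the following (a minimal sketch, not the PR's code; the function name is illustrative):

import torch


def tensor_delta(a: torch.Tensor, b: torch.Tensor) -> float:
    # Euclidean (2-norm) distance between two same-shaped tensors.
    with torch.no_grad():
        return torch.dist(a, b, p=2).item()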

Technical details

The diff function uses a modified version of the default git diff algorithm, the Myers algorithm. The only real difference from the original Myers algorithm is that in addition to the "keep", "insert", and "remove" operations, we also consider a "modify" operation. This operation corresponds to keeping a parameter with the same key and shape, but modifying its weights.
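A rough sketch of the keep/insert/remove/modify idea is below. It uses a simple LCS-based alignment rather than the actual O(ND) Myers algorithm, exact equality rather than the distance threshold discussed later in this thread, and illustrative names; it is not the PR's implementation.

from typing import Dict, List, Tuple

import torch


def diff_state_dicts(
    a: Dict[str, torch.Tensor], b: Dict[str, torch.Tensor]
) -> List[Tuple[str, str]]:
    # Return (op, key) steps, where op is "keep", "insert", "remove", or "modify".
    # Two entries are aligned when they have the same key and shape; aligned
    # entries whose values differ become a "modify" step.
    keys_a, keys_b = list(a.keys()), list(b.keys())
    items_a = [(k, tuple(a[k].shape)) for k in keys_a]
    items_b = [(k, tuple(b[k].shape)) for k in keys_b]

    # Longest-common-subsequence table over the (key, shape) sequences.
    n, m = len(items_a), len(items_b)
    lcs = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        for j in range(m - 1, -1, -1):
            if items_a[i] == items_b[j]:
                lcs[i][j] = lcs[i + 1][j + 1] + 1
            else:
                lcs[i][j] = max(lcs[i + 1][j], lcs[i][j + 1])

    # Walk the table and emit the edit script.
    steps: List[Tuple[str, str]] = []
    i = j = 0
    while i < n and j < m:
        if items_a[i] == items_b[j]:
            key = keys_a[i]
            op = "keep" if torch.equal(a[key], b[key]) else "modify"
            steps.append((op, key))
            i, j = i + 1, j + 1
        elif lcs[i + 1][j] >= lcs[i][j + 1]:
            steps.append(("remove", keys_a[i]))
            i += 1
        else:
            steps.append(("insert", keys_b[j]))
            j += 1
    steps.extend(("remove", k) for k in keys_a[i:])
    steps.extend(("insert", k) for k in keys_b[j:])
    return steps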

Road map

  • Automatically handle other input types (path or non-path), like AllenNLP archives, HF model names, etc.
  • Add an integration test.

Outstanding questions

  • Is this useful?
  • What can we change or add to the output to make this more helpful?
    For example, maybe we add an option to only show modified parameters and sort by the modification distance.
  • What kinds of visualizations can we create on top of this?

@epwalsh (Member, Author) commented Apr 9, 2021

Pinging @AkshitaB, @ArjunSubramonian, @jacob-morrison, and @dirkgr for feedback.

@ArjunSubramonian (Contributor) left a comment

Left some comments! Overall, looks great :D

return subparser


def load_state_dict(
Contributor

Is this function not present anywhere else in the AllenNLP library? It might be better to import it than copy it here, to help with maintenance.

Member Author

Yes, good point. It should definitely be elsewhere.

distance: float

def display(self):
    termcolor.cprint(f"!{self.key}, shape = {self.shape}, △ = {self.distance:.4f}", "yellow")
Contributor

I personally don't think that the L2 distance is very insightful, especially when all the tensors are so high-dimensional.

Contributor

That being said, I can't think of a better single metric to succinctly describe how different two tensors are. Even if we had such a metric, I don't personally think that sorting by it would be helpful; I would much rather see the differing parameters in the order in which they are used during forward propagation.

Contributor

I think a heatmap of the element-wise differences of 2D parameters could be very useful. I don't know how to extend this well to more than 2 dimensions though, perhaps aggregate over the channel dimension?

Member Author

Maybe we flatten the tensors, bin the weights into a manageable number of bins, and then do something with those bins, like show a bar chart of the distance between the corresponding bins of the two tensors?
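A hypothetical sketch of that idea (the bin count and names are made up; this is not part of the PR):

import torch


def binned_abs_diff(a: torch.Tensor, b: torch.Tensor, num_bins: int = 50) -> torch.Tensor:
    # Flatten both tensors, split the element-wise absolute differences into
    # `num_bins` contiguous chunks, and return the mean difference per chunk.
    # The result could then be shown as a bar chart.
    diff = (a.flatten() - b.flatten()).abs()
    return torch.stack([chunk.mean() for chunk in diff.chunk(num_bins)])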

Member

You don't think that L2 distance is going to show up meaningfully when you want to validate, for example, that your gradual unfreezing schedule works? Maybe we can do better than L2, but I think it's a great start.

@epwalsh (Member, Author) Apr 9, 2021

Looking over the example I gave, there is certainly a strong correlation between the number of elements in a modified parameter and the corresponding Euclidean distance, suggesting this is not the best metric to use.

That said, I'm pretty sure any other L*-based metric (even L∞) would suffer from the same correlation unless we add an additional size-based normalization term.

Member Author

Anyway, the exact metric to use should be a configurable option, IMO.

Member Author

Maybe just normalizing by $\sqrt{n}$ is good enough. So then we are really doing

$\sqrt{ \frac{1}{n} \sum_{i=1}^{n} (x_i - y_i)^2 }$

Or the square root of the mean squared "error". Is this meaningful?
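In code, that is roughly (a minimal sketch; the function name is illustrative):

import torch


def normalized_distance(x: torch.Tensor, y: torch.Tensor) -> float:
    # Square root of the mean squared difference, i.e. the L2 distance
    # divided by the square root of the number of elements.
    with torch.no_grad():
        return torch.nn.functional.mse_loss(x, y).sqrt().item()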

Member

Are y'all running plugins that render LaTeX properly in GitHub?

Member Author

Human renderer plugin.

@dirkgr (Member) left a comment

I think this is quite a good feature. I wish we could measure how much people use it.

subparser.add_argument(
    "checkpoint1",
    type=str,
    help="""The URL or path to the first PyTorch checkpoint file (e.g. '.pt' or '.bin').""",
Member

It would be nice if these could also point to model archives.


    out[new_key] = state[key]

if strip_prefix_used is False:
Member

We don't do !strip_prefix_used anymore? Did I miss a memo?

Member Author

The difference here is that strip_prefix_used could be None.

Comment on lines +209 to +210
Uses a modified version of the Myers diff algorithm to compute a representation
of the diff between two model state dictionaries.
Member

I don't know that much about Myers, but isn't that only necessary if the order matters? Does the order of entries in the state_dict matter? I thought that's just alphabetical?

Member Author

The order is meaningful: it is the order in which the corresponding modules were registered, which is generally the order of data flow.
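A quick illustration of that behavior (the module names here are arbitrary):

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # state_dict() keys come out in registration order, not alphabetical order.
        self.embedder = torch.nn.Embedding(10, 4)
        self.encoder = torch.nn.Linear(4, 4)
        self.head = torch.nn.Linear(4, 2)


print(list(Toy().state_dict().keys()))
# ['embedder.weight', 'encoder.weight', 'encoder.bias', 'head.weight', 'head.bias']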

b_tensor = state_dict_b[step.key]
with torch.no_grad():
    dist = torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt()
if dist != 0.0:
Member Author

Should we worry about loss of precision here? I.e. maybe we want to check that it's within some small threshold of 0, not exactly 0.

Member Author

The threshold could be a configurable parameter.

Contributor

This is a great point that I just thought of as well! I think we could maybe ensure better precision if we actually did (a_tensor != b_tensor).any(). But I also like the idea of a configurable parameter: that way, if people train two models with a small difference in implementation, they can also use this tool to identify how similar the weights are under different epsilon thresholds.
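Both variants might look something like this (a sketch; epsilon is a hypothetical configurable threshold):

import torch


def is_modified(a: torch.Tensor, b: torch.Tensor, epsilon: float = 0.0) -> bool:
    # With epsilon == 0, do an exact element-wise comparison; otherwise treat
    # the tensors as equal when every element is within epsilon.
    if epsilon == 0.0:
        return bool((a != b).any())
    return not torch.allclose(a, b, rtol=0.0, atol=epsilon)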

Member Author

That's a great point

@epwalsh marked this pull request as ready for review on April 14, 2021, at 22:33
@epwalsh (Member, Author) commented Apr 14, 2021

@dirkgr @ArjunSubramonian, this is ready for another review.

Now handles different input types: checkpoint files, archives, huggingface model IDs.

@dirkgr (Member) left a comment

Looks good. I'm just wondering whether we can drop checkpoint types entirely once we have native HF Hub integration.

Comment on lines 238 to 243
elif checkpoint_type == "huggingface":
    from transformers.file_utils import (
        hf_bucket_url,
        WEIGHTS_NAME,
        cached_path as hf_cached_path,
    )
Member

Can this go away when #5052 is merged? If so, I'd rather wait. #5052 looks like it's almost ready for merging.

Member

In fact, if that is so, the whole notion of a checkpoint type can go away, right?

-c, shape = (3,)
+b, shape = (4,)
+d, shape = (3,)
!e, shape = (3,), difference = 0.5774
Contributor

Could you add a small comment that explains how you computed 0.5774?
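(For what it's worth, and this is just my reading rather than the PR's documentation: $0.5774 \approx 1/\sqrt{3} = \sqrt{\tfrac{1}{3}(1^2 + 0^2 + 0^2)}$, i.e. the size-normalized distance you would get if exactly one element of the 3-element tensor e changed by 1.)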

Member Author

Done.

@epwalsh (Member, Author) commented Apr 15, 2021

Yeah @dirkgr, I'll hold off until the hub PR is merged.

a_tensor = state_dict_a[step.key]
b_tensor = state_dict_b[step.key]
with torch.no_grad():
    dist = (scale * torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt()).item()
Contributor

Just curious, is there a reason you opted for sqrt?

Member Author

Yeah, two reasons:

  1. It gives you another interpretation of the formula, which is the L2 norm normalized by the square root of the number of elements in the tensor, and
  2. It seems to put the numbers on a scale that is easier to look at.

    type=str,
    help="""a prefix to remove from all of the 2nd checkpoint's keys.""",
)
subparser.add_argument(
Contributor

I don't completely understand the purpose of the scale arg; could you explain its necessity?

Member Author

We only show a limited number of decimal places (4) in the output, so if you care about differences on the order of 1e-5 or smaller, you may want to increase the --scale parameter by an order of magnitude or more. Does that make sense? Should I include that in the help string?
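For example (the checkpoint paths here are placeholders; only the flag comes from this PR):

allennlp diff checkpoint_a.pt checkpoint_b.pt --scale 1e5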

@epwalsh (Member, Author) commented Apr 21, 2021

I had to make a small change/fix to how we deal with hf:// URLs to make it work with files like roberta-large/pytorch_model.bin that are not associated with a user or organization.

cached_path("hf://roberta-large/pytorch_model.bin") fails with a 404. Instead, this proposes adding an extra "/" or "_/" after "hf://" to avoid ambiguity, so you'd do cached_path("hf://_/roberta-large/pytorch_model.bin"). Note this is only necessary when there is no user/org associated with the identifier. I'm not crazy about this solution, so I'm open to better alternatives.

@LysandreJik, @dirkgr

@dirkgr (Member) commented Apr 21, 2021

Is there any way we can make hf://roberta-large/pytorch_model.bin work? Why is that difficult?

@epwalsh (Member, Author) commented Apr 21, 2021

> Is there any way we can make hf://roberta-large/pytorch_model.bin work? Why is that difficult?

Strictly speaking, hf://roberta-large/pytorch_model.bin is ambiguous. It could mean:

  1. the file pytorch_model.bin in the roberta-large repository, which is the intent of course, but it could also mean
  2. the repository pytorch_model.bin (repos are allowed to contain periods) owned by the user/organization called roberta-large.

@dirkgr (Member) commented Apr 21, 2021

We could try 2, and if it fails, fall back to 1.

@epwalsh (Member, Author) commented Apr 21, 2021

> We could try 2, and if it fails, fall back to 1.

I think it would be better to try 1 first. If we try 2 first, this creates a vulnerability:

Say you have a production system that is downloading hf://roberta-large/pytorch_model.bin and getting the result from (1), as expected, because (2) doesn't exist. But then some joker creates a fake user/org account called "roberta-large" with a repo called "pytorch_model.bin". Then your system starts failing because it's downloading this junk repo (2), instead of (1).

@dirkgr (Member) commented Apr 21, 2021

Excellent point!

@LysandreJik (Contributor)

That's a great find indeed, thank you for spotting it and offering a fix! 1 then 2 sounds good to me, is there any way I can help?

@epwalsh (Member, Author) commented Apr 21, 2021

I moved the fix to this PR: #5141

@epwalsh merged commit 7473737 into main on May 7, 2021
@epwalsh deleted the model-diff branch on May 7, 2021
dirkgr added a commit that referenced this pull request May 10, 2021
* add diff command

* fix docs

* no silly geese

* update CHANGELOG

* move 'load_state_dict' to nn.util

* normalize by size

* handle different checkpoint types

* add integration tests

* add 'scale' and 'threshold' params

* HuggingFace Hub support

* support '_/' as well, add test

* revert some changes

* fix

* Update CHANGELOG.md

* Update codecov.yml

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>