Pinging @AkshitaB, @ArjunSubramonian, @jacob-morrison, and @dirkgr for feedback.
Left some comments! Overall, looks great :D
allennlp/commands/diff.py
Outdated
    return subparser

def load_state_dict(
Is this function not present anywhere else in the AllenNLP library? It might be better to import it than copy it here, to help with maintenance.
Yes, good point. It should definitely be elsewhere.
allennlp/commands/diff.py
Outdated
distance: float

def display(self):
    termcolor.cprint(f"!{self.key}, shape = {self.shape}, △ = {self.distance:.4f}", "yellow")
I personally don't think that the L2 distance is too insightful, especially when all the tensors are so high-dimensional.
That being said, I can't think of a better single metric to succinctly describe how different two tensors are. If we have such a metric, I don't personally think that sorting by that metric is helpful. I would much rather prefer to see the differing parameters in the order in which they are used during forward propagation.
I think a heatmap of the element-wise differences of 2D parameters could be very useful. I don't know how to extend this well to more than 2 dimensions though, perhaps aggregate over the channel dimension?
Maybe we flatten the tensors, then bin the weights into a manageable number of bins, then do something with those bins, like show a bar chart of the distance between bins of two tensors?
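One toy way to realize the binning idea above (a pure-Python sketch; `binned_difference` is a hypothetical helper, not code from this PR): flatten both tensors, histogram them over a shared range, and report the per-bin difference in mass.

```python
def binned_difference(xs, ys, n_bins=10):
    # Shared bin edges so both histograms are comparable.
    lo = min(min(xs), min(ys))
    hi = max(max(xs), max(ys))
    width = (hi - lo) / n_bins or 1.0  # avoid zero-width bins

    def hist(vs):
        counts = [0] * n_bins
        for v in vs:
            counts[min(int((v - lo) / width), n_bins - 1)] += 1
        return [c / len(vs) for c in counts]  # normalize to fractions

    # Per-bin difference; identical weights give an all-zero "bar chart".
    return [a - b for a, b in zip(hist(xs), hist(ys))]

print(binned_difference([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]))  # all zeros
```

The bar chart would then just plot this list per bin.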
You don't think that L2 distance is going to show up meaningfully when you want to validate, for example, that your gradual unfreezing schedule works? Maybe we can do better than L2, but I think it's a great start.
Looking over the example I gave, there is certainly a strong correlation between the number of elements in a modified parameter and the corresponding Euclidean distance, suggesting this is not the best metric to use.
That said, I'm pretty sure any other L*-based metric (even L∞) would suffer from the same correlation unless we add an additional size-based normalization term.
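A quick numerical illustration of that size correlation (a pure-Python sketch, not code from this PR): with a constant per-element perturbation, the plain L2 distance grows like √n, while a size-normalized RMS distance stays flat.

```python
import math

def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def rms(a, b):
    # L2 normalized by sqrt(number of elements)
    return l2(a, b) / math.sqrt(len(a))

d = 0.01  # constant per-element perturbation
for n in (100, 10_000, 1_000_000):
    a, b = [0.0] * n, [d] * n
    # l2 grows 10x per row; rms stays at d
    print(n, round(l2(a, b), 4), round(rms(a, b), 4))
```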
Anyway, the exact metric to use should be a configurable option, IMO.
Maybe just normalizing by
Or the square root of the mean squared "error". Is this meaningful?
Are y'all running plugins that render LaTeX properly in GitHub?
Human renderer plugin.
I think this is quite a good feature. I wish we could measure how much people use it.
allennlp/commands/diff.py
Outdated
subparser.add_argument(
    "checkpoint1",
    type=str,
    help="""The URL or path to the first PyTorch checkpoint file (e.g. '.pt' or '.bin').""",
It would be nice if these could also point to model archives.
allennlp/commands/diff.py
Outdated
    out[new_key] = state[key]

if strip_prefix_used is False:
We don't do `!strip_prefix_used` anymore? Did I miss a memo?
The difference here is that `strip_prefix_used` could be `None`.
Uses a modified version of the Myers diff algorithm to compute a representation
of the diff between two model state dictionaries.
I don't know that much about Myers, but isn't that only necessary if the order matters? Does the order of entries in the `state_dict` matter? I thought that's just alphabetical?
The order is meaningful. It is the order in which the corresponding modules were registered. Generally this is the order of data flow.
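This registration-order behavior is easy to verify with a minimal sketch using plain `torch.nn` (not code from this PR):

```python
import torch.nn as nn

# Submodules are recorded in assignment (registration) order, not
# alphabetically: 'zebra' comes first because it was assigned first.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.zebra = nn.Linear(2, 2)
        self.alpha = nn.Linear(2, 2)

keys = list(Tiny().state_dict().keys())
print(keys)  # zebra.* keys precede alpha.* keys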
allennlp/commands/diff.py
Outdated
b_tensor = state_dict_b[step.key] | ||
with torch.no_grad(): | ||
dist = torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt() | ||
if dist != 0.0: |
Should we worry about loss of precision here? I.e. maybe we want to check that it's within some small threshold of 0, not exactly 0.
The threshold could be a configurable parameter.
This is a great point that I just thought of as well! I think we could maybe ensure better precision if we actually did (a_tensor != b_tensor).any(). But, I also like the idea of a configurable parameter. This way, if people train two models with a small difference in implementation, they can also use this tool to identify how similar the weights are with different thresholds epsilon.
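To make the two options concrete, here is a hedged sketch (the helper name `tensors_differ` is made up for illustration): `threshold=None` does the exact element-wise check, anything else tolerates differences whose RMS is below an epsilon.

```python
import torch

def tensors_differ(a, b, threshold=None):
    # Exact check: no floating-point tolerance at all.
    if threshold is None:
        return bool((a != b).any())
    # Thresholded check: compare the RMS difference against epsilon.
    with torch.no_grad():
        return torch.nn.functional.mse_loss(a, b).sqrt().item() > threshold

a = torch.zeros(3)
b = a + 1e-7
print(tensors_differ(a, b))                  # True: exact check sees the drift
print(tensors_differ(a, b, threshold=1e-6))  # False: below epsilon
```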
That's a great point
@dirkgr @ArjunSubramonian, this is ready for another review. It now handles different input types: checkpoint files, archives, and Hugging Face model IDs.
Looks good. I'm just wondering if we can do away with checkpoint types entirely once we have native HF Hub integration.
allennlp/commands/diff.py
Outdated
elif checkpoint_type == "huggingface":
    from transformers.file_utils import (
        hf_bucket_url,
        WEIGHTS_NAME,
        cached_path as hf_cached_path,
    )
In fact, if that is so, the whole notion of a checkpoint type can go away, right?
tests/commands/diff_test.py
Outdated
-c, shape = (3,)
+b, shape = (4,)
+d, shape = (3,)
!e, shape = (3,), difference = 0.5774
Could you add a small comment that explains how you computed 0.5774?
Done.
Yea @dirkgr, I'll hold off until the hub PR is merged.
a_tensor = state_dict_a[step.key]
b_tensor = state_dict_b[step.key]
with torch.no_grad():
    dist = (scale * torch.nn.functional.mse_loss(a_tensor, b_tensor).sqrt()).item()
Just curious, is there a reason you opted for sqrt?
Yea, two reasons:
- It gives you another interpretation of the formula, which is L2 norm, normalized by the sqrt of the number of elements in the tensor, and
- It seems to put the numbers on a scale that is easier to look at
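That first interpretation can be checked numerically (a small sketch, not code from the PR): sqrt(MSE) equals the L2 norm of the difference divided by the square root of the number of elements.

```python
import math
import torch

torch.manual_seed(0)
a = torch.randn(5, 4)
b = torch.randn(5, 4)

# sqrt of the mean squared error...
rmse = torch.nn.functional.mse_loss(a, b).sqrt().item()
# ...equals the L2 norm of the difference, normalized by sqrt(numel).
normalized_l2 = ((a - b).norm(p=2) / math.sqrt(a.numel())).item()
print(abs(rmse - normalized_l2))  # ~0, up to float error
```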
    type=str,
    help="""a prefix to remove from all of the 2nd checkpoint's keys.""",
)
subparser.add_argument(
I don't completely understand the purpose of the scale arg, could you explain its necessity?
We only show a limited number of decimal places (4) in the output, so if you care about differences on the order of 1e-5 or smaller, you may want to increase the `--scale` parameter by an order of magnitude or more. Does that make sense? Should I include that in the help string?
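A tiny illustration of why (pure Python; the 4-decimal formatting mirrors the `:.4f` used in the display code):

```python
diff = 3e-6  # a real but tiny parameter difference

# With 4 decimal places the difference rounds away entirely...
print(f"△ = {diff:.4f}")          # △ = 0.0000

# ...but scaling by 1e4 (the --scale idea) makes it visible again.
scale = 1e4
print(f"△ = {scale * diff:.4f}")  # △ = 0.0300
```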
I had to make a small change/fix to how we deal with
Is there any way we can make
Strictly speaking,
We could try 2, and if it fails, fall back to 1.
I think it would be better to try 1 first. If we try 2 first, this creates a vulnerability: say you have a production system that is downloading
Excellent point!
That's a great find indeed, thank you for spotting it and offering a fix! 1 then 2 sounds good to me. Is there any way I can help?
I moved the fix to this PR: #5141
* add diff command
* fix docs
* no silly geese
* update CHANGELOG
* move 'load_state_dict' to nn.util
* normalize by size
* handle different checkpoint types
* add integration tests
* add 'scale' and 'threshold' params
* HuggingFace Hub support
* support '_/' as well, add test
* revert some changes
* fix
* Update CHANGELOG.md
* Update codecov.yml

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
Adds a `diff` command for comparing two arbitrary model checkpoints. This is analogous to `git diff`, except that instead of comparing lines between two files, we are comparing `(key, tensor)` pairs from the checkpoints' state dictionaries.

Example

For example, run this to compare the pre-trained RoBERTa large weights to a version that has been fine-tuned on the SQuAD 2.0 task:

The `△ = x.xxx` means a tensor was modified by this amount (this is the 2-norm distance between the corresponding tensor from the first checkpoint and the one from the second).

Technical details

The `diff` function uses a modified version of the default `git diff` algorithm: the Myers algorithm. The only real modification between our algorithm and the original Myers algorithm is that in addition to the "keep", "insert", and "remove" operations, we also consider a "modify" operation. This operation corresponds to keeping a parameter with the same key and shape, but modifying its weights.

Road map

Outstanding questions

For example, maybe we add an option to only show modified parameters and sort by the modification distance.
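To make the "modify"-aware diff concrete, here is a hedged sketch of the same idea using Python's `difflib` in place of a hand-rolled Myers implementation (toy values stand in for tensors; this is not the PR's actual code):

```python
import difflib

def simple_state_dict_diff(sd_a, sd_b):
    # Align the two key sequences, then classify each key; matched keys
    # whose values differ become "modify" instead of "keep".
    keys_a, keys_b = list(sd_a), list(sd_b)
    sm = difflib.SequenceMatcher(a=keys_a, b=keys_b, autojunk=False)
    ops = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            for k in keys_a[i1:i2]:
                ops.append(("modify" if sd_a[k] != sd_b[k] else "keep", k))
        else:
            ops += [("remove", k) for k in keys_a[i1:i2]]
            ops += [("insert", k) for k in keys_b[j1:j2]]
    return ops

print(simple_state_dict_diff({"a": 1, "c": 3, "e": 5},
                             {"a": 1, "b": 2, "e": 6}))
# [('keep', 'a'), ('remove', 'c'), ('insert', 'b'), ('modify', 'e')]
```

This reproduces the shape of the test output shown earlier (`-c`, `+b`, `!e`), minus the distance computation.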