Conversation
Still a couple of questions around indexing that aren't clear to me.
# Shape (batch_size, num_head_tags, sequence_length, sequence_length)
# This energy tensor expresses the following relation:
# energy[i,j] = "Score that j is the head of i". In this
I think something's backwards in either this comment or the logic above. Because if I substitute `j` for `ROOT` and `i` for `some_word`, I get "Score that `ROOT` is the head of `some_word`", which you set to very negative above, with `energy[:, 0, :] = -1e8`. Right?
No, that's the wrong way around! You replace `j` with each of the words at index `j`, so you get: "Score that `word[j]` is the head of `ROOT`". Therefore, if I want `ROOT` to never be a child of a word, I should zero out the first row.
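To pin down the convention under discussion, here is a toy sketch (not the PR's actual code; the tensor shape and the `-1e8` mask value come from the thread, the concrete values are made up):

```python
import torch

# Toy energy tensor with shape (batch, num_tags, seq_len, seq_len), where
# energy[b, t, i, j] reads as "score that token j is the head of token i"
# and ROOT sits at index 0.
energy = torch.arange(1.0, 19.0).view(1, 2, 3, 3)

# ROOT must never be a child, i.e. never have a head: mask the i == 0 row
# so every "j is the head of ROOT" score becomes very negative.
energy[:, :, 0, :] = -1e8
```

After the assignment, only the scores in the first row of each `seq_len x seq_len` slice are masked; every other arc score is untouched.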
I'm just saying there's a mismatch between your comment and your code. Oh, wait, no, I was confusing myself because of the batch dimension. When I wrote my comment, I was thinking it was `normalized_arc_logits[i, j, :]`, so `j` was `ROOT`, but it's actually `normalized_arc_logits[:, i, :]`, so `i` is `ROOT`. Ok, all good.
@@ -25,7 +25,7 @@ def test_uses_named_inputs(self):
        assert head_tags is not None
        assert isinstance(head_tags, list)
        assert all(isinstance(x, int) for x in head_tags)

        print(result)
Remove print.
# This is the correct MST, but not desirable for dependency parsing.
assert heads.tolist()[0] == [-1, 0, 0]

energy[:, :, 0, :] = 0
Is this example really doing what you think it's doing?
>>> t = torch.Tensor([[0, 1, 1], [10, 0, 1], [10, 1, 0]]).view(1, 1, 3, 3)
>>> t
tensor([[[[ 0.,  1.,  1.],
          [10.,  0.,  1.],
          [10.,  1.,  0.]]]])
>>> t[:, :, 0, :] = 0
>>> t
tensor([[[[ 0.,  0.,  0.],
          [10.,  0.,  1.],
          [10.,  1.,  0.]]]])
This is zeroing out the top row, not the first column. Is that what you expected? And in general, I'd recommend using different numbers for every non-zero value, so there are no ties and it's more obvious what the MST should be.
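To make the row/column distinction concrete, a small sketch contrasting the two assignments (using unique non-zero weights, as suggested above; the tensor values here are made up):

```python
import torch

# Same layout as the example above: shape (batch, num_tags, seq_len, seq_len).
t = torch.tensor([[0., 1., 2.],
                  [10., 0., 3.],
                  [20., 4., 0.]]).view(1, 1, 3, 3)

row_zeroed = t.clone()
row_zeroed[:, :, 0, :] = 0   # index 0 along dim 2: zeros the first ROW

col_zeroed = t.clone()
col_zeroed[:, :, :, 0] = 0   # index 0 along dim 3: zeros the first COLUMN
```

Indexing a fixed position on the second-to-last dimension selects a row of each `seq_len x seq_len` slice; fixing the last dimension selects a column.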
* initial fix
* correct approach
* fix and test
* fix predictor test
* fix pylint
* use unique edge weights
No description provided.