Introduce scopes during tracing #3016
torch/nn/modules/module.py (outdated)
@@ -276,6 +293,8 @@ def __call__(self, *input, **kwargs):
        wrapper = functools.partial(hook, self)
        functools.update_wrapper(wrapper, hook)
        grad_fn.register_hook(wrapper)
        if tracing_state:
            torch._tracing_state.pop_scope()
torch/nn/modules/module.py (outdated)
@@ -20,6 +20,20 @@ def _addindent(s_, numSpaces):
    return s


def _first_var(input, kwargs):
torch/csrc/jit/ir.h (outdated)
if (scope_.empty()) {
  return scope_name;
}
scope_name = std::accumulate(
This looks pretty good to me!
Thanks for the review @ezyang
torch/nn/modules/module.py (outdated)
tracing_state = torch.jit.get_tracing_state(_first_var(input, kwargs))
try:
    if tracing_state:
        tracing_state.push_scope('%s$%d' % (self.__class__.__name__, id(self)))
torch/nn/modules/module.py (outdated)
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
tracing_state = torch.jit.get_tracing_state(_first_var(input, kwargs))
I think this looks good, thanks @lantiga! I have some small comments that are mostly details, let me know what you think.
I've just pushed a new commit:
So, what was
becomes
torch/jit/__init__.py (outdated)
tracing_state = torch._C.get_tracing_state(vars)
if tracing_state:
    tracing_state.push_scope(scope_name)
yield
torch/nn/modules/module.py (outdated)
@@ -20,6 +20,20 @@ def _addindent(s_, numSpaces):
    return s


def _flatten(*args):
Thanks @fmassa BTW!
Here's the new output (each module appears as
torch/nn/modules/module.py (outdated)
@@ -56,6 +70,7 @@ def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True
        self._name = ''
torch/csrc/jit/ir.h (outdated)
}


if (scope.empty()) {
torch/nn/modules/module.py (outdated)
    modules[name] = value
elif modules is not None and name in modules:
    if value is not None:
        raise TypeError("cannot assign '{}' as child module '{}' "
                        "(torch.nn.Module or None expected)"
                        .format(torch.typename(value), name))
    value._name = name
    if value is not None:
        value._name = name
Awesome, good to finally get scopes :) I have a few requests regarding implementation details
torch/csrc/jit/ir.h (outdated)
if (scope_.empty()) {
  return scope_name;
}
scope_name = std::accumulate(
torch/nn/modules/module.py (outdated)
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
with torch.jit.scope('%s$%s' % (self.__class__.__name__, self._name),
torch/nn/modules/module.py (outdated)
if '_name' not in self.__dict__:
    self._name = ''
for name, module in self.named_children():
    module._name = name
torch/csrc/jit/ir.cpp (outdated)
@@ -245,7 +245,8 @@ std::ostream& printNode(std::ostream & out, Node * n, std::vector<Node*> * group
} else {
  emitUses(out,n);
}
out << "];\n";
out << "], ";
out << "scope: " << n->scopeName() << ";\n";
Note after last update: scopes are not preserved after an ONNX pass. I'm looking into it.
Fixed scopes for the ONNX pass.
@lantiga we'll have to wait with merging this PR anyway. The autograd PR heavily disrupted the JIT work, but it was very high priority, so we were ok with this. Right now we're focused on unbreaking ONNX, but parts of the JIT infra will be different now - e.g. we won't have this initial onnxification pass, only the proper export pass (most functions are in C++ and we already auto-generate code to make them JIT nodes with proper names). Sorry for having you wait 😕
No worries, I had to get this out of my system. When you're ready I'll be happy to fix any breakages that the various merges might have caused. Thanks
Since the C++ refactor has landed, this PR has been rebased on top of current master and updated accordingly. A few tests fail because current expects do not include scope information. Scopes are now added automatically by
torch/csrc/jit/ir.h (outdated)
// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid.
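The ownership scheme this comment describes can be sketched in Python (a hypothetical `Scope` class illustrating the grow-only trie idea, not the actual C++ implementation):

```python
class Scope:
    """Grow-only trie node: each parent owns its children, and children are
    never removed, so a reference to any scope stays valid for as long as
    the root is alive."""

    def __init__(self, parent=None, name="__root__"):
        self.parent = parent
        self.name = name
        self.children = []  # only ever appended to; the trie never shrinks

    def push(self, name):
        child = Scope(self, name)
        self.children.append(child)  # parent keeps the child alive
        return child

    def pop(self):
        # Popping only moves the "current scope" cursor up; the child
        # stays owned by its parent rather than being freed.
        return self.parent


root = Scope()
conv = root.push("Conv2d$1")
current = conv.pop()
assert current is root
assert conv in root.children  # the Conv2d scope is still owned by the trie
```

A node that recorded a reference to `conv` can therefore still resolve its scope after the cursor has moved back up, which is the invariant the comment is after.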
The scopes we add by default are fairly stable, right? I'd be happy with either solution, but certainly
I think this is good enough to be useful, so we should merge it before it bitrots more, but since @zdevito is the owner of the JIT/tracer, he makes the final call. Having done one more review pass on it, there are two high-level things that pop out to me:
(And yes, I know I helped suggest some of these things; I apologize for not having 100% design clarity beforehand ^^)
Thank you @apaszke! No worries, take your time
This looks really good, except for the Scope ownership scheme (this cycle I mentioned in the comment). Once this is fixed, it should be good to go.
torch/csrc/jit/ir.h (outdated)
private:
  Scope* parent_;
  Symbol name_;
  std::unordered_set<std::unique_ptr<Scope> > children_;
torch/csrc/jit/ir.h (outdated)
  children_.insert(std::unique_ptr<Scope>(newScope));
  return newScope;
}
Scope* pop() {
torch/csrc/jit/tracer_state.h (outdated)
@@ -59,6 +61,9 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
  std::shared_ptr<Graph> graph;
  bool active;

  std::unique_ptr<Scope> scope_root;
torch/nn/modules/module.py (outdated)
@@ -319,10 +319,42 @@ def register_forward_hook(self, hook):
        self._forward_hooks[handle.id] = hook
        return handle

    def tracing_name(self, tracing_state):
torch/nn/modules/module.py (outdated)
    return None
module = tracing_state._traced_module_stack[-1]
for name, child in module.named_children():
    if child == self:
torch/nn/modules/module.py (outdated)
if torch.jit._tracing:
    result = self.slow_forward(*input, **kwargs)
else:
    result = self.forward(*input, **kwargs)
torch/nn/modules/module.py (outdated)
        return name
    return None

def slow_forward(self, *input, **kwargs):
@apaszke Fixes are in, scopes are back in Graph.
Looks good, but there's this one problem that can lead to segfaults. Should be ready once this is fixed.
torch/csrc/jit/ir.h (outdated)
Scope* push(Symbol name) {
  Scope* newScope = new Scope(this, name);
  children_.insert(std::unique_ptr<Scope>(newScope));
  return newScope;
torch/csrc/jit/ir.h (outdated)
@@ -567,6 +646,15 @@ friend struct Value;
Graph()
: next_unique_(0)
, new_node_stage_(0)
, scope_root_(std::make_shared<Scope>())
, current_scope_(scope_root_.get())
, output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}
@apaszke Done. I'm using
as I'd like to check with you about deduping.
Looks good, but it conflicts with something now 😕 I also noticed that it could leak memory in case of an error so it would be nice to fix that as well
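The leak mentioned here typically happens when an exception skips the matching pop; the usual fix is a context manager whose `finally` block always pops. A sketch under assumed names (`TracingState`, `push_scope`/`pop_scope`, and `scope` are stand-ins for illustration, not the real API):

```python
from contextlib import contextmanager


class TracingState:
    """Minimal stand-in for the tracing state's scope stack (hypothetical)."""

    def __init__(self):
        self.stack = []

    def push_scope(self, name):
        self.stack.append(name)

    def pop_scope(self):
        self.stack.pop()


@contextmanager
def scope(state, name):
    state.push_scope(name)
    try:
        yield
    finally:
        # Runs even when the traced body raises, so a failed trace
        # cannot leave a dangling scope on the stack.
        state.pop_scope()


state = TracingState()
try:
    with scope(state, "Linear$1"):
        raise RuntimeError("trace failed")
except RuntimeError:
    pass
assert state.stack == []  # the scope was popped despite the error
```

The same shape works on the C++ side with an RAII guard, which is what the later `ResourceGuard` commit in this PR does.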
torch/csrc/jit/ir.h (outdated)
@@ -93,7 +93,7 @@ struct Scope {
}
Scope* push(Symbol name) {
  Scope* newScope = new Scope(this, name);
  children_.insert(std::unique_ptr<Scope>(newScope));
  children_.push_back(std::unique_ptr<Scope>(newScope));
@apaszke fixed!
Thanks Luca!
@lantiga @apaszke I am writing a visualizer for PyTorch using the tracer, and scope in the tracer is extremely useful for grouping the nodes while rendering the graph. I understand that each node has a pointer to its scope, but I don't think we have access to that from Python. Wouldn't it be nice to get the scope name in Python, probably as an attribute of node objects like
@hhsecond adding the scope name from Python is not a problem. You will get a string
Feel free to open an issue for exposing scopeName to Python and mention me.
@lantiga Perfect! Thanks a ton
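For a visualizer like the one described in this thread, grouping nodes by their scope name might look like this sketch (the `Node` class and its `scopeName()` method are stand-ins for illustration, assuming the scope name reads as a `/`-separated path):

```python
from collections import defaultdict


class Node:
    """Stand-in for a traced IR node carrying a scope path."""

    def __init__(self, kind, scope_name):
        self.kind = kind
        self._scope = scope_name

    def scopeName(self):
        return self._scope


def group_by_scope(nodes):
    # Bucket node kinds under their full scope path, ready for rendering
    # one cluster per module instance.
    groups = defaultdict(list)
    for n in nodes:
        groups[n.scopeName()].append(n.kind)
    return dict(groups)


nodes = [
    Node("conv", "Net$1/Conv2d$2"),
    Node("relu", "Net$1/ReLU$3"),
    Node("conv", "Net$1/Conv2d$2"),
]
print(group_by_scope(nodes))
# {'Net$1/Conv2d$2': ['conv', 'conv'], 'Net$1/ReLU$3': ['relu']}
```

Splitting each key on `/` would additionally recover the module hierarchy for nested clusters.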
@lantiga Will scope information be saved in ONNX too? Thanks
* Introduce scopes during tracing (#3016)
* Fix segfault during ONNX export
* Further fix to tracing scope (#4558)
* Set missing temporary scope in callPySymbolicMethod
* Use expected traces in all scope tests
* Fix tracking of tracing scopes during ONNX pass (#4524)
* Fix tracking of tracing scopes during ONNX pass
* Use ResourceGuard to manage setting a temporary current scope in Graph
* Add tests for ONNX pass scopes
* Remove unused num_classes argument
* Expose node scopeName to python (#4200)
* Inherit JIT scopes when cloning only when it's correct. It's correct only when the new graph owns the same scope tree as the original one. We can end up with dangling pointers otherwise.
* Fixes after cherry-picking, still one test to go
* Fix for last failing test after scope cherry-pick
* Fix linting issue
This PR introduces scopes (or namespaces) in order to group operations in the tracing IR.

Scopes work like a stack: they can be pushed and popped manually. Modules automatically push a scope during `__call__` and pop it before returning. The scope is named as `className$id`, where id is the value of the Python `id` function.

Tests fail at the moment because the expected output of traces differs, as I added a `scope` description. I'm not sure it belongs there; at the moment it's handy for debugging purposes. If we decide to keep them this way I'll update the expected output.

Under the hood, scope names are implemented using interned strings.

/cc @ezyang @fmassa
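The stack discipline and the `className$id` naming described above can be sketched as follows (hypothetical helper names, not the tracer's actual API):

```python
class ScopeStack:
    """Sketch of the push/pop discipline described in this PR (hypothetical
    class, not the actual tracer API)."""

    def __init__(self):
        self._stack = []

    def push_scope(self, name):
        self._stack.append(name)

    def pop_scope(self):
        self._stack.pop()

    def current_scope(self):
        # Nested scopes read as a path from the root
        return "/".join(self._stack)


def module_scope_name(module):
    # The PR names a module's scope className$id, with id() the CPython
    # object identity of the module instance.
    return "%s$%d" % (module.__class__.__name__, id(module))


scopes = ScopeStack()
scopes.push_scope("Net$1")
scopes.push_scope("Conv2d$2")
assert scopes.current_scope() == "Net$1/Conv2d$2"
scopes.pop_scope()  # a module pops its scope before __call__ returns
assert scopes.current_scope() == "Net$1"
```

Because `id()` is only unique among live objects, the later revisions of this PR switch to naming scopes by the child's attribute name on its parent module instead.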