Introduce scopes during tracing by lantiga · Pull Request #3016 · pytorch/pytorch · GitHub

Introduce scopes during tracing #3016


Merged · 9 commits merged into pytorch:master on Dec 4, 2017

Conversation

@lantiga (Contributor) commented Oct 6, 2017

This PR introduces scopes (or namespaces) in order to group operations in the tracing IR, e.g.

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def doit(x, y):
    tracing_state = torch.jit.get_tracing_state(x,y)
    if tracing_state:
        tracing_state.push_scope('Foo')
    z = Variable(torch.Tensor([0.7]), requires_grad=True)
    out = torch.sigmoid(torch.tanh(x * (y + z)))
    if tracing_state:
        tracing_state.pop_scope()
    return out

traced, _ = torch.jit.trace(doit, (x, y))
g = torch._C._jit_get_graph(traced)
print(g)

outputs

graph(%1 : Float(1)
      %2 : Float(1)) {
  %3 : Float(1) = Constant[value=<Tensor>](), uses = [%4.i1], scope: Foo;
  %5 : Float(1) = ^Add(False)(%2, %3), uses = [[%6.i1]], scope: Foo;
  %7 : Float(1) = ^Mul()(%1, %5), uses = [[%8.i0]], scope: Foo;
  %9 : Float(1) = ^Tanh()(%7), uses = [[%10.i0]], scope: Foo;
  %11 : Float(1) = ^Sigmoid()(%9), uses = [[%0.i0]], scope: Foo;
  return (%11);
}

Scopes work like a stack: they can be pushed and popped manually (as in the example above). Modules automatically push a scope during __call__ and pop it before returning. The scope is named className$id, where id is the value of the Python id function:

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

    def forward(self, x):
        return self.layer1(x)

net = Net()

t = Variable(torch.ones(2), requires_grad=True)

traced, _ = torch.jit.trace(net, (t, ))
g = torch._C._jit_get_graph(traced)
print(g)

outputs

graph(%1 : Float(2)
      %2 : Float(2, 2)
      %3 : Float(2)) {
  %5 : Float(2!, 2!) = ^Transpose(0, 1)(%2), uses = [[%10.i2]], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %7 : Float(1, 2), %8 : Handle = ^Unsqueeze(0)(%1), uses = [[%10.i1], []], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %9 : Float(1, 2) = Constant[value=<Tensor>](), uses = [%10.i0], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %11 : Float(1, 2), %12 : Handle = ^Addmm(0, 1, True)(%9, %7, %5), uses = [[%13.i0], []], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %14 : Float(2), %15 : Handle = ^Squeeze(0, True)(%11), uses = [[%16.i0], []], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %17 : Float(2) = ^Add(True)(%14, %3), uses = [[%18.i0]], scope: Net$4569918736.Sequential$4569919312.Linear$4569919120;
  %19 : Float(2), %20 : Handle = ^Threshold(0, 0, False)(%17), uses = [[%0.i0], []], scope: Net$4569918736.Sequential$4569919312.ReLU$4569919184;
  return (%19);
}

Tests fail at the moment because the expected output of traces differs now that I've added a scope description to each node. I'm not sure it belongs there; at the moment it's handy for debugging purposes. If we decide to keep it this way I'll update the expected outputs.

Under the hood, scope names are implemented using interned strings.
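For intuition, interning just means each distinct scope name is stored once and then referred to by a cheap handle (in the PR, the C++ Symbol type). A minimal, purely illustrative Python sketch of the idea, not the actual implementation:

_interned = {}

def intern(name):
    # Return one canonical object per distinct string, so scope names can
    # be compared and stored cheaply by identity rather than by value.
    return _interned.setdefault(name, name)

assert intern('Foo') is intern('Foo')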

/cc @ezyang @fmassa

@@ -276,6 +293,8 @@ def __call__(self, *input, **kwargs):
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
if tracing_state:
torch._tracing_state.pop_scope()

@@ -20,6 +20,20 @@ def _addindent(s_, numSpaces):
return s


def _first_var(input, kwargs):

if (scope_.empty()) {
return scope_name;
}
scope_name = std::accumulate(

@ezyang (Contributor) commented Oct 7, 2017

This looks pretty good to me!

@lantiga (Contributor, Author) commented Oct 7, 2017

Thanks for the review @ezyang

tracing_state = torch.jit.get_tracing_state(_first_var(input, kwargs))
try:
if tracing_state:
tracing_state.push_scope('%s$%d' % (self.__class__.__name__, id(self)))
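Pieced together from the diff excerpts in this review, the auto-scoping added to Module.__call__ at this stage of the PR amounts roughly to the following simplified sketch (hook handling is elided; in the actual diff the pop happens after the backward-hook registration, via torch._tracing_state.pop_scope()):

def __call__(self, *input, **kwargs):
    tracing_state = torch.jit.get_tracing_state(_first_var(input, kwargs))
    try:
        if tracing_state:
            # One scope per module instance, named ClassName$id.
            tracing_state.push_scope('%s$%d' % (self.__class__.__name__, id(self)))
        result = self.forward(*input, **kwargs)
        # ... forward/backward hooks run here in the real method ...
    finally:
        if tracing_state:
            tracing_state.pop_scope()
    return result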

wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
tracing_state = torch.jit.get_tracing_state(_first_var(input, kwargs))

@fmassa (Member) commented Oct 7, 2017

I think this looks good, thanks @lantiga! I have some small comments that are mostly details, let me know what you think.

@lantiga (Contributor, Author) commented Oct 7, 2017

I've just pushed a new commit:

  • args to __call__ in Module are now flattened (instead of using _first_var)
  • a torch.jit.scope context manager has been introduced

So, what was

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def doit(x, y):
    tracing_state = torch.jit.get_tracing_state(x,y)
    if tracing_state:
        tracing_state.push_scope('Foo')
    z = Variable(torch.Tensor([0.7]), requires_grad=True)
    out = torch.sigmoid(torch.tanh(x * (y + z)))
    if tracing_state:
        tracing_state.pop_scope()
    return out

traced, _ = torch.jit.trace(doit, (x, y))
g = torch._C._jit_get_graph(traced)
print(g)

becomes

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

def doit(x, y):
    with torch.jit.scope('Foo', (x, y)):
        z = Variable(torch.Tensor([0.7]), requires_grad=True)
        out = torch.sigmoid(torch.tanh(x * (y + z)))
    return out

traced, _ = torch.jit.trace(doit, (x, y))
g = torch._C._jit_get_graph(traced)
print(g)

tracing_state = torch._C.get_tracing_state(vars)
if tracing_state:
tracing_state.push_scope(scope_name)
yield
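Filling in the excerpt above, a context manager with this behavior could look roughly as follows (a sketch, not necessarily the exact code in the commit; the try/finally is one way to make sure the scope is popped even if the body raises):

import contextlib

@contextlib.contextmanager
def scope(scope_name, vars):
    # Push a named scope onto the tracing state, if tracing is active,
    # and pop it again when the with-block exits.
    tracing_state = torch._C.get_tracing_state(vars)
    if tracing_state:
        tracing_state.push_scope(scope_name)
    try:
        yield
    finally:
        if tracing_state:
            tracing_state.pop_scope()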

@@ -20,6 +20,20 @@ def _addindent(s_, numSpaces):
return s


def _flatten(*args):

@lantiga (Contributor, Author) commented Oct 7, 2017

Thanks @fmassa BTW!

@lantiga (Contributor, Author) commented Oct 7, 2017

Here's the new output (each module appears as className$name):

graph(%1 : Float(2)
      %2 : Float(2, 2)
      %3 : Float(2)) {
  %5 : Float(2!, 2!) = ^Transpose(0, 1)(%2), uses = [[%10.i2]], scope: Net$inner.Sequential$layer1.Linear$0;
  %7 : Float(1, 2), %8 : Handle = ^Unsqueeze(0)(%1), uses = [[%10.i1], []], scope: Net$inner.Sequential$layer1.Linear$0;
  %9 : Float(1, 2) = Constant[value= 0.5175 -0.6471 [ CPUFloatTensor{1,2} ]](), uses = [%10.i0], scope: Net$inner.Sequential$layer1.Linear$0;
  %11 : Float(1, 2), %12 : Handle = ^Addmm(0, 1, True)(%9, %7, %5), uses = [[%13.i0], []], scope: Net$inner.Sequential$layer1.Linear$0;
  %14 : Float(2), %15 : Handle = ^Squeeze(0, True)(%11), uses = [[%16.i0], []], scope: Net$inner.Sequential$layer1.Linear$0;
  %17 : Float(2) = ^Add(True)(%14, %3), uses = [[%18.i0]], scope: Net$inner.Sequential$layer1.Linear$0;
  %19 : Float(2), %20 : Handle = ^Threshold(0, 0, False)(%17), uses = [[%0.i0], []], scope: Net$inner.Sequential$layer1.ReLU$1;
  return (%19);
}

@@ -56,6 +70,7 @@ def __init__(self):
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
self._name = ''

}

Graph* pushScope(const std::string& scope) {
if (scope.empty()) {

modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
value._name = name
if value is not None:
value._name = name

@apaszke (Contributor) left a comment

Awesome, good to finally get scopes :) I have a few requests regarding implementation details

if (scope_.empty()) {
return scope_name;
}
scope_name = std::accumulate(

wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
with torch.jit.scope('%s$%s' % (self.__class__.__name__, self._name),

if '_name' not in self.__dict__:
self._name = ''
for name, module in self.named_children():
module._name = name

@@ -245,7 +245,8 @@ std::ostream& printNode(std::ostream & out, Node * n, std::vector<Node*> * group
} else {
emitUses(out,n);
}
out << "];\n";
out << "], ";
out << "scope: " << n->scopeName() << ";\n";

@lantiga (Contributor, Author) commented Oct 20, 2017

Hey @apaszke @ezyang, I had a very packed couple of weeks. I've resumed this: I've got the trie and will make the last changes and update the PR. Thanks for your patience!

@lantiga (Contributor, Author) commented Oct 20, 2017

Note after the last update: scopes are not preserved after an ONNX pass. I'm looking into it.

@lantiga (Contributor, Author) commented Oct 20, 2017

Fixed scopes for the ONNX pass.

@apaszke (Contributor) commented Oct 21, 2017

@lantiga we'll have to hold off on merging this PR anyway. The autograd PR heavily disrupted the JIT work, but it was very high priority, so we were ok with that. Right now we're focused on unbreaking ONNX, and parts of the JIT infra will be different now; e.g. we won't have this initial onnxification pass, only the proper export pass (most functions are in C++ and we already auto-generate code to make them JIT nodes with proper names). Sorry for making you wait 😕

@lantiga (Contributor, Author) commented Oct 21, 2017

No worries, I had to get this out of my system. When you’re ready I’ll be happy to fix any breakages that the various merges might have caused. Thanks

@lantiga (Contributor, Author) commented Nov 5, 2017

Since the C++ refactor has landed, this PR has been rebased on top of current master and updated accordingly.

A few tests fail because the current expect files do not include scope information. Scopes are now added automatically by Module when tracing is active, so they get printed.
We can either update the expects or make printing include scopes only behind a flag.

/cc @ezyang @apaszke

// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid.
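For intuition, here is a rough Python sketch of such a scope trie (illustrative only; the real implementation is the C++ Scope struct discussed in the review excerpts below):

class Scope(object):
    def __init__(self, parent=None, name=None):
        self.parent = parent
        self.name = name
        self.children = []  # the trie only grows; children are never removed

    def push(self, name):
        # Create a child scope, keep it alive in the trie, and make it current.
        child = Scope(self, name)
        self.children.append(child)
        return child

    def pop(self):
        # Popping just returns to the parent; the child scope stays alive,
        # so references to it held by traced nodes remain valid.
        assert self.parent is not None, "cannot pop the root scope"
        return self.parent

    def scope_name(self):
        # Join names from the root down, e.g. "Net$inner.Sequential$layer1".
        parts = []
        scope = self
        while scope.parent is not None:
            parts.append(scope.name)
            scope = scope.parent
        return '.'.join(reversed(parts))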

@ezyang (Contributor) commented Nov 7, 2017

A few tests fail because the current expect files do not include scope information. Scopes are now added automatically by Module when tracing is active, so they get printed. We can either update the expects or make printing include scopes only behind a flag.

The scopes we add by default are fairly stable, right? I'd be happy with either solution, but certainly accepting the new output with --accept seems easiest.
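For reference, regenerating the expect files is just a matter of rerunning the affected tests with the accept flag, along these lines (the exact test file is an assumption):

python test/test_jit.py --accept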

@ezyang (Contributor) commented Nov 7, 2017

I think this is good enough to be useful, so we should merge it before it bitrots more, but since @zdevito is the owner of the JIT/tracer, he makes the final call.

Having done one more review pass on it, two high-level things pop out to me:

  • This PR adds another "observation point" for detecting tracing, the global torch.jit._tracing. I can see why this is expedient, and I think it is being used in a sound way, but it means that backwards tracing with scopes will not work out of the box. Since the primary use case for this was ONNX (forwards only), I think it makes sense not to try to also nail the backwards case today; furthermore, I suspect @zdevito may be touching the tracer a lot in the near future, so the mechanics of how to do this correctly may change.

  • Scopes were attached as state to the graph. This is consistent with the way that we've handled other similar concerns (e.g., stages), but it is also morally wrong: logically, an IR doesn't have a "scope" in any sense; it's a property of the tracer. But I don't think it would be awful to fix this later, if we ever need it fixed, so it seems OK for now as well.

(And yes, I know I helped suggest some of these things; I apologize for not having 100% design clarity beforehand ^^)

@lantiga (Contributor, Author) commented Nov 25, 2017

Thank you @apaszke! No worries, take your time

@lantiga mentioned this pull request on Nov 27, 2017

@apaszke (Contributor) left a comment

This looks really good, except for the Scope ownership scheme (the cycle I mentioned in the inline comment). Once this is fixed, it should be good to go.

private:
Scope* parent_;
Symbol name_;
std::unordered_set<std::unique_ptr<Scope> > children_;

children_.insert(std::unique_ptr<Scope>(newScope));
return newScope;
}
Scope* pop() {

@@ -59,6 +61,9 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
std::shared_ptr<Graph> graph;
bool active;

std::unique_ptr<Scope> scope_root;

@@ -319,10 +319,42 @@ def register_forward_hook(self, hook):
self._forward_hooks[handle.id] = hook
return handle

def tracing_name(self, tracing_state):

return None
module = tracing_state._traced_module_stack[-1]
for name, child in module.named_children():
if child == self:

if torch.jit._tracing:
result = self.slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)

return name
return None

def slow_forward(self, *input, **kwargs):

@lantiga (Contributor, Author) commented Nov 29, 2017

@apaszke Fixes are in; Scopes are back in Graph.

@apaszke (Contributor) left a comment

Looks good, but there's this one problem that can lead to segfaults. Should be ready once this is fixed.

Scope* push(Symbol name) {
Scope* newScope = new Scope(this, name);
children_.insert(std::unique_ptr<Scope>(newScope));
return newScope;

@@ -567,6 +646,15 @@ friend struct Value;
Graph()
: next_unique_(0)
, new_node_stage_(0)
, scope_root_(std::make_shared<Scope>())
, current_scope_(scope_root_.get())
, output_(initOutput(create(kReturn, 0))), input_(create(kParam, 0)) {}

@lantiga (Contributor, Author) commented Nov 30, 2017

@apaszke Done. I'm using a vector now (see discussion inline). I'll just quote this line:

Deduping based on names only makes sense if we'll ever compare the scope_ pointers stored in Node to check whether two nodes have the same scope. I'm going without it for now; let me know what you think.

as I'd like to confirm the deduping approach with you.

@apaszke (Contributor) left a comment

Looks good, but it conflicts with something now 😕 I also noticed that it could leak memory in case of an error, so it would be nice to fix that as well.

@@ -93,7 +93,7 @@ struct Scope {
}
Scope* push(Symbol name) {
Scope* newScope = new Scope(this, name);
children_.insert(std::unique_ptr<Scope>(newScope));
children_.push_back(std::unique_ptr<Scope>(newScope));

@lantiga (Contributor, Author) commented Dec 3, 2017

@apaszke fixed!

@apaszke merged commit 4eb8e12 into pytorch:master on Dec 4, 2017

@apaszke (Contributor) commented Dec 4, 2017

Thanks Luca!

@hhsecond commented Dec 15, 2017

@lantiga @apaszke I am writing a visualizer for PyTorch using the tracer, and the scope in the tracer is extremely useful for grouping nodes while rendering the graph. I understand that each node has a pointer to its scope, but I don't think we have access to that from Python. Wouldn't it be nice to expose the scope name in Python, perhaps as an attribute of node objects like node.scope() or something?

@lantiga (Contributor, Author) commented Dec 15, 2017

@hhsecond exposing the scope name from Python is not a problem. You would get a string like Foo/Bar or Foo/Bar[baz], where Foo and Bar are scope names. The second form is generated automatically by modules, one scope per module: in this case baz is the name of the attribute holding the module Bar inside the parent module Foo:

class Foo(nn.Module):
    def __init__(self):
        [...]
        self.baz = Bar()

Feel free to open an issue for exposing scopeName to python and mention me.
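For example, once a scopeName accessor is exposed on nodes (see the follow-up work referenced at the bottom of this page), grouping traced nodes for a visualizer could look roughly like this; it assumes the Graph/Node Python bindings expose nodes(), kind() and scopeName():

traced, _ = torch.jit.trace(net, (t,))
g = torch._C._jit_get_graph(traced)

groups = {}
for node in g.nodes():
    # Bucket nodes by their scope string, e.g. "Foo/Bar[baz]".
    groups.setdefault(node.scopeName(), []).append(node.kind())

for scope_name, kinds in groups.items():
    print(scope_name, kinds)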

@hhsecond

@lantiga Perfect! Thanks a ton

@lanpa (Collaborator) commented Dec 18, 2017

@lantiga Will scope information be saved in ONNX too? Thanks

@lantiga (Contributor, Author) commented Dec 18, 2017

I don't know what the specific plans are. /cc @ezyang and @bddppq

lantiga added a commit to lantiga/pytorch that referenced this pull request Feb 8, 2018
soumith pushed a commit that referenced this pull request Feb 9, 2018
* Introduce scopes during tracing (#3016)

* Fix segfault during ONNX export

* Further fix to tracing scope (#4558)

* Set missing temporary scope in callPySymbolicMethod

* Use expected traces in all scope tests

* Fix tracking of tracing scopes during ONNX pass (#4524)

* Fix tracking of tracing scopes during ONNX pass

* Use ResourceGuard to manage setting a temporary current scope in Graph

* Add tests for ONNX pass scopes

* Remove unused num_classes argument

* Expose node scopeName to python (#4200)

* Inherit JIT scopes when cloning only when it's correct

It's correct only when the new graph owns the same scope tree
as the original one. We can end up with dangling pointers otherwise.

* Fixes after cherry-picking, still one test to go

* Fix for last failing test after scope cherry-pick

* Fix linting issue