Use implicit gradient `Op`s · Issue #1275 · aesara-devs/aesara · GitHub
Open
@brandonwillard


We can delay the construction of explicit gradient graphs (i.e., uses of `Op.grad` and the like) by employing implicit gradient `Op`s that are later replaced with explicit sub-graphs (e.g., similar to how `OpFromGraph`s can be "in-lined").
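To make the delay-then-replace idea concrete without depending on Aesara, here is a minimal pure-Python sketch. It assumes a hypothetical toy representation in which graphs are nested tuples; `grad`, `explicit_grad`, and `expand_gradients` are illustrative names, not Aesara's API. The `("grad", cost, wrt)` node only records the request, and a later rewrite pass swaps it for an explicit sub-graph, much like in-lining an `OpFromGraph`.

```python
# Hypothetical toy graphs as nested tuples (not Aesara):
#   "x", "y"             input variables
#   ("sum", e)           scalar sum of e
#   ("mul", a, b)        elementwise multiplication
#   ("grad", cost, wrt)  implicit gradient: only records the request
#   ("ones_like", x)     explicit gradient sub-graph


def grad(cost, wrt):
    """Build an implicit gradient node; nothing is differentiated yet."""
    return ("grad", cost, wrt)


def explicit_grad(cost, wrt):
    """Toy expansion rule; only d(sum(wrt))/d(wrt) is supported."""
    if cost == ("sum", wrt):
        return ("ones_like", wrt)
    raise NotImplementedError(cost)


def expand_gradients(expr):
    """Rewrite pass: replace every implicit grad node with an explicit sub-graph."""
    if not isinstance(expr, tuple):
        return expr
    # Expand children first, so the parent node sees already-expanded inputs.
    expr = (expr[0],) + tuple(expand_gradients(a) for a in expr[1:])
    if expr[0] == "grad":
        return explicit_grad(expr[1], expr[2])
    return expr


g = grad(("sum", "x"), "x")
print(g)  # ('grad', ('sum', 'x'), 'x')
print(expand_gradients(("mul", g, "y")))  # ('mul', ('ones_like', 'x'), 'y')
```

The implicit node can sit anywhere inside a larger graph; the expansion pass rewrites it in place and leaves the surrounding nodes untouched.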

The approach would look as follows:

```python
from functools import wraps

import aesara
import aesara.tensor as at

from aesara.graph.basic import Apply
from aesara.graph.op import Op

from aesara.compile.mode import optdb
from aesara.graph.rewriting.basic import in2out, node_rewriter


class Gradient(Op):
    """An implicit gradient `Op` that is expanded into an explicit graph at compile time."""

    __props__ = ("grad_options",)

    def __init__(self, **grad_options):
        # Store the options as a tuple of items so that `__props__` can be
        # hashed (this assumes the option values themselves are hashable).
        self.grad_options = tuple(grad_options.items())

    def make_node(self, cost, *wrt):

        # Only the output types are needed, but, since there's some caching
        # here and, if we also assume that most gradients are _eventually_
        # expanded as-is, this seems somewhat less wasteful.
        grad_res = aesara.grad(cost, wrt, **dict(self.grad_options))

        if not isinstance(grad_res, (tuple, list)):
            grads = (grad_res,)
        else:
            grads = grad_res

        inputs = (cost,) + wrt
        outputs = [g.clone() for g in grads]

        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs):
        raise NotImplementedError("This shouldn't ever be called")


@wraps(aesara.grad)
def grad(cost, wrt, **kwargs):
    # Build an implicit `Gradient` node instead of an explicit gradient graph.
    if not isinstance(wrt, (list, tuple)):
        wrt = [wrt]

    GradOp = Gradient(**kwargs)
    return GradOp(cost, *wrt)


@node_rewriter([Gradient])
def expand_gradients(fgraph, node):
    # Replace an implicit `Gradient` node with the explicit graph from `aesara.grad`.
    op = node.op

    cost, *wrt = node.inputs
    grad_res = aesara.grad(cost, wrt, **dict(op.grad_options))

    if not isinstance(grad_res, (tuple, list)):
        grads = (grad_res,)
    else:
        grads = grad_res

    return grads


# A small negative position places this pass ahead of the standard `optdb`
# passes (canonicalization, for instance, is registered at position 1).
optdb.register(
    "expand_gradients",
    in2out(expand_gradients),
    "fast_compile",
    "fast_run",
    position=-0.01,
)


x = at.vector("x")
x_sum = x.sum()

x_grad = grad(x_sum, x)

aesara.dprint(x_grad)
# Gradient{grad_options=()} [id A]
#  |Sum{acc_dtype=float64} [id B]
#  | |x [id C]
#  |x [id C]


with aesara.config.change_flags(on_opt_error="raise"):
    x_grad_fn = aesara.function([x], x_grad)

aesara.dprint(x_grad_fn)
# Alloc [id A] 1
#  |TensorConstant{(1,) of 1.0} [id B]
#  |Shape_i{0} [id C] 0
#    |x [id D]
```

This also makes it possible to apply rewrites to gradient expressions and to provide more shape information to our gradient implementations.

For instance, this could be used to remove shape inference responsibilities and requirements from some `Op.make_node` and `Op.grad` implementations (e.g., `Elemwise.L_op`) by giving them access to `ShapeFeature`s and other information that is only available at compile/rewrite time. Simply put, this is probably the easiest, and perhaps even the best, way to guarantee that symbolic gradient implementations always have the most `Type` and shape information available, all without wastefully cloning shape graphs and re-performing rewrites (e.g., constant folding) on them.
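A toy model of the shape-information benefit, using the same kind of hypothetical tuple representation (not Aesara's API): the expansion pass receives a `shapes` dictionary, a hypothetical stand-in for what a `ShapeFeature` could supply at rewrite time. When the length of `x` is unknown, the expanded gradient needs a symbolic shape sub-graph (analogous to the `Alloc`/`Shape_i` output printed in the example above); when it is known, a constant can be used directly.

```python
# Hypothetical toy graphs as nested tuples (not Aesara). `shapes` is a
# hypothetical stand-in for rewrite-time shape information, like what a
# ShapeFeature could provide: variable name -> known length.


def ones_like(x, shapes):
    """Explicit gradient of sum(x) w.r.t. x, built with whatever shape info exists."""
    n = shapes.get(x)
    if n is None:
        # Length unknown at rewrite time: emit a symbolic shape sub-graph.
        return ("alloc", 1.0, ("shape_i", x, 0))
    # Length known at rewrite time: use it as a constant; no shape graph needed.
    return ("alloc", 1.0, n)


def expand_gradients(expr, shapes):
    """Replace implicit ("grad", ("sum", x), x) nodes using the available shapes."""
    if not isinstance(expr, tuple):
        return expr
    expr = (expr[0],) + tuple(expand_gradients(a, shapes) for a in expr[1:])
    if expr[0] == "grad":
        _, cost, wrt = expr
        if cost != ("sum", wrt):
            raise NotImplementedError(cost)
        return ones_like(wrt, shapes)
    return expr


g = ("grad", ("sum", "x"), "x")
print(expand_gradients(g, shapes={}))        # ('alloc', 1.0, ('shape_i', 'x', 0))
print(expand_gradients(g, shapes={"x": 3}))  # ('alloc', 1.0, 3)
```

The design choice being modeled is that the gradient rule never infers shapes itself; it only reads what the rewrite-time context already knows.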

This approach was proposed in Theano/Theano#4452 and might also help with #682. As mentioned in the latter, we need to think carefully about when we make implicit gradient `Op`s explicit (i.e., "expand" them). Depending on exactly which rewrites are applied and when, the resulting gradient graphs could be quite different and have distinct, possibly unexpected, numerical properties.

To keep things simple, we can expand implicit `Op`s right after the first pass of basic canonicalizations, so that shape inference (i.e., `ShapeFeature`) is useful and other rewrites (e.g., specializations) won't get in the way. If this approach helps with #682, then great; if not, I don't know whether we should get into the details of further delayed or staged expansions just yet. Regardless, we'll have the machinery available to do that whenever we want.
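The ordering concern can be illustrated with another hypothetical toy (not Aesara's rewrite machinery): a pipeline of bottom-up rewrite passes, a deliberately incomplete canonicalizer that only rewrites `x * 1` to `x`, and a small set of reverse-mode rules for elementwise expressions. Expanding before canonicalization and expanding after it produce different gradient graphs for the same cost.

```python
# Hypothetical toy graphs as nested tuples (not Aesara):
#   "x"                input variable
#   ("const", v)       scalar constant
#   ("add", a, b)      elementwise addition
#   ("mul", a, b)      elementwise multiplication
#   ("sum", e)         scalar sum of an elementwise expression
#   ("ones_like", x)   array of ones shaped like x
#   ("grad", cost, x)  implicit gradient node


def rewrite_bottom_up(expr, rule):
    """Apply `rule` to every node, children first."""
    if isinstance(expr, tuple):
        expr = (expr[0],) + tuple(rewrite_bottom_up(a, rule) for a in expr[1:])
    return rule(expr)


def canonicalize(expr):
    """Deliberately incomplete canonicalizer: mul(a, 1) -> a."""
    if isinstance(expr, tuple) and expr[0] == "mul" and expr[2] == ("const", 1.0):
        return expr[1]
    return expr


def deriv(e, wrt):
    """Elementwise derivative of `e` w.r.t. `wrt`, i.e. the gradient of sum(e)."""
    if e == wrt:
        return ("ones_like", wrt)
    if isinstance(e, str) or e[0] == "const":
        return ("const", 0.0)
    op, a, b = e
    da, db = deriv(a, wrt), deriv(b, wrt)
    if op == "add":
        return ("add", da, db)
    if op == "mul":
        return ("add", ("mul", da, b), ("mul", a, db))
    raise NotImplementedError(op)


def expand(expr):
    """Replace an implicit ("grad", ("sum", e), wrt) node with an explicit graph."""
    if isinstance(expr, tuple) and expr[0] == "grad":
        _, cost, wrt = expr
        if cost[0] != "sum":
            raise NotImplementedError(cost)
        return deriv(cost[1], wrt)
    return expr


def compile_graph(expr, passes):
    """Run each rewrite pass over the whole graph, in order."""
    for rule in passes:
        expr = rewrite_bottom_up(expr, rule)
    return expr


cost = ("sum", ("mul", "x", ("const", 1.0)))
g = ("grad", cost, "x")

early = compile_graph(g, [expand, canonicalize])  # expand, then canonicalize
late = compile_graph(g, [canonicalize, expand])   # canonicalize, then expand
print(early)  # ('add', ('ones_like', 'x'), ('mul', 'x', ('const', 0.0)))
print(late)   # ('ones_like', 'x')
```

Expanding after canonicalization differentiates the simplified cost, `sum(x)`, so it never produces the `x * 0` term that this incomplete canonicalizer cannot remove. That is the sense in which the expanded gradients become dependent on the canonicalizations.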

N.B. Performing expansions in this way still makes our gradient results dependent on our canonicalizations. In some ways, this relationship sounds good, since it seems to imply that the set of graphs we deal with from then on would be more "regular". Over time, we could converge on a more concentrated and effective set of stabilizing rewrites for the exact kinds of gradients that our implementations and canonicalizations tend to produce, because we would have to deal less with the particulars of "random" user-formulated graphs.
