We can delay the construction of explicit gradient graphs (i.e. use of `Op.grad` and the like) by employing implicit gradient `Op`s that are later replaced with explicit sub-graphs (e.g. similar to how `OpFromGraph`s can be "in-lined").

The approach would look as follows:
```python
from functools import wraps

import aesara
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.compile.mode import optdb
from aesara.graph.rewriting.basic import in2out, node_rewriter


class Gradient(Op):
    __props__ = ("grad_options",)

    def __init__(self, **grad_options):
        self.grad_options = tuple(grad_options.items())

    def make_node(self, cost, *wrt):
        # Only the output types are needed, but, since there's some caching
        # here and, if we also assume that most gradients are _eventually_
        # expanded as-is, this seems somewhat less wasteful.
        grad_res = aesara.grad(cost, wrt, **dict(self.grad_options))

        if not isinstance(grad_res, (tuple, list)):
            grads = (grad_res,)
        else:
            grads = grad_res

        inputs = (cost,) + wrt
        outputs = [g.clone() for g in grads]

        return Apply(self, inputs, outputs)

    def perform(self, *args, **kwargs):
        raise NotImplementedError("This shouldn't ever be called")


@wraps(aesara.grad)
def grad(cost, wrt, **kwargs):
    if not isinstance(wrt, (list, tuple)):
        wrt = [wrt]

    GradOp = Gradient(**kwargs)
    return GradOp(cost, *wrt)


@node_rewriter([Gradient])
def expand_gradients(fgraph, node):
    op = node.op

    cost, *wrt = node.inputs
    grad_res = aesara.grad(cost, wrt, **dict(op.grad_options))

    if not isinstance(grad_res, (tuple, list)):
        grads = (grad_res,)
    else:
        grads = grad_res

    return grads


optdb.register(
    "expand_gradients",
    in2out(expand_gradients),
    "fast_compile",
    "fast_run",
    position=-0.01,
)

x = at.vector("x")
x_sum = x.sum()
x_grad = grad(x_sum, x)

aesara.dprint(x_grad)
# Gradient{grad_options=()} [id A]
#  |Sum{acc_dtype=float64} [id B]
#  | |x [id C]
#  |x [id C]

with aesara.config.change_flags(on_opt_error="raise"):
    x_grad_fn = aesara.function([x], x_grad)

aesara.dprint(x_grad_fn)
# Alloc [id A] 1
#  |TensorConstant{(1,) of 1.0} [id B]
#  |Shape_i{0} [id C] 0
#  | |x [id D]
```
This also has the effect of enabling rewrites on gradient expressions and of providing more shape information to our gradient implementations.

For instance, it could be used to remove shape-inference responsibilities and requirements from some `Op.make_node` and `Op.grad` implementations (e.g. `Elemwise.L_op`) by allowing access to `ShapeFeature`s and other information that is only available at compile/rewrite time. Simply put, this is probably the easiest, and arguably the best, way to guarantee that symbolic gradient implementations always have the most `Type` and shape information available, all without wastefully cloning shape graphs and re-performing rewrites (e.g. constant folding) on them.
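To make the `ShapeFeature` point concrete, here is a hypothetical variant of the `expand_gradients` rewrite from the snippet above. It is only a sketch: it assumes a `ShapeFeature` has already been attached to the `FunctionGraph` (as the standard shape rewrites do), and the `wrt_shapes` lookup is purely illustrative of the kind of rewrite-time information a gradient implementation could consume.

```python
@node_rewriter([Gradient])
def expand_gradients_with_shapes(fgraph, node):
    # Hypothetical sketch (not part of the proposal above): it only
    # illustrates that rewrite-time information, like an attached
    # `ShapeFeature`, is reachable when the implicit gradient is expanded.
    cost, *wrt = node.inputs

    shape_feature = getattr(fgraph, "shape_feature", None)
    if shape_feature is not None:
        # `ShapeFeature.shape_of` maps variables to their inferred symbolic
        # shapes; gradient implementations could consume these instead of
        # rebuilding shape graphs themselves.  Unused here on purpose.
        wrt_shapes = [shape_feature.shape_of.get(w) for w in wrt]

    grad_res = aesara.grad(cost, wrt, **dict(node.op.grad_options))

    return list(grad_res) if isinstance(grad_res, (list, tuple)) else [grad_res]
```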
This approach was proposed in Theano/Theano#4452 and might also help with #682. As mentioned in the latter, we need to think carefully about when we make implicit gradient `Op`s explicit (i.e. "expand" them). Depending on exactly which rewrites are applied and when, the resulting gradient graphs could differ considerably and have distinct, possibly unexpected, numerical properties.
To keep things simple, we can expand implicit `Op`s right after the first pass of basic canonicalizations, so that shape inference/`ShapeFeature`s are useful and other rewrites (e.g. specializations) don't get in the way; see the sketch below. If this approach helps with #682, then great; if not, I don't know if we should get into the details of further delayed or staged expansions just yet. Regardless, we'll have the machinery available to do that whenever we want.
N.B.: Performing expansions this way still makes our gradient results dependent on our canonicalizations. In some ways, that relationship sounds good, since it implies that the set of graphs we would deal with from then on would be more "regular". Over time, we could converge on a more concentrated and effective set of stabilizing rewrites for the exact kinds of gradients our implementations and canonicalizations tend to produce, because we would have to deal less with the particulars of "random" user-formulated graphs.