[WIP] Various bugs in passes by mbrookhart · Pull Request #6906 · apache/tvm

Status: Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions src/arith/const_int_bound.cc
@@ -519,8 +519,8 @@ class ConstIntBoundAnalyzer::Impl
   */
  static Entry MakeBound(int64_t min_value, int64_t max_value) {
    Entry e;
-   e.min_value = min_value;
-   e.max_value = max_value;
+   e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value;
+   e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value;
    return e;
  }
  /*!
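
Why this matters: a bound whose min is the kPosInf sentinel (or whose max is kNegInf) can show up when constants equal to ±max(int64) flow through a model, and arithmetic on that sentinel overflows int64. A minimal sketch of the fixed behavior, mirroring the new unit test at the bottom of this PR (the Analyzer/ConstIntBound calls are the same ones that test uses):

import tvm
from tvm import te

analyzer = tvm.arith.Analyzer()
x, y = te.var("x"), te.var("y")

# A free variable is unbounded; grab the analyzer's sentinels from its bound.
bd = analyzer.const_int_bound(x)
NEG_INF, POS_INF = bd.NEG_INF, bd.POS_INF

# Constants equal to max(int64) (e.g. the 9223372036854775807 used as a
# strided_slice end elsewhere in this PR) can produce a bound whose min
# is POS_INF; before this fix, adding such bounds overflowed int64.
analyzer.update(x, tvm.arith.ConstIntBound(POS_INF, POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(NEG_INF, POS_INF), override=True)
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == NEG_INF and bd.max_value == POS_INF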
10 changes: 7 additions & 3 deletions src/relay/op/tensor/transform.cc
@@ -903,7 +903,7 @@ bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (updates == nullptr) {
     return false;
   }
-  ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+  ICHECK(indices->dtype.is_int()) << "indices of scatter must be tensor of integer";
   const auto param = attrs.as<ScatterAttrs>();
   ICHECK(param != nullptr);
   reporter->Assign(types[3], TensorType(data->shape, data->dtype));
@@ -1076,7 +1076,7 @@ Examples::
     .set_support_level(3)
     .add_type_rel("Take", TakeRel)
     .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
-    .set_attr<TOpPattern>("TOpPattern", kInjective);
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
Member: I hope we could come up with a way to avoid this. This would hurt performance on the Hummingbird workload and possibly other models.

Member: For example, maybe we can use annotation.stop_fusion when we encounter take + dynamic, solving this problem at the frontend.

mbrookhart (Contributor, Author): Good thing I have a typo in CI.

I'm not sure I see a clean way to do this in the frontends; it demands we already have infer_type run to check for dynamic inputs.

Maybe we write a pass that selectively stops fusion on certain ops under certain conditions?

masahi (Member), Nov 12, 2020: Yes, a new pass that inserts stop_fusion sounds good. I can work on this. We can make take opaque for now, until I have that new pass ready (or take compute is fixed).

Anyway, I think the graph runtime shouldn't be affected by the take + dynamic issue, so take should be injective eventually.

Member: @mbrookhart Can you add TODO(masahi) there?

mbrookhart (Contributor, Author): @masahi want to request changes so this doesn't merge while we chat about it?

Member: Sure, done. But I think it's good to go after adding a comment on why this change is temporary.

mbrookhart (Contributor, Author): I think the other option is to edit the relay take function that creates the op. We could remove the normalization that causes this problem from the take kernel in topi and do it in relay with select/shape_of, but that might end up causing some performance degradation; it's hard to predict.
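
For concreteness, a rough sketch of the kind of pass floated above — wrapping dynamic take in annotation.stop_fusion after type inference. The class name and the Any-based shape check are illustrative assumptions, not the pass masahi later implemented:

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator


class StopFusionOnDynamicTake(ExprMutator):
    """Sketch: wrap take calls whose data has a dynamic shape in
    annotation.stop_fusion so the fuser leaves them alone. Assumes
    InferType has already run, so checked_type is populated."""

    def visit_call(self, call):
        new_call = super().visit_call(call)
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "take":
            data_shape = call.args[0].checked_type.shape
            # A dimension is dynamic if it is a tvm.tir.Any node.
            if any(isinstance(dim, tvm.tir.Any) for dim in data_shape):
                return relay.annotation.stop_fusion(new_call)
        return new_call

Applied after InferType, e.g. StopFusionOnDynamicTake().visit(mod["main"]), this would leave static-shape take fusable while fencing off the dynamic case.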


// Init ops
TVM_REGISTER_NODE_TYPE(InitOpAttrs);
@@ -2322,7 +2322,11 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<IndexExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
if (input->shape[i]->IsInstance<tvm::IntImmNode>()) {
out_shape.push_back(input->shape[i]);
} else {
out_shape.push_back(tvm::tir::Var("dim"));
}
}
Array<PrimExpr> begin_expr;
Array<PrimExpr> strides_expr;
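In plain terms, the hunk above stops inventing a fresh symbolic extent for every axis and forwards extents that are statically known. A Python restatement of the rule, illustrative only (strings stand in for symbolic vars; this is not TVM API):

def dynamic_strided_slice_out_shape(input_shape):
    # Only axes whose extent is unknown get a fresh symbolic "dim";
    # statically known extents (tvm::IntImmNode in the C++) pass through.
    out_shape = []
    for extent in input_shape:
        if isinstance(extent, int):
            out_shape.append(extent)
        else:
            out_shape.append("dim")
    return out_shape

assert dynamic_strided_slice_out_shape([3, "n", 4]) == [3, "dim", 4]
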
13 changes: 12 additions & 1 deletion src/relay/transforms/fold_scale_axis.cc
@@ -243,7 +243,18 @@ class ForwardPrep : private ExprVisitor {
     }
   }
   // Visitor pattern override.
-  void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; }
+  void VisitExpr_(const LetNode* op) {
+    ExprVisitor::VisitExpr_(op);
+    // Pass through by assigning NullValue<Message>: the fold-scale
+    // signal cannot propagate into the let's value and body.
+    auto flazy = [this, op]() {
+      this->Update(op->value, NullValue<Message>());
+      this->Update(op->body, NullValue<Message>());
+    };
+    flist_.push_back(flazy);
+  }

   void VisitExpr_(const FunctionNode* op) {
     ExprVisitor::VisitExpr_(op);

Member (on the new LetNode visitor): cc @kevinthesun This might be of interest to you
38 changes: 38 additions & 0 deletions tests/python/relay/test_pass_fold_scale_axis.py
@@ -311,6 +311,44 @@ def check(shape, channels, blocking, in_scale):
check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)


def test_fold_fwd_let_fail():
"""testcase where we canont fold"""

def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias]
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
x_var = relay.Var("x_var")
y1 = relay.nn.conv2d(
x_var,
conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
kernel_layout="HWIO",
padding=(1, 1),
)
z = relay.add(y1, x)
let = relay.Let(x_var, x, z)
return relay.Function(args, let)

def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
assert tvm.ir.structural_equal(y1, y1_folded)

check((2, 11, 10, 4), 4)


def test_fold_fwd_negative_scale():
"""Testcase of folding negative scale"""

30 changes: 30 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm import te
from tvm import relay
@@ -623,6 +625,8 @@ def expected(n, max_fused_ops):
assert tvm.ir.structural_equal(zz, after)


'''
TODO(mbrookhart): Disabling this test because fusion on take doesn't work if the input is dynamic. Fix take compute before re-enabling.
def test_fuse_take():
"""Test fusion case involving concat and take"""

@@ -654,6 +658,7 @@ def expected():
relay.build(m, "llvm")
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)
'''


def test_fuse_gather_nd():
@@ -759,6 +764,31 @@ def create_diamond_func(inp):
assert tvm.ir.structural_equal(fused, expected)


def test_fuse_dynamic_squeeze_slice_take():
input_data = [
np.random.random([1, 2, 4]).astype("float32"),
np.array([0]).astype("int64"),
]

x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64")

squeeze = relay.op.squeeze(x, axis=[0])
strided_slice = relay.op.strided_slice(
squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
)
take = relay.op.take(strided_slice, take_val, axis=0)

mod = tvm.IRModule.from_expr(take)
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")

result = ex.evaluate()(*input_data)

np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0)

assert np.allclose(result.asnumpy(), np_result)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
14 changes: 14 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
@@ -76,6 +76,20 @@ def test_add_sub_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == 1

## constants equal to negative or positive max(int64) occasionally show up
## in models; this is to ensure we can handle those cases
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.NEG_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(bd.POS_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF


def test_mul_bound():
analyzer = tvm.arith.Analyzer()