diff --git a/python/tvm/meta_schedule/testing/tir_tensor_intrin.py b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py index 902183a215e2..04722e5668d7 100644 --- a/python/tvm/meta_schedule/testing/tir_tensor_intrin.py +++ b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py @@ -95,94 +95,16 @@ def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: ) -# @T.prim_func -# def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: -# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a") -# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b") -# C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator") - -# with T.block("root"): -# for i, j, k in T.grid(16, 16, 16): -# with T.block("update"): -# vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) -# C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast( -# B[vkk, vjj], "float32" -# ) - - -# @T.prim_func -# def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: -# A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") -# B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") -# C = T.match_buffer( -# c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" -# ) - -# with T.block("root"): -# T.reads( -# [ -# C[0:16, 0:16], -# A[0:16, 0:16], -# B[0:16, 0:16], -# ] -# ) -# T.writes(C[0:16, 0:16]) -# T.evaluate( -# T.tvm_mma_sync( -# C.data, -# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), -# A.data, -# A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), -# B.data, -# B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16), -# C.data, -# C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), -# dtype="handle", -# ) -# ) - @T.prim_func -def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, - scope="wmma.matrix_a") - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, - scope="wmma.matrix_b") - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, - scope="wmma.accumulator") - - with T.block("root"): - T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0: 16, 0 : 16]) - T.writes(C[0 : 16, 0 : 16]) - for i, j, k in T.grid(16, 16, 16): - with T.block(""): - vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], 'float32') * T.cast(B[vkk, vjj], 'float32') - - -@T.prim_func -def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: +def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") - B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, - scope="wmma.matrix_b") - C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, - scope="wmma.accumulator") with T.block("root"): - T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0: 16, 0 : 16]) + T.reads(A[0 : 16, 0 : 16]) T.writes(C[0 : 16, 0 : 16]) - T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), - A.data, A.elem_offset // 256, - B.data, B.elem_offset // 256, - C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), dtype='handle')) - - -@T.prim_func -def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") - - with T.block("root"): for i, j in T.grid(16, 16): with T.block("load"): vii, vjj = T.axis.remap("SS", [i, j]) @@ -193,34 +115,27 @@ def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: def wmma_load_a_impl(a: T.handle, c: T.handle) -> None: s1 = T.var("int32") s0 = T.var("int32") - A = T.match_buffer( - a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0] - ) + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - T.evaluate( - T.tvm_load_matrix_sync( - C.data, - 16, - 16, - 16, - C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), - A.access_ptr("r"), - s1, - "row_major", - dtype="handle", - ) - ) + T.reads(A[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) @T.prim_func def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") - C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_b") + with T.block("root"): + T.reads(A[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) for i, j in T.grid(16, 16): with T.block("load"): vii, vjj = T.axis.remap("SS", [i, j]) @@ -231,34 +146,60 @@ def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: def wmma_load_b_impl(a: T.handle, c: T.handle) -> None: s1 = T.var("int32") s0 = T.var("int32") - A = T.match_buffer( - a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0] - ) + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - T.evaluate( - T.tvm_load_matrix_sync( - C.data, - 16, - 16, - 16, - C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), - A.access_ptr("r"), - s1, - "row_major", - dtype="handle", - ) - ) + T.reads(A[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "col_major", + dtype="handle")) + + +@T.prim_func +def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, + scope="wmma.accumulator") + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0: 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + for i, j, k in T.grid(16, 16, 16): + with T.block(""): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], 'float32') * T.cast(B[vjj, vkk], 'float32') + + +@T.prim_func +def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, + scope="wmma.accumulator") + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0: 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.data, A.elem_offset // 256, + B.data, B.elem_offset // 256, + C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), dtype='handle')) @T.prim_func def wmma_fill_desc(c: T.handle) -> None: - C = T.match_buffer( - c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + T.reads() + T.writes(C[0 : 16, 0 : 16]) for i, j in T.grid(16, 16): with T.block("init"): vii, vjj = T.axis.remap("SS", [i, j]) @@ -267,32 +208,20 @@ def wmma_fill_desc(c: T.handle) -> None: @T.prim_func def wmma_fill_impl(c: T.handle) -> None: - C = T.match_buffer( - c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") with T.block("root"): - T.reads([]) - T.writes(C[0:16, 0:16]) - T.evaluate( - T.tvm_fill_fragment( - C.data, - 16, - 16, - 16, - C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), - T.float32(0), - dtype="handle", - ) - ) + T.reads() + T.writes(C[0 : 16, 0 : 16]) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), T.float32(0), dtype="handle")) @T.prim_func def wmma_store_desc(a: T.handle, c: T.handle) -> None: - A = T.match_buffer( - a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") with T.block("root"): + T.reads(A[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) for i, j in T.grid(16, 16): with T.block("store"): vii, vjj = T.axis.remap("SS", [i, j]) @@ -303,28 +232,14 @@ def wmma_store_desc(a: T.handle, c: T.handle) -> None: def wmma_store_impl(a: T.handle, c: T.handle) -> None: s1 = T.var("int32") s0 = T.var("int32") - A = T.match_buffer( - a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" - ) - C = T.match_buffer( - c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0] - ) + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0]) with T.block("root"): - T.reads(A[0:16, 0:16]) - T.writes(C[0:16, 0:16]) - T.evaluate( - T.tvm_store_matrix_sync( - A.data, - 16, - 16, - 16, - A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), - C.access_ptr("w"), - s1, - "row_major", - dtype="handle", - ) - ) + T.reads(A[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate(T.tvm_store_matrix_sync( + A.data, 16, 16, 16, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), C.access_ptr("w"), s1, "row_major", + dtype="handle")) # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 2d944d112834..123a652cd250 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1547,16 +1547,20 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { << " may not be an iterator"; return GetRef(op); } - IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { return GetRef(op); } - PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); - if (!remainder.defined()) { - return GetRef(op); + ICHECK(preprocessed->args.size() <= 1); + if (preprocessed->args.empty()) { + return IterSumExpr({}, floordiv(preprocessed->base, b)); + } else { + PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); + } + return remainder; } - return remainder; } PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { @@ -1636,12 +1640,16 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (!preprocessed.defined()) { return GetRef(op); } - - PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); - if (!remainder.defined()) { - return GetRef(op); + ICHECK(preprocessed->args.size() <= 1); + if (preprocessed->args.empty()) { + return IterSumExpr({}, floormod(preprocessed->base, b)); + } else { + PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); + if (!remainder.defined()) { + return GetRef(op); + } + return remainder; } - return remainder; } /*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 86548c84df1c..16f60498f6d4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -51,6 +51,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.predicate_opt", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -208,6 +209,7 @@ Array CreatePassList(bool disable_loop_partition) { bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); + bool predicate_opt = pass_ctx->GetConfig("tir.predicate_opt", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -304,7 +306,7 @@ Array CreatePassList(bool disable_loop_partition) { } pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir)); - pass_list.push_back(tir::transform::OptimizePredicatedLoad(true)); + pass_list.push_back(tir::transform::OptimizePredicatedLoad(predicate_opt)); return pass_list; } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index cdc17be2969b..f66b725acd27 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -170,6 +170,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); }); +/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/transforms/optimize_predicated_load.cc b/src/tir/transforms/optimize_predicated_load.cc index 51f607ea044a..03486ed0afaf 100644 --- a/src/tir/transforms/optimize_predicated_load.cc +++ b/src/tir/transforms/optimize_predicated_load.cc @@ -180,18 +180,19 @@ class LetVarBindingCanonicalizer : public ExprMutator { : var_range_(var_range) {} bool Canonicalize(const Var& top_var, const PrimExpr& binding) { - top_let_var_[binding] = top_var; - PrimExpr res = this->VisitExpr(binding); + PrimExpr res = this->VisitExpr(Substitute(binding, replace_map)); if (fail) return false; const SumFormNode* ret = res.as(); - ICHECK(ret != nullptr); + if (ret == nullptr) return false; ICHECK_EQ(ret->vars.size(), 1); if (!is_one(ret->scales[0]) || !is_zero(ret->base)) { let_var_buffer_map[top_var] = decl_buffer({int32(1)}, DataType::Int(32), top_var->name_hint, "local"); BuildAttachMap(top_var, ret->vars[0], Attach::AttachType::kAddition, ret->scales[0], ret->base); + } else { + replace_map[top_var] = ret->vars[0]; } return true; } @@ -199,6 +200,7 @@ class LetVarBindingCanonicalizer : public ExprMutator { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> inv_attach_map; std::unordered_map let_var_buffer_map; std::unordered_map attach_map; + std::unordered_map replace_map; private: void BuildAttachMap(const Var& cur_var, const Var& dependent_var, Attach::AttachType type, @@ -250,6 +252,7 @@ class LetVarBindingCanonicalizer : public ExprMutator { if (a != nullptr && b != nullptr) { ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); if (b->vars.size() == 0) { + if (a->vars.size() == 0) return SumForm({}, {}, int32(0)); // define let var for a Var inner; if (is_one(a->scales[0]) && is_zero(a->base)) { @@ -268,12 +271,9 @@ class LetVarBindingCanonicalizer : public ExprMutator { SearchExisitingAttach(inner, Attach::AttachType::kFloordiv, b->base); Optional var_mod = SearchExisitingAttach(inner, Attach::AttachType::kFloormod, b->base); - // introduce new intermediate vars if doesn't exsit now + // introduce new intermediate vars if doesn't exist now if (!var_div.defined()) { - auto it = top_let_var_.find(GetRef(op)); - var_div = it == top_let_var_.end() - ? inner.copy_with_suffix("_div_" + std::to_string(b->base->value)) - : it->second; + var_div = inner.copy_with_suffix("_div_" + std::to_string(b->base->value)); let_var_buffer_map[var_div.value()] = decl_buffer({int32(1)}, DataType::Int(32), var_div.value()->name_hint, "local"); BuildAttachMap(var_div.value(), inner, Attach::AttachType::kFloordiv, b->base, int32(0)); @@ -299,6 +299,7 @@ class LetVarBindingCanonicalizer : public ExprMutator { if (a != nullptr && b != nullptr) { ICHECK(a->vars.size() <= 1 && b->vars.size() <= 1); if (b->vars.size() == 0) { + if (a->vars.size() == 0) return SumForm({}, {}, int32(0)); // define let var for a Var inner; if (is_one(a->scales[0]) && is_zero(a->base)) { @@ -315,12 +316,9 @@ class LetVarBindingCanonicalizer : public ExprMutator { // first search for existing vars Optional var_mod = SearchExisitingAttach(inner, Attach::AttachType::kFloormod, b->base); - // introduce new intermediate var if doesn't exsits now + // introduce new intermediate var if doesn't exist now if (!var_mod.defined()) { - auto it = top_let_var_.find(GetRef(op)); - var_mod = it == top_let_var_.end() - ? inner.copy_with_suffix("_mod_" + std::to_string(b->base->value)) - : it->second; + var_mod = inner.copy_with_suffix("_mod_" + std::to_string(b->base->value)); let_var_buffer_map[var_mod.value()] = decl_buffer({int32(1)}, DataType::Int(32), var_mod.value()->name_hint, "local"); BuildAttachMap(var_mod.value(), inner, Attach::AttachType::kFloormod, b->base, int32(0)); @@ -382,7 +380,6 @@ class LetVarBindingCanonicalizer : public ExprMutator { } bool fail{false}; - std::unordered_map top_let_var_; std::unordered_map* var_range_; }; @@ -710,8 +707,10 @@ class PredicatePrecompute : public StmtMutator { LoadAddressLinearizer linearizer(&var_range_); if (!MatchLetVars(&canonicalizer)) return GetRef(store); local_predicate_map_.clear(); + // Replace the buffer store + BufferStore replaced_store = Downcast(Substitute(GetRef(store), canonicalizer.replace_map)); // Check the pattern of load address and predicate - const CallNode* call = store->value.as(); + const CallNode* call = replaced_store->value.as(); if (call != nullptr) { const OpNode* op = call->op.as(); if (op != nullptr && op->name == "tir.if_then_else") { @@ -853,9 +852,9 @@ class PredicatePrecompute : public StmtMutator { new_lhs = BufferLoad(load->buffer, {BufferLoad(buffer, {int32(0)})}); } return BufferStore( - store->buffer, - if_then_else(cast(DataType::Bool(1), new_predicate), new_lhs, rhs, store->span), - store->indices); + replaced_store->buffer, + if_then_else(cast(DataType::Bool(1), new_predicate), new_lhs, rhs, replaced_store->span), + replaced_store->indices); } } }