[CMSIS-NN] Convert CMSIS-NN to use Target Hooks by Mousius · Pull Request #9397 · apache/tvm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[CMSIS-NN] Convert CMSIS-NN to use Target Hooks #9397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import tvm

from tvm.driver import tvmc
from tvm import relay
from tvm import transform
from tvm._ffi import registry
Expand Down Expand Up @@ -206,6 +207,7 @@ def parse_target(target):
a key-value for all options passed via CLI; 'raw',
containing the plain string for this codegen
"""
codegen_names = tvmc.composite_target.get_codegen_names()
codegens = []

tvm_target_kinds = tvm.target.Target.list_kinds()
Expand All @@ -232,7 +234,7 @@ def parse_target(target):
for codegen_def in split_codegens:
# the first is expected to be the name
name = codegen_def[0]
is_tvm_target = name in tvm_target_kinds
is_tvm_target = name in tvm_target_kinds and name not in codegen_names
raw_target = " ".join(codegen_def)
all_opts = codegen_def[1:] if len(codegen_def) > 1 else []
opts = {}
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/driver/tvmc/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
This file contains functions for processing target inputs for the TVMC CLI
"""

from tvm.driver import tvmc
from tvm.target import Target

# We can't tell the type inside an Array but all current options are strings so
Expand All @@ -27,6 +28,11 @@
INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"}


def _valid_target_kinds():
codegen_names = tvmc.composite_target.get_codegen_names()
return filter(lambda target: target not in codegen_names, Target.list_kinds())


def _generate_target_kind_args(parser, kind):
target_group = parser.add_argument_group(f"target {kind.name}")
for target_option, target_type in kind.options.items():
Expand All @@ -45,8 +51,7 @@ def generate_target_args(parser):
help="compilation target as plain string, inline JSON or path to a JSON file",
required=True,
)
target_kinds = Target.list_kinds()
for target_kind in target_kinds:
for target_kind in _valid_target_kinds():
target = Target(target_kind)
_generate_target_kind_args(parser, target.kind)

Expand All @@ -55,7 +60,7 @@ def _reconstruct_target_kind_args(args, kind):
kind_options = {}
for target_option, target_type in kind.options.items():
if target_type in INTERNAL_TO_NATIVE_TYPE:
var_name = f"target_{kind.name}_{target_option.replace('-', '_')}"
var_name = f"target_{kind.name.replace('-', '_')}_{target_option.replace('-', '_')}"
option_value = getattr(args, var_name)
if option_value is not None:
kind_options[target_option] = getattr(args, var_name)
Expand All @@ -64,9 +69,8 @@ def _reconstruct_target_kind_args(args, kind):

def reconstruct_target_args(args):
"""Reconstructs the target options from the arguments"""
target_kinds = Target.list_kinds()
reconstructed = {}
for target_kind in target_kinds:
for target_kind in _valid_target_kinds():
target = Target(target_kind)
kind_options = _reconstruct_target_kind_args(args, target.kind)
if kind_options:
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument
"""Arm(R) CMSIS-NN supported operators for Cortex-M."""
import tvm.ir
from tvm.target import Target
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

Expand All @@ -25,7 +26,7 @@


def enabled():
return bool(tvm.get_global_func("relay.ext.cmsisnn", True))
return "cmsis-nn" in Target.list_kinds()


def partition_for_cmsisnn(mod, params=None, **opts):
Expand All @@ -51,7 +52,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
[
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cmsisnn"),
transform.AnnotateTarget("cmsis-nn"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
]
Expand All @@ -60,9 +61,9 @@ def partition_for_cmsisnn(mod, params=None, **opts):
return seq(mod)


@register_pattern_table("cmsisnn")
@register_pattern_table("cmsis-nn")
def pattern_table():
"""Get the cmsisnn compiler pattern table."""
"""Get the CMSIS-NN compiler pattern table."""

def softmax_pattern():
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
Expand Down Expand Up @@ -104,14 +105,14 @@ def check_quantized_binary_op(extract):
)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
(
"cmsisnn.quantized_mul",
"cmsis-nn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsisnn.quantized_add",
"cmsis-nn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
Expand Down
5 changes: 4 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
// The host Target contains these parameters at the moment rather than
// the specific Target
// TODO(Mousius) - Move these to the Executor object rather than Target
if (target->GetHost().value()->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
Expand Down
127 changes: 73 additions & 54 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/builtin.h>
Expand All @@ -33,29 +34,46 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

class RelayToTIRVisitor : public MixedModeVisitor {
class RelayToTIRVisitor : public MixedModeMutator {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be renamed RelayToTIRVisitor --> RelayToTIRMutator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll move this to a follow-up 😸

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't mean it as a blocking comment. A follow-up is fine; I should've marked it as a nit.

public:
explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}
explicit RelayToTIRVisitor(IRModule ir_module, Target target)
: ir_module_(ir_module), target_(target) {}

tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }
IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);

ir_module_->Update(main_global_var, mutated_main);

return ir_module_;
}

private:
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
tvm::Array<PrimExpr> call_extern_args) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
dict_attrs.Set(tvm::attr::kTarget, target_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));

ir_module_->Add(global_var, replacement_func);
}

void EmitSoftMax(const Expr& expr) {
void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
Expand Down Expand Up @@ -102,10 +120,10 @@ class RelayToTIRVisitor : public MixedModeVisitor {
out_var,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void EmitMul(const Expr& expr) {
void EmitMul(const GlobalVar& global_var, const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = GetScalarFromConstant<float>(mul_call->args[2]);
Expand Down Expand Up @@ -145,10 +163,10 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void EmitAdd(const Expr& expr) {
void EmitAdd(const GlobalVar& global_var, const Expr& expr) {
auto* add_call = expr.as<CallNode>();

const float input_0_scale = GetScalarFromConstant<float>(add_call->args[2]);
Expand Down Expand Up @@ -212,58 +230,59 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
CreatePrimFuncForExtern(global_var, func_signature, args);
}

void VisitExpr_(const CallNode* call) final {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return;
}

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
if (comp_name == "cmsisnn.quantized_softmax") {
EmitSoftMax(func->body);
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return post;
}
if (comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
if (comp_name == "cmsisnn.quantized_add") {
EmitAdd(func->body);

auto codegen_name = func->GetAttr<String>(attr::kCompiler);
if (codegen_name.defined() && codegen_name == "cmsis-nn") {
const CallNode* inner_call = func->body.as<CallNode>();
const FunctionNode* composite_func = inner_call->op.as<FunctionNode>();
auto comp_name = composite_func->GetAttr<String>(attr::kComposite);
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);

GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = composite_func->checked_type();

if (comp_name == "cmsis-nn.quantized_softmax") {
EmitSoftMax(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.quantized_mul") {
EmitMul(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.quantized_add") {
EmitAdd(new_global_var, composite_func->body);
}

Array<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}

return Call(new_global_var, args, call->attrs, call->type_args, call->span);
}
}
}

public:
String func_name_;
tir::PrimFunc primfunc_;
};

IRModule GenerateTIR(IRModule mod) {
String func_name;
Function func;

// Obtain external Relay Function that needs to be translated into TIR
ICHECK(mod->functions.size() == 1) << "Supports modules with single external Relay function.";
for (auto kv : mod->functions) {
func = Downcast<Function>(kv.second);
func_name = func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
return post;
}

// Prepare PrimFunc from Relay Function
auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);

// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}
private:
IRModule ir_module_;
Target target_;
};

transform::Pass RelayToTIR() {
tvm::transform::Pass RelayToTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return GenerateTIR(m); };
[=](IRModule ir_module, transform::PassContext pass_context) {
auto relay_to_tir = RelayToTIRVisitor(ir_module, Target("cmsis-nn"));
return relay_to_tir.Mutate();
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand All @@ -16,34 +17,22 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relay/transform.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>

namespace tvm {

namespace relay {
namespace contrib {
namespace cmsisnn {

transform::Pass RelayToTIR();

runtime::Module CompileCMSISNN(const ObjectRef& ref) {
IRModule relay_mod;
Function relay_func = Downcast<Function>(ref);
auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
GlobalVar var = GlobalVar(func_name.value());
relay_mod->Add(var, relay_func);
relay_mod = transform::InferType()(relay_mod);

Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()};
transform::Sequential seq(pass_seqs);
IRModule tir_mod = seq(relay_mod);

const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate");
return (*pf)(tir_mod);
}
tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN);
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

} // namespace cmsisnn
} // namespace contrib
Expand Down
Loading
0