[LANG/PASS] Support Vectorize by tqchen · Pull Request #37 · apache/tvm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[LANG/PASS] Support Vectorize #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class IRMutator {
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ Stmt StorageFlatten(Stmt stmt,
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
 * \brief vectorize the constant loops
 * \param stmt The statement to be vectorized.
 * \return The transformed statement.
 */
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief Make a user-callable API LoweredFunc.
*
Expand Down
57 changes: 56 additions & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class StageNode;
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;

/*! \brief the attachment type */
enum AttachType : int {
Expand All @@ -27,6 +29,12 @@ enum AttachType : int {
kScope = 3
};

/*!
 * \brief Scheduling annotation kinds that can be attached to an IterVar.
 */
enum IterVarType : int {
  /*! \brief The iteration is annotated for unrolling (see Stage::unroll). */
  kUnrolled = 1,
  /*! \brief The iteration is annotated for vectorization (see Stage::vectorize). */
  kVectorized = 2,
};

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
public:
Expand Down Expand Up @@ -123,12 +131,23 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
 * \brief Unroll iteration.
 * \param var The axis to be unrolled.
 * \return reference to self.
 */
Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
inline bool is_scheduled() const;

// declare container type
using ContainerType = StageNode;
};
Expand Down Expand Up @@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef {
inline const IterVarRelationNode* operator->() const;
};

/*!
 * \brief Reference to additional schedulable attributes of an IterVar.
 */
class IterVarAttr : public NodeRef {
 public:
  /*! \brief Default constructor; defers to NodeRef's default constructor. */
  IterVarAttr() = default;
  /*!
   * \brief Construct an attribute carrying the given iteration type.
   * \param t The iteration type to record.
   */
  explicit IterVarAttr(IterVarType t);
  /*!
   * \brief Wrap an existing node.
   * \param n The node this reference will point to.
   */
  explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
  /*!
   * \brief Access the underlying node container.
   * \return Pointer to the IterVarAttrNode held by this reference.
   */
  inline const IterVarAttrNode* operator->() const;
};

// definition of node containers
/*!
* \brief represents the schedule of the tensor
Expand Down Expand Up @@ -223,6 +257,8 @@ class StageNode : public Node {
Array<IterVar> leaf_iter_vars;
/*! \brief The relations between IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Map<IterVar, IterVarAttr> iter_var_attrs;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
/*! \brief The attach point of this schedule. */
Expand All @@ -236,6 +272,7 @@ class StageNode : public Node {
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
Expand Down Expand Up @@ -268,6 +305,20 @@ class ScheduleNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};

/*! \brief Node container for IterVarAttr. */
class IterVarAttrNode : public Node {
 public:
  /*!
   * \brief The iteration type.
   * \note Value-initialized so that a default-constructed node never holds
   *  an indeterminate value, consistent with StageNode's brace-initialized
   *  attach_type. 0 is not a named IterVarType; a real value is presumably
   *  assigned by IterVarAttr(IterVarType) or by attribute visiting (confirm).
   */
  IterVarType iter_type{};

  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("iter_type", &iter_type);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
};

/*! \brief Common base class of every IterVar relation node. */
class IterVarRelationNode : public Node {};
Expand Down Expand Up @@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}

/*!
 * \brief Access the IterVarAttrNode held by this reference.
 * \return The held node, downcast with static_cast (no runtime type check).
 */
inline const IterVarAttrNode* IterVarAttr::operator->() const {
  const Node* raw = node_.get();
  return static_cast<const IterVarAttrNode*>(raw);
}

} // namespace tvm
#endif // TVM_SCHEDULE_H_
1 change: 1 addition & 0 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,23 @@ def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner

def vectorize(self, var):
    """Vectorize the iteration.

    Forwards to the C++ ``Stage::vectorize`` via
    ``_api_internal._StageVectorize``.

    Parameters
    ----------
    var : IterVar
        The iteration to be vectorized.
    """
    _api_internal._StageVectorize(self, var)

def unroll(self, var):
    """Mark an iteration of this stage for unrolling.

    Delegates to the registered ``_StageUnroll`` API function.

    Parameters
    ----------
    var : IterVar
        The iteration to be unrolled.
    """
    stage_unroll = _api_internal._StageUnroll
    stage_unroll(self, var)
12 changes: 12 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
    // args[0] is the Stage to modify; args[1] is the IterVar to unroll.
    Stage stage = args[0].operator Stage();
    stage.unroll(args[1]);
  });

TVM_REGISTER_API(_StageVectorize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
    // args[0] is the Stage to modify; args[1] is the IterVar to vectorize.
    Stage stage = args[0].operator Stage();
    stage.vectorize(args[1]);
  });

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
Expand Down
18 changes: 18 additions & 0 deletions src/arithmetic/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <tvm/ir.h>
#include <pass/Interval.h>
#include <limits>

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
}
}

/*!
 * \brief Get a small constant int.
 *
 * Succeeds when e is recognized as a constant by GetConst<int64_t> or
 * GetConst<uint64_t> and its value is representable as int.
 * \param e The expression to inspect.
 * \param out Receives the value on success; left untouched on failure.
 * \return true if e is a constant that fits in int, false otherwise.
 */
inline bool GetConstInt(Expr e, int* out) {
  int64_t sval = 0;
  uint64_t uval = 0;
  if (GetConst(e, &sval)) {
    // Check both ends of the range. Without the lower-bound check, constants
    // below INT_MIN would be narrowed by the cast below into an
    // implementation-defined (typically wrapped) value and reported as success.
    if (sval > static_cast<int64_t>(std::numeric_limits<int>::max()) ||
        sval < static_cast<int64_t>(std::numeric_limits<int>::min())) {
      return false;
    }
    *out = static_cast<int>(sval);
    return true;
  }
  if (GetConst(e, &uval)) {
    // Unsigned values have no lower bound to check; only the upper bound matters.
    if (uval > static_cast<uint64_t>(std::numeric_limits<int>::max())) {
      return false;
    }
    *out = static_cast<int>(uval);
    return true;
  }
  return false;
}

#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
Expand Down
Loading
0