Do target-specific lowering of lerp by abadams · Pull Request #6432 · halide/Halide · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Do target-specific lowering of lerp #6432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,7 @@ void CodeGen_C::visit(const Call *op) {
}
} else if (op->is_intrinsic(Call::lerp)) {
internal_assert(op->args.size() == 3);
Expr e = lower_lerp(op->args[0], op->args[1], op->args[2]);
Expr e = lower_lerp(op->args[0], op->args[1], op->args[2], target);
rhs << print_expr(e);
} else if (op->is_intrinsic(Call::absd)) {
internal_assert(op->args.size() == 2);
Expand Down
3 changes: 2 additions & 1 deletion src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2697,7 +2697,8 @@ void CodeGen_LLVM::visit(const Call *op) {
Type wt = upgrade_type_for_arithmetic(op->args[2].type());
Expr e = lower_lerp(cast(t, op->args[0]),
cast(t, op->args[1]),
cast(wt, op->args[2]));
cast(wt, op->args[2]),
target);
e = cast(op->type, e);
codegen(e);
} else if (op->is_intrinsic(Call::popcount)) {
Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ class OptimizePatterns : public IRMutator {
// We need to lower lerps now to optimize the arithmetic
// that they generate.
internal_assert(op->args.size() == 3);
return mutate(lower_lerp(op->args[0], op->args[1], op->args[2]));
return mutate(lower_lerp(op->args[0], op->args[1], op->args[2], target));
} else if ((op->is_intrinsic(Call::div_round_to_zero) ||
op->is_intrinsic(Call::mod_round_to_zero)) &&
!op->type.is_float() && op->type.is_vector()) {
Expand Down
29 changes: 21 additions & 8 deletions src/Lerp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
#include "IROperator.h"
#include "Lerp.h"
#include "Simplify.h"
#include "Target.h"

namespace Halide {
namespace Internal {

Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight) {
Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target) {

Expr result;

Expand Down Expand Up @@ -134,13 +135,25 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight) {
case 8:
case 16:
case 32: {
Expr shift = Cast::make(UInt(2 * bits), bits);
Expr prod_sum = widening_mul(zero_val, inverse_typed_weight) + widening_mul(one_val, typed_weight);
// Computes x / (2 ** N - 1) as (x / 2 ** N + x) / 2 ** N.
// TODO: on x86 it's actually one instruction cheaper to do the division directly.
Expr divided = rounding_shift_right(rounding_shift_right(prod_sum, shift) + prod_sum, shift);

result = Cast::make(UInt(bits, computation_type.lanes()), divided);
Expr prod_sum = (widening_mul(zero_val, inverse_typed_weight) +
widening_mul(one_val, typed_weight));
// Now we need to do a rounding divide and narrow. For
// 8-bit, this rounding divide looks like (x + 127) /
// 255. On most platforms we can compute this as
// ((x + 128) / 256 + x + 128) / 256. Note that
// overflow is impossible here because the most our
// prod_sum can be is 255^2.
if (target.arch == Target::X86) {
// On x86 we have no rounding shifts but we do
// have a multiply-keep-high-half. So it's
// actually one instruction cheaper to do the
// division directly.
Expr divisor = cast(UInt(bits), -1);
result = (prod_sum + divisor / 2) / divisor;
} else {
result = rounding_shift_right(rounding_shift_right(prod_sum, bits) + prod_sum, bits);
}
result = Cast::make(UInt(bits, computation_type.lanes()), result);
break;
}
case 64:
Expand Down
5 changes: 4 additions & 1 deletion src/Lerp.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
#include "Expr.h"

namespace Halide {

struct Target;

namespace Internal {

/** Build Halide IR that computes a lerp. Used by codegen targets that
* don't have a native lerp. */
Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight);
Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target);

} // namespace Internal
} // namespace Halide
Expand Down
0