-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_x86_64.cc
 * \brief X86-64 specific code generator
 */
#ifdef TVM_LLVM_VERSION | ||
#include "codegen_cpu.h" | ||
|
||
namespace tvm { | ||
namespace codegen { | ||
|
||
class CodeGenX86_64 final : public CodeGenCPU { | ||
public: | ||
llvm::Value* VisitExpr_(const Cast* op) override; | ||
}; | ||
|
||
/*!
 * \brief Lower a Cast, special-casing vectorized float16 -> float32 conversion.
 *
 * 8-lane and 16-lane half -> float casts are emitted as explicit calls to the
 * x86 vcvtph2ps intrinsics when the target feature string enables the required
 * ISA extension. Every other cast is delegated to CodeGenCPU.
 *
 * \param op The cast expression to lower.
 * \return The LLVM value holding the result of the cast.
 */
llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
  // LLVM does not automatically generate the correct instruction sequences for
  // half -> float conversion (using AVX2/AVX512 variants of vcvtph2ps).
  const auto from = op->value.type();
  const auto to = op->type;
  if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
    CHECK_EQ(from.lanes(), to.lanes());
    CHECK_NOTNULL(target_machine_);

    // The feature string is a comma-separated list of "+name" / "-name"
    // entries. Compare whole entries instead of searching for a substring so
    // that an explicitly disabled feature ("-f16c") or a longer feature name
    // that merely contains the requested one is not reported as enabled.
    // When a feature is listed more than once, the last entry decides.
    const auto has_feature = [this](llvm::StringRef name) {
      llvm::StringRef remaining = target_machine_->getTargetFeatureString();
      bool enabled = false;
      while (!remaining.empty()) {
        const auto parts = remaining.split(',');
        const llvm::StringRef entry = parts.first;
        remaining = parts.second;
        if (entry.size() > 1 && entry.substr(1) == name) {
          enabled = (entry[0] == '+');
        }
      }
      return enabled;
    };

    const bool has_f16c = has_feature("f16c");
    const bool has_avx512f = has_feature("avx512f");

    // TODO(tulloch): implement version generic over lanes.
    if (from.lanes() == 8 && (has_f16c || has_avx512f)) {
      Array<Expr> vcvt_args;
      ::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_vcvtph2ps_256;
      // llvm_intrin calls take the intrinsic id first, followed by a 0 here
      // (presumably the overloaded-signature count -- TODO confirm), then the
      // operands of the intrinsic itself.
      vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
      vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
      // The intrinsic operand is passed as int16x8, so reinterpret the
      // float16 lanes bit-for-bit.
      vcvt_args.push_back(
          ir::Call::make(Int(16, 8), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
      return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
    }

    // TODO(tulloch): implement version generic over lanes.
    if (from.lanes() == 16 && has_avx512f) {
      Array<Expr> vcvt_args;
      ::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512;
      vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
      vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
      // Source lanes reinterpreted as int16x16.
      vcvt_args.push_back(
          ir::Call::make(Int(16, 16), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
      // Remaining operands of the masked intrinsic: a zero float32x16 vector,
      // an int16 of -1 (all 16 bits set) and the int32 constant 4.
      // NOTE(review): these look like pass-through, lane mask and rounding
      // control (4 == current direction) -- confirm against the LLVM
      // intrinsic definition.
      vcvt_args.push_back(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), 16));
      vcvt_args.push_back(ir::IntImm::make(Int(16), -1));
      vcvt_args.push_back(ir::IntImm::make(Int(32), 4));
      return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
    }
  }

  // Anything else (other types, other lane counts, missing features) uses the
  // generic CPU lowering.
  return CodeGenCPU::VisitExpr_(op);
}
|
||
// Factory registered under "tvm.codegen.llvm.target_x86-64": creates a new
// CodeGenX86_64 and returns it through rv as an opaque handle.
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
    // Convert to the CodeGenLLVM base pointer before erasing the type, so the
    // void* stored in rv is the base-class pointer.
    // NOTE(review): the raw pointer is presumably owned by whoever consumes
    // the handle -- confirm against the caller.
    *rv = static_cast<void*>(static_cast<CodeGenLLVM*>(new CodeGenX86_64()));
  });
|
||
} // namespace codegen | ||
} // namespace tvm | ||
#endif // TVM_LLVM_VERSION |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import tvm | ||
import re | ||
import os | ||
import ctypes | ||
|
||
def test_fp16_to_fp32_with_f16c():
    """Check that an 8-lane float16 -> float32 cast emits ymm ``vcvtph2ps`` with F16C enabled."""
    target = 'llvm -mcpu=core-avx2 -mattr=+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    # Vectorize the inner axis of extent 8 so the cast operates on 8 lanes.
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that more than one assembly line uses vcvtph2ps with ymm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
    assert (len(matches) > 1)
|
||
def test_fp16_to_fp32_with_avx512():
    """Check that a 16-lane float16 -> float32 cast emits zmm ``vcvtph2ps`` with AVX-512F enabled."""
    target = 'llvm -mcpu=skylake-avx512 -mattr=+avx512f,+f16c'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 16), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    # Vectorize the inner axis of extent 16 so the cast operates on 16 lanes.
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that more than one assembly line uses vcvtph2ps with zmm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
    assert (len(matches) > 1)
|
||
def test_fp16_to_fp32_without_f16c():
    """Check that no vector ``vcvtph2ps`` is emitted when the target string enables no extra features."""
    target = 'llvm'
    elements = 64
    n = tvm.convert(elements)
    A = tvm.placeholder((n, 8), dtype="float16", name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
    s = tvm.create_schedule(B.op)
    # Vectorize the inner axis of extent 8 so the cast operates on 8 lanes.
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], target)

    # Verify that the assembly contains no vcvtph2ps on ymm or zmm registers.
    assembly = f.get_source('asm').splitlines()
    matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
    assert (len(matches) == 0)
    matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
    assert (len(matches) == 0)
|
||
if __name__ == "__main__":
    # Run every test in this module when the file is executed as a script.
    for run_test in (test_fp16_to_fp32_with_f16c,
                     test_fp16_to_fp32_without_f16c,
                     test_fp16_to_fp32_with_avx512):
        run_test()