Skip to content

Commit

Permalink
Fix vcvtph2ps codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtulloch committed Mar 30, 2019
1 parent 4ac64fc commit d529109
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/codegen/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_arm.cc
* \brief X86-64 specific code generator
*/
#ifdef TVM_LLVM_VERSION
#include "codegen_cpu.h"

namespace tvm {
namespace codegen {

class CodeGenX86_64 final : public CodeGenCPU {
public:
llvm::Value* VisitExpr_(const Cast* op) override;
};

llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
// LLVM does not automatically generate the correct instruction sequences for
// half -> float conversion (using AVX2/AVX512 variants of vcvtph2ps).
const auto from = op->value.type();
const auto to = op->type;
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
CHECK_EQ(from.lanes(), to.lanes());
CHECK_NOTNULL(target_machine_);

const auto has_f16c =
target_machine_->getTargetFeatureString().find("f16c") != llvm::StringRef::npos;
const auto has_avx512f =
target_machine_->getTargetFeatureString().find("avx512f") != llvm::StringRef::npos;

// TODO(tulloch): implement version generic over lanes.
if (from.lanes() == 8 && (has_f16c || has_avx512f)) {
Array<Expr> vcvt_args;
::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_vcvtph2ps_256;
vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
vcvt_args.push_back(
ir::Call::make(Int(16, 8), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
}

// TODO(tulloch): implement version generic over lanes.
if (from.lanes() == 16 && has_avx512f) {
Array<Expr> vcvt_args;
::llvm::Intrinsic::ID vcvtph2ps_id = ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512;
vcvt_args.push_back(ir::UIntImm::make(UInt(32), vcvtph2ps_id));
vcvt_args.push_back(ir::UIntImm::make(UInt(32), 0));
vcvt_args.push_back(
ir::Call::make(Int(16, 16), ir::Call::reinterpret, {op->value}, ir::Call::PureIntrinsic));
vcvt_args.push_back(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), 16));
vcvt_args.push_back(ir::IntImm::make(Int(16), -1));
vcvt_args.push_back(ir::IntImm::make(Int(32), 4));
return MakeValue(ir::Call::make(to, "llvm_intrin", vcvt_args, ir::Call::PureIntrinsic));
}
}

return CodeGenCPU::VisitExpr_(op);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
CodeGenLLVM* cg = new CodeGenX86_64();
*rv = static_cast<void*>(cg);
});

} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
59 changes: 59 additions & 0 deletions tests/python/unittest/test_codegen_x86.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import tvm
import re
import os
import ctypes

def test_fp16_to_fp32_with_f16c():
target = 'llvm -mcpu=core-avx2 -mattr=+f16c'
elements = 64
n = tvm.convert(elements)
A = tvm.placeholder((n, 8), dtype="float16", name='A')
B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[1])
f = tvm.build(s, [A, B], target)

# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
ll = f.get_source('ll')
assembly = f.get_source('asm').splitlines()
matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
assert (len(matches) > 1)

def test_fp16_to_fp32_with_avx512():
target = 'llvm -mcpu=skylake-avx512 -mattr=+avx512f,+f16c'
elements = 64
n = tvm.convert(elements)
A = tvm.placeholder((n, 16), dtype="float16", name='A')
B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[1])
f = tvm.build(s, [A, B], target)

# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
ll = f.get_source('ll')
assembly = f.get_source('asm').splitlines()
matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
assert (len(matches) > 1)

def test_fp16_to_fp32_without_f16c():
target = 'llvm'
elements = 64
n = tvm.convert(elements)
A = tvm.placeholder((n, 8), dtype="float16", name='A')
B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
s = tvm.create_schedule(B.op)
s[B].vectorize(s[B].op.axis[1])
f = tvm.build(s, [A, B], target)

# Verify we see the correct number of vpaddl and vcnt instructions in the assembly
ll = f.get_source('ll')
assembly = f.get_source('asm').splitlines()
matches = [l for l in assembly if re.search("vcvtph2ps.*ymm", l)]
assert (len(matches) == 0)
matches = [l for l in assembly if re.search("vcvtph2ps.*zmm", l)]
assert (len(matches) == 0)

if __name__ == "__main__":
test_fp16_to_fp32_with_f16c()
test_fp16_to_fp32_without_f16c()
test_fp16_to_fp32_with_avx512()

0 comments on commit d529109

Please sign in to comment.