Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: truncate Float16 vector ops also #46130

Merged
merged 1 commit into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ configure-y: | $(BUILDDIRMAKE)
configure:
ifeq ("$(origin O)", "command line")
@if [ "$$(ls '$(BUILDROOT)' 2> /dev/null)" ]; then \
echo 'WARNING: configure called on non-empty directory $(BUILDROOT)'; \
printf $(WARNCOLOR)'WARNING: configure called on non-empty directory'$(ENDCOLOR)' %s\n' '$(BUILDROOT)'; \
read -p "Proceed [y/n]? " answer; \
else \
answer=y;\
fi; \
[ $$answer = 'y' ] && $(MAKE) configure-$$answer
[ "y$$answer" = yy ] && $(MAKE) configure-$$answer
else
$(error "cannot rerun configure from within a build directory")
endif
Expand Down
120 changes: 63 additions & 57 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,21 @@ namespace {
static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// skip instructions that neither produce nor consume Float16 (scalar or vector) values
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -64,6 +73,7 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
// TODO: Do calls to llvm.fma.f16 need to go to f64 to be correct?
continue;
}

Expand All @@ -75,72 +85,68 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
if (Op->getType()->getScalarType()->isHalfTy()) {
++TotalExt;
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
}
Operands[i] = (Op);
Operands[i] = Op;
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
24 changes: 21 additions & 3 deletions test/llvmpasses/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,29 @@ include $(JULIAHOME)/Make.inc

check: .

TESTS = $(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.jl $(SRCDIR)/*.ll))
TESTS_ll := $(filter-out update-%,$(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.ll)))
TESTS_jl := $(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.jl))
TESTS := $(TESTS_ll) $(TESTS_jl)

. $(TESTS):
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(build_depsbindir)/lit/lit.py -v $(addprefix $(SRCDIR)/,$@)
$(build_depsbindir)/lit/lit.py -v "$(addprefix $(SRCDIR)/,$@)"

.PHONY: $(TESTS) check all .
# update-<name>.ll: one target per .ll test listed in TESTS_ll; runs LLVM's
# update_test_checks.py on that test and writes the result back in place.
# Recipe steps, in order:
#   1. print a note that a local LLVM source checkout is required
#   2. ask for confirmation; the recipe only continues when the reply is exactly "y"
#   3. write a copy of the test, with %shlibext replaced by .$(SHLIB_EXT), to a
#      file named after the target (relative to the current directory)
#   4. run update_test_checks.py --preserve-names on that copy, with the build's
#      bin/lib directories prepended to PATH and LD_LIBRARY_PATH
#   5. move the copy over the original test in $(SRCDIR)
# NOTE(review): the file moved back in step 5 is the sed-substituted copy, so the
# %shlibext placeholder does not appear to be restored by this recipe -- confirm.
$(addprefix update-,$(TESTS_ll)):
@echo 'NOTE: This requires a llvm source files locally, such as via `make -C deps USE_BINARYBUILDER_LLVM=0 DEPS_GIT=llvm checkout-llvm`'
@read -p "$$(printf $(WARNCOLOR)'This will directly modify %s, are you sure you want to proceed? '$(ENDCOLOR) '$@')" REPLY && [ yy = "y$$REPLY" ]
sed -e 's/%shlibext/.$(SHLIB_EXT)/g' < "$(@:update-%=$(SRCDIR)/%)" > "$@"
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(JULIAHOME)/deps/srccache/llvm/llvm/utils/update_test_checks.py "$@" \
--preserve-names
mv "$@" "$(@:update-%=$(SRCDIR)/%)"

# update-help: invoke LLVM's update_test_checks.py with --help, using the same
# PATH / LD_LIBRARY_PATH prefixing as the lit and update-* recipes in this file.
update-help:
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(JULIAHOME)/deps/srccache/llvm/llvm/utils/update_test_checks.py \
--help

.PHONY: $(TESTS) $(addprefix update-,$(TESTS_ll)) check all .
64 changes: 64 additions & 0 deletions test/llvmpasses/float16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s

; Exercises the DemoteFloat16 pass (see the RUN lines) on scalar and
; <2 x half> vector arithmetic. The autogenerated assertions below expect every
; half fadd/fmul/fdiv to be rewritten as fpext to float, the float operation,
; and an fptrunc back to half -- for vector operands as well as scalars.
define half @demotehalf_test(half %a, half %b) {
; CHECK-LABEL: @demotehalf_test(
; CHECK-NEXT: top:
; CHECK-NEXT: %0 = fpext half %a to float
; CHECK-NEXT: %1 = fpext half %b to float
; CHECK-NEXT: %2 = fadd float %0, %1
; CHECK-NEXT: %3 = fptrunc float %2 to half
; CHECK-NEXT: %4 = fpext half %3 to float
; CHECK-NEXT: %5 = fpext half %b to float
; CHECK-NEXT: %6 = fadd float %4, %5
; CHECK-NEXT: %7 = fptrunc float %6 to half
; CHECK-NEXT: %8 = fpext half %7 to float
; CHECK-NEXT: %9 = fpext half %b to float
; CHECK-NEXT: %10 = fadd float %8, %9
; CHECK-NEXT: %11 = fptrunc float %10 to half
; CHECK-NEXT: %12 = fpext half %11 to float
; CHECK-NEXT: %13 = fpext half %b to float
; CHECK-NEXT: %14 = fmul float %12, %13
; CHECK-NEXT: %15 = fptrunc float %14 to half
; CHECK-NEXT: %16 = fpext half %15 to float
; CHECK-NEXT: %17 = fpext half %b to float
; CHECK-NEXT: %18 = fdiv float %16, %17
; CHECK-NEXT: %19 = fptrunc float %18 to half
; CHECK-NEXT: %20 = insertelement <2 x half> undef, half %a, i32 0
; CHECK-NEXT: %21 = insertelement <2 x half> %20, half %b, i32 1
; CHECK-NEXT: %22 = insertelement <2 x half> undef, half %b, i32 0
; CHECK-NEXT: %23 = insertelement <2 x half> %22, half %b, i32 1
; CHECK-NEXT: %24 = fpext <2 x half> %21 to <2 x float>
; CHECK-NEXT: %25 = fpext <2 x half> %23 to <2 x float>
; CHECK-NEXT: %26 = fadd <2 x float> %24, %25
; CHECK-NEXT: %27 = fptrunc <2 x float> %26 to <2 x half>
; CHECK-NEXT: %28 = extractelement <2 x half> %27, i32 0
; CHECK-NEXT: %29 = extractelement <2 x half> %27, i32 1
; CHECK-NEXT: %30 = fpext half %28 to float
; CHECK-NEXT: %31 = fpext half %29 to float
; CHECK-NEXT: %32 = fadd float %30, %31
; CHECK-NEXT: %33 = fptrunc float %32 to half
; CHECK-NEXT: %34 = fpext half %33 to float
; CHECK-NEXT: %35 = fpext half %19 to float
; CHECK-NEXT: %36 = fadd float %34, %35
; CHECK-NEXT: %37 = fptrunc float %36 to half
; CHECK-NEXT: ret half %37
;
top:
; scalar chain on half: three fadds, then an fmul and an fdiv, each using %b
%0 = fadd half %a, %b
%1 = fadd half %0, %b
%2 = fadd half %1, %b
%3 = fmul half %2, %b
%4 = fdiv half %3, %b
; pack (%a, %b) and (%b, %b) into two <2 x half> vectors
%5 = insertelement <2 x half> undef, half %a, i32 0
%6 = insertelement <2 x half> %5, half %b, i32 1
%7 = insertelement <2 x half> undef, half %b, i32 0
%8 = insertelement <2 x half> %7, half %b, i32 1
; vector fadd on <2 x half>
%9 = fadd <2 x half> %6, %8
; unpack both lanes, add them together, then add the scalar result %4
%10 = extractelement <2 x half> %9, i32 0
%11 = extractelement <2 x half> %9, i32 1
%12 = fadd half %10, %11
%13 = fadd half %12, %4
ret half %13
}