Skip to content

Commit

Permalink
[SPIRV] Implement log10 for logical SPIR-V (#66921)
Browse files Browse the repository at this point in the history
There is no log10 instruction in the GLSL Extended Instruction Set so to
implement the HLSL log10 intrinsic when targeting Vulkan this change
adds the logic to derive the result using the following formula:
```
log10(x) = log2(x) * (1 / log2(10))
         = log2(x) * 0.30103
```
  • Loading branch information
sudonatalie authored Oct 6, 2023
1 parent 469b9cb commit 0a2aaab
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 14 deletions.
25 changes: 14 additions & 11 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,24 +244,27 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
auto &MF = MIRBuilder.getMF();
const Type *LLVMFPTy;
if (SpvType) {
LLVMFPTy = getTypeForSPIRVType(SpvType);
assert(LLVMFPTy->isFloatingPointTy());
} else {
LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
auto &Ctx = MF.getFunction().getContext();
if (!SpvType) {
const Type *LLVMFPTy = Type::getFloatTy(Ctx);
SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
}
// Find a constant in DT or build a new one.
const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
const auto ConstFP = ConstantFP::get(Ctx, Val);
Register Res = DT.find(ConstFP, &MF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);
MIRBuilder.buildFConstant(Res, *ConstFP);

MachineInstrBuilder MIB;
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
}

return Res;
}

Expand Down
56 changes: 55 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, const ExtInstList &ExtInsts) const;

bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Register buildI32Constant(uint32_t Val, MachineInstr &I,
const SPIRVType *ResType = nullptr) const;

Expand Down Expand Up @@ -362,7 +365,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FLOG2:
return selectExtInst(ResVReg, ResType, I, CL::log2, GL::Log2);
case TargetOpcode::G_FLOG10:
return selectExtInst(ResVReg, ResType, I, CL::log10);
return selectLog10(ResVReg, ResType, I);

case TargetOpcode::G_FABS:
return selectExtInst(ResVReg, ResType, I, CL::fabs, GL::FAbs);
Expand Down Expand Up @@ -1562,6 +1565,57 @@ bool SPIRVInstructionSelector::selectGlobalValue(
return Reg.isValid();
}

bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
return selectExtInst(ResVReg, ResType, I, CL::log10);
}

// There is no log10 instruction in the GLSL Extended Instruction set, so it
// is implemented as:
// log10(x) = log2(x) * (1 / log2(10))
// = log2(x) * 0.30103

MachineIRBuilder MIRBuilder(I);
MachineBasicBlock &BB = *I.getParent();

// Build log2(x).
Register VarReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
.addDef(VarReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
.addImm(GL::Log2)
.add(I.getOperand(1))
.constrainAllUses(TII, TRI, RBI);

// Build 0.30103.
assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
ResType->getOpcode() == SPIRV::OpTypeFloat);
// TODO: Add matrix implementation once supported by the HLSL frontend.
const SPIRVType *SpirvScalarType =
ResType->getOpcode() == SPIRV::OpTypeVector
? GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg())
: ResType;
Register ScaleReg =
GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType);

// Multiply log2(x) by 0.30103 to get log10(x) result.
auto Opcode = ResType->getOpcode() == SPIRV::OpTypeVector
? SPIRV::OpVectorTimesScalar
: SPIRV::OpFMulS;
Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(VarReg)
.addUse(ScaleReg)
.constrainAllUses(TII, TRI, RBI);

return Result;
}

namespace llvm {
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

// TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
// tighten these requirements. Many of these math functions are only legal on
// specific bitwidths, so they are not selectable for
// allFloatScalarsAndVectors.
getActionDefinitionsBuilder({G_FPOW,
G_FEXP,
G_FEXP2,
G_FLOG,
G_FLOG2,
G_FLOG10,
G_FABS,
G_FMINNUM,
G_FMAXNUM,
Expand All @@ -259,8 +264,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
allFloatScalarsAndVectors, allIntScalarsAndVectors);

if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);

getActionDefinitionsBuilder(
{G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
.legalForCartesianProduct(allIntScalarsAndVectors,
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s

; CHECK: %[[#extinst:]] = OpExtInstImport "GLSL.std.450"

; CHECK: %[[#float:]] = OpTypeFloat 32
; CHECK: %[[#v4float:]] = OpTypeVector %[[#float]] 4
; CHECK: %[[#float_0_30103001:]] = OpConstant %[[#float]] 0.30103000998497009
; CHECK: %[[#_ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
; CHECK: %[[#_ptr_Function_float:]] = OpTypePointer Function %[[#float]]

define void @main() {
entry:
; CHECK: %[[#f:]] = OpVariable %[[#_ptr_Function_float]] Function
; CHECK: %[[#logf:]] = OpVariable %[[#_ptr_Function_float]] Function
; CHECK: %[[#f4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
; CHECK: %[[#logf4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
%f = alloca float, align 4
%logf = alloca float, align 4
%f4 = alloca <4 x float>, align 16
%logf4 = alloca <4 x float>, align 16

; CHECK: %[[#load:]] = OpLoad %[[#float]] %[[#f]] Aligned 4
; CHECK: %[[#log2:]] = OpExtInst %[[#float]] %[[#extinst]] Log2 %[[#load]]
; CHECK: %[[#res:]] = OpFMul %[[#float]] %[[#log2]] %[[#float_0_30103001]]
; CHECK: OpStore %[[#logf]] %[[#res]] Aligned 4
%0 = load float, ptr %f, align 4
%elt.log10 = call float @llvm.log10.f32(float %0)
store float %elt.log10, ptr %logf, align 4

; CHECK: %[[#load:]] = OpLoad %[[#v4float]] %[[#f4]] Aligned 16
; CHECK: %[[#log2:]] = OpExtInst %[[#v4float]] %[[#extinst]] Log2 %[[#load]]
; CHECK: %[[#res:]] = OpVectorTimesScalar %[[#v4float]] %[[#log2]] %[[#float_0_30103001]]
; CHECK: OpStore %[[#logf4]] %[[#res]] Aligned 16
%1 = load <4 x float>, ptr %f4, align 16
%elt.log101 = call <4 x float> @llvm.log10.v4f32(<4 x float> %1)
store <4 x float> %elt.log101, ptr %logf4, align 16

ret void
}

declare float @llvm.log10.f32(float)
declare <4 x float> @llvm.log10.v4f32(<4 x float>)

0 comments on commit 0a2aaab

Please sign in to comment.