Skip to content

Commit

Permalink
Fix zero extend arithmetic operation for all three backends
Browse files Browse the repository at this point in the history
  • Loading branch information
gigiblender committed Sep 22, 2024
1 parent ed4b6ab commit ea7b602
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,23 @@ public static JavaKind convertJavaKind(JavaType type) {
};
}


public static JavaKind javaKindFromBitSize(int bitSize, boolean isFloat) {
if (isFloat) {
return switch (bitSize) {
case 32 -> JavaKind.Float;
case 64 -> JavaKind.Double;
default -> throw new IllegalArgumentException("Unsupported floating point bit size: " + bitSize);
};
} else {
return switch (bitSize) {
case 8 -> JavaKind.Byte;
case 16 -> JavaKind.Short;
case 32 -> JavaKind.Int;
case 64 -> JavaKind.Long;
default -> throw new IllegalArgumentException("Unsupported integer bit size: " + bitSize);
};
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ public OCLLIRKindTool(OCLTargetDescription target) {
this.target = target;
}


public LIRKind getUnsignedIntegerKind(int numBits) {
if (numBits <= 8) {
return LIRKind.value(OCLKind.UCHAR);
} else if (numBits <= 16) {
return LIRKind.value(OCLKind.USHORT);
} else if (numBits <= 32) {
return LIRKind.value(OCLKind.UINT);
} else if (numBits <= 64) {
return LIRKind.value(OCLKind.ULONG);
} else {
throw shouldNotReachHere();
}
}

@Override
public LIRKind getIntegerKind(int numBits) {
if (numBits <= 8) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
public enum OCLVariablePrefix {
// @formatter:off
ULONG("ulong", "ul_"),
UINT("uint", "ui_"),
USHORT("ushort", "us_"),
UCHAR("uchar", "uc_"),
INT("int", "i_"),
LONG("long", "l_"),
BOOL("bool", "b_"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLUnaryIntrinsic.RSQRT;

import jdk.vm.ci.meta.AllocatableValue;
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.PlatformKind;
import jdk.vm.ci.meta.PrimitiveConstant;
import jdk.vm.ci.meta.Value;
import jdk.vm.ci.meta.ValueKind;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.calc.FloatConvert;
import org.graalvm.compiler.core.common.memory.MemoryExtendKind;
Expand All @@ -35,12 +41,7 @@
import org.graalvm.compiler.lir.Variable;
import org.graalvm.compiler.lir.gen.ArithmeticLIRGenerator;

import jdk.vm.ci.meta.AllocatableValue;
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.PlatformKind;
import jdk.vm.ci.meta.PrimitiveConstant;
import jdk.vm.ci.meta.Value;
import jdk.vm.ci.meta.ValueKind;
import uk.ac.manchester.tornado.drivers.common.code.CodeUtil;
import uk.ac.manchester.tornado.drivers.common.logging.Logger;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLArchitecture.OCLMemoryBase;
import uk.ac.manchester.tornado.drivers.opencl.graal.OCLLIRKindTool;
Expand Down Expand Up @@ -284,16 +285,16 @@ public Value emitZeroExtend(Value value, int fromBits, int toBits) {
OCLKind kind = (OCLKind) value.getPlatformKind();
LIRKind toKind;
if (kind.isInteger()) {
toKind = kindTool.getIntegerKind(toBits);
toKind = kindTool.getUnsignedIntegerKind(toBits);
} else if (kind.isFloating()) {
toKind = kindTool.getFloatingKind(toBits);
} else {
throw shouldNotReachHere();
}

Variable result = getGen().newVariable(toKind);

getGen().emitMove(result, value);
// Apply a bitwise mask in order to avoid sign extension and instead zero extend the value.
ConstantValue mask = new ConstantValue(toKind, JavaConstant.forIntegerKind(CodeUtil.javaKindFromBitSize(toBits, kind.isFloating()), (1L << fromBits) - 1));
Variable result = emitBinaryAssign(OCLBinaryOp.BITWISE_AND, toKind, value, mask);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ public PTXLIRKindTool(PTXTargetDescription target) {
this.target = target;
}

public LIRKind getUnsignedIntegerKind(int numBits) {
if (numBits <= 8) {
return LIRKind.value(PTXKind.U8);
} else if (numBits <= 16) {
return LIRKind.value(PTXKind.U16);
} else if (numBits <= 32) {
return LIRKind.value(PTXKind.U32);
} else if (numBits <= 64) {
return LIRKind.value(PTXKind.U64);
} else {
throw shouldNotReachHere();
}
}

@Override
public LIRKind getIntegerKind(int bits) {
if (bits <= 8) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
import static uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler.PTXUnaryIntrinsic.RSQRT;

import jdk.vm.ci.meta.JavaConstant;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.calc.FloatConvert;
import org.graalvm.compiler.core.common.memory.MemoryExtendKind;
import org.graalvm.compiler.core.common.memory.MemoryOrderMode;
import org.graalvm.compiler.lir.ConstantValue;
import org.graalvm.compiler.lir.LIRFrameState;
import org.graalvm.compiler.lir.Variable;
import org.graalvm.compiler.lir.gen.ArithmeticLIRGenerator;

import jdk.vm.ci.meta.PlatformKind;
import jdk.vm.ci.meta.Value;
import jdk.vm.ci.meta.ValueKind;
import uk.ac.manchester.tornado.drivers.common.code.CodeUtil;
import uk.ac.manchester.tornado.drivers.common.logging.Logger;
import uk.ac.manchester.tornado.drivers.ptx.graal.PTXLIRKindTool;
import uk.ac.manchester.tornado.drivers.ptx.graal.asm.PTXAssembler;
Expand Down Expand Up @@ -198,7 +201,22 @@ public Value emitNarrow(Value inputVal, int bits) {
@Override
public Value emitSignExtend(Value inputVal, int fromBits, int toBits) {
Logger.traceBuildLIR(Logger.BACKEND.PTX, "emitSignExtend inputVal=%s fromBits=%d toBits=%d", inputVal, fromBits, toBits);
return emitZeroExtend(inputVal, fromBits, toBits);
PTXLIRKindTool kindTool = getGen().getLIRKindTool();
PTXKind kind = (PTXKind) inputVal.getPlatformKind();
LIRKind toKind;
if (kind.isInteger()) {
toKind = kindTool.getIntegerKind(toBits);
} else if (kind.isFloating()) {
toKind = kindTool.getFloatingKind(toBits);
} else {
throw shouldNotReachHere();
}

Variable result = getGen().newVariable(toKind);

getGen().emitMove(result, inputVal);
return result;

}

@Override
Expand All @@ -208,16 +226,19 @@ public Value emitZeroExtend(Value inputVal, int fromBits, int toBits) {
PTXKind kind = (PTXKind) inputVal.getPlatformKind();
LIRKind toKind;
if (kind.isInteger()) {
toKind = kindTool.getIntegerKind(toBits);
toKind = kindTool.getUnsignedIntegerKind(toBits);
} else if (kind.isFloating()) {
toKind = kindTool.getFloatingKind(toBits);
} else {
throw shouldNotReachHere();
}

Variable result = getGen().newVariable(toKind);
Variable signExtendedValue = getGen().newVariable(toKind);
getGen().emitMove(signExtendedValue, inputVal);

getGen().emitMove(result, inputVal);
// Apply a bitwise mask in order to avoid sign extension and instead zero extend the value.
ConstantValue mask = new ConstantValue(toKind, JavaConstant.forIntegerKind(CodeUtil.javaKindFromBitSize(toBits, kind.isFloating()), (1L << fromBits) - 1));
Variable result = emitBinaryAssign(PTXBinaryOp.BITWISE_AND, toKind, signExtendedValue, mask);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ public Value emitZeroExtend(Value inputVal, int fromBits, int toBits) {
Variable result = getGen().newVariable(toKind);

LIRKind lirKind = getGen().getLIRKindTool().getIntegerKind(toBits);
SPIRVUnary.SignExtend signExtend = new SPIRVUnary.SignExtend(lirKind, result, inputVal, fromBits, toBits);
getGen().append(new SPIRVLIRStmt.AssignStmt(result, signExtend));
SPIRVUnary.ZeroExtend zeroExtend = new SPIRVUnary.ZeroExtend(lirKind, result, inputVal, fromBits, toBits);
getGen().append(new SPIRVLIRStmt.AssignStmt(result, zeroExtend));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,43 @@ public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {
}
}

public static class ZeroExtend extends AbstractExtend {

private int fromBits;
private int toBits;

public ZeroExtend(LIRKind lirKind, Variable result, Value inputVal, int fromBits, int toBits) {
super(null, result, lirKind, inputVal);
this.fromBits = fromBits;
this.toBits = toBits;
}

@Override
public void emit(SPIRVCompilationResultBuilder crb, SPIRVAssembler asm) {

Logger.traceCodeGen(Logger.BACKEND.SPIRV, "emit SPIRVOpUConvert : " + fromBits + " -> " + toBits);

SPIRVKind spirvKind = (SPIRVKind) value.getPlatformKind();
SPIRVId type = asm.primitives.getTypePrimitive(spirvKind);

SPIRVId loadConvert = loadConvertIfNeeded(crb, asm, type, spirvKind);

SPIRVId toTypeId = switch (toBits) {
case 64 -> asm.primitives.getTypePrimitive(SPIRVKind.OP_TYPE_INT_64);
case 32 -> asm.primitives.getTypePrimitive(SPIRVKind.OP_TYPE_INT_32);
case 16 -> asm.primitives.getTypePrimitive(SPIRVKind.OP_TYPE_INT_16);
case 8 -> asm.primitives.getTypePrimitive(SPIRVKind.OP_TYPE_INT_8);
default -> throw new TornadoRuntimeException("to Type not supported: " + toBits);
};

SPIRVId result = obtainPhiValueIdIfNeeded(asm);
asm.currentBlockScope().add(new SPIRVOpUConvert(toTypeId, result, loadConvert));

asm.registerLIRInstructionValue(this, result);

}
}

public static class CastOperations extends UnaryConsumer {

protected CastOperations(SPIRVUnaryOp opcode, Variable result, LIRKind valueKind, Value value) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package uk.ac.manchester.tornado.unittests.numpromotion;

import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;

import java.util.Random;

import static org.junit.Assert.assertEquals;

public class TestZeroExtend {

public static void narrowByte(ByteArray a, IntArray result, int size) {
for(@Parallel int i = 0; i < size; i++) {
result.set(i, a.get(i) & 0xFF);
}
}

public static void narrowShort(ShortArray a, LongArray result, int size) {
for(@Parallel int i = 0; i < size; i++) {
result.set(i, a.get(i) & 0xFFFF);
}
}

public static void narrowInt(IntArray a, LongArray result, int size) {
for(@Parallel int i = 0; i < size; i++) {
result.set(i, a.get(i) & 0xFFFFFFFFL);
}
}

@Test
public void testByte() throws TornadoExecutionPlanException {
Random r = new Random();
int size = 1024;

ByteArray a = new ByteArray(size);
for(int i = 0; i < size/2; i++) {
a.set(i, (byte) (128 + r.nextInt(128)));
a.set(i+size/2, (byte) (r.nextInt(128)));
}

IntArray expected = new IntArray(size);
IntArray result = new IntArray(size);

expected.init(0);
result.init(0);

TaskGraph graph = new TaskGraph("s0")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, result)
.task("t0", TestZeroExtend::narrowByte, a, result, size)
.transferToHost(DataTransferMode.EVERY_EXECUTION, result);

ImmutableTaskGraph immutableTaskGraph = graph.snapshot();
try(TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
narrowByte(a, expected, size);
}

for (int i = 0; i < expected.getSize(); i++) {
assertEquals(expected.get(i), result.get(i));
}
}

@Test
public void testShort() throws TornadoExecutionPlanException {
Random r = new Random();
int size = 1024;

ShortArray a = new ShortArray(size);

for(int i = 0; i < size/2; i++) {
a.set(i, (short) (Short.MAX_VALUE + (short) r.nextInt(Short.MAX_VALUE)));
a.set(i+size/2, (short) (r.nextInt(Short.MAX_VALUE)));
}

LongArray expected = new LongArray(size);
LongArray result = new LongArray(size);
expected.init(0);
result.init(0);

TaskGraph graph = new TaskGraph("s0")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, result)
.task("t0", TestZeroExtend::narrowShort, a, result, size)
.transferToHost(DataTransferMode.EVERY_EXECUTION, result);

ImmutableTaskGraph immutableTaskGraph = graph.snapshot();
try(TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
narrowShort(a, expected, size);
}

for (int i = 0; i < expected.getSize(); i++) {
assertEquals(expected.get(i), result.get(i));
}
}

@Test
public void testInt() throws TornadoExecutionPlanException {
Random r = new Random();
int size = 1024;

IntArray a = new IntArray(size);

for(int i = 0; i < size/2; i++) {
a.set(i, Integer.MAX_VALUE + r.nextInt(Integer.MAX_VALUE));
a.set(i+size/2, r.nextInt(Integer.MAX_VALUE));
}

LongArray expected = new LongArray(size);
LongArray result = new LongArray(size);
expected.init(0);
result.init(0);

TaskGraph graph = new TaskGraph("s0")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, a, result)
.task("t0", TestZeroExtend::narrowInt, a, result, size)
.transferToHost(DataTransferMode.EVERY_EXECUTION, result);

ImmutableTaskGraph immutableTaskGraph = graph.snapshot();
try(TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
executionPlan.execute();
narrowInt(a, expected, size);
}

for (int i = 0; i < expected.getSize(); i++) {
assertEquals(expected.get(i), result.get(i));
}
}
}

0 comments on commit ea7b602

Please sign in to comment.