Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix/180] Fix offset for vector instructions that operate in private memory space #181

Merged
merged 5 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2020, APT Group, Department of Computer Science,
* Copyright (c) 2018, 2020, 2022, APT Group, Department of Computer Science,
* The University of Manchester. All rights reserved.
* Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
Expand Down Expand Up @@ -28,6 +28,7 @@
import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.unimplemented;
import static uk.ac.manchester.tornado.drivers.opencl.graal.asm.OCLAssembler.OCLUnaryIntrinsic.RSQRT;

import jdk.vm.ci.meta.JavaConstant;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.calc.FloatConvert;
import org.graalvm.compiler.lir.ConstantValue;
Expand Down Expand Up @@ -62,6 +63,7 @@
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt.VectorStoreStmt;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLUnary.MemoryAccess;
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLUnary.OCLAddressCast;
import uk.ac.manchester.tornado.drivers.opencl.graal.meta.OCLMemorySpace;
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.vector.VectorUtil;

public class OCLArithmeticTool extends ArithmeticLIRGenerator {
Expand Down Expand Up @@ -332,7 +334,7 @@ public Variable emitLoad(LIRKind lirKind, Value address, LIRFrameState state) {
if (oclKind.isVector()) {
OCLBinaryIntrinsic intrinsic = VectorUtil.resolveLoadIntrinsic(oclKind);
OCLAddressCast cast = new OCLAddressCast(base, LIRKind.value(oclKind.getElementKind()));
emitVectorLoad(result, intrinsic, new ConstantValue(LIRKind.value(OCLKind.INT), PrimitiveConstant.INT_0), cast, (MemoryAccess) address);
emitVectorLoad(result, intrinsic, getOffsetValue(oclKind, (MemoryAccess) address), cast, (MemoryAccess) address);
} else {
OCLAddressCast cast = new OCLAddressCast(base, lirKind);
emitLoad(result, cast, (MemoryAccess) address);
Expand Down Expand Up @@ -364,7 +366,7 @@ public void emitStore(ValueKind<?> lirKind, Value address, Value input, LIRFrame
if (oclKind.isVector()) {
OCLTernaryIntrinsic intrinsic = VectorUtil.resolveStoreIntrinsic(oclKind);
OCLAddressCast cast = new OCLAddressCast(memAccess.getBase(), LIRKind.value(oclKind.getElementKind()));
getGen().append(new VectorStoreStmt(intrinsic, new ConstantValue(LIRKind.value(OCLKind.INT), PrimitiveConstant.INT_0), cast, memAccess, input));
getGen().append(new VectorStoreStmt(intrinsic, getOffsetValue(oclKind, memAccess), cast, memAccess, input));
} else {

/**
Expand Down Expand Up @@ -474,6 +476,38 @@ public Value emitMathCopySign(Value magnitude, Value sign) {
return null;
}

/**
* It calculates and returns the offset for vstore/vload operations as a Value
* object.
*
* @param oclKind
* the kind for getting the size of the element type in a vector
* @param memoryAccess
* the object that holds the index of an element in a vector
* @return
*/
private Value getPrivateOffsetValue(OCLKind oclKind, MemoryAccess memoryAccess) {
Value privateOffsetValue = null;
if (memoryAccess == null) {
return null;
}
if (memoryAccess.getIndex() instanceof ConstantValue) {
ConstantValue constantValue = (ConstantValue) memoryAccess.getIndex();
int parsedIntegerIndex = Integer.parseInt(constantValue.getConstant().toValueString());
int index = parsedIntegerIndex / oclKind.getVectorLength();
privateOffsetValue = new ConstantValue(LIRKind.value(OCLKind.INT), JavaConstant.forInt(index));
}
return privateOffsetValue;
}

private Value getOffsetValue(OCLKind oclKind, MemoryAccess memoryAccess) {
if (memoryAccess.getBase().memorySpace == OCLMemorySpace.GLOBAL.getBase().memorySpace) {
return new ConstantValue(LIRKind.value(OCLKind.INT), PrimitiveConstant.INT_0);
} else {
return getPrivateOffsetValue(oclKind, memoryAccess);
}
}

public Value emitFMAInstruction(Value op1, Value op2, Value op3) {
LIRKind resultKind = LIRKind.combine(op1, op2, op3);
Variable result = getGen().newVariable(resultKind);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
* Copyright (c) 2013-2020, APT Group, Department of Computer Science,
* Copyright (c) 2013-2022, APT Group, Department of Computer Science,
* The University of Manchester. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
Expand Down Expand Up @@ -250,4 +250,8 @@ public DoubleBuffer asBuffer() {
public int size() {
return numElements;
}

public int getLength() {
return numElements;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
* Copyright (c) 2013-2020, APT Group, Department of Computer Science,
* Copyright (c) 2013-2022, APT Group, Department of Computer Science,
* The University of Manchester. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
Expand Down Expand Up @@ -245,4 +245,8 @@ public FloatBuffer asBuffer() {
public int size() {
return numElements;
}

public int getLength() {
return numElements;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* This file is part of Tornado: A heterogeneous programming framework:
* https://github.com/beehive-lab/tornadovm
*
* Copyright (c) 2013-2020, APT Group, Department of Computer Science,
* Copyright (c) 2013-2022, APT Group, Department of Computer Science,
* The University of Manchester. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
Expand Down Expand Up @@ -254,4 +254,8 @@ public IntBuffer asBuffer() {
public int size() {
return numElements;
}

public int getLength() {
return numElements;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2013-2020, APT Group, Department of Computer Science,
* Copyright (c) 2013-2022, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -34,10 +34,13 @@
import uk.ac.manchester.tornado.api.collections.types.VectorDouble2;
import uk.ac.manchester.tornado.api.collections.types.VectorDouble3;
import uk.ac.manchester.tornado.api.collections.types.VectorDouble4;
import uk.ac.manchester.tornado.api.collections.types.VectorDouble8;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class TestDoubles extends TornadoTestBase {

public static final double DELTA = 0.001;

private static void addDouble2(Double2 a, Double2 b, VectorDouble results) {
Double2 d2 = Double2.add(a, b);
double r = d2.getX() + d2.getY();
Expand All @@ -59,7 +62,7 @@ public void testDoubleAdd2() {
//@formatter:on

for (int i = 0; i < size; i++) {
assertEquals(8.0, output.get(i), 0.001);
assertEquals(8.0, output.get(i), DELTA);
}
}

Expand All @@ -84,7 +87,7 @@ public void testDoubleAdd3() {
//@formatter:on

for (int i = 0; i < size; i++) {
assertEquals(12.0, output.get(i), 0.001);
assertEquals(12.0, output.get(i), DELTA);
}
}

Expand All @@ -109,7 +112,7 @@ public void testDoubleAdd4() {
//@formatter:on

for (int i = 0; i < size; i++) {
assertEquals(20.0, output.get(i), 0.001);
assertEquals(20.0, output.get(i), DELTA);
}
}

Expand All @@ -134,7 +137,7 @@ public void testDoubleAdd8() {
//@formatter:on

for (int i = 0; i < size; i++) {
assertEquals(72., output.get(i), 0.001);
assertEquals(72., output.get(i), DELTA);
}
}

Expand Down Expand Up @@ -165,7 +168,7 @@ public void testDoubleAdd() {
//@formatter:on

for (int i = 0; i < size; i++) {
assertEquals(i + i, output[i], 0.001);
assertEquals(i + i, output[i], DELTA);
}
}

Expand Down Expand Up @@ -213,7 +216,7 @@ public void testDotProductDouble() {
.execute();
//@formatter:on

assertEquals(seqReduce[0], outputReduce[0], 0.001);
assertEquals(seqReduce[0], outputReduce[0], DELTA);
}

public static void addVectorDouble2(VectorDouble2 a, VectorDouble2 b, VectorDouble2 results) {
Expand Down Expand Up @@ -244,8 +247,8 @@ public void testVectorDouble2() {

for (int i = 0; i < size; i++) {
Double2 sequential = new Double2(i + (size - i), i + (size - i));
assertEquals(sequential.getX(), output.get(i).getX(), 0.001);
assertEquals(sequential.getY(), output.get(i).getY(), 0.001);
assertEquals(sequential.getX(), output.get(i).getX(), DELTA);
assertEquals(sequential.getY(), output.get(i).getY(), DELTA);
}
}

Expand Down Expand Up @@ -277,9 +280,9 @@ public void testVectorDouble3() {

for (int i = 0; i < size; i++) {
Double3 sequential = new Double3(i + (size - i), i + (size - i), i + (size - i));
assertEquals(sequential.getX(), output.get(i).getX(), 0.001);
assertEquals(sequential.getY(), output.get(i).getY(), 0.001);
assertEquals(sequential.getZ(), output.get(i).getZ(), 0.001);
assertEquals(sequential.getX(), output.get(i).getX(), DELTA);
assertEquals(sequential.getY(), output.get(i).getY(), DELTA);
assertEquals(sequential.getZ(), output.get(i).getZ(), DELTA);
}
}

Expand Down Expand Up @@ -311,10 +314,126 @@ public void testVectorDouble4() {

for (int i = 0; i < size; i++) {
Double4 sequential = new Double4(i + (size - i), i + (size - i), i + (size - i), i + (size - i));
assertEquals(sequential.getX(), output.get(i).getX(), 0.001);
assertEquals(sequential.getY(), output.get(i).getY(), 0.001);
assertEquals(sequential.getZ(), output.get(i).getZ(), 0.001);
assertEquals(sequential.getW(), output.get(i).getW(), 0.001);
assertEquals(sequential.getX(), output.get(i).getX(), DELTA);
assertEquals(sequential.getY(), output.get(i).getY(), DELTA);
assertEquals(sequential.getZ(), output.get(i).getZ(), DELTA);
assertEquals(sequential.getW(), output.get(i).getW(), DELTA);
}
}

public static void testPrivateVectorDouble2(VectorDouble2 output) {
VectorDouble2 vectorDouble2 = new VectorDouble2(output.getLength());

for (int i = 0; i < vectorDouble2.getLength(); i++) {
vectorDouble2.set(i, new Double2(i, i));
}

Double2 sum = new Double2(0, 0);

for (int i = 0; i < vectorDouble2.getLength(); i++) {
Double2 f = vectorDouble2.get(i);
sum = Double2.add(f, sum);
}

output.set(0, sum);
}

@Test
public void privateVectorDouble2() {
int size = 16;
VectorDouble2 sequentialOutput = new VectorDouble2(size);
VectorDouble2 tornadoOutput = new VectorDouble2(size);

TaskSchedule ts = new TaskSchedule("s0");
ts.task("t0", TestDoubles::testPrivateVectorDouble2, tornadoOutput);
ts.streamOut(tornadoOutput);
ts.execute();

testPrivateVectorDouble2(sequentialOutput);

for (int i = 0; i < size; i++) {
assertEquals(sequentialOutput.get(i).getX(), tornadoOutput.get(i).getX(), DELTA);
assertEquals(sequentialOutput.get(i).getY(), tornadoOutput.get(i).getY(), DELTA);
}
}

public static void testPrivateVectorDouble4(VectorDouble4 output) {
VectorDouble4 vectorDouble4 = new VectorDouble4(output.getLength());

for (int i = 0; i < vectorDouble4.getLength(); i++) {
vectorDouble4.set(i, new Double4(i, i, i, i));
}

Double4 sum = new Double4(0, 0, 0, 0);

for (int i = 0; i < vectorDouble4.getLength(); i++) {
Double4 f = vectorDouble4.get(i);
sum = Double4.add(f, sum);
}

output.set(0, sum);
}

@Test
public void privateVectorDouble4() {
int size = 16;
VectorDouble4 sequentialOutput = new VectorDouble4(size);
VectorDouble4 tornadoOutput = new VectorDouble4(size);

TaskSchedule ts = new TaskSchedule("s0");
ts.task("t0", TestDoubles::testPrivateVectorDouble4, tornadoOutput);
ts.streamOut(tornadoOutput);
ts.execute();

testPrivateVectorDouble4(sequentialOutput);

for (int i = 0; i < size; i++) {
assertEquals(sequentialOutput.get(i).getX(), tornadoOutput.get(i).getX(), DELTA);
assertEquals(sequentialOutput.get(i).getY(), tornadoOutput.get(i).getY(), DELTA);
assertEquals(sequentialOutput.get(i).getZ(), tornadoOutput.get(i).getZ(), DELTA);
assertEquals(sequentialOutput.get(i).getW(), tornadoOutput.get(i).getW(), DELTA);
}
}

public static void testPrivateVectorDouble8(VectorDouble8 output) {
VectorDouble8 vectorDouble8 = new VectorDouble8(output.getLength());

for (int i = 0; i < vectorDouble8.getLength(); i++) {
vectorDouble8.set(i, new Double8(i, i, i, i, i, i, i, i));
}

Double8 sum = new Double8(0, 0, 0, 0, 0, 0, 0, 0);

for (int i = 0; i < vectorDouble8.getLength(); i++) {
Double8 f = vectorDouble8.get(i);
sum = Double8.add(f, sum);
}

output.set(0, sum);
}

@Test
public void privateVectorDouble8() {
int size = 16;
VectorDouble8 sequentialOutput = new VectorDouble8(16);
VectorDouble8 tornadoOutput = new VectorDouble8(16);

TaskSchedule ts = new TaskSchedule("s0");
ts.task("t0", TestDoubles::testPrivateVectorDouble8, tornadoOutput);
ts.streamOut(tornadoOutput);
ts.execute();

testPrivateVectorDouble8(sequentialOutput);

for (int i = 0; i < size; i++) {
assertEquals(sequentialOutput.get(i).getS0(), tornadoOutput.get(i).getS0(), DELTA);
assertEquals(sequentialOutput.get(i).getS1(), tornadoOutput.get(i).getS1(), DELTA);
assertEquals(sequentialOutput.get(i).getS2(), tornadoOutput.get(i).getS2(), DELTA);
assertEquals(sequentialOutput.get(i).getS3(), tornadoOutput.get(i).getS3(), DELTA);
assertEquals(sequentialOutput.get(i).getS4(), tornadoOutput.get(i).getS4(), DELTA);
assertEquals(sequentialOutput.get(i).getS5(), tornadoOutput.get(i).getS5(), DELTA);
assertEquals(sequentialOutput.get(i).getS6(), tornadoOutput.get(i).getS6(), DELTA);
assertEquals(sequentialOutput.get(i).getS7(), tornadoOutput.get(i).getS7(), DELTA);
}
}

Expand Down
Loading