Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ public CudaHATKernelBuilder defines() {
.hashDefine("HAT_BIY", _ -> keyword("blockIdx").dot().threadDimId(1))
.hashDefine("HAT_BIZ", _ -> keyword("blockIdx").dot().threadDimId(2))
.hashDefine("HAT_BARRIER", _->keyword("__syncthreads").ocparen())
.includeSys("cuda_fp16.h")
.buildStructSingleMember("F16", "value", "half");
.includeSys("cuda_fp16.h", "cuda_bf16.h")
.hashDefine("BFLOAT16", _->keyword("__nv_bfloat16"))
.buildStructSingleMember("F16", "value", "half")
.buildStructSingleMember("BF16", "value", "BFLOAT16");
}

@Override
Expand Down Expand Up @@ -226,8 +228,13 @@ public CudaHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildConte

@Override
public CudaHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
oparen().halfType().cparen().obrace();
identifier("__float2half").oparen();
oparen();
ReducedFloatType reducedFloatType = hatF16ConvOp.reducedFloatType();
generateReduceFloatType(reducedFloatType);
cparen().obrace();

buildReducedFloatType(reducedFloatType);
oparen();
Value param = hatF16ConvOp.operands().getFirst();
if (param instanceof Op.Result r) {
recurse(buildContext, r.op());
Expand All @@ -238,7 +245,8 @@ public CudaHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext,

@Override
public CudaHATKernelBuilder hatF16ToFloatConvOp(ScopedCodeBuilderContext builderContext, HATF16ToFloatConvOp hatF16ToFloatConvOp) {
identifier("__half2float").oparen();
buildReducedFloatType(hatF16ToFloatConvOp.reducedFloatType());
oparen();
Value param = hatF16ToFloatConvOp.operands().getFirst();
if (param instanceof Op.Result r) {
recurse(builderContext, r.op());
Expand Down Expand Up @@ -281,13 +289,16 @@ public CudaHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builder
public CudaHATKernelBuilder hatF16BinaryOp(ScopedCodeBuilderContext buildContext, HATF16BinaryOp hatF16BinaryOp) {
Value op1 = hatF16BinaryOp.operands().get(0);
Value op2 = hatF16BinaryOp.operands().get(1);
ReducedFloatType reducedFloatType = hatF16BinaryOp.reducedFloatType();
List<Boolean> references = hatF16BinaryOp.references();
byte f32Mixed = hatF16BinaryOp.getF32();

oparen().halfType().cparen().obrace().oparen();
oparen();
generateReduceFloatType(reducedFloatType);
cparen().obrace().oparen();

if (f32Mixed == HATF16BinaryOp.LAST_OP) {
identifier("__half2float").oparen();
generateReducedFloatConversionToFloat(reducedFloatType);
}

if (op1 instanceof Op.Result r) {
Expand All @@ -303,10 +314,10 @@ public CudaHATKernelBuilder hatF16BinaryOp(ScopedCodeBuilderContext buildContext
cparen();
}

space().identifier(hatF16BinaryOp.operationType().symbol()).space();
space().identifier(hatF16BinaryOp.binaryOperationType().symbol()).space();

if (f32Mixed == HATF16BinaryOp.FIRST_OP) {
identifier("__half2float").oparen();
generateReducedFloatConversionToFloat(reducedFloatType);
}

if (op2 instanceof Op.Result r) {
Expand All @@ -326,4 +337,28 @@ public CudaHATKernelBuilder hatF16BinaryOp(ScopedCodeBuilderContext buildContext
cparen().cbrace();
return self();
}

private void buildReducedFloatType(ReducedFloatType reducedFloatType) {
switch (reducedFloatType) {
case ReducedFloatType.HalfFloat _ -> identifier("__half2float");
case ReducedFloatType.BFloat16 _ -> identifier("__nv_bfloat16");
default -> throw new IllegalStateException("Unexpected value: " + reducedFloatType);
}
}

private void generateReduceFloatType(ReducedFloatType reducedFloatType) {
switch (reducedFloatType) {
case ReducedFloatType.HalfFloat _ -> halfType();
case ReducedFloatType.BFloat16 _ -> bfloatType();
default -> throw new IllegalStateException("Unexpected value: " + reducedFloatType);
}
}

private void generateReducedFloatConversionToFloat(ReducedFloatType reducedFloatType) {
switch (reducedFloatType) {
case ReducedFloatType.HalfFloat _ -> identifier("__half2float").oparen();
case ReducedFloatType.BFloat16 _ -> identifier("__bfloat162float").oparen();
default -> throw new IllegalStateException("Unexpected value: " + reducedFloatType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import hat.dialect.ReducedFloatType;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;

import java.util.Objects;

public class OpenCLHATKernelBuilder extends C99HATKernelBuilder<OpenCLHATKernelBuilder> {

@Override
Expand Down Expand Up @@ -71,7 +74,32 @@ public OpenCLHATKernelBuilder defines() {
.hashDefine("HAT_BIY", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstOne())))
.hashDefine("HAT_BIZ", _ -> paren(_ -> identifier("get_group_id").paren(_ -> intConstTwo())))
.hashDefine("HAT_BARRIER", _ -> identifier("barrier").oparen().identifier("CLK_LOCAL_MEM_FENCE").cparen())
.buildStructSingleMember("F16", "value", "half");
.hashDefine("BFLOAT16", _ -> keyword("ushort"))
.buildStructSingleMember("F16", "value", "half")
.buildStructSingleMember("BF16", "value", "BFLOAT16")
.identifier("""
void byteCopy(void *dest, const void* src, size_t size) {
unsigned char *c = (unsigned char*)dest;
unsigned char *s = (unsigned char*)src;
for (int i = 0; i < size; i++) {
*c++ = *s++;
}
}

float bfloat162float(ushort bf16) {
uint bitsRecovered = bf16 << 16;
float r = bitsRecovered;
byteCopy(&r, &bitsRecovered, sizeof(r));
return r;
}

ushort float2bfloat16(float f) {
uint bits;
byteCopy(&bits, &f, sizeof(bits));
short bf16 = bits >> 16;
return bf16;
}
""");
}

@Override
Expand Down Expand Up @@ -191,11 +219,26 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon

@Override
public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
oparen().halfType().cparen().obrace();
ReducedFloatType reducedFloatType = hatF16ConvOp.reducedFloatType();

oparen();
if (reducedFloatType instanceof ReducedFloatType.HalfFloat) {
halfType();
} else if (reducedFloatType instanceof ReducedFloatType.BFloat16) {
bfloatType();
}

cparen().obrace();
if (reducedFloatType instanceof ReducedFloatType.BFloat16) {
builtin_float2bfloat16().oparen();
}
Value initValue = hatF16ConvOp.operands().getFirst();
if (initValue instanceof Op.Result r) {
recurse(buildContext, r.op());
}
if (reducedFloatType instanceof ReducedFloatType.BFloat16) {
cparen();
}
cbrace();
return self();
}
Expand All @@ -222,7 +265,20 @@ public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext build

@Override
public OpenCLHATKernelBuilder hatF16ToFloatConvOp(ScopedCodeBuilderContext builderContext, HATF16ToFloatConvOp hatF16ToFloatConvOp) {
oparen().floatType().cparen();

// Type conversions:
// half -> float
// bfloat16 -> float

ReducedFloatType reducedFloatType = hatF16ToFloatConvOp.reducedFloatType();

if (reducedFloatType instanceof ReducedFloatType.HalfFloat) {
// half -> float
oparen().floatType().cparen();
} else if (reducedFloatType instanceof ReducedFloatType.BFloat16) {
// bfloat16 -> float
builtin_bfloat162float().oparen();
}
Value value = hatF16ToFloatConvOp.operands().getFirst();
if (value instanceof Op.Result r) {
recurse(builderContext, r.op());
Expand All @@ -232,6 +288,9 @@ public OpenCLHATKernelBuilder hatF16ToFloatConvOp(ScopedCodeBuilderContext build
} else if (!hatF16ToFloatConvOp.wasFloat()) {
dot().identifier("value");
}
if (reducedFloatType instanceof ReducedFloatType.BFloat16) {
cparen();
}
return self();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@
import hat.annotations.Kernel;
import hat.annotations.Preformatted;
import hat.annotations.TypeDef;
import hat.buffer.F16;
import hat.buffer.KernelBufferContext;
import hat.buffer.*;
import hat.codebuilders.C99HATKernelBuilder;
import hat.buffer.ArgArray;
import hat.buffer.Buffer;
import hat.buffer.BufferTracker;
import hat.callgraph.KernelCallGraph;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.device.DeviceSchema;
Expand Down Expand Up @@ -283,6 +279,7 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern

// Add HAT reserved types
typedefs.add(F16.class.getName());
typedefs.add(BF16.class.getName());

for (TypeElement typeElement : localIFaceList) {
try {
Expand Down
124 changes: 124 additions & 0 deletions hat/core/src/main/java/hat/buffer/BF16.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package hat.buffer;

public interface BF16 {

char value();
void value(char value);

static BF16 of(float value) {
return new BF16() {
@Override
public char value() {
int bits = Float.floatToRawIntBits(value);
bits >>= 16;
return (char) bits;
}

@Override
public void value(char value) {
}
};
}

static BF16 of(char value) {
return new BF16() {
@Override
public char value() {
return value;
}

@Override
public void value(char value) {
}
};
}

static BF16 float2bfloat16(float value) {
return of(value);
}

static float bfloat162float(BF16 value) {
return Float.intBitsToFloat(value.value() << 16);
}

static BF16 add(BF16 ha, BF16 hb) {
return BF16.of(bfloat162float(ha) + bfloat162float(hb));
}

static BF16 add(float f32, BF16 hb) {
return BF16.of(f32 + bfloat162float(hb));
}

static BF16 sub(BF16 ha, BF16 hb) {
return BF16.of(bfloat162float(ha) - bfloat162float(hb));
}

static BF16 sub(float f32, BF16 hb) {
return BF16.of(f32 - bfloat162float(hb));
}

static BF16 sub(BF16 hb, float f32) {
return BF16.of(bfloat162float(hb) - f32);
}

static BF16 mul(BF16 ha, BF16 hb) {
return BF16.of(bfloat162float(ha) * bfloat162float(hb));
}

static BF16 mul(float f32, BF16 hb) {
return BF16.of(f32 * bfloat162float(hb));
}

static BF16 div(BF16 ha, BF16 hb) {
return BF16.of(bfloat162float(ha) / bfloat162float(hb));
}

static BF16 div(float f32, BF16 hb) {
return BF16.of(f32 / bfloat162float(hb));
}

static BF16 add(BF16 hb, float f32) {
return BF16.of(bfloat162float(hb) / f32);
}

default BF16 add(BF16 ha) {
return BF16.add(this, ha);
}

default BF16 sub(BF16 ha) {
return BF16.sub(this, ha);
}

default BF16 mul(BF16 ha) {
return BF16.mul(this, ha);
}

default BF16 div(BF16 ha) {
return BF16.div(this, ha);
}

}
Loading