Skip to content

Commit

Permalink
Initial BFloat16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Nov 6, 2024
1 parent 3613366 commit b76de49
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
4 changes: 2 additions & 2 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ using ObjectiveC, .Foundation

import GPUArrays

using BFloat16s
using BFloat16s: BFloat16

const MtlFloat = Union{Float32, Float16}
const MtlFloat = Union{Float32, Float16, BFloat16}

is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)

Expand Down
1 change: 1 addition & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using ExprTools: splitdef, combinedef
using Artifacts
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
import KernelAbstractions
using BFloat16s

include("version.jl")

Expand Down
3 changes: 2 additions & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
# pointer type information for typed intrinsics
# (this is consumed by the LLVM IR downgrader)
for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
Float16 => :f16, Float32 => :f32),
Float16 => :f16, Float32 => :f32,
BFloat16 => :bf16),
(as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")

# map of intrinsics to pointer operand indices and eltypes
Expand Down
3 changes: 2 additions & 1 deletion src/device/intrinsics/simd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
end

for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf18"))
for as in (AS.Device, AS.ThreadGroup)
@eval begin
@device_function simdgroup_load(
Expand Down Expand Up @@ -88,6 +88,7 @@ Returns `a * b + c`.

simd_shuffle_map = ((Float32, "f32"),
(Float16, "f16"),
(BFloat16, "bf16"),
(Int32, "s.i32"),
(UInt32, "u.i32"),
(Int16, "s.i16"),
Expand Down
16 changes: 10 additions & 6 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SpecialFunctions
using BFloat16s
using Metal: metal_support

@testset "arguments" begin
Expand Down Expand Up @@ -308,8 +309,9 @@ end
############################################################################################

@testset "simd intrinsics" begin

@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
metal_support() >= v"3.1" && push!(types, BFloat16)
@testset "shuffle($typ)" for typ in types
function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
idx = thread_position_in_grid_1d()
idx_in_simd = thread_index_in_simdgroup()
Expand Down Expand Up @@ -344,7 +346,9 @@ end
end

@testset "matrix functions" begin
@testset "load_store($typ)" for typ in [Float16, Float32]
simdgroup_types = [Float16, Float32]
metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
@testset "load_store($typ)" for typ in simdgroup_types
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
origin_a=(1, 1), origin_b=(1, 1)) where {T}
sg_a = simdgroup_load(a, origin_a)
Expand All @@ -367,7 +371,7 @@ end
end
end

@testset "load_store_tg($typ)" for typ in [Float16, Float32]
@testset "load_store_tg($typ)" for typ in simdgroup_types
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
pos = thread_position_in_threadgroup_2d()

Expand All @@ -391,7 +395,7 @@ end
@test Array(a) == Array(b)
end

@testset "mul($typ)" for typ in [Float16, Float32]
@testset "mul($typ)" for typ in simdgroup_types
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
sg_a = simdgroup_load(a)
sg_b = simdgroup_load(b)
Expand All @@ -407,7 +411,7 @@ end
@test Array(a) * Array(b) Array(c)
end

@testset "mad($typ)" for typ in [Float16, Float32]
@testset "mad($typ)" for typ in simdgroup_types
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T},
d::MtlDeviceArray{T}) where {T}
sg_a = simdgroup_load(a)
Expand Down

0 comments on commit b76de49

Please sign in to comment.