From b76de496b797be634a165ab87fbf9f4a63e2f02e Mon Sep 17 00:00:00 2001
From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com>
Date: Thu, 26 Sep 2024 12:05:38 -0300
Subject: [PATCH] Initial BFloat16 support

---
 lib/mps/MPS.jl                |  4 ++--
 src/Metal.jl                  |  1 +
 src/compiler/compilation.jl   |  3 ++-
 src/device/intrinsics/simd.jl |  3 ++-
 test/device/intrinsics.jl     | 16 ++++++++++------
 5 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl
index 09f7754b7..b2874e26d 100644
--- a/lib/mps/MPS.jl
+++ b/lib/mps/MPS.jl
@@ -16,9 +16,9 @@ using ObjectiveC, .Foundation
 
 import GPUArrays
 
-using BFloat16s
+using BFloat16s: BFloat16
 
-const MtlFloat = Union{Float32, Float16}
+const MtlFloat = Union{Float32, Float16, BFloat16}
 
 is_supported(dev::MTLDevice) = ccall(:MPSSupportsMTLDevice, Bool, (id{MTLDevice},), dev)
 
diff --git a/src/Metal.jl b/src/Metal.jl
index b63a69910..1a51456c2 100644
--- a/src/Metal.jl
+++ b/src/Metal.jl
@@ -13,6 +13,7 @@ using ExprTools: splitdef, combinedef
 using Artifacts
 using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
 import KernelAbstractions
+using BFloat16s
 
 include("version.jl")
 
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 7c4284ca8..c8fede871 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -18,7 +18,8 @@ function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
     # pointer type information for typed intrinsics
     # (this is consumed by the LLVM IR downgrader)
     for (jltyp, llvmtyp) in (Int32 => :i32, Int64 => :i64,
-                             Float16 => :f16, Float32 => :f32),
+                             Float16 => :f16, Float32 => :f32,
+                             BFloat16 => :bf16),
         (as, asname) in (AS.Device => "global", AS.ThreadGroup => "local")
 
         # map of intrinsics to pointer operand indices and eltypes
diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl
index 79250e330..e8815797d 100644
--- a/src/device/intrinsics/simd.jl
+++ b/src/device/intrinsics/simd.jl
@@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
     return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
 end
 
-for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"))
+for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
     for as in (AS.Device, AS.ThreadGroup)
         @eval begin
             @device_function simdgroup_load(
@@ -88,6 +88,7 @@ Returns `a * b + c`.
 
 simd_shuffle_map = ((Float32, "f32"),
                     (Float16, "f16"),
+                    (BFloat16, "bf16"),
                     (Int32, "s.i32"),
                     (UInt32, "u.i32"),
                     (Int16, "s.i16"),
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index c100b2ddd..a849bfda8 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -1,4 +1,5 @@
 using SpecialFunctions
+using BFloat16s
 using Metal: metal_support
 
 @testset "arguments" begin
@@ -308,8 +309,9 @@ end
 ############################################################################################
 
 @testset "simd intrinsics" begin
-
-@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
+types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8]
+metal_support() >= v"3.1" && push!(types, BFloat16)
+@testset "shuffle($typ)" for typ in types
     function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T
         idx = thread_position_in_grid_1d()
         idx_in_simd = thread_index_in_simdgroup()
@@ -344,7 +346,9 @@ end
 end
 
 @testset "matrix functions" begin
-    @testset "load_store($typ)" for typ in [Float16, Float32]
+    simdgroup_types = [Float16, Float32]
+    metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16)
+    @testset "load_store($typ)" for typ in simdgroup_types
         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
                         origin_a=(1, 1), origin_b=(1, 1)) where {T}
             sg_a = simdgroup_load(a, origin_a)
@@ -367,7 +371,7 @@ end
         end
     end
 
-    @testset "load_store_tg($typ)" for typ in [Float16, Float32]
+    @testset "load_store_tg($typ)" for typ in simdgroup_types
         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T}
             pos = thread_position_in_threadgroup_2d()
 
@@ -391,7 +395,7 @@ end
         @test Array(a) == Array(b)
     end
 
-    @testset "mul($typ)" for typ in [Float16, Float32]
+    @testset "mul($typ)" for typ in simdgroup_types
         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T}
            sg_a = simdgroup_load(a)
            sg_b = simdgroup_load(b)
@@ -407,7 +411,7 @@ end
         @test Array(a) * Array(b) ≈ Array(c)
     end
 
-    @testset "mad($typ)" for typ in [Float16, Float32]
+    @testset "mad($typ)" for typ in simdgroup_types
         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
                         c::MtlDeviceArray{T}, d::MtlDeviceArray{T}) where {T}
             sg_a = simdgroup_load(a)
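
A minimal usage sketch of what this patch enables, mirroring the `mul` testset above. It assumes a device where `metal_support() >= v"3.1"` (the same guard the tests use); the `matmul` kernel name, the 8x8 inputs, and their values are illustrative only.

    using Metal, BFloat16s
    using Metal: metal_support

    # BFloat16 device code requires Metal 3.1, so guard like the tests do.
    if metal_support() >= v"3.1"
        # 8x8 tiles match the simdgroup matrix size used by the intrinsics.
        a = MtlArray(BFloat16.(rand(Float32, 8, 8)))
        b = MtlArray(BFloat16.(rand(Float32, 8, 8)))
        c = similar(a)

        # Load two tiles, multiply them, and store the result (as in the `mul` testset).
        function matmul(a, b, c)
            sg_a = simdgroup_load(a)
            sg_b = simdgroup_load(b)
            sg_c = simdgroup_multiply(sg_a, sg_b)
            simdgroup_store(sg_c, c)
            return
        end

        @metal threads=(8, 8) matmul(a, b, c)
        @assert Array(a) * Array(b) ≈ Array(c)
    end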