Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support WrapperGPUArray nd indexing by fusing vectorized fallback. #512

Merged
merged 3 commits into from
Jan 17, 2024

Conversation

N5N3
Copy link
Contributor

@N5N3 N5N3 commented Jan 10, 2024

A quick trial to extend JuliaLang/julia#52626.
It seems bad to overloading internal functions. But this looks like the simplest way.
I took a quick code search on juliahub, and the result looks quite clean. (I wish there's no ambiguity risk.)

@maleadt Since this PR fuses the current fallback. IIUC, your concern on various lower-level routines should be resolved?

@N5N3 N5N3 changed the title Add WrapperGPUArray nd indexing by fusing vectorized fallback. Support WrapperGPUArray nd indexing by fusing vectorized fallback. Jan 10, 2024
@N5N3
Copy link
Contributor Author

N5N3 commented Jan 11, 2024

Looks like oneAPI.jl and Metal.jl don't support Int128 (used in slow ReshapedArray's indexing transformation)?
I have no mac so I tried oneAPI.jl via wsl.
This is ok:

julia> a = oneArray(rand(Int, 1, 1)) .|> identity
1×1 oneArray{Int64, 2, oneAPI.oneL0.DeviceBuffer}:
 2601239328758681190

while this is bad

julia> a = oneArray(rand(Int128, 1, 1)) .|> identity
InvalidBitWidth: Invalid bit width in input: 128

As for the test failure, I tried to avoid the Int128 usage by ReshapedArray(src, (length(src)), ()) but it throws the same error.
@maleadt is it OK to mark related test skipped for now?

Comment on lines 151 to 152
## Vectorized index overloading for `WrappedGPUArray`
# We overloading `getindex` by dispatch the copy part to our implement.
Copy link
Member

@maleadt maleadt Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment why these lower-level overloads are required?

@maleadt
Copy link
Member

maleadt commented Jan 16, 2024

@maleadt is it OK to mark related test skipped for now?

What changed causing this Int128 code path to be hit now? Unless the test explicitly creates an Int128[], I don't think it's OK to disable those tests because they now trigger use of Int128; that seems like a regression?

@N5N3
Copy link
Contributor Author

N5N3 commented Jan 16, 2024

What changed causing this Int128 code path to be hit now?

It comes from Base.ReshapedArray.
It use Base.MultiplicativeInverses.SignedMultiplicativeInverse to accelerate the index transformation if its parent is not IndexLinear.
And Base use widened bit integer for internal divrem

julia> a = Base.MultiplicativeInverses.SignedMultiplicativeInverse(1)
Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}(1, 1, 0, 0x00)

julia> @code_typed divrem(1, a)
CodeInfo(
1%1  = Base.getfield(b, :multiplier)::Int64%2  = Core.sext_int(Core.Int128, a)::Int128    # here %3  = Base.sext_int(Int128, %1)::Int128          # here %4  = Base.mul_int(%2, %3)::Int128                    # here %5  = Base.lshr_int(%4, 0x0000000000000040)::Int128       # here %6  = Base.trunc_int(Int64, %5)::Int64%7  = Base.getfield(b, :addmul)::Int8%8  = Base.sext_int(Int64, %7)::Int64%9  = Base.mul_int(a, %8)::Int64%10 = Base.add_int(%6, %9)::Int64%11 = Base.getfield(b, :divisor)::Int64%12 = Base.flipsign_int(%11, %11)::Int64%13 = (%12 === 1)::Bool%14 = Base.getfield(b, :divisor)::Int64%15 = Base.mul_int(a, %14)::Int64%16 = Base.slt_int(%10, 0)::Bool%17 = Base.getfield(b, :shift)::UInt8%18 = Base.ashr_int(%10, %17)::Int64%19 = Core.zext_int(Core.Int64, %16)::Int64%20 = Core.and_int(%19, 1)::Int64%21 = Base.add_int(%20, %18)::Int64%22 = Core.ifelse(%13, %15, %21)::Int64%23 = Base.getfield(b, :divisor)::Int64%24 = Base.mul_int(%22, %23)::Int64%25 = Base.sub_int(a, %24)::Int64%26 = Core.tuple(%22, %25)::Tuple{Int64, Int64}
└──       return %26
) => Tuple{Int64, Int64}

I guess this means Base.ReshapedArray is "broken" for Meta.jl/oneApi.jl for a long time.

@maleadt
Copy link
Member

maleadt commented Jan 16, 2024

Oh, it's only triggering in the newly added test. Yeah, it's probably fine to mark those as broken right now. We don't have an interface for that though, maybe something like broken=(string(AT) in ["MtlArray", "oneArray"]) is fine for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants