Skip to content

Commit

Permalink
faster path for GPU creation
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Dec 31, 2022
1 parent 32e06c8 commit 4f55dc8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.2"
version = "0.2.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
10 changes: 10 additions & 0 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<
return OneHotArray(indices, length(labels))
end

function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
offset = 1 - first(labels)
# The bounds check with extrema synchronises, often 10x slower than rest of the function.
indices = map(data) do datum
checkbounds(labels, datum)
UInt32(datum + offset)
end
return OneHotArray(indices, length(labels))
end

"""
onecold(y::AbstractArray, labels = 1:size(y,1))
Expand Down

0 comments on commit 4f55dc8

Please sign in to comment.