onecold does not work on CuMatrix #864
Comments
It does work if you replace
Since `onecold` only accepts numbers, can I try fixing this?
I don't see how this works for strings:
`onecold(y::OneHotMatrix, labels...) = mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)`
What does
`onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data)`
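To make the concern about string labels concrete, here is a minimal sketch in plain Julia (no Flux, no GPU). `onecold_column` is a hypothetical stand-in for what `Flux.onecold` does on a single one-hot column, and the data is made up for illustration; it is not the library's implementation.

```julia
# Hypothetical stand-in for decoding one one-hot column into its label.
onecold_column(col::AbstractVector{Bool}, labels = 1:length(col)) =
    labels[findfirst(col)]

# Two one-hot columns over three classes (illustrative data).
columns = [Bool[0, 1, 0], Bool[1, 0, 0]]
labels = ["a", "b", "c"]

# Mapping over the columns returns one label per column, whatever the label type.
decoded = map(col -> onecold_column(col, labels), columns)

# Reducing the same results with `|` and `init = 0` would need `0 | "b"`,
# and Base defines no `|` method for an integer and a string.
```

With numeric labels the `|` reduction can at least be evaluated, which is why the `mapreduce` form only has a chance of working for numbers.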
1448: Arbitrary dimension one-hot arrays r=DhairyaLGandhi a=darsnack

This supersedes #1447. It should address the same issues:
- fix #1445, #1229
- probably fix also #864, #556, #189

This PR introduces a new one-hot N-dimensional array type, `OneHotArray`. Like #1447, this approach avoids the pointer allocations associated with `OneHotMatrix` being an array of `OneHotVector`s. It also lifts the "height" into the type parameter to avoid unnecessary allocation. Unlike #1447, this approach does not introduce a new primitive type. Instead, a "one-hot vector" is represented with a single subtype of `Integer` that is configurable by the user. By default, the exposed API will use `UInt32`.

Fundamentally, the primitive type is necessary because wrapping a `UInt32` as a `OneHotVector` will suffer memory penalties when you create an `Array{<:OneHotVector}`. But if we begin by designing for N-dimensions, then `OneHotVector` is just the specialized 1D case (similar to how `Vector{T} = Array{T, 1}`).

## Performance

I compared against the same tests mentioned in #1447. Please suggest more if you want to.

1. #189

```jl
#master
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.095 μs (13 allocations: 50.86 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  24.948 μs (86 allocations: 3.11 KiB)

#1447
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.312 μs (3 allocations: 50.36 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.466 μs (61 allocations: 1.69 KiB)

# this PR
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  4.708 μs (3 allocations: 50.56 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.576 μs (63 allocations: 1.73 KiB)
```

2. #556

```jl
#master
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  365.712 μs (1131 allocations: 38.16 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43
  1.330 s (781248 allocations: 31.59 MiB)

#1447
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  524.767 μs (8 allocations: 4.00 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  27.563 μs (169 allocations: 5.56 KiB)

# this PR
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  493.017 μs (8 allocations: 4.53 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  26.702 μs (171 allocations: 5.61 KiB)
```

## Summary

This should basically be #1447 but simpler to maintain w/ fewer changes. Tests are passing, though I think we should add more tests for one-hot data (currently our test set seems pretty sparse). Performance matches #1447 where I have tested, but please suggest more performance tests. In theory, any performance difference between #1447 and this PR should be recoverable.

### PR Checklist

- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from @DhairyaLGandhi (for API changes).

cc @CarloLucibello @chengchingwen

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
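The design described in the PR (integer indices as storage, the "height" as a type parameter, and the 1D case as a specialization of the N-dimensional one) can be illustrated with a rough sketch in plain Julia. `SketchOneHot` and `sketch_onecold` are hypothetical names invented for this example; this is not Flux's actual `OneHotArray` definition, and it assumes CPU arrays only.

```julia
# Hypothetical sketch, NOT Flux's OneHotArray.
# T: integer storage type, L: height (number of classes),
# N: dimensions of the index array, M = N + 1: dimensions of the one-hot array.
struct SketchOneHot{T<:Integer, L, N, M} <: AbstractArray{Bool, M}
    indices::Array{T, N}   # one hot index per one-hot vector
end

function SketchOneHot(indices::Array{T, N}, L::Integer) where {T<:Integer, N}
    all(i -> 1 <= i <= L, indices) || throw(ArgumentError("indices must lie in 1:$L"))
    return SketchOneHot{T, Int(L), N, N + 1}(indices)
end

# The height comes from the type, the remaining dimensions from the storage.
Base.size(x::SketchOneHot{T, L}) where {T, L} = (L, size(x.indices)...)

# An entry is `true` exactly when its row equals the stored hot index.
function Base.getindex(x::SketchOneHot{T, L, N, M}, I::Vararg{Int, M}) where {T, L, N, M}
    @boundscheck checkbounds(x, I...)
    return x.indices[Base.tail(I)...] == I[1]
end

# Decoding is a lookup on the stored indices, with no reduction over Bool entries.
sketch_onecold(x::SketchOneHot, labels) = map(i -> labels[i], x.indices)

# A 3×3 one-hot matrix backed by three UInt32 indices.
x = SketchOneHot(UInt32[2, 1, 3], 3)
decoded = sketch_onecold(x, ["a", "b", "c"])
```

Because only the index array is stored, a one-hot "vector" in this sketch is just the case where `indices` is zero-dimensional, in the same spirit as `Vector{T} = Array{T, 1}`.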