diff --git a/src/onehot.jl b/src/onehot.jl index 307611cc66..c82dce237a 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -78,6 +78,8 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) +onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data) + function argmax(xs...) Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) return onecold(xs...)