Skip to content

Commit

Permalink
Merge #1221
Browse files Browse the repository at this point in the history
1221: DataLoader with NamedTuple r=CarloLucibello a=cossio

Just a couple of small changes, so that `DataLoader` can be created with a `NamedTuple` of tensors instead of `Tuple`. This way the tensors can be referred to by name. For example

```
train_loader = DataLoader((images = Xtrain, labels = Ytrain), batchsize=16)
batch = first(train_loader)
y = model(batch.images)
logitcrossentropy(y, batch.labels)
```

If we only use tuples, then in datasets with multiple tensors one has to be careful about the order in which the tensors are fed into the `DataLoader` constructor and be consistent with this elsewhere. With `NamedTuples` one just have to be consistent about the names used, which I think is a minor improvement.

CC @CarloLucibello 

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable

I don't think this qualifies as an API change. It's just a minor feature addition. So final review probably not required.

- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes).


Co-authored-by: cossio <j.cossio.diaz@gmail.com>
Co-authored-by: cossio <cossio@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 16, 2020
2 parents 254e4a7 + 9078f85 commit 19b45b4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# v0.11
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].

# v0.10.5
Expand Down
24 changes: 14 additions & 10 deletions src/data/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ end
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input a data tensors or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
Expand Down Expand Up @@ -57,6 +57,13 @@ Usage example:
# train for 10 epochs
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
# can use NamedTuple to name tensors
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
for datum in train_loader
@assert size(datum.images) == (10, 2)
@assert size(datum.labels) == (2,)
end
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
Expand Down Expand Up @@ -88,19 +95,16 @@ end

_nobs(data::AbstractArray) = size(data)[end]

function _nobs(data::Tuple)
function _nobs(data::Union{Tuple, NamedTuple})
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, data[2:end])
if !all(x -> _nobs(x) == n, Base.tail(data))
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end

function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
end

_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)

Base.eltype(d::DataLoader{D}) where D = D
Base.eltype(::DataLoader{D}) where D = D
20 changes: 20 additions & 0 deletions test/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Y = [1:5;]

d = DataLoader(X, batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
Expand All @@ -11,20 +12,23 @@
@test batches[3] == X[:,5:5]

d = DataLoader(X, batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]

d = DataLoader((X,), batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)

d = DataLoader((X, Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
Expand All @@ -38,6 +42,22 @@
@test batches[3][1] == X[:,5:5]
@test batches[3][2] == Y[5:5]

# test with NamedTuple
d = DataLoader((x=X, y=Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@test length(batches[3]) == 2
@test batches[1][1] == batches[1].x == X[:,1:2]
@test batches[1][2] == batches[1].y == Y[1:2]
@test batches[2][1] == batches[2].x == X[:,3:4]
@test batches[2][2] == batches[2].y == Y[3:4]
@test batches[3][1] == batches[3].x == X[:,5:5]
@test batches[3][2] == batches[3].y == Y[5:5]

# test interaction with `train!`
θ = ones(2)
X = zeros(2, 10)
Expand Down

0 comments on commit 19b45b4

Please sign in to comment.