From 9d55bfd652f464c36cdc19b4b56200cae95ced5b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 22 Sep 2022 22:28:48 -0400 Subject: [PATCH 1/2] pretty printing for DataLoader --- src/eachobs.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/eachobs.jl b/src/eachobs.jl index 7f49306..1314b96 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -255,3 +255,29 @@ end e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) _dataloader_foldl1(rf, val, e, ObsView(e.data)) end + +function Base.show(io::IO, e::DataLoader) + print(io, "DataLoader(") + Base.showarg(io, e.data, false) + e.buffer == false || print(io, ", buffer=", e.buffer) + e.parallel == false || print(io, ", parallel=", e.parallel) + e.shuffle == false || print(io, ", shuffle=", e.shuffle) + e.batchsize == 1 || print(io, ", batchsize=", e.batchsize) + e.partial == true || print(io, ", partial=", e.partial) + e.collate == Val(nothing) || print(io, ", collate=", e.collate) + e.rng == Random.GLOBAL_RNG || print(io, ", rng=", e.rng) + print(io, ")") +end + +function Base.show(io::IO, m::MIME"text/plain", e::DataLoader) + print(io, length(e), "-element ") + show(io, e) + # print(io, " for ", numobs(e.data), " observations,") + println(io, "\n starting with:") + print(io, " ", _summary(first(e))) +end + +_summary(x) = summary(x) +_summary(xs::Tuple) = "tuple(" * join([_summary(x) for x in xs], ", ") * ")" +_summary(xs::NamedTuple) = "(; " * join(["$k = "*_summary(x) for (k,x) in zip(keys(xs),xs)], ", ") * ")" + From 9043acc8f8f1006aa5ce86ff05d9397ca139a711 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 23 Sep 2022 09:08:17 -0400 Subject: [PATCH 2/2] tidy, tests --- src/eachobs.jl | 30 +++++++++++++++++++++--------- test/dataloader.jl | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/eachobs.jl b/src/eachobs.jl index 1314b96..99c96c4 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -256,7 +256,8 @@ end _dataloader_foldl1(rf, val, e, ObsView(e.data)) end -function Base.show(io::IO, e::DataLoader) +# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))) +function Base.showarg(io::IO, e::DataLoader, toplevel) print(io, "DataLoader(") Base.showarg(io, e.data, false) e.buffer == false || print(io, ", buffer=", e.buffer) @@ -269,15 +270,26 @@ function Base.show(io::IO, e::DataLoader) print(io, ")") end +Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) + function Base.show(io::IO, m::MIME"text/plain", e::DataLoader) - print(io, length(e), "-element ") - show(io, e) - # print(io, " for ", numobs(e.data), " observations,") - println(io, "\n starting with:") - print(io, " ", _summary(first(e))) + if Base.haslength(e) + print(io, length(e), "-element ") + else + print(io, "Unknown-length ") + end + Base.showarg(io, e, false) + print(io, "\n with first element:") + print(io, "\n ", _expanded_summary(first(e))) end -_summary(x) = summary(x) -_summary(xs::Tuple) = "tuple(" * join([_summary(x) for x in xs], ", ") * ")" -_summary(xs::NamedTuple) = "(; " * join(["$k = "*_summary(x) for (k,x) in zip(keys(xs),xs)], ", ") * ")" +_expanded_summary(x) = summary(x) +function _expanded_summary(xs::Tuple) + parts = [_expanded_summary(x) for x in xs] + "(" * join(parts, ", ") * ",)" +end +function _expanded_summary(xs::NamedTuple) + parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] + "(; " * join(parts, ", ") * ")" +end diff --git a/test/dataloader.jl b/test/dataloader.jl index 45b2a2c..dc569d7 100644 --- a/test/dataloader.jl +++ b/test/dataloader.jl @@ -214,4 +214,25 @@ dloader = DataLoader(1:1000; batchsize = 2, shuffle = true) @test copy(Map(x -> x[1]), Vector{Int}, dloader) != collect(1:2:1000) end + + @testset "printing" begin + X2 = reshape(Float32[1:10;], (2, 5)) + Y2 = [1:5;] + + d = DataLoader((X2, Y2), batchsize=3) + + @test contains(repr(d), "DataLoader(::Tuple{Matrix") + @test contains(repr(d), "batchsize=3") + + @test contains(repr(MIME"text/plain"(), d), "2-element DataLoader") + @test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector") + + d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false) + + @test contains(repr(d2), "DataLoader(::NamedTuple") + @test contains(repr(d2), "partial=false") + + @test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::NamedTuple") + @test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector") + end end