Skip to content

Commit

Permalink
threaded utils; generalize split_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jondeuce committed Apr 13, 2024
1 parent 56ee5b1 commit 46e9da8
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 38 deletions.
7 changes: 4 additions & 3 deletions api/DECAESCLI/cli_builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function execute(cmd::Cmd)
err = Pipe()
process = run(pipeline(ignorestatus(cmd); stdout = err, stderr = err))
close(err.in)
return (stderr = String(read(err)), exitcode = process.exitcode)
return (; stderr = String(read(err)), exitcode = process.exitcode)
end

function install()
Expand Down Expand Up @@ -50,7 +50,7 @@ function install()
end

function cli_script()

# Create a batch (Windows) or bash (Unix) script for running DECAES CLI
cmds = String[]

# Julia executable path
Expand All @@ -68,8 +68,8 @@ function cli_script()
push!(cmds, "-- \"\${BASH_SOURCE[0]}\" \"\$@\"")
end

# Return batch script on Windows or bash script on Linux
if Sys.iswindows()
# Windows batch script
"""
@echo off
setlocal
Expand All @@ -80,6 +80,7 @@ function cli_script()
endlocal
"""
else
# Unix shell script
"""
#!/usr/bin/env bash
#=
Expand Down
5 changes: 3 additions & 2 deletions src/DECAES.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Statistics: Statistics, mean, std
using ArgParse: ArgParse, ArgParseSettings, add_arg_group!, add_arg_table!, parse_args
using Dierckx: Dierckx
using DocStringExtensions: DocStringExtensions, @doc, FIELDS, SIGNATURES, TYPEDFIELDS, TYPEDSIGNATURES
using ForwardDiff: ForwardDiff, DiffResults
using ForwardDiff: ForwardDiff, DiffResults, Dual
using LoggingExtras: LoggingExtras, FileLogger, LevelOverrideLogger, TeeLogger, TransformerLogger
using MAT: MAT
using MuladdMacro: MuladdMacro, @muladd
Expand All @@ -28,6 +28,7 @@ using Parameters: Parameters, @with_kw, @with_kw_noshow
using PrecompileTools: PrecompileTools, @compile_workload, @setup_workload
using ProgressMeter: ProgressMeter, Progress, BarGlyphs
# using Roots: Roots
# using SLEEFPirates: SLEEFPirates
# using SIMD: SIMD, FloatingTypes, Vec, shufflevector
using Scratch: Scratch, @get_scratch!, get_scratch!
using SpecialFunctions: SpecialFunctions, erfc, erfinv
Expand Down Expand Up @@ -64,7 +65,7 @@ export main
main(["--help"])
main(["--version"])
mock_load_image()
for Reg in ["lcurve", "gcv", "chi2", "mdp"]
for Reg in ["none", "lcurve", "gcv", "chi2", "mdp"]
NumVoxels = max(4, Threads.nthreads()) * default_blocksize()
mock_T2_pipeline(; MatrixSize = (NumVoxels, 1, 1), Reg)
end
Expand Down
6 changes: 3 additions & 3 deletions src/NNLS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ Charles L. Lawson and Richard J. Hanson at Jet Propulsion Laboratory
Revised FEB 1995 to accompany reprinting of the book by SIAM.
"""
@inline function construct_householder!(x::AbstractVector{T}) where {T}
if length(x) == 0
if length(x) <= 1
return zero(T)
end

Expand Down Expand Up @@ -249,7 +249,7 @@ function apply_householder!(
tau::T,
) where {T}
m = length(u)
if m == 0
if m <= 1
return nothing
end

Expand Down Expand Up @@ -282,7 +282,7 @@ function apply_householder_dual!(
j1::Int,
m1::Int,
) where {T}
if j1 > m1
if j1 >= m1
return nothing
end

Expand Down
26 changes: 13 additions & 13 deletions src/T2mapSEcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ function T2Maps(opts::T2mapOptions{T}) where {T}
convert(Array{T}, copy(thread_buffer.flip_angle_work.decay_basis)),

# Default output maps
gdn = fill(T(NaN), opts.MatrixSize...),
ggm = fill(T(NaN), opts.MatrixSize...),
gva = fill(T(NaN), opts.MatrixSize...),
fnr = fill(T(NaN), opts.MatrixSize...),
snr = fill(T(NaN), opts.MatrixSize...),
alpha = fill(T(NaN), opts.MatrixSize...),
gdn = tfill(T(NaN), opts.MatrixSize...),
ggm = tfill(T(NaN), opts.MatrixSize...),
gva = tfill(T(NaN), opts.MatrixSize...),
fnr = tfill(T(NaN), opts.MatrixSize...),
snr = tfill(T(NaN), opts.MatrixSize...),
alpha = tfill(T(NaN), opts.MatrixSize...),
is_alpha_provided = Ref(false),

# Optional output maps
resnorm = !opts.SaveResidualNorm ? nothing : fill(T(NaN), opts.MatrixSize...),
decaycurve = !opts.SaveDecayCurve ? nothing : fill(T(NaN), opts.MatrixSize..., opts.nTE),
mu = !opts.SaveRegParam ? nothing : fill(T(NaN), opts.MatrixSize...),
chi2factor = !opts.SaveRegParam ? nothing : fill(T(NaN), opts.MatrixSize...),
resnorm = !opts.SaveResidualNorm ? nothing : tfill(T(NaN), opts.MatrixSize...),
decaycurve = !opts.SaveDecayCurve ? nothing : tfill(T(NaN), opts.MatrixSize..., opts.nTE),
mu = !opts.SaveRegParam ? nothing : tfill(T(NaN), opts.MatrixSize...),
chi2factor = !opts.SaveRegParam ? nothing : tfill(T(NaN), opts.MatrixSize...),
decaybasis = !opts.SaveNNLSBasis ? nothing :
opts.SetFlipAngle === nothing ?
fill(T(NaN), opts.MatrixSize..., opts.nTE, opts.nT2) : # unique decay basis set for each voxel
tfill(T(NaN), opts.MatrixSize..., opts.nTE, opts.nT2) : # unique decay basis set for each voxel
convert(Array{T}, copy(thread_buffer.decay_basis)), # single decay basis set used for all voxels
)
end
Expand All @@ -67,7 +67,7 @@ end

function T2Distributions(opts::T2mapOptions{T}) where {T}
return T2Distributions(;
distributions = fill(T(NaN), opts.MatrixSize..., opts.nT2),
distributions = tfill(T(NaN), opts.MatrixSize..., opts.nT2),
)
end

Expand Down Expand Up @@ -180,7 +180,7 @@ function T2mapSEcorr!(
return convert(Dict{String, Any}, maps), convert(Array{T, 4}, dist)
end
ntasks = opts.Threaded ? Threads.nthreads() : 1
indices_blocks = split_indices(length(indices), default_blocksize())
indices_blocks = split_indices(; length = length(indices), minchunksize = default_blocksize())

# Run analysis in parallel
with_singlethreaded_blas() do
Expand Down
11 changes: 5 additions & 6 deletions src/T2partSEcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ function T2partSEcorr(T2distributions::Array{T, 4}, opts::T2partOptions{T}) wher
# Run T2-Part analysis
ntasks = opts.Threaded ? Threads.nthreads() : 1
indices = CartesianIndices(opts.MatrixSize)
blocksize = ceil(Int, length(indices) / ntasks)
indices_blocks = split_indices(length(indices), blocksize)
indices_blocks = split_indices(; length = length(indices), minchunksize = default_blocksize(), maxpartitions = ntasks)

with_singlethreaded_blas() do
workerpool(with_thread_buffer, indices_blocks; ntasks, verbose = !opts.Silent) do inds, thread_buffer
Expand All @@ -83,10 +82,10 @@ Base.convert(::Type{Dict{String, Any}}, maps::T2Parts) = Dict{String, Any}(Any[s

function T2Parts(opts::T2partOptions{T}) where {T}
return T2Parts(;
sfr = fill(T(NaN), opts.MatrixSize...),
sgm = fill(T(NaN), opts.MatrixSize...),
mfr = fill(T(NaN), opts.MatrixSize...),
mgm = fill(T(NaN), opts.MatrixSize...),
sfr = tfill(T(NaN), opts.MatrixSize...),
sgm = tfill(T(NaN), opts.MatrixSize...),
mfr = tfill(T(NaN), opts.MatrixSize...),
mgm = tfill(T(NaN), opts.MatrixSize...),
)
end

Expand Down
48 changes: 43 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ function tforeach(work!, allocate, x::AbstractArray; blocksize::Int = default_bl
nt = Threads.nthreads()
len = length(x)
if nt > 1 && len > blocksize
@sync for p in split_indices(len, blocksize)
@sync for p in split_indices(; length = len, minchunksize = blocksize)
Threads.@spawn allocate() do resource
@simd for i in p
work!(x[i], resource)
Expand All @@ -360,6 +360,26 @@ function tforeach(work!, allocate, x::AbstractArray; blocksize::Int = default_bl
end
tforeach(f, x::AbstractArray; kwargs...) = tforeach((x, r) -> f(x), g -> g(nothing), x; kwargs...)

function tmap!(f, y::AbstractArray, x::AbstractArray; blocksize::Int = 1024^2)
if Threads.nthreads() > 1 && length(x) >= blocksize
Threads.@threads for i in eachindex(x, y)
@inbounds xi = x[i]
@inbounds y[i] = f(xi)
end
else
@inbounds @simd for i in eachindex(x, y)
xi = x[i]
y[i] = f(xi)
end
end
return x
end
tmap!(f, x::AbstractArray; kwargs...) = tmap!(f, x, x; kwargs...)

tfill!(v, x) = tmap!(Returns(v), x)
tfill(v, sz::NTuple{N, Int}) where {N} = tfill!(v, Array{typeof(v), N}(undef, sz))
tfill(v, sz::Int...) = tfill(v, sz)

default_blocksize() = 64

# Worker pool for allocating thread-local resources. This is a more robust alternative to
Expand Down Expand Up @@ -421,11 +441,29 @@ function workerpool(work!, allocate, inputs, args...; kwargs...)
return workerpool(work!, allocate, ch, args...; ninputs = length(inputs), kwargs...)
end

function split_indices(len::Int, basesize::Int)
len′ = Int64(len) # Avoid overflow on 32-bit machines
np = max(1, div(len′, basesize))
return [Int(1 + ((i - 1) * len′) ÷ np):Int((i * len′) ÷ np) for i in 1:np]
function split_indices(; length::Int, minchunksize::Union{Int, Nothing} = nothing, maxpartitions::Union{Int, Nothing} = nothing)
@assert minchunksize === nothing || minchunksize::Int > 0 "Basesize must be a positive integer, got minchunksize = $minchunksize"
@assert maxpartitions === nothing || maxpartitions::Int > 0 "Maximum partitions must be a positive integer, got maxpartitions = $maxpartitions"
return split_indices(length, something(minchunksize, 1)::Int, something(maxpartitions, length)::Int)
end

function split_indices(len::Int, minchunksize::Int, maxpartitions::Int)
# Partition the range `1:len` into at most `maxpartitions` ranges with balanced lengths which are all at least `min(len, minchunksize)`
chunks = len >= minchunksize * maxpartitions ? maxpartitions : len >= minchunksize ? len ÷ minchunksize : 1
return split_indices(len, chunks)
end

function split_indices(len::Int, chunks::Int)
# Partition the range `1:len` into `chunks` ranges with balanced lengths
return MappedArray{UnitRange{Int64}}(SubRange(len, chunks), 1:chunks)
end

# Produce the `i`th subrange resulting from dividing `1:len` into `chunks` partitions. Assumes 1 <= chunks <= len
struct SubRange
len::Int64 # avoid overflow on 32-bit machines
chunks::Int64
end
(r::SubRange)(i::Int) = Int(1 + ((i - 1) * r.len) ÷ r.chunks):Int((i * r.len) ÷ r.chunks)

####
#### Logging
Expand Down
14 changes: 8 additions & 6 deletions test/interactive/compare/compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ versions = [
settings_files = [
"settings.txt",
]
cli_extra_args = `--Reg lcurve --SaveResidualNorm --SaveRegParam --quiet`
cli_extra_args = `--Reg lcurve --SaveResidualNorm --SaveRegParam --quiet`

for (julia_version, decaes_versions) in versions
for decaes_version in decaes_versions, settings in settings_files
@info "----------------"
@info "Running DECAES with Julia $(julia_version) and DECAES $(decaes_version)"
@info "----------------"

project_dir = joinpath(@__DIR__, "envs", "julia-v" * julia_version, decaes_version)
output_dir = joinpath(@__DIR__, "output", splitext(settings)[1], "julia-v" * julia_version * "_" * decaes_version)
@time run(`julia +$(julia_version) --startup-file=no --project=$(project_dir) --threads=$(Threads.nthreads()) -e "using DECAES; main()" @$(joinpath(@__DIR__, "settings", settings)) --output $(output_dir) $(cli_extra_args)`)
cmd = `julia +$(julia_version) --startup-file=no --project=$(project_dir) --threads=auto -e "using DECAES; main()" -- @$(joinpath(@__DIR__, "settings", settings)) --output $(output_dir) $(cli_extra_args)`

@info "----------------"
@info "Running DECAES with Julia $(julia_version) and DECAES $(decaes_version)" cmd
@info "----------------"
println()
@time run(cmd)
println()
end
end
Expand Down
37 changes: 37 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,40 @@ end
@test @allocated(svdvals!(work, A)) == 0
end
end

@testset "split_indices" begin
function test_valid_partition(p, len)
@test first(first(p)) == 1
@test last(last(p)) == len
@test all(i -> last(p[i-1]) + 1 == first(p[i]), 2:length(p))
end

for len in 1:20, minchunksize in 1:20
p = DECAES.split_indices(; length = len, minchunksize)
test_valid_partition(p, len)

if len <= minchunksize
@test length(p) == 1
@test length(only(p)) == len
else
@test 1 <= length(p) <= len
@test length(p) == len ÷ minchunksize
@test all(c -> length(c) >= minchunksize, p)
end
end

for len in 1:10, minchunksize in 1:10, maxpartitions in 1:10
_basesize = min(len, max(minchunksize, len ÷ maxpartitions))
p = DECAES.split_indices(; length = len, minchunksize, maxpartitions)
test_valid_partition(p, len)

@test length(p) >= 1
@test all(c -> length(c) >= _basesize, p)
if len >= minchunksize * maxpartitions
@test length(p) == maxpartitions
else
@test 1 <= length(p) <= min(len, maxpartitions)
@test length(p) == len ÷ _basesize
end
end
end

0 comments on commit 46e9da8

Please sign in to comment.