diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index 0d9f0f8dd..e517b7d9a 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,2 +1,6 @@
 style="blue"
 format_markdown = true
+# The below should actually be part of Blue according to
+# https://github.com/JuliaDiff/BlueStyle?tab=readme-ov-file#method-definitions
+# but JuliaFormatter v2.10 doesn't enforce it.
+always_use_return = true
diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl
index 2c881aa95..7009bc2cb 100644
--- a/benchmarks/src/Models.jl
+++ b/benchmarks/src/Models.jl
@@ -47,7 +47,7 @@ A short model that tries to cover many DynamicPPL features.
 Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops;
 allocating a variable vector; observations passed as arguments, and as literals.
 """
-@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
+@model function smorgasbord(x, y, (::Type{TV})=Vector{Float64}) where {TV}
     @assert length(x) == length(y)
     m ~ truncated(Normal(); lower=0)
     means ~ product_distribution(fill(Exponential(m), length(x)))
@@ -68,7 +68,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
 
 See `multivariate` for a version that uses `product_distribution` rather than loops.
 """
-@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function loop_univariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     for i in 1:num_dims
@@ -88,7 +88,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
 
 See `loop_univariate` for a version that uses loops rather than `product_distribution`.
 """
-@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function multivariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     a ~ product_distribution(fill(Normal(0, 1), num_dims))
@@ -118,7 +118,7 @@ end
 A model with random variables that have changing support under linking, or otherwise
 complicated bijectors.
 """
-@model function dynamic(::Type{T}=Vector{Float64}) where {T}
+@model function dynamic((::Type{T})=Vector{Float64}) where {T}
     eta ~ truncated(Normal(); lower=0.0, upper=0.1)
     mat1 ~ LKJCholesky(4, eta)
     mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0]))
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 7fcbd6a7c..9ffc2ac1e 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -48,10 +48,10 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters
 fixed to each sample in `chain`, and return the resulting `Chains`.
 
-The `model` passed to `predict` is often different from the one used to generate `chain`. 
-Typically, the model from which `chain` originated treats certain variables as observed (i.e., 
-data points), while the model you pass to `predict` may mark these same variables as missing 
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to 
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
 simulate what new, unobserved data might look like, given your posterior beliefs.
 
 For each parameter configuration in `chain`:
@@ -59,7 +59,7 @@ For each parameter configuration in `chain`:
 2. Any variables not included in `chain` are sampled from their prior distributions.
 
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior 
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
@@ -161,8 +161,8 @@ function _predictive_samples_to_arrays(predictive_samples)
     variable_names = collect(variable_names_set)
 
     variable_values = [
-        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
-        key in variable_names
+        get(sample_dicts[i], key, missing) for
+        i in eachindex(sample_dicts), key in variable_names
     ]
 
     return variable_names, variable_values
@@ -254,7 +254,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
         DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
         # NOTE: Some of the variables can be a view into the `varinfo`, so we need to
         # `deepcopy` the `varinfo` before passing it to the `model`.
-        model(deepcopy(varinfo))
+        return model(deepcopy(varinfo))
     end
 end
diff --git a/src/compiler.jl b/src/compiler.jl
index 4771b0171..4c2fcdf13 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -635,7 +635,7 @@ end
 function namedtuple_from_splitargs(splitargs)
     names = map(splitargs) do (arg_name, arg_type, is_splat, default)
-        is_splat ? Symbol("#splat#$(arg_name)") : arg_name
+        return is_splat ? Symbol("#splat#$(arg_name)") : arg_name
     end
     names_expr = Expr(:tuple, map(QuoteNode, names)...)
     vals = Expr(:tuple, map(first, splitargs)...)
diff --git a/src/debug_utils.jl b/src/debug_utils.jl
index 328fe6983..d825aa01b 100644
--- a/src/debug_utils.jl
+++ b/src/debug_utils.jl
@@ -521,7 +521,7 @@ function has_static_constraints(
     rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs...
 )
     results = map(1:num_evals) do _
-        check_model_and_trace(rng, model; kwargs...)
+        return check_model_and_trace(rng, model; kwargs...)
     end
     issuccess = all(first, results)
     issuccess || throw(ArgumentError("model check failed"))
@@ -530,7 +530,7 @@
     traces = map(last, results)
     dists_per_trace = map(distributions_in_trace, traces)
     transforms = map(dists_per_trace) do dists
-        map(DynamicPPL.link_transform, dists)
+        return map(DynamicPPL.link_transform, dists)
     end
 
     # Check if the distributions are the same across all runs.
diff --git a/src/extract_priors.jl b/src/extract_priors.jl
index 0f312fa2c..a52c3e218 100644
--- a/src/extract_priors.jl
+++ b/src/extract_priors.jl
@@ -105,8 +105,9 @@ julia> length(extract_priors(rng, model)[@varname(x)])
 9
 ```
 """
-extract_priors(args::Union{Model,AbstractVarInfo}...) =
-    extract_priors(Random.default_rng(), args...)
+function extract_priors(args::Union{Model,AbstractVarInfo}...)
+    return extract_priors(Random.default_rng(), args...)
+end
 function extract_priors(rng::Random.AbstractRNG, model::Model)
     context = PriorExtractorContext(SamplingContext(rng))
     evaluate!!(model, VarInfo(), context)
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index a42855f05..00ad9d582 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -175,7 +175,7 @@ end
 Evaluate the log density of the given `model` at the given parameter values `x`, using the given `varinfo` and `context`.
 
 Note that the `varinfo` argument is provided only for its structure, in the sense that the parameters from the vector `x` are inserted into
-it, and its own parameters are discarded. 
+it, and its own parameters are discarded.
 """
 function logdensity_at(
     x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
@@ -245,9 +245,11 @@ model.
 
 By default, this just returns the input unchanged.
 """
-tweak_adtype(
+function tweak_adtype(
     adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
-) = adtype
+)
+    return adtype
+end
 
 """
     use_closure(adtype::ADTypes.AbstractADType)
diff --git a/src/model.jl b/src/model.jl
index a0451b1b6..26841f4d3 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -96,8 +96,9 @@ Return a `Model` which now treats variables on the right-hand side as observatio
 
 See [`condition`](@ref) for more information and examples.
 """
-Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
-    condition(model, values)
+function Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}})
+    return condition(model, values)
+end
 
 """
     condition(model::Model; values...)
@@ -1068,7 +1069,7 @@ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logjoint(model, argvals_dict)
+        return logjoint(model, argvals_dict)
     end
 end
 
@@ -1115,7 +1116,7 @@ function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logprior(model, argvals_dict)
+        return logprior(model, argvals_dict)
    end
end
 
@@ -1162,7 +1163,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
             values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        loglikelihood(model, argvals_dict)
+        return loglikelihood(model, argvals_dict)
     end
 end
 
@@ -1467,5 +1468,6 @@ ERROR: ArgumentError: `~` with a model on the right-hand side of an observe stat
 [...]
 ```
 """
-to_submodel(model::Model, auto_prefix::Bool=true) =
-    to_sampleable(returned(model), auto_prefix)
+function to_submodel(model::Model, auto_prefix::Bool=true)
+    return to_sampleable(returned(model), auto_prefix)
+end
diff --git a/src/model_utils.jl b/src/model_utils.jl
index ac4ec7022..dbf47f6c4 100644
--- a/src/model_utils.jl
+++ b/src/model_utils.jl
@@ -204,6 +204,8 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
     return Iterators.map(
         Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     ) do (iteration_idx, chain_idx)
-        values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
+        return values_from_chain!(
+            vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()
+        )
     end
 end
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index 064483ddd..91725e621 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -244,7 +244,7 @@ function SimpleVarInfo{T}(
 end
 
 # Constructor from `VarInfo`.
-function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
+function SimpleVarInfo(vi::TypedVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
     return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
 end
 function SimpleVarInfo{T}(
@@ -315,7 +315,7 @@ function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
 end
 function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
     vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
+        return getindex(vi, vn, dist)
     end
     return recombine(dist, vals_linked, length(vns))
 end
@@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
     # Attempt to split into `parent` and `child` optic.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(dict, VarName(vn, o))
+        return haskey(dict, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl
index e29614982..2662c4994 100644
--- a/src/test_utils/models.jl
+++ b/src/test_utils/models.jl
@@ -49,7 +49,7 @@ x[4:5] ~ Dirichlet([1.0, 2.0])
 ```
 """
 @model function demo_one_variable_multiple_constraints(
-    ::Type{TV}=Vector{Float64}
+    (::Type{TV})=Vector{Float64}
 ) where {TV}
     x = TV(undef, 5)
     x[1] ~ Normal()
@@ -186,7 +186,9 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
     return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp
 end
 
-@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe(
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and `observe`
     s = TV(undef, length(x))
     m = TV(undef, length(x))
@@ -212,7 +214,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe)})
 end
 
 @model function demo_assume_index_observe(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `assume` with indexing and `observe`
     s = TV(undef, length(x))
@@ -268,7 +270,7 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe)})
 end
 
 @model function demo_dot_assume_observe_index(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `dot_assume` and `observe` with indexing
     s = TV(undef, length(x))
@@ -348,7 +350,9 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}
     return [@varname(s), @varname(m)]
 end
 
-@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe_index_literal(
+    (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and literal `observe` with indexing
     s = TV(undef, 2)
     m = TV(undef, 2)
@@ -425,7 +429,7 @@ function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
 end
 
 # Only used as a submodel
-@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
+@model function _prior_dot_assume((::Type{TV})=Vector{Float64}) where {TV}
     s = TV(undef, 2)
     s .~ InverseGamma(2, 3)
     m = TV(undef, 2)
@@ -466,7 +470,7 @@ end
 
 @model function demo_dot_assume_observe_submodel(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -496,7 +500,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)})
 end
 
 @model function demo_dot_assume_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -525,7 +529,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)})
 end
 
 @model function demo_assume_matrix_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Array{Float64}
 ) where {TV}
     n = length(x)
     d = n ÷ 2
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 6a655ded4..f55160ce1 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -58,7 +58,7 @@ function setup_varinfos(
         svi_vnv_ref,
     )) do vi
         # Set them all to the same values.
-        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
+        return DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
     end
 
     if include_threadsafe
diff --git a/src/utils.jl b/src/utils.jl
index 50f9baf61..2412ce5a5 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -710,7 +710,7 @@ ERROR: Could not find x.a[2] in x.a[1]
 function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
     _, child, issuccess = splitoptic(getoptic(vn_child)) do optic
         o = optic === nothing ? identity : optic
-        VarName(vn_child, o) == vn_parent
+        return VarName(vn_child, o) == vn_parent
     end
 
     issuccess || error("Could not find $vn_parent in $vn_child")
@@ -905,7 +905,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
     # If `issuccess` is `true`, we found such a split, and hence `vn` is present.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(vals, VarName(vn, o))
+        return haskey(vals, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -934,7 +934,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
     # Split the optic into the key / `parent` and the extraction optic / `child`.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(values, VarName(vn, o))
+        return haskey(values, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -1078,7 +1078,7 @@ end
 function varname_leaves(vn::VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = Accessors.PropertyLens{sym}()
-        varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
+        return varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
     end
     return Iterators.flatten(iter)
 end
@@ -1244,7 +1244,7 @@ end
 function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = DynamicPPL.Accessors.PropertyLens{sym}()
-        varname_and_value_leaves_inner(
+        return varname_and_value_leaves_inner(
             VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
         )
     end
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 0c033e504..c48195708 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -311,7 +311,7 @@ function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
     vns_syms = Set(unique(map(getsym, vns)))
     syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
     metadatas = map(syms) do sym
-        subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
+        return subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
     end
     return NamedTuple{syms}(metadatas)
 end
@@ -327,7 +327,7 @@ end
         # TODO(mhauru) Note that this could still generate an empty metadata object if none
         # of the lenses in `vns` are in `metadata`. Not sure if that's okay.
         # Checking for emptiness would make this type unstable again.
-        :((; $sym=subset(metadata.$sym, vns)))
+        :((; ($sym)=subset(metadata.$sym, vns)))
     else
         :(NamedTuple{}())
     end
@@ -708,8 +708,9 @@ findinds(vnv::VarNamedVector) = 1:length(vnv.varnames)
 
 Return a `NamedTuple` of the variables in `vi` grouped by symbol.
 """
-all_varnames_grouped_by_symbol(vi::TypedVarInfo) =
-    all_varnames_grouped_by_symbol(vi.metadata)
+function all_varnames_grouped_by_symbol(vi::TypedVarInfo)
+    return all_varnames_grouped_by_symbol(vi.metadata)
+end
 
 @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names}
     expr = Expr(:tuple)
@@ -981,25 +982,22 @@ end
         if !(f in vns_names)
             continue
         end
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(vns.$f, f_vns)
-                if !isempty(f_vns)
-                    if !istrans(vi, f_vns[1])
-                        # Iterate over all `f_vns` and transform
-                        for vn in f_vns
-                            f = internal_to_linked_internal_transform(vi, vn)
-                            _inner_transform!(vi, vn, f)
-                            settrans!!(vi, true, vn)
-                        end
-                    else
-                        @warn("[DynamicPPL] attempt to link a linked vi")
+        push!(expr.args, quote
+            f_vns = vi.metadata.$f.vns
+            f_vns = filter_subsumed(vns.$f, f_vns)
+            if !isempty(f_vns)
+                if !istrans(vi, f_vns[1])
+                    # Iterate over all `f_vns` and transform
+                    for vn in f_vns
+                        f = internal_to_linked_internal_transform(vi, vn)
+                        _inner_transform!(vi, vn, f)
+                        settrans!!(vi, true, vn)
                     end
+                else
+                    @warn("[DynamicPPL] attempt to link a linked vi")
                 end
-            end,
-        )
+            end
+        end)
     end
     return expr
 end
@@ -1085,23 +1083,20 @@ end
             continue
         end
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(vns.$f, f_vns)
-                if istrans(vi, f_vns[1])
-                    # Iterate over all `f_vns` and transform
-                    for vn in f_vns
-                        f = linked_internal_to_internal_transform(vi, vn)
-                        _inner_transform!(vi, vn, f)
-                        settrans!!(vi, false, vn)
-                    end
-                else
-                    @warn("[DynamicPPL] attempt to invlink an invlinked vi")
+        push!(expr.args, quote
+            f_vns = vi.metadata.$f.vns
+            f_vns = filter_subsumed(vns.$f, f_vns)
+            if istrans(vi, f_vns[1])
+                # Iterate over all `f_vns` and transform
+                for vn in f_vns
+                    f = linked_internal_to_internal_transform(vi, vn)
+                    _inner_transform!(vi, vn, f)
+                    settrans!!(vi, false, vn)
                 end
+            else
+                @warn("[DynamicPPL] attempt to invlink an invlinked vi")
             end
-            end,
-        )
+            end
+        end)
     end
     return expr
 end
@@ -1496,7 +1491,7 @@ end
 
 function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution)
     @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo"
     vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
+        return getindex(vi, vn, dist)
     end
     return recombine(dist, vals_linked, length(vns))
@@ -1542,7 +1537,7 @@ Check whether `vn` has a value in `vi`.
 Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn)
 function Base.haskey(vi::TypedVarInfo, vn::VarName)
     md_haskey = map(vi.metadata) do metadata
-        haskey(metadata, vn)
+        return haskey(metadata, vn)
     end
     return any(md_haskey)
 end
@@ -1774,23 +1769,20 @@ end
         f_idcs = :(idcs.$f)
         f_orders = :(metadata.$f.orders)
         f_flags = :(metadata.$f.flags)
-        push!(
-            expr.args,
-            quote
-                # Set the flag for variables with symbol `f`
-                if num_produce == 0
-                    for i in length($f_idcs):-1:1
-                        $f_flags["del"][$f_idcs[i]] = true
-                    end
-                else
-                    for i in 1:length($f_orders)
-                        if i in $f_idcs && $f_orders[i] > num_produce
-                            $f_flags["del"][i] = true
-                        end
+        push!(expr.args, quote
+            # Set the flag for variables with symbol `f`
+            if num_produce == 0
+                for i in length($f_idcs):-1:1
+                    $f_flags["del"][$f_idcs[i]] = true
+                end
+            else
+                for i in 1:length($f_orders)
+                    if i in $f_idcs && $f_orders[i] > num_produce
+                        $f_flags["del"][i] = true
                     end
                 end
-            end,
-        )
+            end
+        end)
     end
     return expr
 end
@@ -1831,7 +1823,7 @@ end
     kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys
 ) where {names}
     updates = map(names) do n
-        quote
+        return quote
             for vn in Base.keys(metadata.$n)
                 indices_found = kernel!(vi, vn, values, keys_strings)
                 if indices_found !== nothing
@@ -1863,7 +1855,7 @@ function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys)
     string_vns = map(string, collect_maybe(Base.keys(vi)))
     # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
     missing_keys = filter(keys) do key
-        !any(Base.Fix2(subsumes_string, key), string_vns)
+        return !any(Base.Fix2(subsumes_string, key), string_vns)
     end
 
     return missing_keys
diff --git a/test/ad.jl b/test/ad.jl
index a4f3dbfa7..76ba28281 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -71,7 +71,7 @@ using DynamicPPL: LogDensityFunction
         t = 1:0.05:8
        σ = 0.3
         y = @. rand(sin(t) + Normal(0, σ))
-        @model function state_space(y, TT, ::Type{T}=Float64) where {T}
+        @model function state_space(y, TT, (::Type{T})=Float64) where {T}
             # Priors
             α ~ Normal(y[1], 0.001)
             τ ~ Exponential(1)
@@ -94,9 +94,11 @@ using DynamicPPL: LogDensityFunction
         # overload assume so that model evaluation doesn't fail due to a lack
         # of implementation
         struct MyEmptyAlg end
-        DynamicPPL.assume(
+        function DynamicPPL.assume(
             ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
-        ) = DynamicPPL.assume(dist, vn, vi)
+        )
+            return DynamicPPL.assume(dist, vn, vi)
+        end
 
         # Compiling the ReverseDiff tape used to fail here
         spl = Sampler(MyEmptyAlg())
@@ -117,7 +119,7 @@ using DynamicPPL: LogDensityFunction
             return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
         end
 
-        @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
+        @model function scalar_matrix_model((::Type{T})=Float64) where {T<:Real}
             m = Matrix{T}(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
@@ -126,14 +128,14 @@
             scalar_matrix_model, test_m, ref_adtype
         )
 
-        @model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
+        @model function matrix_model((::Type{T})=Matrix{Float64}) where {T}
             m = T(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
 
         matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
 
-        @model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
+        @model function scalar_array_model((::Type{T})=Float64) where {T<:Real}
             m = Array{T}(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
@@ -142,7 +144,7 @@
             scalar_array_model, test_m, ref_adtype
         )
 
-        @model function array_model(::Type{T}=Array{Float64}) where {T}
+        @model function array_model((::Type{T})=Array{Float64}) where {T}
             m = T(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
diff --git a/test/compiler.jl b/test/compiler.jl
index 3d3c6d9e3..e44234b14 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -1,6 +1,6 @@
 macro custom(expr)
     (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage")
-    quote
+    return quote
         $(esc(expr.args[2])) = 0.0
     end
 end
@@ -487,7 +487,7 @@ module Issue537 end
         @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10
 
         # AR1 model. Dynamic prefixing.
-        @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
+        @model function AR1(num_steps, α, μ, σ, (::Type{TV})=Vector{Float64}) where {TV}
             η ~ MvNormal(zeros(num_steps), I)
             δ = sqrt(1 - α^2)
             x = TV(undef, num_steps)
@@ -678,7 +678,7 @@ module Issue537 end
     end
 
     @testset "issue #393: anonymous argument with type parameter" begin
-        @model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1
+        @model f_393((::Val{ispredict})=Val(false)) where {ispredict} = ispredict ? 0 : 1
         @test f_393()() == 1
         @test f_393(Val(true))() == 0
     end
@@ -759,7 +759,7 @@ module Issue537 end
 
     @testset "signature parsing + TypeWrap" begin
         @model function demo_typewrap(
-            a, b=1, ::Type{T1}=Float64; c, d=2, t::Type{T2}=Int
+            a, b=1, (::Type{T1})=Float64; c, d=2, t::Type{T2}=Int
         ) where {T1,T2}
             return (; a, b, c, d, t)
         end
diff --git a/test/linking.jl b/test/linking.jl
index d424a9c2d..4f4c23b16 100644
--- a/test/linking.jl
+++ b/test/linking.jl
@@ -182,10 +182,9 @@ end
     @model function demo_highdim_dirichlet(ns...)
         return x ~ filldist(Dirichlet(ones(2)), ns...)
     end
 
-    @testset "ns=$ns" for ns in [
-        (3,),
-        # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
-        # (3, 4), (3, 4, 5)
+    @testset "ns=$ns" for ns in [(3,),
+        # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
+        # (3, 4), (3, 4, 5)
     ]
         model = demo_highdim_dirichlet(ns...)
         example_values = rand(NamedTuple, model)
diff --git a/test/lkj.jl b/test/lkj.jl
index d581cd21b..18de427c0 100644
--- a/test/lkj.jl
+++ b/test/lkj.jl
@@ -43,7 +43,7 @@ end
         # Build correlation matrix from factor
         corr_matrices = map(samples) do s
             M = reshape(s.metadata.vals, (2, 2))
-            pd_from_triangular(M, uplo)
+            return pd_from_triangular(M, uplo)
         end
         @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
     end
@@ -54,7 +54,7 @@ end
         # Build correlation matrix from factor
         corr_matrices = map(samples) do s
             M = reshape(s.metadata.vals, (2, 2))
-            pd_from_triangular(M, uplo)
+            return pd_from_triangular(M, uplo)
         end
         @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
     end
diff --git a/test/model.jl b/test/model.jl
index a863b6596..459a7665c 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -353,7 +353,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         )
         vns = map(first, first(vns_and_vals_xs))
         vals = map(vns_and_vals_xs) do vns_and_vals
-            map(last, vns_and_vals)
+            return map(last, vns_and_vals)
         end
 
         # Construct the chain.
@@ -376,7 +376,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             d
         end
         vals_with_extra = map(enumerate(vals)) do (i, v)
-            vcat(v, i)
+            return vcat(v, i)
         end
         chain_with_extra = MCMCChains.Chains(
             permutedims(stack(vals_with_extra)),
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index 8e48814a4..a0ce0aa94 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -87,6 +87,7 @@
     @testset "link!! & invlink!! on $(nameof(model))" for model in
                                                           DynamicPPL.TestUtils.DEMO_MODELS
+
         values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
         @testset "$(typeof(vi))" for vi in (
             SimpleVarInfo(Dict()),
diff --git a/test/test_util.jl b/test/test_util.jl
index 87c69b5fe..41daf8615 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -33,8 +33,9 @@ end
 
 Return string representing a short description of `vi`.
 """
-short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
-    "threadsafe($(short_varinfo_name(vi.varinfo)))"
+function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo)
+    return "threadsafe($(short_varinfo_name(vi.varinfo)))"
+end
 function short_varinfo_name(vi::TypedVarInfo)
     DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
     return "TypedVarInfo"
@@ -91,7 +92,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
         iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
         tuples = mapreduce(collect, vcat, iters)
         push!(varnames, map(first, tuples)...)
-        OrderedDict(tuples)
+        return OrderedDict(tuples)
     end
     # Convert back to list
     varnames = collect(varnames)
diff --git a/test/varinfo.jl b/test/varinfo.jl
index 74feb42f6..0e47cc30b 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -184,7 +184,7 @@ end
         return x ~ MvNormal(m, s^2 * I)
     end
 
-    @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV}
+    @model function testmodel_univariate(x, (::Type{TV})=Vector{Float64}) where {TV}
         n = length(x)
         s ~ truncated(Normal(), 0, Inf)
@@ -374,7 +374,7 @@ end
     end
 
     @testset "link!! and invlink!!" begin
-        @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin
+        @model gdemo(a, b, (::Type{T})=Float64) where {T} = begin
             s ~ InverseGamma(2, 3)
             m ~ Uniform(0, 2)
             x = Vector{T}(undef, length(a))
@@ -534,7 +534,9 @@ end
                 vals = values_as(vi, OrderedDict)
                 # All varnames in `vns` should be subsumed by one of `keys(vals)`.
                 @test all(vns) do vn
-                    any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals))
+                    return any(
+                        DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)
+                    )
                 end
                 # Iterate over `keys(vals)` because we might have scenarios such as
                 # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is
@@ -624,7 +626,7 @@ end
     end
 
     @testset "subset" begin
-        @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV}
+        @model function demo_subsetting_varinfo((::Type{TV})=Vector{Float64}) where {TV}
             s ~ InverseGamma(2, 3)
             m ~ Normal(0, sqrt(s))
             x = TV(undef, 2)
@@ -691,12 +693,14 @@ end
         @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
                                                                        vns_supported
+
            varinfo_subset = subset(varinfo, VarName[])
            @test isempty(varinfo_subset)
        end
 
        @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported
+
            varinfo_subset = subset(varinfo, vns_subset)
            # Should now only contain the variables in `vns_subset`.
            check_varinfo_keys(varinfo_subset, vns_subset)
@@ -715,6 +719,7 @@ end
        @testset "$(convert(Vector{VarName}, vns_subset))" for (
            vns_subset, vns_target
        ) in vns_supported_with_subsumes
+
            varinfo_subset = subset(varinfo, vns_subset)
            # Should now only contain the variables in `vns_subset`.
            check_varinfo_keys(varinfo_subset, vns_target)
@@ -732,6 +737,7 @@ end
        @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in
                                                                     vns_supported
+
            varinfo_subset = subset(varinfo, vns_subset)
            vns_subset_reversed = reverse(vns_subset)
            varinfo_subset_reversed = subset(varinfo, vns_subset_reversed)
diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl
index bd3f5553f..1c40bc2d0 100644
--- a/test/varnamedvector.jl
+++ b/test/varnamedvector.jl
@@ -14,7 +14,7 @@ function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, v
     # we need `vn` to also be of this type.
     # => If the varname types don't match, we need to relax the container type.
     return any(keys(vnv)) do vn_present
-        typeof(vn_present) !== typeof(val)
+        return typeof(vn_present) !== typeof(val)
     end
 end
@@ -40,7 +40,7 @@ function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName,
         # If the container is concrete, we need to make sure that the sizes match.
         # => If the sizes don't match, we need to relax the container type.
         any(keys(vnv)) do vn_present
-            size(vnv[vn_present]) != size(val)
+            return size(vnv[vn_present]) != size(val)
         end
     elseif eltype(vnv.transforms) !== Any
         # If it's not concrete AND it's not `Any`, then we should just make it `Any`.
@@ -619,7 +619,7 @@ end
         @test getlogp(varinfo_sample) != getlogp(varinfo)
         # Values should be different.
         DynamicPPL.TestUtils.test_values(
-            varinfo_sample, value_true, vns; compare=!isequal
+            varinfo_sample, value_true, vns; compare=(!isequal)
         )
     end
 end