diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
index 0d9f0f8dd..e517b7d9a 100644
--- a/.JuliaFormatter.toml
+++ b/.JuliaFormatter.toml
@@ -1,2 +1,6 @@
 style="blue"
 format_markdown = true
+# The below should actually be part of Blue according to
+# https://github.com/JuliaDiff/BlueStyle?tab=readme-ov-file#method-definitions
+# but JuliaFormatter v2.10 doesn't enforce it.
+always_use_return = true
diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl
index 2c881aa95..7009bc2cb 100644
--- a/benchmarks/src/Models.jl
+++ b/benchmarks/src/Models.jl
@@ -47,7 +47,7 @@ A short model that tries to cover many DynamicPPL features.
 Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops; allocating
 a variable vector; observations passed as arguments, and as literals.
 """
-@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
+@model function smorgasbord(x, y, (::Type{TV})=Vector{Float64}) where {TV}
     @assert length(x) == length(y)
     m ~ truncated(Normal(); lower=0)
     means ~ product_distribution(fill(Exponential(m), length(x)))
@@ -68,7 +68,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
 
 See `multivariate` for a version that uses `product_distribution` rather than loops.
 """
-@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function loop_univariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     for i in 1:num_dims
@@ -88,7 +88,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
 
 See `loop_univariate` for a version that uses loops rather than `product_distribution`.
 """
-@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
+@model function multivariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
     a = TV(undef, num_dims)
     o = TV(undef, num_dims)
     a ~ product_distribution(fill(Normal(0, 1), num_dims))
@@ -118,7 +118,7 @@ end
 A model with random variables that have changing support under linking, or otherwise
 complicated bijectors.
 """
-@model function dynamic(::Type{T}=Vector{Float64}) where {T}
+@model function dynamic((::Type{T})=Vector{Float64}) where {T}
     eta ~ truncated(Normal(); lower=0.0, upper=0.1)
     mat1 ~ LKJCholesky(4, eta)
     mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0]))
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 7fcbd6a7c..9ffc2ac1e 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -48,10 +48,10 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.
 
-The `model` passed to `predict` is often different from the one used to generate `chain`. 
-Typically, the model from which `chain` originated treats certain variables as observed (i.e., 
-data points), while the model you pass to `predict` may mark these same variables as missing 
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to 
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
 simulate what new, unobserved data might look like, given your posterior beliefs.
 
 For each parameter configuration in `chain`:
@@ -59,7 +59,7 @@ For each parameter configuration in `chain`:
 2. Any variables not included in `chain` are sampled from their prior distributions.
 
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior 
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
@@ -161,8 +161,8 @@ function _predictive_samples_to_arrays(predictive_samples)
 
     variable_names = collect(variable_names_set)
     variable_values = [
-        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
-        key in variable_names
+        get(sample_dicts[i], key, missing) for
+        i in eachindex(sample_dicts), key in variable_names
     ]
 
     return variable_names, variable_values
@@ -254,7 +254,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
         DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
         # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
         # `deepcopy` the `varinfo` before passing it to the `model`.
-        model(deepcopy(varinfo))
+        return model(deepcopy(varinfo))
     end
 end
 
diff --git a/src/compiler.jl b/src/compiler.jl
index 4771b0171..4c2fcdf13 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -635,7 +635,7 @@ end
 
 function namedtuple_from_splitargs(splitargs)
     names = map(splitargs) do (arg_name, arg_type, is_splat, default)
-        is_splat ? Symbol("#splat#$(arg_name)") : arg_name
+        return is_splat ? Symbol("#splat#$(arg_name)") : arg_name
     end
     names_expr = Expr(:tuple, map(QuoteNode, names)...)
     vals = Expr(:tuple, map(first, splitargs)...)
diff --git a/src/debug_utils.jl b/src/debug_utils.jl
index 328fe6983..d825aa01b 100644
--- a/src/debug_utils.jl
+++ b/src/debug_utils.jl
@@ -521,7 +521,7 @@ function has_static_constraints(
     rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs...
 )
     results = map(1:num_evals) do _
-        check_model_and_trace(rng, model; kwargs...)
+        return check_model_and_trace(rng, model; kwargs...)
     end
     issuccess = all(first, results)
     issuccess || throw(ArgumentError("model check failed"))
@@ -530,7 +530,7 @@ function has_static_constraints(
     traces = map(last, results)
     dists_per_trace = map(distributions_in_trace, traces)
     transforms = map(dists_per_trace) do dists
-        map(DynamicPPL.link_transform, dists)
+        return map(DynamicPPL.link_transform, dists)
     end
 
     # Check if the distributions are the same across all runs.
diff --git a/src/extract_priors.jl b/src/extract_priors.jl
index 0f312fa2c..a52c3e218 100644
--- a/src/extract_priors.jl
+++ b/src/extract_priors.jl
@@ -105,8 +105,9 @@ julia> length(extract_priors(rng, model)[@varname(x)])
 9
 ```
 """
-extract_priors(args::Union{Model,AbstractVarInfo}...) =
-    extract_priors(Random.default_rng(), args...)
+function extract_priors(args::Union{Model,AbstractVarInfo}...)
+    return extract_priors(Random.default_rng(), args...)
+end
 function extract_priors(rng::Random.AbstractRNG, model::Model)
     context = PriorExtractorContext(SamplingContext(rng))
     evaluate!!(model, VarInfo(), context)
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index a42855f05..00ad9d582 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -175,7 +175,7 @@ end
 Evaluate the log density of the given `model` at the given parameter values `x`,
 using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
 only for its structure, in the sense that the parameters from the vector `x` are inserted into
-it, and its own parameters are discarded. 
+it, and its own parameters are discarded.
 """
 function logdensity_at(
     x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
@@ -245,9 +245,11 @@ model.
 
 By default, this just returns the input unchanged.
 """
-tweak_adtype(
+function tweak_adtype(
     adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
-) = adtype
+)
+    return adtype
+end
 
 """
     use_closure(adtype::ADTypes.AbstractADType)
diff --git a/src/model.jl b/src/model.jl
index a0451b1b6..26841f4d3 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -96,8 +96,9 @@ Return a `Model` which now treats variables on the right-hand side as observatio
 
 See [`condition`](@ref) for more information and examples.
 """
-Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
-    condition(model, values)
+function Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}})
+    return condition(model, values)
+end
 
 """
     condition(model::Model; values...)
@@ -1068,7 +1069,7 @@ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
                 values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logjoint(model, argvals_dict)
+        return logjoint(model, argvals_dict)
     end
 end
 
@@ -1115,7 +1116,7 @@ function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
                 values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        logprior(model, argvals_dict)
+        return logprior(model, argvals_dict)
     end
 end
 
@@ -1162,7 +1163,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
                 values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
             vn_parent in keys(var_info)
         )
-        loglikelihood(model, argvals_dict)
+        return loglikelihood(model, argvals_dict)
     end
 end
 
@@ -1467,5 +1468,6 @@ ERROR: ArgumentError: `~` with a model on the right-hand side of an observe stat
 [...]
 ```
 """
-to_submodel(model::Model, auto_prefix::Bool=true) =
-    to_sampleable(returned(model), auto_prefix)
+function to_submodel(model::Model, auto_prefix::Bool=true)
+    return to_sampleable(returned(model), auto_prefix)
+end
diff --git a/src/model_utils.jl b/src/model_utils.jl
index ac4ec7022..dbf47f6c4 100644
--- a/src/model_utils.jl
+++ b/src/model_utils.jl
@@ -204,6 +204,8 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
     return Iterators.map(
         Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     ) do (iteration_idx, chain_idx)
-        values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
+        return values_from_chain!(
+            vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()
+        )
     end
 end
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index 064483ddd..91725e621 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -244,7 +244,7 @@ function SimpleVarInfo{T}(
 end
 
 # Constructor from `VarInfo`.
-function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
+function SimpleVarInfo(vi::TypedVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
     return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
 end
 function SimpleVarInfo{T}(
@@ -315,7 +315,7 @@ function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
 end
 function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
     vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
+        return getindex(vi, vn, dist)
     end
     return recombine(dist, vals_linked, length(vns))
 end
@@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
     # Attempt to split into `parent` and `child` optic.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(dict, VarName(vn, o))
+        return haskey(dict, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl
index e29614982..2662c4994 100644
--- a/src/test_utils/models.jl
+++ b/src/test_utils/models.jl
@@ -49,7 +49,7 @@ x[4:5] ~ Dirichlet([1.0, 2.0])
 ```
 """
 @model function demo_one_variable_multiple_constraints(
-    ::Type{TV}=Vector{Float64}
+    (::Type{TV})=Vector{Float64}
 ) where {TV}
     x = TV(undef, 5)
     x[1] ~ Normal()
@@ -186,7 +186,9 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
     return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp
 end
 
-@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe(
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and `observe`
     s = TV(undef, length(x))
     m = TV(undef, length(x))
@@ -212,7 +214,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe)})
 end
 
 @model function demo_assume_index_observe(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `assume` with indexing and `observe`
     s = TV(undef, length(x))
@@ -268,7 +270,7 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe)})
 end
 
 @model function demo_dot_assume_observe_index(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     # `dot_assume` and `observe` with indexing
     s = TV(undef, length(x))
@@ -348,7 +350,9 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}
     return [@varname(s), @varname(m)]
 end
 
-@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
+@model function demo_dot_assume_observe_index_literal(
+    (::Type{TV})=Vector{Float64}
+) where {TV}
     # `dot_assume` and literal `observe` with indexing
     s = TV(undef, 2)
     m = TV(undef, 2)
@@ -425,7 +429,7 @@ function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
 end
 
 # Only used as a submodel
-@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
+@model function _prior_dot_assume((::Type{TV})=Vector{Float64}) where {TV}
     s = TV(undef, 2)
     s .~ InverseGamma(2, 3)
     m = TV(undef, 2)
@@ -466,7 +470,7 @@ end
 end
 
 @model function demo_dot_assume_observe_submodel(
-    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
+    x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -496,7 +500,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)})
 end
 
 @model function demo_dot_assume_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Vector{Float64}
 ) where {TV}
     s = TV(undef, length(x))
     s .~ InverseGamma(2, 3)
@@ -525,7 +529,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)})
 end
 
 @model function demo_assume_matrix_observe_matrix_index(
-    x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
+    x=transpose([1.5 2.0;]), (::Type{TV})=Array{Float64}
 ) where {TV}
     n = length(x)
     d = n ÷ 2
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 6a655ded4..f55160ce1 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -58,7 +58,7 @@ function setup_varinfos(
         svi_vnv_ref,
     )) do vi
         # Set them all to the same values.
-        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
+        return DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
     end
 
     if include_threadsafe
diff --git a/src/utils.jl b/src/utils.jl
index 50f9baf61..2412ce5a5 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -710,7 +710,7 @@ ERROR: Could not find x.a[2] in x.a[1]
 function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
     _, child, issuccess = splitoptic(getoptic(vn_child)) do optic
         o = optic === nothing ? identity : optic
-        VarName(vn_child, o) == vn_parent
+        return VarName(vn_child, o) == vn_parent
     end
 
     issuccess || error("Could not find $vn_parent in $vn_child")
@@ -905,7 +905,7 @@ function hasvalue(vals::AbstractDict, vn::VarName)
     # If `issuccess` is `true`, we found such a split, and hence `vn` is present.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(vals, VarName(vn, o))
+        return haskey(vals, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -934,7 +934,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
     # Split the optic into the key / `parent` and the extraction optic / `child`.
     parent, child, issuccess = splitoptic(getoptic(vn)) do optic
         o = optic === nothing ? identity : optic
-        haskey(values, VarName(vn, o))
+        return haskey(values, VarName(vn, o))
     end
     # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
     keyoptic = parent === nothing ? identity : parent
@@ -1078,7 +1078,7 @@ end
 function varname_leaves(vn::VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = Accessors.PropertyLens{sym}()
-        varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
+        return varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val))
     end
     return Iterators.flatten(iter)
 end
@@ -1244,7 +1244,7 @@ end
 function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
     iter = Iterators.map(keys(val)) do sym
         optic = DynamicPPL.Accessors.PropertyLens{sym}()
-        varname_and_value_leaves_inner(
+        return varname_and_value_leaves_inner(
             VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)
         )
     end
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 0c033e504..c48195708 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -311,7 +311,7 @@ function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
     vns_syms = Set(unique(map(getsym, vns)))
     syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
     metadatas = map(syms) do sym
-        subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
+        return subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
     end
     return NamedTuple{syms}(metadatas)
 end
@@ -327,7 +327,7 @@ end
         # TODO(mhauru) Note that this could still generate an empty metadata object if none
         # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
         # emptiness would make this type unstable again.
-        :((; $sym=subset(metadata.$sym, vns)))
+        :((; ($sym)=subset(metadata.$sym, vns)))
     else
         :(NamedTuple{}())
     end
@@ -708,8 +708,9 @@ findinds(vnv::VarNamedVector) = 1:length(vnv.varnames)
 
 Return a `NamedTuple` of the variables in `vi` grouped by symbol.
 """
-all_varnames_grouped_by_symbol(vi::TypedVarInfo) =
-    all_varnames_grouped_by_symbol(vi.metadata)
+function all_varnames_grouped_by_symbol(vi::TypedVarInfo)
+    return all_varnames_grouped_by_symbol(vi.metadata)
+end
 
 @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names}
     expr = Expr(:tuple)
@@ -981,25 +982,22 @@ end
         if !(f in vns_names)
             continue
         end
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(vns.$f, f_vns)
-                if !isempty(f_vns)
-                    if !istrans(vi, f_vns[1])
-                        # Iterate over all `f_vns` and transform
-                        for vn in f_vns
-                            f = internal_to_linked_internal_transform(vi, vn)
-                            _inner_transform!(vi, vn, f)
-                            settrans!!(vi, true, vn)
-                        end
-                    else
-                        @warn("[DynamicPPL] attempt to link a linked vi")
+        push!(expr.args, quote
+            f_vns = vi.metadata.$f.vns
+            f_vns = filter_subsumed(vns.$f, f_vns)
+            if !isempty(f_vns)
+                if !istrans(vi, f_vns[1])
+                    # Iterate over all `f_vns` and transform
+                    for vn in f_vns
+                        f = internal_to_linked_internal_transform(vi, vn)
+                        _inner_transform!(vi, vn, f)
+                        settrans!!(vi, true, vn)
                     end
+                else
+                    @warn("[DynamicPPL] attempt to link a linked vi")
                 end
-            end,
-        )
+            end
+        end)
     end
     return expr
 end
@@ -1085,23 +1083,20 @@ end
             continue
         end
 
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(vns.$f, f_vns)
-                if istrans(vi, f_vns[1])
-                    # Iterate over all `f_vns` and transform
-                    for vn in f_vns
-                        f = linked_internal_to_internal_transform(vi, vn)
-                        _inner_transform!(vi, vn, f)
-                        settrans!!(vi, false, vn)
-                    end
-                else
-                    @warn("[DynamicPPL] attempt to invlink an invlinked vi")
+        push!(expr.args, quote
+            f_vns = vi.metadata.$f.vns
+            f_vns = filter_subsumed(vns.$f, f_vns)
+            if istrans(vi, f_vns[1])
+                # Iterate over all `f_vns` and transform
+                for vn in f_vns
+                    f = linked_internal_to_internal_transform(vi, vn)
+                    _inner_transform!(vi, vn, f)
+                    settrans!!(vi, false, vn)
                 end
-            end,
-        )
+            else
+                @warn("[DynamicPPL] attempt to invlink an invlinked vi")
+            end
+        end)
     end
     return expr
 end
@@ -1496,7 +1491,7 @@ end
 function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution)
     @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo"
     vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
+        return getindex(vi, vn, dist)
     end
     return recombine(dist, vals_linked, length(vns))
 end
@@ -1542,7 +1537,7 @@ Check whether `vn` has a value in `vi`.
 Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn)
 function Base.haskey(vi::TypedVarInfo, vn::VarName)
     md_haskey = map(vi.metadata) do metadata
-        haskey(metadata, vn)
+        return haskey(metadata, vn)
     end
     return any(md_haskey)
 end
@@ -1774,23 +1769,20 @@ end
         f_idcs = :(idcs.$f)
         f_orders = :(metadata.$f.orders)
         f_flags = :(metadata.$f.flags)
-        push!(
-            expr.args,
-            quote
-                # Set the flag for variables with symbol `f`
-                if num_produce == 0
-                    for i in length($f_idcs):-1:1
-                        $f_flags["del"][$f_idcs[i]] = true
-                    end
-                else
-                    for i in 1:length($f_orders)
-                        if i in $f_idcs && $f_orders[i] > num_produce
-                            $f_flags["del"][i] = true
-                        end
+        push!(expr.args, quote
+            # Set the flag for variables with symbol `f`
+            if num_produce == 0
+                for i in length($f_idcs):-1:1
+                    $f_flags["del"][$f_idcs[i]] = true
+                end
+            else
+                for i in 1:length($f_orders)
+                    if i in $f_idcs && $f_orders[i] > num_produce
+                        $f_flags["del"][i] = true
                     end
                 end
-            end,
-        )
+            end
+        end)
     end
     return expr
 end
@@ -1831,7 +1823,7 @@ end
     kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys
 ) where {names}
     updates = map(names) do n
-        quote
+        return quote
             for vn in Base.keys(metadata.$n)
                 indices_found = kernel!(vi, vn, values, keys_strings)
                 if indices_found !== nothing
@@ -1863,7 +1855,7 @@ function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys)
     string_vns = map(string, collect_maybe(Base.keys(vi)))
     # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
     missing_keys = filter(keys) do key
-        !any(Base.Fix2(subsumes_string, key), string_vns)
+        return !any(Base.Fix2(subsumes_string, key), string_vns)
     end
 
     return missing_keys
diff --git a/test/ad.jl b/test/ad.jl
index a4f3dbfa7..76ba28281 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -71,7 +71,7 @@ using DynamicPPL: LogDensityFunction
         t = 1:0.05:8
         σ = 0.3
         y = @. rand(sin(t) + Normal(0, σ))
-        @model function state_space(y, TT, ::Type{T}=Float64) where {T}
+        @model function state_space(y, TT, (::Type{T})=Float64) where {T}
             # Priors
             α ~ Normal(y[1], 0.001)
             τ ~ Exponential(1)
@@ -94,9 +94,11 @@ using DynamicPPL: LogDensityFunction
         # overload assume so that model evaluation doesn't fail due to a lack
         # of implementation
         struct MyEmptyAlg end
-        DynamicPPL.assume(
+        function DynamicPPL.assume(
             ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
-        ) = DynamicPPL.assume(dist, vn, vi)
+        )
+            return DynamicPPL.assume(dist, vn, vi)
+        end
 
         # Compiling the ReverseDiff tape used to fail here
         spl = Sampler(MyEmptyAlg())
@@ -117,7 +119,7 @@ using DynamicPPL: LogDensityFunction
             return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
         end
 
-        @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
+        @model function scalar_matrix_model((::Type{T})=Float64) where {T<:Real}
             m = Matrix{T}(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
@@ -126,14 +128,14 @@ using DynamicPPL: LogDensityFunction
             scalar_matrix_model, test_m, ref_adtype
         )
 
-        @model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
+        @model function matrix_model((::Type{T})=Matrix{Float64}) where {T}
             m = T(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
 
         matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
 
-        @model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
+        @model function scalar_array_model((::Type{T})=Float64) where {T<:Real}
             m = Array{T}(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
@@ -142,7 +144,7 @@ using DynamicPPL: LogDensityFunction
             scalar_array_model, test_m, ref_adtype
         )
 
-        @model function array_model(::Type{T}=Array{Float64}) where {T}
+        @model function array_model((::Type{T})=Array{Float64}) where {T}
             m = T(undef, 2, 3)
             return m ~ filldist(MvNormal(zeros(2), I), 3)
         end
diff --git a/test/compiler.jl b/test/compiler.jl
index 3d3c6d9e3..e44234b14 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -1,6 +1,6 @@
 macro custom(expr)
     (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage")
-    quote
+    return quote
         $(esc(expr.args[2])) = 0.0
     end
 end
@@ -487,7 +487,7 @@ module Issue537 end
         @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10
 
         # AR1 model. Dynamic prefixing.
-        @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
+        @model function AR1(num_steps, α, μ, σ, (::Type{TV})=Vector{Float64}) where {TV}
             η ~ MvNormal(zeros(num_steps), I)
             δ = sqrt(1 - α^2)
             x = TV(undef, num_steps)
@@ -678,7 +678,7 @@ module Issue537 end
     end
 
     @testset "issue #393: anonymous argument with type parameter" begin
-        @model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1
+        @model f_393((::Val{ispredict})=Val(false)) where {ispredict} = ispredict ? 0 : 1
         @test f_393()() == 1
         @test f_393(Val(true))() == 0
     end
@@ -759,7 +759,7 @@ module Issue537 end
 
     @testset "signature parsing + TypeWrap" begin
         @model function demo_typewrap(
-            a, b=1, ::Type{T1}=Float64; c, d=2, t::Type{T2}=Int
+            a, b=1, (::Type{T1})=Float64; c, d=2, t::Type{T2}=Int
         ) where {T1,T2}
             return (; a, b, c, d, t)
         end
diff --git a/test/linking.jl b/test/linking.jl
index d424a9c2d..4f4c23b16 100644
--- a/test/linking.jl
+++ b/test/linking.jl
@@ -182,10 +182,9 @@ end
         @model function demo_highdim_dirichlet(ns...)
             return x ~ filldist(Dirichlet(ones(2)), ns...)
         end
-        @testset "ns=$ns" for ns in [
-            (3,),
-            # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
-            # (3, 4), (3, 4, 5)
+        @testset "ns=$ns" for ns in [(3,),
+        # TODO: Uncomment once we have https://github.com/TuringLang/Bijectors.jl/pull/304
+        # (3, 4), (3, 4, 5)
         ]
             model = demo_highdim_dirichlet(ns...)
             example_values = rand(NamedTuple, model)
diff --git a/test/lkj.jl b/test/lkj.jl
index d581cd21b..18de427c0 100644
--- a/test/lkj.jl
+++ b/test/lkj.jl
@@ -43,7 +43,7 @@ end
         # Build correlation matrix from factor
         corr_matrices = map(samples) do s
             M = reshape(s.metadata.vals, (2, 2))
-            pd_from_triangular(M, uplo)
+            return pd_from_triangular(M, uplo)
         end
         @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
     end
@@ -54,7 +54,7 @@ end
         # Build correlation matrix from factor
         corr_matrices = map(samples) do s
             M = reshape(s.metadata.vals, (2, 2))
-            pd_from_triangular(M, uplo)
+            return pd_from_triangular(M, uplo)
         end
         @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
     end
diff --git a/test/model.jl b/test/model.jl
index a863b6596..459a7665c 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -353,7 +353,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         )
         vns = map(first, first(vns_and_vals_xs))
         vals = map(vns_and_vals_xs) do vns_and_vals
-            map(last, vns_and_vals)
+            return map(last, vns_and_vals)
         end
 
         # Construct the chain.
@@ -376,7 +376,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             d
         end
         vals_with_extra = map(enumerate(vals)) do (i, v)
-            vcat(v, i)
+            return vcat(v, i)
         end
         chain_with_extra = MCMCChains.Chains(
             permutedims(stack(vals_with_extra)),
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index 8e48814a4..a0ce0aa94 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -87,6 +87,7 @@
 
     @testset "link!! & invlink!! on $(nameof(model))" for model in
                                                           DynamicPPL.TestUtils.DEMO_MODELS
+
         values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
         @testset "$(typeof(vi))" for vi in (
             SimpleVarInfo(Dict()),
diff --git a/test/test_util.jl b/test/test_util.jl
index 87c69b5fe..41daf8615 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -33,8 +33,9 @@ end
 
 Return string representing a short description of `vi`.
 """
-short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
-    "threadsafe($(short_varinfo_name(vi.varinfo)))"
+function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo)
+    return "threadsafe($(short_varinfo_name(vi.varinfo)))"
+end
 function short_varinfo_name(vi::TypedVarInfo)
     DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
     return "TypedVarInfo"
@@ -91,7 +92,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I
         iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
         tuples = mapreduce(collect, vcat, iters)
         push!(varnames, map(first, tuples)...)
-        OrderedDict(tuples)
+        return OrderedDict(tuples)
     end
     # Convert back to list
     varnames = collect(varnames)
diff --git a/test/varinfo.jl b/test/varinfo.jl
index 74feb42f6..0e47cc30b 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -184,7 +184,7 @@ end
             return x ~ MvNormal(m, s^2 * I)
         end
 
-        @model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV}
+        @model function testmodel_univariate(x, (::Type{TV})=Vector{Float64}) where {TV}
             n = length(x)
             s ~ truncated(Normal(), 0, Inf)
 
@@ -374,7 +374,7 @@ end
     end
 
     @testset "link!! and invlink!!" begin
-        @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin
+        @model gdemo(a, b, (::Type{T})=Float64) where {T} = begin
             s ~ InverseGamma(2, 3)
             m ~ Uniform(0, 2)
             x = Vector{T}(undef, length(a))
@@ -534,7 +534,9 @@ end
                     vals = values_as(vi, OrderedDict)
                     # All varnames in `vns` should be subsumed by one of `keys(vals)`.
                     @test all(vns) do vn
-                        any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals))
+                        return any(
+                            DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)
+                        )
                     end
                     # Iterate over `keys(vals)` because we might have scenarios such as
                     # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is
@@ -624,7 +626,7 @@ end
     end
 
     @testset "subset" begin
-        @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV}
+        @model function demo_subsetting_varinfo((::Type{TV})=Vector{Float64}) where {TV}
             s ~ InverseGamma(2, 3)
             m ~ Normal(0, sqrt(s))
             x = TV(undef, 2)
@@ -691,12 +693,14 @@ end
 
             @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
                                                                            vns_supported
+
                 varinfo_subset = subset(varinfo, VarName[])
                 @test isempty(varinfo_subset)
             end
 
             @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
                                                                    vns_supported
+
                 varinfo_subset = subset(varinfo, vns_subset)
                 # Should now only contain the variables in `vns_subset`.
                 check_varinfo_keys(varinfo_subset, vns_subset)
@@ -715,6 +719,7 @@ end
             @testset "$(convert(Vector{VarName}, vns_subset))" for (
                 vns_subset, vns_target
             ) in vns_supported_with_subsumes
+
                 varinfo_subset = subset(varinfo, vns_subset)
                 # Should now only contain the variables in `vns_subset`.
                 check_varinfo_keys(varinfo_subset, vns_target)
@@ -732,6 +737,7 @@ end
 
             @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in
                                                                          vns_supported
+
                 varinfo_subset = subset(varinfo, vns_subset)
                 vns_subset_reversed = reverse(vns_subset)
                 varinfo_subset_reversed = subset(varinfo, vns_subset_reversed)
diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl
index bd3f5553f..1c40bc2d0 100644
--- a/test/varnamedvector.jl
+++ b/test/varnamedvector.jl
@@ -14,7 +14,7 @@ function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, v
         # we need `vn` to also be of this type.
         # => If the varname types don't match, we need to relax the container type.
         return any(keys(vnv)) do vn_present
-            typeof(vn_present) !== typeof(val)
+            return typeof(vn_present) !== typeof(val)
         end
     end
 
@@ -40,7 +40,7 @@ function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName,
         # If the container is concrete, we need to make sure that the sizes match.
         # => If the sizes don't match, we need to relax the container type.
         any(keys(vnv)) do vn_present
-            size(vnv[vn_present]) != size(val)
+            return size(vnv[vn_present]) != size(val)
         end
     elseif eltype(vnv.transforms) !== Any
         # If it's not concrete AND it's not `Any`, then we should just make it `Any`.
@@ -619,7 +619,7 @@ end
             @test getlogp(varinfo_sample) != getlogp(varinfo)
             # Values should be different.
             DynamicPPL.TestUtils.test_values(
-                varinfo_sample, value_true, vns; compare=!isequal
+                varinfo_sample, value_true, vns; compare=(!isequal)
             )
         end
     end