Skip to content

Commit

Permalink
shorter kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye committed Feb 8, 2025
1 parent 24a4e13 commit 991b2d3
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 70 deletions.
2 changes: 1 addition & 1 deletion ext/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Creates and ModelingToolkit.NonlinearProblem from a DifferentialEquation.
function ModelingToolkit.NonlinearProblem(
eom::HarmonicEquation, u0, p::AbstractDict; in_place=true, kwargs...
)
ss_prob = SteadyStateProblem(eom, u0, p::AbstractDict; in_place=in_place, kwargs...)
ss_prob = SteadyStateProblem(eom, u0, p::AbstractDict; in_place, kwargs...)
return NonlinearProblem(ss_prob)
end

Expand Down
8 changes: 3 additions & 5 deletions ext/PlotsExt/linear_response.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ function HarmonicBalance.plot_linear_response(
X = collect(values(res.swept_parameters))[1][stable]

C = if order == 1
get_jacobian_response(res, nat_var, Ω_range, branch; show_progress=show_progress)
get_jacobian_response(res, nat_var, Ω_range, branch; show_progress)
else
get_linear_response(
res, nat_var, Ω_range, branch; order=order, show_progress=show_progress
)
get_linear_response(res, nat_var, Ω_range, branch; order=order, show_progress)
end
C = logscale ? log.(C) : C

Expand Down Expand Up @@ -124,7 +122,7 @@ function HarmonicBalance.plot_rotframe_jacobian_response(
X = Vector{P}(collect(values(res.swept_parameters))[1][stable])

C = get_rotframe_jacobian_response(
res, Ω_range, branch; show_progress=show_progress, damping_mod=damping_mod
res, Ω_range, branch; show_progress, damping_mod=damping_mod
)
C = logscale ? log.(C) : C

Expand Down
27 changes: 7 additions & 20 deletions ext/PlotsExt/steady_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function plot1D(
if class == "default"
args = [:x => x, :y => y, :branches => branches]
if not_class == [] # plot stable full, unstable dashed
p = plot1D(res; args..., class=["physical", "stable"], add=add, kwargs...)
p = plot1D(res; args..., class=["physical", "stable"], add, kwargs...)
plot1D(
res;
args...,
Expand All @@ -89,9 +89,7 @@ function plot1D(
)
return p
else
p = plot1D(
res; args..., not_class=not_class, class="physical", add=add, kwargs...
)
p = plot1D(res; args..., not_class, class="physical", add, kwargs...)
return p
end
end
Expand Down Expand Up @@ -170,9 +168,7 @@ function plot2D_cut(
)
if class == "default"
if not_class == [] # plot stable full, unstable dashed
p = plot2D_cut(
res; y=y, cut=cut, class=["physical", "stable"], add=add, kwargs...
)
p = plot2D_cut(res; y=y, cut=cut, class=["physical", "stable"], add, kwargs...)
plot2D_cut(
res;
y=y,
Expand All @@ -185,9 +181,7 @@ function plot2D_cut(
)
return p
else
p = plot2D_cut(
res; y=y, cut=cut, not_class=not_class, class="physical", add=add, kwargs...
)
p = plot2D_cut(res; y=y, cut=cut, not_class, class="physical", add, kwargs...)
return p
end
end
Expand Down Expand Up @@ -270,7 +264,7 @@ function HarmonicBalance.plot_phase_diagram(res::Result{D}; kwargs...)::Plots.Pl
end

function HarmonicBalance.plot_phase_diagram(res::Result, class::String; kwargs...)
return plot_phase_diagram(res; class=class, kwargs...)
return plot_phase_diagram(res; class, kwargs...)
end

function plot_phase_diagram_2D(
Expand Down Expand Up @@ -343,7 +337,7 @@ function HarmonicBalance.plot_spaghetti(
if class == "default"
if not_class == [] # plot stable full, unstable dashed
p = HarmonicBalance.plot_spaghetti(
res; x=x, y=y, z=z, class=["physical", "stable"], add=add, kwargs...
res; x=x, y=y, z=z, class=["physical", "stable"], add, kwargs...
)
HarmonicBalance.plot_spaghetti(
res;
Expand All @@ -359,14 +353,7 @@ function HarmonicBalance.plot_spaghetti(
return p
else
p = HarmonicBalance.plot_spaghetti(
res;
x=x,
y=y,
z=z,
class="physical",
not_class=not_class,
add=add,
kwargs...,
res; x=x, y=y, z=z, class="physical", not_class, add, kwargs...
)
return p
end
Expand Down
2 changes: 1 addition & 1 deletion ext/PlotsExt/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function HarmonicBalance.plot_1D_solutions_branch(
not_class=[],
kwargs...,
)
p = plot(res; x=x, y=y, class=class, not_class=not_class, kwargs...)
p = plot(res; x=x, y=y, class, not_class, kwargs...)

followed_branch, Ys = HarmonicBalance.follow_branch(
starting_branch, res; y=y, sweep=sweep, tf=tf, ϵ=ϵ
Expand Down
14 changes: 6 additions & 8 deletions ext/SteadyStateDiffEqExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function HarmonicBalance.steady_state_sweep(
foreach(pairs(sweep_range)) do (i, value)
u0 = i == 1 ? [0.0, 0.0] : result[i - 1]
# make type-stable: FD.Dual or Float
parameters = get_new_parameters(prob, varied_idx, value)
sol = solve(remake(prob; p=parameters, u0=u0), alg; kwargs...)
p = get_new_parameters(prob, varied_idx, value)
sol = solve(remake(prob; p, u0), alg; kwargs...)
result[i] = sol.u
end
return result
Expand Down Expand Up @@ -58,19 +58,17 @@ function HarmonicBalance.steady_state_sweep(
foreach(pairs(sweep_range)) do (i, value)
u0 = i == 1 ? Base.zeros(length(prob_np.u0)) : result[i - 1]
# make type-stable: FD.Dual or Float
parameters = get_new_parameters(prob_np, varied_idx, value)
sol_nn = solve(remake(prob_np; p=parameters, u0=u0), alg_np; kwargs...)
p = get_new_parameters(prob_np, varied_idx, value)
sol_nn = solve(remake(prob_np; p, u0), alg_np; kwargs...)

# last argument is time but does not matter
param_val = tunable_parameters(parameters)
param_val = tunable_parameters(p)
zeros = norm(prob_np.f.f.f.f.f_oop(sol_nn.u, param_val, 0))
jac = prob_np.f.jac.f.f.f_oop(sol_nn.u, param_val, 0)
eigval = jac isa Vector ? jac : eigvals(jac) # eigvals favourable supports FD.Dual

if !isapprox(zeros, 0; atol=1e-5) || any-> λ > 0, real.(eigval))
sol_ss = solve(
remake(prob_ss; p=parameters, u0=u0), alg_ss; abstol=1e-5, reltol=1e-5
)
sol_ss = solve(remake(prob_ss; p, u0), alg_ss; abstol=1e-5, reltol=1e-5)
result[i] = sol_ss.u
else
result[i] = sol_nn.u
Expand Down
8 changes: 6 additions & 2 deletions ext/TimeEvolution/ODEProblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ Return `true` the solution evolves within `tol` of the initial value (interprete
"""
function HarmonicBalance.is_stable(
soln::StateDict, eom::HarmonicEquation; timespan, tol=1e-1, perturb_initial=1e-3
steady_solution::StateDict,
eom::HarmonicEquation;
timespan,
tol=1e-1,
perturb_initial=1e-3,
)
problem = ODEProblem(eom; steady_solution=soln, timespan=timespan)
problem = ODEProblem(eom; steady_solution, timespan)
solution = solve(problem)
dist = norm(solution[end] - solution[1]) / (norm(solution[end]) + norm(solution[1]))
return if !is_real(solution[end]) || !is_real(solution[1])
Expand Down
22 changes: 2 additions & 20 deletions src/Jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,6 @@ function _compile_Jacobian(
return JacobianFunction(soltype)(compiled_J)
end

# function _compile_Jacobian(
# prob::Problem,
# soltype::DataType,
# swept_parameters::OrderedDict,
# fixed_parameters::OrderedDict,
# )::JacobianFunction(soltype)
# if "Hopf" ∈ getfield.(prob.eom.variables, :type)
# compiled_J = prob.jacobian
# elseif !hasnan(prob.jacobian)
# compiled_J = compile_matrix(
# prob.jacobian, _free_symbols(prob); rules=prob.fixed_parameters
# )
# else
# compiled_J = get_implicit_Jacobian(prob)
# end
# return JacobianFunction(soltype)(compiled_J)
# end

"""
Take a matrix containing symbolic variables `variables` and keys of `fixed_parameters`.
Substitute the values according to `fixed_parameters` and compile into a function that takes
Expand Down Expand Up @@ -132,8 +114,8 @@ avoiding huge symbolic operations.
Returns a function `f(soln::OrderedDict{Num,T})::Matrix{T}`.
"""
function get_implicit_Jacobian(eom::HarmonicEquation; sym_order, rules=Dict())
J0c = compile_matrix(_get_J_matrix(eom; order=0), sym_order; rules=rules)
J1c = compile_matrix(_get_J_matrix(eom; order=1), sym_order; rules=rules)
J0c = compile_matrix(_get_J_matrix(eom; order=0), sym_order; rules)
J1c = compile_matrix(_get_J_matrix(eom; order=1), sym_order; rules)
jacfunc(vals::Vector) = -inv(real.(J1c(vals))) * J0c(vals)
return jacfunc
end
Expand Down
2 changes: 1 addition & 1 deletion src/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classify_solutions!(res, "sqrt(u1^2 + v1^2) > 1.0" , "large_amplitude")
function classify_solutions!(
res::Result, func::Union{String,Function}, name::String; physical=true
)
values = classify_solutions(res, func; physical=physical)
values = classify_solutions(res, func; physical)
return res.classes[name] = values
end

Expand Down
2 changes: 1 addition & 1 deletion src/modules/LimitCycles/gauge_fixing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function _gaugefixed_Jacobian(
)
rules = Dict(rules)
setindex!(rules, 0, _remove_brackets(fixed_var))
jac = get_implicit_Jacobian(eom; rules=rules, sym_order=sym_order)
jac = get_implicit_Jacobian(eom; rules, sym_order=sym_order)
return JacobianFunction(soltype)(jac)
end

Expand Down
10 changes: 5 additions & 5 deletions src/solve_homotopy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function get_steady_states(
unique_fixed, input_array = _prepare_input_params(
prob, swept_parameters, fixed_parameters
)
solutions = get_solutions(prob, method, input_array; show_progress=show_progress)
solutions = get_solutions(prob, method, input_array; show_progress)

result = Result(
solutions,
Expand All @@ -89,7 +89,7 @@ function get_steady_states(
)

if sorting != "no_sorting"
sort_solutions!(result; sorting=sorting, show_progress=show_progress)
sort_solutions!(result; sorting, show_progress)
end
classify_default ? _classify_default!(result) : nothing

Expand Down Expand Up @@ -125,7 +125,7 @@ function get_steady_states(eom::HarmonicEquation, swept, fixed; kwargs...)
end

function get_solutions(prob, method, input_array; show_progress)
raw = _get_raw_solution(prob, method, input_array; show_progress=show_progress)
raw = _get_raw_solution(prob, method, input_array; show_progress)

solutions = HC.solutions.(getindex.(raw, 1))
if all(isempty.(solutions))
Expand Down Expand Up @@ -181,7 +181,7 @@ function _get_raw_solution(
problem.system;
start_system=method_symbol(warm_up_method),
target_parameters=start_parameters,
show_progress=show_progress,
show_progress,
alg_default_options(warm_up_method)...,
alg_specific_options(warm_up_method)...,
)
Expand All @@ -191,7 +191,7 @@ function _get_raw_solution(
HC.solutions(warmup_solution);
start_parameters=start_parameters,
target_parameters=parameter_values,
show_progress=show_progress,
show_progress,
alg_default_options(method)...,
)

Expand Down
8 changes: 3 additions & 5 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function sort_solutions(
error("Only the following sorting options are allowed: ", sorting_schemes)
sorting == "none" && return solutions
l = length(size(solutions))
l == 1 && return sort_1D(solutions; show_progress=show_progress)
l == 2 && return sort_2D(solutions; sorting=sorting, show_progress=show_progress)
l == 1 && return sort_1D(solutions; show_progress)
l == 2 && return sort_2D(solutions; sorting, show_progress)
return error("do not know how to solve solution which are not 1D or 2D")
end

Expand All @@ -34,9 +34,7 @@ specifies the method used to get continuous solution branches. Options are `"hil
and `"none"`. The `show_progress` keyword argument indicates whether a progress bar should be displayed.
"""
function sort_solutions!(res::Result; sorting="nearest", show_progress=true)
return res.solutions .= sort_solutions(
res.solutions; sorting=sorting, show_progress=show_progress
)
return res.solutions .= sort_solutions(res.solutions; sorting, show_progress)
end

#####
Expand Down
2 changes: 1 addition & 1 deletion src/transform_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end
function transform_solutions(res::Result, f::String; rules=Dict(), kwargs...)
# a string is used as input
# a macro would not "see" the user's namespace while the user's namespace does not "see" the variables
func = _build_substituted(f, res; rules=rules)
func = _build_substituted(f, res; rules)
return transform_solutions(res, func; kwargs...)
end

Expand Down

0 comments on commit 991b2d3

Please sign in to comment.