Skip to content

Commit

Permalink
Pass rng so that Global RNG is not reset
Browse files Browse the repository at this point in the history
  • Loading branch information
rejuvyesh committed Aug 9, 2021
1 parent b860ba4 commit 1663709
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 29 deletions.
18 changes: 10 additions & 8 deletions src/problem/solver_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ function newton_solve!(prob::GameProblem{KN,n,m,T,SVd,SVx}) where {KN,n,m,T,SVd,
opts = prob.opts

# Set initial trajectory
Random.seed!(opts.seed)
init_traj!(prob.pdtraj; x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.pdtraj_trial; x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.Δpdtraj; x0=prob.x0, f=zeros, amplitude=0.0)
#Random.seed!(opts.seed)
rng = MersenneTwister(opts.seed)
init_traj!(prob.pdtraj; rng=rng, x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.pdtraj_trial; rng=rng, x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.Δpdtraj; rng=rng, x0=prob.x0, f=(rng, args)->zeros(args), amplitude=0.0)

rollout!(RK3, prob.model, prob.pdtraj.pr)
rollout!(RK3, prob.model, prob.pdtraj_trial.pr)
Expand Down Expand Up @@ -138,10 +139,11 @@ function ibr_newton_solve!(prob::GameProblem{KN,n,m,T,SVd,SVx};
# Reset Statistics
reset!(prob.stats)
# Set initial trajectory
Random.seed!(opts.seed)
init_traj!(prob.pdtraj; x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.pdtraj_trial; x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.Δpdtraj; x0=prob.x0, f=zeros, amplitude=0.0)
#Random.seed!(opts.seed)
rng = MersenneTwister(opts.seed)
init_traj!(prob.pdtraj; rng=rng, x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.pdtraj_trial; rng=rng, x0=prob.x0, f=opts.f_init, amplitude=opts.amplitude_init, s=opts.shift)
init_traj!(prob.Δpdtraj; rng=rng, x0=prob.x0, f=(rng, args)->zeros(args), amplitude=0.0)

rollout!(RK3, prob.model, prob.pdtraj.pr)
rollout!(RK3, prob.model, prob.pdtraj_trial.pr)
Expand Down
16 changes: 10 additions & 6 deletions src/struct/primal_dual_traj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ mutable struct PrimalDualTraj{KN,n,m,T,SVd}
du::Vector{SVd} # Dual trajectory
end

function PrimalDualTraj(probsize::ProblemSize, dt::T; f=rand, amplitude=1e-8) where {T}
function PrimalDualTraj(probsize::ProblemSize, dt::T; f=rand, amplitude=1e-8, rng=Random.GLOBAL_RNG) where {T}
N = probsize.N
n = probsize.n
m = probsize.m
p = probsize.p
pr = Traj(n,m,dt,N)
du = [[amplitude*f(SVector{n,T}) for k=1:N-1] for i=1:p]
du = [[amplitude*f(rng, SVector{n,T}) for k=1:N-1] for i=1:p]
for k = 1:N
pr[k].z = amplitude*f(n+m)
pr[k].z = amplitude*f(rng, n+m)
end
TYPE = (eltype(pr),n,m,T,eltype(du))
return PrimalDualTraj{TYPE...}(probsize, pr, du)
Expand All @@ -26,17 +26,21 @@ end
# Methods
################################################################################

function init_traj!(pdtraj::PrimalDualTraj{KN,n,m,T,SVd}; x0=1e-8*rand(SVector{n,T}),
"""
- `f`: Always takes two arguments (rng, TYPE).
"""
function init_traj!(pdtraj::PrimalDualTraj{KN,n,m,T,SVd}; rng=Random.GLOBAL_RNG, x0=1e-8*rand(rng, SVector{n,T}),
s::Int=2^10, f=rand, amplitude=1e-8) where {KN,n,m,T,SVd}
N = pdtraj.probsize.N
p = pdtraj.probsize.p

for k = 1:N
pdtraj.pr[k].z = (k+s<=N) ? pdtraj.pr[k+s].z : amplitude*f(SVector{n+m,T})
pdtraj.pr[k].z = (k+s<=N) ? pdtraj.pr[k+s].z : amplitude*f(rng, SVector{n+m,T})
end
for i = 1:p
for k = 1:N-1
pdtraj.du[i][k] = (k+s<=N-1) ? pdtraj.du[i][k+s] : amplitude*f(SVector{n,T})
pdtraj.du[i][k] = (k+s<=N-1) ? pdtraj.du[i][k+s] : amplitude*f(rng, SVector{n,T})
end
end
RobotDynamics.set_state!(pdtraj.pr[1], x0)
Expand Down
2 changes: 1 addition & 1 deletion test/constraints/constraint_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
dt = 0.1
model = UnicycleGame(p=3)
probsize = ProblemSize(N,model)
pdtraj = PrimalDualTraj(probsize, dt, f=ones, amplitude=0.1)
pdtraj = PrimalDualTraj(probsize, dt, f=(rng,args)->ones(args), amplitude=0.1)
game_con = GameConstraintValues(probsize)

u_max = ones(SVector{model.m,T})
Expand Down
4 changes: 2 additions & 2 deletions test/constraints/constraints_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@
@test game_con.state_conval[1][1].λ[1] == 0.0*ones(P)

pdtraj = PrimalDualTraj(probsize, dt)
init_traj!(pdtraj, f=ones, amplitude=1e2)
init_traj!(pdtraj, f=(rng,args)->ones(args), amplitude=1e2)
evaluate!(game_con, pdtraj.pr)
@test game_con.control_conval[1].vals[1] == [90*ones(model.m); -110*ones(model.m)]
dual_update!(game_con)
@test game_con.control_conval[1].λ[1] == 1e-3 * [90*ones(model.m); 0*ones(model.m)]
@test game_con.state_conval[1][1].λ[1] == 1e-3 * max.(0, game_con.state_conval[1][1].vals[1])

init_traj!(pdtraj, f=ones, amplitude=1e5)
init_traj!(pdtraj, f=(rng,args)->ones(args), amplitude=1e5)
evaluate!(game_con, pdtraj.pr)
@test game_con.control_conval[1].vals[1] == [(1e5-10)*ones(model.m); -(1e5+10)*ones(model.m)]
dual_update!(game_con)
Expand Down
2 changes: 1 addition & 1 deletion test/constraints/velocity_constraint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
dt = 0.1
model = UnicycleGame(p=3)
probsize = ProblemSize(N,model)
pdtraj = PrimalDualTraj(probsize, dt, f=ones, amplitude=0.1)
pdtraj = PrimalDualTraj(probsize, dt, f=(rng,args)->ones(args), amplitude=0.1)
game_con = GameConstraintValues(probsize)

v_max = ones(model.p)
Expand Down
14 changes: 7 additions & 7 deletions test/struct/primal_dual_traj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Test init_traj!
T = Float64
x0 = rand(SVector{model.n,T})
init_traj!(pdtraj, x0=x0, f=ones, amplitude=10.0)
init_traj!(pdtraj, x0=x0, f=(rng, args) -> ones(args), amplitude=10.0)
@test state(pdtraj.pr[1]) == x0
@test state(pdtraj.pr[2]) == 10*ones(model.n)
@test control(pdtraj.pr[1]) == 10*ones(model.m)
Expand All @@ -33,7 +33,7 @@
x0 = rand(SVector{model.n,T})
Δpdtraj = PrimalDualTraj(probsize, dt)
Δtraj = ones(n*(N-1)+m*(N-1)+n*p*(N-1))
init_traj!(Δpdtraj, x0=x0, f=ones, amplitude=10.0)
init_traj!(Δpdtraj, x0=x0, f=(rng, args)->ones(args), amplitude=10.0)
set_traj!(core, Δpdtraj, Δtraj)

@test state(Δpdtraj.pr[1]) == x0
Expand All @@ -46,7 +46,7 @@
@test Δpdtraj.du[end][end] == ones(n)

Δtraj = rand(n*(N-1)+m*(N-1)+n*p*(N-1))
init_traj!(Δpdtraj, x0=x0, f=ones, amplitude=10.0)
init_traj!(Δpdtraj, x0=x0, f=(rng,args)->ones(args), amplitude=10.0)
set_traj!(core, Δpdtraj, Δtraj)

@test state(Δpdtraj.pr[1]) == x0
Expand Down Expand Up @@ -91,9 +91,9 @@
target = PrimalDualTraj(probsize, dt)
source = PrimalDualTraj(probsize, dt)
Δ = PrimalDualTraj(probsize, dt)
init_traj!(target, x0=x0, f=ones, amplitude=0.0)
init_traj!(source, x0=x0, f=ones, amplitude=10.0)
init_traj!(Δ, x0=x0, f=ones, amplitude=100.0)
init_traj!(target, x0=x0, f=(rng,args)->ones(args), amplitude=0.0)
init_traj!(source, x0=x0, f=(rng, args)->ones(args), amplitude=10.0)
init_traj!(Δ, x0=x0, f=(rng,args)->ones(args), amplitude=100.0)
α = 0.5
update_traj!(target, source, α, Δ)

Expand All @@ -118,7 +118,7 @@
m = model.m
x0 = 1e3*ones(SVector{model.n,T})
Δ = PrimalDualTraj(probsize, dt)
init_traj!(pdtraj, x0=x0, f=ones, amplitude=10.0)
init_traj!(pdtraj, x0=x0, f=(rng,args)->ones(args), amplitude=10.0)
α = 0.5
@test Δ_step(pdtraj, α) == 10.0*α

Expand Down
8 changes: 4 additions & 4 deletions test/struct/violations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
pi = probsize.pz[i]
pdtraj = PrimalDualTraj(probsize, dt)

init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=zeros)
init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=(rng,args)->zeros(args))
dyn_vio = dynamics_violation(model, pdtraj)
@test dyn_vio.N == pdtraj.probsize.N
@test dyn_vio.vio == zeros(N-1)
@test dyn_vio.max == 0.0
init_traj!(pdtraj, x0=ones(SVector{model.n,T}), f=ones, amplitude=1.0)
init_traj!(pdtraj, x0=ones(SVector{model.n,T}), f=(rng,args)->ones(args), amplitude=1.0)
@test dynamics_violation(model, pdtraj).max - maximum(abs.(dynamics_residual(model, pdtraj, 1))) < 1e-10
@test dynamics_violation(model, pdtraj, i).max - maximum(abs.(dynamics_residual(model, pdtraj, 1)[pi])) < 1e-10

Expand All @@ -24,7 +24,7 @@
u_max = 0.1*ones(model.m)
u_min = -0.1*ones(model.m)
add_control_bound!(game_con, u_max, u_min)
init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=ones, amplitude=1.0)
init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=(rng,args)->ones(args), amplitude=1.0)
con_vio = control_violation(game_con, pdtraj)
@test con_vio.N == pdtraj.probsize.N
@test con_vio.vio == 0.9*ones(N-1)
Expand All @@ -38,7 +38,7 @@
game_con = GameConstraintValues(probsize)
walls = [Wall([0.,1], [1,0], [1,1]/sqrt(2))]
add_wall_constraint!(game_con, walls)
init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=ones, amplitude=1.0)
init_traj!(pdtraj, x0=zeros(SVector{model.n,T}), f=(rng,args)->ones(args), amplitude=1.0)
sta_vio = state_violation(game_con, pdtraj)
@test sta_vio.N == pdtraj.probsize.N
@test norm(sta_vio.vio - [0; sqrt(2)/2*ones(N-1)...], 1) <= 1e-10
Expand Down

0 comments on commit 1663709

Please sign in to comment.