Skip to content

Commit

Permalink
Merge pull request #492 from JuliaRobotics/master
Browse files Browse the repository at this point in the history
fast forward feature branch
  • Loading branch information
dehann authored Dec 19, 2019
2 parents 48545eb + b339a5e commit dbefd2d
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 15 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.8.2"
[deps]
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
42 changes: 36 additions & 6 deletions src/CliqStateMachine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ Notes
"""
function determineCliqNeedDownMsg_StateMachine(csmc::CliqStateMachineContainer)

infocsm(csmc, "7, start, forceproceed=$(csmc.forceproceed)")

# fetch children status
stdict = blockCliqUntilChildrenHaveUpStatus(csmc.tree, csmc.cliq, csmc.logger)

Expand All @@ -510,20 +512,32 @@ function determineCliqNeedDownMsg_StateMachine(csmc::CliqStateMachineContainer)
# :needdownmsg # 'send' downward init msg direction
!(clst in [:initialized;:upsolved;:marginalized;:downsolved;:uprecycled]) ? (proceed = false) : nothing
end
infocsm(csmc, "7, proceed=$(proceed), forceproceed=$(csmc.forceproceed)")
infocsm(csmc, "7, proceed=$(proceed)")


if proceed || csmc.forceproceed
# TODO, remove csmc.forceproceed
csmc.forceproceed = false
# return doCliqInferAttempt_StateMachine
cliqst = getCliqStatus(csmc.cliq)
infocsm(csmc, "7, status=$(cliqst), before attemptCliqInitDown_StateMachine")
# d1,d2,cliqst = doCliqInitUpOrDown!(csmc.cliqSubFg, csmc.tree, csmc.cliq, isprntnddw)
if cliqst == :needdownmsg && !isCliqParentNeedDownMsg(csmc.tree, csmc.cliq, csmc.logger)
# go to 8a
return attemptCliqInitDown_StateMachine
# HALF DUPLICATED IN STEP 4
elseif cliqst == :marginalized
# go to 1
return isCliqUpSolved_StateMachine
## NOTE -- what about notifyCliqUpInitStatus! ??
# go to 10
# return determineCliqIfDownSolve_StateMachine
end

# go to 8b
return attemptCliqInitUp_StateMachine
else
# go to 7b
return slowCliqIfChildrenNotUpsolved_StateMachine
end
end
Expand All @@ -539,6 +553,7 @@ function blockCliqSiblingsParentChildrenNeedDown_StateMachine(csmc::CliqStateMac
infocsm(csmc, "6c, check/block sibl&prnt :needdownmsg")
blockCliqSiblingsParentNeedDown(csmc.tree, csmc.cliq, logger=csmc.logger)

# go to 7
return determineCliqNeedDownMsg_StateMachine
end

Expand Down Expand Up @@ -572,18 +587,30 @@ end
$SIGNATURES
Notes
- State machine function nr. 4
- State machine function nr.4
"""
function isCliqNull_StateMachine(csmc::CliqStateMachineContainer)

prnt = getParent(csmc.tree, csmc.cliq)
infocsm(csmc, "4, isCliqNull_StateMachine, csmc.incremental=$(csmc.incremental), len(prnt)=$(length(prnt))")
cliqst = getCliqStatus(csmc.oldcliqdata)
infocsm(csmc, "4, isCliqNull_StateMachine, $cliqst, len(prnt)=$(length(prnt)), csmc.incremental=$(csmc.incremental)")

if cliqst in [:marginalized;]
# if cliqst == :marginalized || cliqst in [:downsolved; :uprecycled] && ( length(prnt) == 0 ) # ||
# 0 < length(prnt) && getCliqStatus(prnt) in [:downsolved; :uprecycled; :marginalized;] )
# go to 10 -- Add case for IIF issue #474
return determineCliqIfDownSolve_StateMachine
end

#must happen before if :null
stdict = blockCliqUntilChildrenHaveUpStatus(csmc.tree, csmc.cliq, csmc.logger)
csmc.forceproceed = false

# if clique is marginalized, then no reason to continue here
# if no parent or parent will not update

# for recycle computed clique values case
if csmc.incremental && getCliqStatus(csmc.oldcliqdata) == :downsolved
if csmc.incremental && cliqst == :downsolved
csmc.incremental = false
# might be able to recycle the previous clique solve, go to 0b
return checkChildrenAllUpRecycled_StateMachine
Expand Down Expand Up @@ -614,6 +641,7 @@ function doesCliqNeeddownmsg_StateMachine(csmc::CliqStateMachineContainer)

if cliqst != :null
if cliqst != :needdownmsg
# go to 6c
return blockCliqSiblingsParentChildrenNeedDown_StateMachine
end
else
Expand All @@ -638,9 +666,11 @@ function doesCliqNeeddownmsg_StateMachine(csmc::CliqStateMachineContainer)
return blockCliqSiblingsParentChildrenNeedDown_StateMachine
end # != :null

areChildDown = areCliqChildrenNeedDownMsg(csmc.tree, csmc.cliq)
infocsm(csmc, "4b, areCliqChildrenNeedDownMsg(csmc.tree, csmc.cliq)=$(areChildDown)")
# if cliqst == :needdownmsg
if areCliqChildrenNeedDownMsg(csmc.tree, csmc.cliq)
infocsm(csmc, "4, must deal with child :needdownmsg")
if areChildDown
infocsm(csmc, "4b, must deal with child :needdownmsg")
csmc.forceproceed = true
else
# go to 5
Expand Down
3 changes: 2 additions & 1 deletion src/FactorGraph01.jl
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,8 @@ end

function ensureAllInitialized!(dfg::T; solvable::Int=1) where T <: AbstractDFG
# allvarnodes = getVariables(dfg)
syms = ls(dfg, solvable=solvable) |> sortDFG
syms = intersect(getAddHistory(dfg), ls(dfg, solvable=solvable) )
# syms = ls(dfg, solvable=solvable) # |> sortDFG
repeatCount = 0
repeatFlag = true
while repeatFlag
Expand Down
2 changes: 2 additions & 0 deletions src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using Reexport
using
Dates,
DistributedFactorGraphs,
DelimitedFiles,
Statistics,
Random,
NLsolve,
Expand Down Expand Up @@ -419,6 +420,7 @@ export
getCliqVarSingletons,
getCliqAllFactIds,
getCliqFactorIdsAll,
getCliqFactors,
areCliqVariablesAllMarginalized,
setTreeCliquesMarginalized!,

Expand Down
25 changes: 19 additions & 6 deletions src/JunctionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ Build Bayes/Junction/Elimination tree from a given variable ordering.
function buildTreeFromOrdering!(dfg::G,
p::Vector{Symbol};
drawbayesnet::Bool=false,
maxparallel::Int=50 ,
maxparallel::Int=50,
solvable::Int=1 ) where G <: InMemoryDFGTypes
#
println()
Expand Down Expand Up @@ -441,15 +441,24 @@ Notes
"""
function prepBatchTree!(dfg::AbstractDFG;
ordering::Symbol=:qr,
variableOrder::Union{Nothing, Vector{Symbol}}=nothing,
drawpdf::Bool=false,
show::Bool=false,
filepath::String="/tmp/caesar/bt.pdf",
viewerapp::String="evince",
imgs::Bool=false,
drawbayesnet::Bool=false,
maxparallel::Int=50 )
maxparallel::Int=50 )
#
p = getEliminationOrder(dfg, ordering=ordering)
p = variableOrder != nothing ? variableOrder : getEliminationOrder(dfg, ordering=ordering)

# for debuggin , its useful to have the variable ordering
if drawpdf
ispath(getLogPath(dfg)) ? nothing : Base.mkpath(getLogPath(dfg))
open(joinLogPath(dfg,"variableOrder.txt"), "a") do io
writedlm(io, string.(reshape(p,1,:)), ',')
end
end

tree = buildTreeFromOrdering!(dfg, p, drawbayesnet=drawbayesnet, maxparallel=maxparallel)

Expand Down Expand Up @@ -526,10 +535,11 @@ function wipeBuildNewTree!(dfg::G;
filepath::String="/tmp/caesar/bt.pdf",
viewerapp::String="evince",
imgs::Bool=false,
maxparallel::Int=50 )::BayesTree where G <: AbstractDFG
maxparallel::Int=50,
variableOrder::Union{Nothing, Vector{Symbol}}=nothing )::BayesTree where G <: AbstractDFG
#
resetFactorGraphNewTree!(dfg);
return prepBatchTree!(dfg, ordering=ordering, drawpdf=drawpdf, show=show, filepath=filepath, viewerapp=viewerapp, imgs=imgs, maxparallel=maxparallel);
return prepBatchTree!(dfg, variableOrder=variableOrder, ordering=ordering, drawpdf=drawpdf, show=show, filepath=filepath, viewerapp=viewerapp, imgs=imgs, maxparallel=maxparallel);
end

"""
Expand Down Expand Up @@ -953,10 +963,13 @@ DEPRECATED, use getCliqFactorIdsAll instead.
Related
getCliqVarIdsAll
getCliqVarIdsAll, getCliqFactors
"""
getCliqFactorIdsAll(cliqd::BayesTreeNodeData) = cliqd.potentials
getCliqFactorIdsAll(cliq::Graphs.ExVertex) = getCliqFactorIdsAll(getData(cliq))
getCliqFactorIdsAll(treel::BayesTree, frtl::Symbol) = getCliqFactorIdsAll(getCliq(treel, frtl))

const getCliqFactors = getCliqFactorIdsAll

"""
$SIGNATURES
Expand Down
5 changes: 3 additions & 2 deletions src/SolverAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ function solveTree!(dfgl::G,
delaycliqs::Vector{Symbol}=Symbol[],
recordcliqs::Vector{Symbol}=Symbol[],
skipcliqids::Vector{Symbol}=Symbol[],
maxparallel::Int=50 ) where G <: DFG.AbstractDFG
maxparallel::Int=50,
variableOrder::Union{Nothing, Vector{Symbol}}=nothing ) where G <: DFG.AbstractDFG
#
@info "Solving over the Bayes (Junction) tree."
smtasks=Vector{Task}()
Expand All @@ -38,7 +39,7 @@ function solveTree!(dfgl::G,
end

# current incremental solver builds a new tree and matches against old tree for recycling.
tree = wipeBuildNewTree!(dfgl, drawpdf=opt.drawtree, show=opt.showtree, maxparallel=maxparallel, filepath=joinpath(getSolverParams(dfgl).logpath,"bt.pdf"))
tree = wipeBuildNewTree!(dfgl, variableOrder=variableOrder, drawpdf=opt.drawtree, show=opt.showtree, maxparallel=maxparallel, filepath=joinpath(getSolverParams(dfgl).logpath,"bt.pdf"))
# setAllSolveFlags!(tree, false)

@info "Do tree based init-inference on tree"
Expand Down
5 changes: 5 additions & 0 deletions src/TreeBasedInitialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ function blockCliqUntilParentDownSolved(prnt::Graphs.ExVertex; logger=ConsoleLog
with_logger(logger) do
@info "blockCliqUntilParentDownSolved, prntcliq=$(prnt.index) | $lbl | going to fetch initdownchannel..."
end
flush(logger.stream)
while fetch(getData(prnt).initDownChannel) != :downsolved
# @sync begin
# @async begin
Expand Down Expand Up @@ -355,8 +356,12 @@ function blockCliqUntilChildrenHaveUpStatus(tree::BayesTree,
with_logger(logger) do
@info "cliq $(prnt.index), child $(ch.index) status is $(chst), isready(initUpCh)=$(isready(getData(ch).initUpChannel))."
end
flush(logger.stream)
ret[ch.index] = fetch(getData(ch).initUpChannel)
end
with_logger(logger) do
@info "cliq $(prnt.index), fetched all."
end
return ret
end

Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,17 @@ end
@testset "GenericWrapParam functors..." begin
include("testCommonConvWrapper.jl")
end

include("testBasicForwardConvolve.jl")

@testset "with simple local constraint examples Odo, Obsv2..." begin
include("testlocalconstraintexamples.jl")
end

include("testFactorMetadata.jl")

include("testBasicCSM.jl")

include("testExplicitMultihypo.jl")

include("TestCSMMultihypo.jl")
Expand Down
71 changes: 71 additions & 0 deletions test/testBasicCSM.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# IIF #485 --

# using Revise

using Test
using Logging
using Statistics
using DistributedFactorGraphs
using IncrementalInference


@testset "test basic three variable graph with prior" begin

VAR1 = :a
VAR2 = :b
VAR3 = :c

logger = SimpleLogger(stdout, Logging.Debug)
global_logger(logger)
dfg = initfg() #LightDFG{SolverParams}(solverParams=SolverParams())
# Add some nodes.
v1 = addVariable!(dfg, VAR1, ContinuousScalar, labels = [:POSE])
v2 = addVariable!(dfg, VAR2, ContinuousScalar, labels = [:POSE])
v3 = addVariable!(dfg, VAR3, ContinuousScalar, labels = [:LANDMARK])
f1 = addFactor!(dfg, [VAR1; VAR2], LinearConditional(Normal(50.0,2.0)) )
f2 = addFactor!(dfg, [VAR2; VAR3], LinearConditional(Normal(50.0,2.0)) )

addFactor!(dfg, [VAR1], Prior(Normal()))

# drawGraph(dfg, show=true)


# tree = wipeBuildNewTree!(dfg)
# # drawTree(tree, show=true)
#
# getCliqFactors(tree, VAR3)
# getCliqFactors(tree, VAR1)

ensureAllInitialized!(dfg)


# cliq= getCliq(tree, VAR3)
# getData(cliq)
#
# cliq= getCliq(tree, VAR1)
# getData(cliq)



getSolverParams(dfg).limititers = 50
# getSolverParams(dfg).drawtree = true
# getSolverParams(dfg).showtree = true
# getSolverParams(dfg).dbg = true
## getSolverParams(dfg).async = true


tree, smtasks, hist = solveTree!(dfg) #, recordcliqs=ls(dfg))


@test 70 < Statistics.mean(getKDE(dfg, :c) |> getPoints) < 130

# #
# using Gadfly, Cairo, Fontconfig
# drawTree(tree, show=true, imgs=true)

end




#
62 changes: 62 additions & 0 deletions test/testBasicForwardConvolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# test basic forward convolve, see IIF issue #477

# using Revise

using Test
using IncrementalInference
using Statistics


@testset "Test basic convolution result..." begin


function forwardConvolve(X0::Array{Float64,2}, model)

fg = initfg()

addVariable!(fg, :x0, ContinuousScalar)
manualinit!(fg, :x0, X0)
addVariable!(fg, :x1, ContinuousScalar)

addFactor!(fg, [:x0;:x1], model)

## TODO -- dont use name here, add API to just use z2 here
return approxConv(fg, :x0x1f1, :x1)
end


## Start

# first numerical values -- samples from the marginal of X0
z1 = Normal(0,0.1)
X0 = rand(z1, 1,100)


## predict -- project / conv
# 0 -> 1 seconds
# make approx function
z2 = Normal(11,1.0) # odo
statemodel = LinearConditional( z2 )
X1_ = forwardConvolve(X0, statemodel)


## measure -- product of beliefs, using `ApproxManifoldProducts.jl`

predX1 = manikde!(X1_, ContinuousScalar)
z3 = Normal(9.5,0.75)
measX1 = manikde!(reshape(rand(z3,100),1,:), ContinuousScalar)

# do actual product
posterioriX1 = predX1 * measX1
X1 = getPoints(posterioriX1)


## predict, 1->2 seconds
z4 = Normal(8,2.0) # odo
statemodel = LinearConditional( z4 )
X2_ = forwardConvolve(X1, statemodel)

@test size(X2_) == (1,100)
@test 15 < Statistics.mean(X2_) < 25

end

0 comments on commit dbefd2d

Please sign in to comment.