Merge pull request #81 from reworkhow/BayesNN
update Bayesian Neural Network with HMC
reworkhow authored Mar 17, 2021
2 parents 37015a0 + 0c6e91e commit d9df108
Showing 6 changed files with 146 additions and 29 deletions.
20 changes: 13 additions & 7 deletions src/1.JWAS/src/JWAS.jl
@@ -47,6 +47,7 @@ include("structure_equation_model/SEM.jl")

#Latent Traits
include("Nonlinear/nonlinear.jl")
include("Nonlinear/bnn_hmc.jl")

#input
include("input_data_validation.jl")
@@ -179,6 +180,14 @@ function runMCMC(mme::MME,df;
Pi = 0.0,
estimatePi = false,
estimateScale = false)

#Nonlinear
if mme.latent_traits == true
yobs = df[!,Symbol(string(Symbol(mme.lhsVec[1]))[1:(end-1)])]
for i in mme.lhsVec
df[!,i]= yobs
end
end
#for deprecated JWAS functions
if mme.M != 0
for Mi in mme.M
@@ -256,13 +265,9 @@ function runMCMC(mme::MME,df;
error("The causal structue needs to be a lower triangular matrix.")
end
end
#Nonlinear
if mme.latent_traits == true
yobs = df_whole[!,Symbol(string(Symbol(mme.lhsVec[1]))[1:(end-1)])]
for i in mme.lhsVec
df_whole[!,i]= yobs
end
end



# Double Precision
if double_precision == true
if mme.M != 0
@@ -320,6 +325,7 @@ function runMCMC(mme::MME,df;
for (key,value) in mme.output
CSV.write(output_folder*"/"*replace(key," "=>"_")*".txt",value)
end

if mme.M != 0
for Mi in mme.M
if Mi.name == "GBLUP"
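The relocated #Nonlinear block above seeds every latent-trait column of the input data frame with the observed phenotype, recovering the observed trait's column name by stripping the trailing index from the first latent trait's name. A minimal sketch of that naming convention, with made-up column names and data (not the package's API):

using DataFrames

df = DataFrame(y = [1.2, 0.7, 2.1], y1 = zeros(3), y2 = zeros(3), y3 = zeros(3))
lhsVec   = [:y1, :y2, :y3]                     # latent traits in the model equation
obs_name = Symbol(string(lhsVec[1])[1:end-1])  # :y1 -> :y (drop the trailing index)
yobs     = df[!, obs_name]                     # observed phenotype
for i in lhsVec
    df[!, i] = yobs                            # each latent trait starts at yobs
end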
11 changes: 10 additions & 1 deletion src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl
@@ -166,6 +166,12 @@ function MCMC_BayesianAlphabet(mme,df)
############################################################################
# MCMC (starting values for sol (zeros); mme.RNew; G0 are used)
############################################################################
# Initialize mme for HMC before Gibbs
if mme.latent_traits == true
num_latent_traits = mme.M[1].ntraits
mme.weights_NN = vcat(mean(mme.ySparse),zeros(num_latent_traits))
end

@showprogress "running MCMC ..." for iter=1:chain_length
########################################################################
# 0. Categorical traits (liabilities)
@@ -204,6 +210,7 @@ function MCMC_BayesianAlphabet(mme,df)
else
Gibbs(mme.mmeLhs,mme.sol,mme.mmeRhs,mme.R)
end

ycorr[:] = ycorr - mme.X*mme.sol
########################################################################
# 2. Marker Effects
@@ -326,7 +333,9 @@ function MCMC_BayesianAlphabet(mme,df)
########################################################################
# 5. Latent Traits
########################################################################
if latent_traits == true

#mme.M[1].genotypes here is 5-by-5
if latent_traits == true #to update ycorr!
sample_latent_traits(yobs,mme,ycorr,nonlinear_function)
end
########################################################################
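For reference, the initialization added before the Gibbs loop stacks the output-layer bias and the l1 hidden-node weights into one vector, weights_NN = [μ1; w1], starting from the mean of mme.ySparse (the latent traits, just seeded with yobs) and zeros. A toy illustration with made-up values:

using Statistics

ySparse = [1.0, 2.0, 3.0, 4.0]  # stacked latent-trait values (toy)
num_latent_traits = 3
weights_NN = vcat(mean(ySparse), zeros(num_latent_traits))  # -> [2.5, 0.0, 0.0, 0.0]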
95 changes: 95 additions & 0 deletions src/1.JWAS/src/Nonlinear/bnn_hmc.jl
@@ -0,0 +1,95 @@
#tianjing: modified from non-linear.jl
#sample all hidden nodes together

# |---- Z[:,1] ----- Z0*W0[:,1]
#yobs ---f(X)----|---- Z[:,2] ----- Z0*W0[:,2]
# |---- Z[:,3] ----- Z0*W0[:,3]
#
# in total 1 hidden layer, with l1 nodes.

###########
# X: marker covariate matrix, n by p, each col is for 1 marker
# Z: latent traits, n by l1 matrix, each col is for 1 latent trait
# y: observed trait, vector of n by 1
###########

###########
# W0: marker effects, matrix of p by l1, each col is for a latent trait
# W1: weights from hidden layer to observed trait, vector of l1 by 1
###########

###########
# Mu0: bias of latent traits, vector of length l1
# mu: bias of observed trait, scalar
###########

###########
# Sigma2z: residual variance of latent trait, diagonal matrix of size l1*l1
# sigma2e: residual variance of observed trait, scalar
###########


#helper 1: calculate gradient of all latent traits for all individuals
function calc_gradient_z(ylats,yobs,weights_NN,σ_ylats,σ_yobs,ycorr)
μ1, w1 = weights_NN[1], weights_NN[2:end]
tanh_ylats = tanh.(ylats)
#dlogfz =- (Z - ones(n)*Mu0' - X*W0) * inv(Sigma2z) #(n,l1)
dlogf_ylats = - ycorr * inv(σ_ylats)
dlogfy = ((yobs .- μ1 - tanh_ylats*w1)/σ_yobs) * w1' .* (-tanh_ylats.^2 .+ 1) #size: (n, l1)
gradient_ylats = dlogf_ylats + dlogfy

return gradient_ylats #size (n,l1)
end

# helper 2: calculate log p(z|y) to help calculate the acceptance rate
function calc_log_p_z(ylats,yobs,weights_NN,σ_ylats,σ_yobs,ycorr)
μ1 = weights_NN[1]
w1 = weights_NN[2:end]
#logfz = -0.5*sum(((Z-ones(n)*Mu0'-X*W0).^2)*inv(Sigma2z),dims=2) .- (0.5*log(prod(diag(Sigma2z))))
logf_ylats = -0.5*sum((ycorr.^2)*inv(σ_ylats),dims=2) .- (0.5*log(prod(diag(σ_ylats))))
logfy = -0.5*(yobs .- μ1 - tanh.(ylats)*w1).^2 /σ_yobs .- 0.5*log(σ_yobs)
log_p_ylats = logf_ylats .+ logfy #broadcast: (n,1) .+ (n,) -> (n,1)

return log_p_ylats #size: (n,1)
end

#helper 3: one iteration of HMC to sample Z
#the ycorr argument holds the residuals of the latent traits, reshaped to n-by-l1
function hmc_one_iteration(nLeapfrog,ϵ,ylats_old,yobs,weights_NN,σ_ylats,σ_yobs,ycorr)
nobs, ntraits = size(ylats_old)
ylats_old = copy(ylats_old)
ylats_new = copy(ylats_old)

#step 1: draw momentum Φ from N(0,M)
Φ = randn(nobs, ntraits) #rand(n,Normal(0,M=1.0)), tuning parameter: M
log_p_old = calc_log_p_z(ylats_old,yobs,weights_NN,σ_ylats,σ_yobs,ycorr) - 0.5*sum(Φ.^2,dims=2) #(n,1)
#step 2: update (ylats,Φ) with nLeapfrog leapfrog steps
#2(a): update Φ
Φ += 0.5 * ϵ * calc_gradient_z(ylats_new,yobs,weights_NN,σ_ylats,σ_yobs,ycorr) #(n,l1)
for leap_i in 1:nLeapfrog
#2(b) update latent traits
ylats_new += ϵ * Φ # (n,l1)
ycorr += ϵ * Φ #update ycorr due to change of Z
#(c) half step of phi
if leap_i == nLeapfrog
#2(c): update Φ
Φ += 0.5 * ϵ * calc_gradient_z(ylats_new,yobs,weights_NN,σ_ylats,σ_yobs,ycorr)
else
#2(a)+2(c): update Φ
Φ += ϵ * calc_gradient_z(ylats_new,yobs,weights_NN,σ_ylats,σ_yobs,ycorr)
end
end

#step 3: acceptance rate
log_p_new = calc_log_p_z(ylats_new,yobs,weights_NN,σ_ylats,σ_yobs,ycorr) - 0.5*sum(Φ.^2,dims=2) #(n,1)
r = exp.(log_p_new - log_p_old) # (n,1)
nojump = rand(nobs) .> r # bool (n,1)

for i in 1:nobs
if nojump[i]
ylats_new[i,:] = ylats_old[i,:]
end
end

return ylats_new
end
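A minimal, self-contained sketch of how this sampler could be driven outside the package, with toy dimensions and data (assumes the three helpers above are in scope; σ_ylats is the l1-by-l1 residual covariance of the latent traits and ycorr holds the n-by-l1 latent-trait residuals):

using LinearAlgebra, Random

Random.seed!(1)
n, l1   = 50, 3                 # individuals, hidden nodes
ylats   = randn(n, l1)          # current latent traits Z
yobs    = randn(n)              # observed trait
weights = vcat(0.0, randn(l1))  # weights_NN = [μ1; w1]
σ_ylats = Matrix(1.0I, l1, l1)  # residual (co)variance of the latent traits
σ_yobs  = 1.0                   # residual variance of yobs
ycorr   = randn(n, l1)          # latent-trait residuals
ylats_new = hmc_one_iteration(10, 0.1, ylats, yobs, weights, σ_ylats, σ_yobs, ycorr)

Note that ycorr += ϵ*Φ inside the helper rebinds a local copy rather than mutating the caller's array; sample_latent_traits recomputes the residuals afterwards.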
44 changes: 25 additions & 19 deletions src/1.JWAS/src/Nonlinear/nonlinear.jl
@@ -5,19 +5,39 @@
#nonlinear function: e.g.,
# (1) pig_growth(x1,x2) = sqrt(x1^2 / (x1^2 + x2^2))
# (2) neural network: a1*tanh(x1)+a2*tanh(x2)

#nonlinear_function: user-provided function, or "Neural Network"
function sample_latent_traits(yobs,mme,ycorr,nonlinear_function)
ylats_old = mme.ySparse # current values of each latent trait
μ_ylats = mme.ySparse - ycorr # mean of each latent trait
ylats_old = mme.ySparse # current values of each latent trait; [trait_1_obs;trait_2_obs;...]
μ_ylats = mme.ySparse - ycorr # mean of each latent trait, [trait_1_obs-residuals;trait_2_obs-residuals;...]
# = vcat(getEBV(mme,1).+mme.sol[1],getEBV(mme,2).+mme.sol[2]))
σ2_yobs = mme.σ2_yobs # residual variance of yobs (scalar)

#reshape the vector to nind X ntraits
nobs, ntraits = length(mme.obsID), mme.nModels
ylats_old = reshape(ylats_old,nobs,ntraits)
ylats_old = reshape(ylats_old,nobs,ntraits) #Tianjing's mme.Z
μ_ylats = reshape(μ_ylats,nobs,ntraits)

if nonlinear_function == "Neural Network" #HMC
ylats_new = hmc_one_iteration(10,0.1,ylats_old,yobs,mme.weights_NN,mme.R,σ2_yobs,reshape(ycorr,nobs,ntraits))
else
candidates = μ_ylats+randn(size(μ_ylats)) #candidate samples
if nonlinear_function == "Neural Network (MH)"
μ_yobs_candidate = [ones(nobs) tanh.(candidates)]*mme.weights_NN
μ_yobs_current   = [ones(nobs) tanh.(ylats_old)]*mme.weights_NN #weights from the previous iteration
else #user-defined non-linear function
μ_yobs_candidate = nonlinear_function.(Tuple([view(candidates,:,i) for i in 1:ntraits])...)
μ_yobs_current = nonlinear_function.(Tuple([view(ylats_old,:,i) for i in 1:ntraits])...)
end
llh_current = -0.5*(yobs - μ_yobs_current ).^2/σ2_yobs
llh_candidate = -0.5*(yobs - μ_yobs_candidate).^2/σ2_yobs
mhRatio = exp.(llh_candidate - llh_current)
updateus = rand(nobs) .< mhRatio
ylats_new = candidates.*updateus + ylats_old.*(.!updateus)
end

if nonlinear_function == "Neural Network" #sample weights
X = [ones(nobs) tanh.(ylats_old)]
X = [ones(nobs) tanh.(ylats_new)]
lhs = X'X + I*0.00001
Ch = cholesky(lhs)
L = Ch.L
Expand All @@ -27,22 +47,8 @@ function sample_latent_traits(yobs,mme,ycorr,nonlinear_function)
mme.weights_NN = weights
end

candidates = μ_ylats+randn(size(μ_ylats)) #candidate samples
if nonlinear_function == "Neural Network"
μ_yobs_candidate = [ones(nobs) tanh.(candidates)]*weights
μ_yobs_current = X*weights
else
μ_yobs_candidate = nonlinear_function.(Tuple([view(candidates,:,i) for i in 1:ntraits])...)
μ_yobs_current = nonlinear_function.(Tuple([view(ylats_old,:,i) for i in 1:ntraits])...)
end
llh_current = -0.5*(yobs - μ_yobs_current ).^2/σ2_yobs
llh_candidate = -0.5*(yobs - μ_yobs_candidate).^2/σ2_yobs
mhRatio = exp.(llh_candidate - llh_current)
updateus = rand(nobs) .< mhRatio
ylats_new = candidates.*updateus + ylats_old.*(.!updateus)

mme.ySparse = vec(ylats_new)
ycorr[:] = mme.ySparse - vec(μ_ylats)
ycorr[:] = mme.ySparse - vec(μ_ylats) # =(ylats_new - ylats_old) + ycorr: update residuals (ycorr)

#sample σ2_yobs
if nonlinear_function != "Neural Network"
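For the user-defined branch above, the Metropolis-Hastings step reduces to comparing likelihoods under the current and candidate latent traits. A toy sketch using the pig_growth example from the header comment, with made-up data (the in-package version centers candidates on μ_ylats and works on mme fields):

pig_growth(x1, x2) = sqrt(x1^2 / (x1^2 + x2^2))

nobs, ntraits = 4, 2
yobs       = rand(nobs)                        # observed trait (toy)
σ2_yobs    = 1.0                               # residual variance of yobs
ylats_old  = randn(nobs, ntraits)              # current latent traits
candidates = ylats_old + randn(nobs, ntraits)  # simplified random-walk proposal

μ_current   = pig_growth.(ylats_old[:,1],  ylats_old[:,2])
μ_candidate = pig_growth.(candidates[:,1], candidates[:,2])
llh_current   = -0.5*(yobs - μ_current  ).^2/σ2_yobs
llh_candidate = -0.5*(yobs - μ_candidate).^2/σ2_yobs
accept    = rand(nobs) .< exp.(llh_candidate - llh_current)
ylats_new = candidates.*accept + ylats_old.*(.!accept)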
2 changes: 1 addition & 1 deletion src/1.JWAS/src/build_MME.jl
@@ -109,7 +109,7 @@ function build_model(model_equations::AbstractString, R = false; df = 4.0,
mme.M = genotypes
end

#laten traits
#latent traits
if num_latent_traits != false
mme.latent_traits = true
if nonlinear_function != false
3 changes: 2 additions & 1 deletion src/1.JWAS/src/types.jl
@@ -239,10 +239,11 @@ mutable struct MME
causal_structure

latent_traits
nonlinear_function
nonlinear_function #user-provided function, or "Neural Network"
weights_NN
σ2_yobs


function MME(nModels,modelVec,modelTerms,dict,lhsVec,R,ν)
if nModels == 1
scaleR = R*(ν-2)/ν
