Skip to content

Commit

Permalink
Merge pull request #92 from reworkhow/BayesNN
Browse files Browse the repository at this point in the history
add partial-connected Bayesian Neural Networks
  • Loading branch information
reworkhow authored Jun 18, 2021
2 parents 082b76b + 4a76da2 commit b529ba8
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 102 deletions.
27 changes: 11 additions & 16 deletions src/1.JWAS/src/JWAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ include("structure_equation_model/SEM.jl")
#Latent Traits
include("Nonlinear/nonlinear.jl")
include("Nonlinear/bnn_hmc.jl")
include("Nonlinear/nnbayes_check.jl")

#input
include("input_data_validation.jl")
Expand Down Expand Up @@ -286,21 +287,15 @@ function runMCMC(mme::MME,df;
end
mme.Gi = map(Float64,mme.Gi)
end
#mega_trait

# NNBayes mega trait: from multi-trait to multiple single-trait
if mme.MCMCinfo.mega_trait == true || mme.MCMCinfo.constraint == true
if mme.nModels == 1
error("more than 1 trait is required for MegaLMM analysis.")
end
mme.MCMCinfo.constraint = true
##sample from scale-inv-⁠χ2, not InverseWishart
mme.df.residual = mme.df.residual - mme.nModels
mme.scaleR = diag(mme.scaleR/(mme.df.residual - 1))*(mme.df.residual-2)/mme.df.residual #diag(R_prior_mean)*(ν-2)/ν
if mme.M != 0
for Mi in mme.M
Mi.df = Mi.df - mme.nModels
Mi.scale = diag(Mi.scale/(Mi.df - 1))*(Mi.df-2)/Mi.df
end
end
nnbayes_mega_trait(mme)
end

# NNBayes: modify parameters for partial connected NN
if mme.nnbayes_partial==true
nnbayes_partial_para_modify2(mme)
end
############################################################################
#Make incidence matrices and genotype covariates for training observations
Expand Down Expand Up @@ -462,7 +457,7 @@ function getMCMCinfo(mme)
@printf("%-30s %20s\n","Method",Mi.method)
for Mi in mme.M
if Mi.genetic_variance != false
if mme.nModels == 1
if mme.nModels == 1 || mme.nnbayes_partial == true
@printf("%-30s %20.3f\n","genetic variances (genomic):",Mi.genetic_variance)
else
@printf("%-30s\n","genetic variances (genomic):")
Expand All @@ -471,7 +466,7 @@ function getMCMCinfo(mme)
end
end
if !(Mi.method in ["GBLUP"])
if mme.nModels == 1
if mme.nModels == 1 || mme.nnbayes_partial == true
@printf("%-30s %20.3f\n","marker effect variances:",Mi.G)
else
@printf("%-30s\n","marker effect variances:")
Expand Down
29 changes: 20 additions & 9 deletions src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,15 @@ function MCMC_BayesianAlphabet(mme,df)
end
end
end
if mme.nnbayes_partial == true
nnbayes_partial_para_modify3(mme)
end

#phenotypes corrected for all effects
ycorr = vec(Matrix(mme.ySparse)-mme.X*mme.sol)
if mme.M != 0
for Mi in mme.M
for traiti in 1:mme.nModels
for traiti in 1:Mi.ntraits
if Mi.α[traiti] != zero(Mi.α[traiti])
ycorr[(traiti-1)*Mi.nObs+1 : traiti*Mi.nObs] = ycorr[(traiti-1)*Mi.nObs+1 : traiti*Mi.nObs]
- Mi.genotypes*Mi.α[traiti]
Expand Down Expand Up @@ -204,48 +208,57 @@ function MCMC_BayesianAlphabet(mme,df)
# 2. Marker Effects
########################################################################
if mme.M !=0
for Mi in mme.M
for i in 1:length(mme.M)
Mi=mme.M[i]
########################################################################
# Marker Effects
########################################################################
if Mi.method in ["BayesC","BayesB","BayesA"]
locus_effect_variances = (Mi.method == "BayesC" ? fill(Mi.G,Mi.nMarkers) : Mi.G)
if is_multi_trait
if is_multi_trait && mme.nnbayes_partial==false
if is_mega_trait
megaBayesABC!(Mi,wArray,mme.R,locus_effect_variances)
else
MTBayesABC!(Mi,wArray,mme.R,locus_effect_variances)
end
elseif mme.nnbayes_partial==true
BayesABC!(Mi,wArray[i],mme.R[i,i],locus_effect_variances)
else
BayesABC!(Mi,ycorr,mme.R,locus_effect_variances)
end
elseif Mi.method =="RR-BLUP"
if is_multi_trait
if is_multi_trait && mme.nnbayes_partial==false
if is_mega_trait
megaBayesC0!(Mi,wArray,mme.R)
else
MTBayesC0!(Mi,wArray,mme.R)
end
elseif mme.nnbayes_partial==true
BayesC0!(Mi,wArray[i],mme.R[i,i])
else
BayesC0!(Mi,ycorr,mme.R)
end
elseif Mi.method == "BayesL"
if is_multi_trait
if is_multi_trait && mme.nnbayes_partial==false
if is_mega_trait #problem with sampleGammaArray
megaBayesL!(Mi,wArray,mme.R)
else
MTBayesL!(Mi,wArray,mme.R)
end
elseif mme.nnbayes_partial==true
BayesC0!(Mi,wArray[i],mme.R[i,i])
else
BayesL!(Mi,ycorr,mme.R)
end
elseif Mi.method == "GBLUP"
if is_multi_trait
if is_multi_trait && mme.nnbayes_partial==false
if is_mega_trait
megaGBLUP!(Mi,wArray,mme.R,invweights)
else
MTGBLUP!(Mi,wArray,ycorr,mme.R,invweights)
end
elseif mme.nnbayes_partial==true
GBLUP!(Mi,wArray[i],mme.R[i,i],invweights)
else
GBLUP!(Mi,ycorr,mme.R,invweights)
end
Expand All @@ -254,7 +267,7 @@ function MCMC_BayesianAlphabet(mme,df)
# Marker Inclusion Probability
########################################################################
if Mi.estimatePi == true
if is_multi_trait
if is_multi_trait && mme.nnbayes_partial==false
if is_mega_trait
Mi.π = [samplePi(sum(Mi.δ[i]), Mi.nMarkers) for i in 1:mme.nModels]
else
Expand Down Expand Up @@ -321,8 +334,6 @@ function MCMC_BayesianAlphabet(mme,df)
########################################################################
# 5. Latent Traits
########################################################################

#mme.M[1].genotypes here is 5-by-5
if latent_traits == true #to update ycorr!
sample_latent_traits(yobs,mme,ycorr,nonlinear_function,activation_function)
end
Expand Down
157 changes: 157 additions & 0 deletions src/1.JWAS/src/Nonlinear/nnbayes_check.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#Below function is to check parameters for NNBayes and print information
function nnbayes_check_print_parameter(num_latent_traits,nonlinear_function,activation_function)
printstyled("Bayesian Neural Network is used with follwing information: \n",bold=false,color=:green)

#part1: fully/partial-connected NN
if typeof(num_latent_traits) == Int64 #fully-connected. e.g, num_latent_traits=5
printstyled(" - Neural network: fully connected neural network \n",bold=false,color=:green)
printstyled(" - Number of hidden nodes: $num_latent_traits \n",bold=false,color=:green)
nnbayes_partial=false
elseif num_latent_traits == false #partial-connected.
printstyled(" - Neural network: partially connected neural network \n",bold=false,color=:green)
nnbayes_partial=true
else
error("Please check you number of latent traits")
end

#part2: activation function/user-defined non-linear function
if nonlinear_function == "Neural Network" #NN
if activation_function in ["tanh","sigmoid","relu","leakyrelu","linear"]
printstyled(" - Activation function: $activation_function.\n",bold=false,color=:green)
printstyled(" - Sampler: Hamiltonian Monte Carlo. \n",bold=false,color=:green)
else
error("Please select the activation function from tanh/sigmoid/relu/leakyrelu/linear")
end
elseif isa(nonlinear_function, Function) #user-defined nonlinear function. e.g, CropGrowthModel()
if activation_function == false
printstyled(" - Nonlinear function: user-defined nonlinear_function for the relationship between hidden nodes and observed trait is used.\n",bold=false,color=:green)
printstyled(" - Sampler: Matropolis-Hastings.\n",bold=false,color=:green)
else
error("activation function is not allowed for user-defined nonlinear function")
end
else
error("nonlinear_function can only be Neural Network or a user-defined nonlinear function")
end
return nnbayes_partial
end


#Below function is to re-phase modelm for NNBayes
function nnbayes_model_equation(model_equations,num_latent_traits)

lhs, rhs = strip.(split(model_equations,"="))
model_equations = ""

if typeof(num_latent_traits) == Int64 #fully-connected
# old: y=intercept+geno
# new: y1=intercept+geno;y2=intercept+geno
for i = 1:num_latent_traits
model_equations = model_equations*lhs*string(i)*"="*rhs*";"
end
elseif num_latent_traits == false #partially-connected
# old: y=intercept+geno1+geno2
# new: y1= intercept+geno1;y2=intercept+geno2
rhs_split=strip.(split(rhs,"+"))
geno_term=[]
for i in rhs_split
if isdefined(Main,Symbol(i)) && typeof(getfield(Main,Symbol(i))) == Genotypes
push!(geno_term,i)
end
end
non_gene_term = filter(x->x geno_term,rhs_split)
non_gene_term = join(non_gene_term,"+")

for i = 1:length(geno_term)
model_equations = model_equations*lhs*string(i)*"="*non_gene_term*"+"*geno_term[i]*";"
end
end
model_equations = model_equations[1:(end-1)]
end


# below function is to check whether the loaded genotype matches the model equation
function nnbayes_check_nhiddennode(num_latent_traits,mme)
if typeof(num_latent_traits) == Int64 #fully-connected. e.g, num_latent_traits=5
if length(mme.M)>1
error("fully-connected NN only allow one genotype; num_latent_traits is not allowed in partial-connected NN ")
end
elseif num_latent_traits == false #partial-connected.
if length(mme.M)==1
error("partial-connected NN requirs >1 genotype group")
else
num_latent_traits = length(mme.M)
printstyled(" - Number of hidden nodes: $num_latent_traits \n",bold=false,color=:green)
end #Note, if only geno1 & geno2 are loaded by get_genotypes, but there is "geno3" in equation, then geno3 will be treated like age.
end
end



# below function is to define the activation function for neural network
function nnbayes_activation(activation_function)
if activation_function == "tanh"
mytanh(x) = tanh(x)
return mytanh
elseif activation_function == "sigmoid"
mysigmoid(x) = 1/(1+exp(-x))
return mysigmoid
elseif activation_function == "relu"
myrelu(x) = max(0, x)
return myrelu
elseif activation_function == "leakyrelu"
myleakyrelu(x) = max(0.01x, x)
return myleakyrelu
elseif activation_function == "linear"
mylinear(x) = x
return mylinear
else
error("invalid actication function")
end
end


# below function is to modify mme from multi-trait model to multiple single trait models
# coded by Hao
function nnbayes_mega_trait(mme)
#mega_trait
if mme.nModels == 1
error("more than 1 trait is required for MegaLMM analysis.")
end
mme.MCMCinfo.constraint = true

##sample from scale-inv-⁠χ2, not InverseWishart
mme.df.residual = mme.df.residual - mme.nModels
mme.scaleR = diag(mme.scaleR/(mme.df.residual - 1))*(mme.df.residual-2)/mme.df.residual #diag(R_prior_mean)*(ν-2)/ν
if mme.M != 0
for Mi in mme.M
Mi.df = Mi.df - mme.nModels
Mi.scale = diag(Mi.scale/(Mi.df - 1))*(Mi.df-2)/Mi.df
end
end

end



# below function is to modify essential parameters for partial connected NN
function nnbayes_partial_para_modify2(mme)
for Mi in mme.M
Mi.scale = Mi.scale[1]
Mi.G = Mi.G[1,1]
Mi.genetic_variance=Mi.genetic_variance[1,1]
end
end


# below function is to modify essential parameters for partial connected NN
function nnbayes_partial_para_modify3(mme)
for Mi in mme.M
Mi.meanVara = Mi.meanVara[1]
Mi.meanVara2 = Mi.meanVara2[1]
Mi.meanScaleVara = Mi.meanScaleVara[1]
Mi.meanScaleVara2 = Mi.meanScaleVara2[1]
Mi.π = Mi.π[1]
Mi.mean_pi = Mi.mean_pi[1]
Mi.mean_pi2 = Mi.mean_pi2[1]
end
end
Loading

0 comments on commit b529ba8

Please sign in to comment.