diff --git a/src/1.JWAS/src/JWAS.jl b/src/1.JWAS/src/JWAS.jl index 10c8328f..d994ebab 100644 --- a/src/1.JWAS/src/JWAS.jl +++ b/src/1.JWAS/src/JWAS.jl @@ -49,6 +49,7 @@ include("structure_equation_model/SEM.jl") #Latent Traits include("Nonlinear/nonlinear.jl") include("Nonlinear/bnn_hmc.jl") +include("Nonlinear/nnbayes_check.jl") #input include("input_data_validation.jl") @@ -286,21 +287,15 @@ function runMCMC(mme::MME,df; end mme.Gi = map(Float64,mme.Gi) end - #mega_trait + + # NNBayes mega trait: from multi-trait to multiple single-trait if mme.MCMCinfo.mega_trait == true || mme.MCMCinfo.constraint == true - if mme.nModels == 1 - error("more than 1 trait is required for MegaLMM analysis.") - end - mme.MCMCinfo.constraint = true - ##sample from scale-inv-⁠χ2, not InverseWishart - mme.df.residual = mme.df.residual - mme.nModels - mme.scaleR = diag(mme.scaleR/(mme.df.residual - 1))*(mme.df.residual-2)/mme.df.residual #diag(R_prior_mean)*(ν-2)/ν - if mme.M != 0 - for Mi in mme.M - Mi.df = Mi.df - mme.nModels - Mi.scale = diag(Mi.scale/(Mi.df - 1))*(Mi.df-2)/Mi.df - end - end + nnbayes_mega_trait(mme) + end + + # NNBayes: modify parameters for partial connected NN + if mme.nnbayes_partial==true + nnbayes_partial_para_modify2(mme) end ############################################################################ #Make incidence matrices and genotype covariates for training observations @@ -462,7 +457,7 @@ function getMCMCinfo(mme) @printf("%-30s %20s\n","Method",Mi.method) for Mi in mme.M if Mi.genetic_variance != false - if mme.nModels == 1 + if mme.nModels == 1 || mme.nnbayes_partial == true @printf("%-30s %20.3f\n","genetic variances (genomic):",Mi.genetic_variance) else @printf("%-30s\n","genetic variances (genomic):") @@ -471,7 +466,7 @@ function getMCMCinfo(mme) end end if !(Mi.method in ["GBLUP"]) - if mme.nModels == 1 + if mme.nModels == 1 || mme.nnbayes_partial == true @printf("%-30s %20.3f\n","marker effect variances:",Mi.G) else @printf("%-30s\n","marker effect variances:") diff --git a/src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl b/src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl index afb72d04..4333e19a 100644 --- a/src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl +++ b/src/1.JWAS/src/MCMC/MCMC_BayesianAlphabet.jl @@ -103,11 +103,15 @@ function MCMC_BayesianAlphabet(mme,df) end end end + if mme.nnbayes_partial == true + nnbayes_partial_para_modify3(mme) + end + #phenotypes corrected for all effects ycorr = vec(Matrix(mme.ySparse)-mme.X*mme.sol) if mme.M != 0 for Mi in mme.M - for traiti in 1:mme.nModels + for traiti in 1:Mi.ntraits if Mi.α[traiti] != zero(Mi.α[traiti]) ycorr[(traiti-1)*Mi.nObs+1 : traiti*Mi.nObs] = ycorr[(traiti-1)*Mi.nObs+1 : traiti*Mi.nObs] - Mi.genotypes*Mi.α[traiti] @@ -204,48 +208,57 @@ function MCMC_BayesianAlphabet(mme,df) # 2. Marker Effects ######################################################################## if mme.M !=0 - for Mi in mme.M + for i in 1:length(mme.M) + Mi=mme.M[i] ######################################################################## # Marker Effects ######################################################################## if Mi.method in ["BayesC","BayesB","BayesA"] locus_effect_variances = (Mi.method == "BayesC" ? fill(Mi.G,Mi.nMarkers) : Mi.G) - if is_multi_trait + if is_multi_trait && mme.nnbayes_partial==false if is_mega_trait megaBayesABC!(Mi,wArray,mme.R,locus_effect_variances) else MTBayesABC!(Mi,wArray,mme.R,locus_effect_variances) end + elseif mme.nnbayes_partial==true + BayesABC!(Mi,wArray[i],mme.R[i,i],locus_effect_variances) else BayesABC!(Mi,ycorr,mme.R,locus_effect_variances) end elseif Mi.method =="RR-BLUP" - if is_multi_trait + if is_multi_trait && mme.nnbayes_partial==false if is_mega_trait megaBayesC0!(Mi,wArray,mme.R) else MTBayesC0!(Mi,wArray,mme.R) end + elseif mme.nnbayes_partial==true + BayesC0!(Mi,wArray[i],mme.R[i,i]) else BayesC0!(Mi,ycorr,mme.R) end elseif Mi.method == "BayesL" - if is_multi_trait + if is_multi_trait && mme.nnbayes_partial==false if is_mega_trait #problem with sampleGammaArray megaBayesL!(Mi,wArray,mme.R) else MTBayesL!(Mi,wArray,mme.R) end + elseif mme.nnbayes_partial==true + BayesC0!(Mi,wArray[i],mme.R[i,i]) else BayesL!(Mi,ycorr,mme.R) end elseif Mi.method == "GBLUP" - if is_multi_trait + if is_multi_trait && mme.nnbayes_partial==false if is_mega_trait megaGBLUP!(Mi,wArray,mme.R,invweights) else MTGBLUP!(Mi,wArray,ycorr,mme.R,invweights) end + elseif mme.nnbayes_partial==true + GBLUP!(Mi,wArray[i],mme.R[i,i],invweights) else GBLUP!(Mi,ycorr,mme.R,invweights) end @@ -254,7 +267,7 @@ function MCMC_BayesianAlphabet(mme,df) # Marker Inclusion Probability ######################################################################## if Mi.estimatePi == true - if is_multi_trait + if is_multi_trait && mme.nnbayes_partial==false if is_mega_trait Mi.π = [samplePi(sum(Mi.δ[i]), Mi.nMarkers) for i in 1:mme.nModels] else @@ -321,8 +334,6 @@ function MCMC_BayesianAlphabet(mme,df) ######################################################################## # 5. Latent Traits ######################################################################## - - #mme.M[1].genotypes here is 5-by-5 if latent_traits == true #to update ycorr! sample_latent_traits(yobs,mme,ycorr,nonlinear_function,activation_function) end diff --git a/src/1.JWAS/src/Nonlinear/nnbayes_check.jl b/src/1.JWAS/src/Nonlinear/nnbayes_check.jl new file mode 100644 index 00000000..67438073 --- /dev/null +++ b/src/1.JWAS/src/Nonlinear/nnbayes_check.jl @@ -0,0 +1,157 @@ +#Below function is to check parameters for NNBayes and print information +function nnbayes_check_print_parameter(num_latent_traits,nonlinear_function,activation_function) + printstyled("Bayesian Neural Network is used with follwing information: \n",bold=false,color=:green) + + #part1: fully/partial-connected NN + if typeof(num_latent_traits) == Int64 #fully-connected. e.g, num_latent_traits=5 + printstyled(" - Neural network: fully connected neural network \n",bold=false,color=:green) + printstyled(" - Number of hidden nodes: $num_latent_traits \n",bold=false,color=:green) + nnbayes_partial=false + elseif num_latent_traits == false #partial-connected. + printstyled(" - Neural network: partially connected neural network \n",bold=false,color=:green) + nnbayes_partial=true + else + error("Please check you number of latent traits") + end + + #part2: activation function/user-defined non-linear function + if nonlinear_function == "Neural Network" #NN + if activation_function in ["tanh","sigmoid","relu","leakyrelu","linear"] + printstyled(" - Activation function: $activation_function.\n",bold=false,color=:green) + printstyled(" - Sampler: Hamiltonian Monte Carlo. \n",bold=false,color=:green) + else + error("Please select the activation function from tanh/sigmoid/relu/leakyrelu/linear") + end + elseif isa(nonlinear_function, Function) #user-defined nonlinear function. e.g, CropGrowthModel() + if activation_function == false + printstyled(" - Nonlinear function: user-defined nonlinear_function for the relationship between hidden nodes and observed trait is used.\n",bold=false,color=:green) + printstyled(" - Sampler: Matropolis-Hastings.\n",bold=false,color=:green) + else + error("activation function is not allowed for user-defined nonlinear function") + end + else + error("nonlinear_function can only be Neural Network or a user-defined nonlinear function") + end + return nnbayes_partial +end + + +#Below function is to re-phase modelm for NNBayes +function nnbayes_model_equation(model_equations,num_latent_traits) + + lhs, rhs = strip.(split(model_equations,"=")) + model_equations = "" + + if typeof(num_latent_traits) == Int64 #fully-connected + # old: y=intercept+geno + # new: y1=intercept+geno;y2=intercept+geno + for i = 1:num_latent_traits + model_equations = model_equations*lhs*string(i)*"="*rhs*";" + end + elseif num_latent_traits == false #partially-connected + # old: y=intercept+geno1+geno2 + # new: y1= intercept+geno1;y2=intercept+geno2 + rhs_split=strip.(split(rhs,"+")) + geno_term=[] + for i in rhs_split + if isdefined(Main,Symbol(i)) && typeof(getfield(Main,Symbol(i))) == Genotypes + push!(geno_term,i) + end + end + non_gene_term = filter(x->x ∉ geno_term,rhs_split) + non_gene_term = join(non_gene_term,"+") + + for i = 1:length(geno_term) + model_equations = model_equations*lhs*string(i)*"="*non_gene_term*"+"*geno_term[i]*";" + end + end + model_equations = model_equations[1:(end-1)] +end + + +# below function is to check whether the loaded genotype matches the model equation +function nnbayes_check_nhiddennode(num_latent_traits,mme) + if typeof(num_latent_traits) == Int64 #fully-connected. e.g, num_latent_traits=5 + if length(mme.M)>1 + error("fully-connected NN only allow one genotype; num_latent_traits is not allowed in partial-connected NN ") + end + elseif num_latent_traits == false #partial-connected. + if length(mme.M)==1 + error("partial-connected NN requirs >1 genotype group") + else + num_latent_traits = length(mme.M) + printstyled(" - Number of hidden nodes: $num_latent_traits \n",bold=false,color=:green) + end #Note, if only geno1 & geno2 are loaded by get_genotypes, but there is "geno3" in equation, then geno3 will be treated like age. + end +end + + + +# below function is to define the activation function for neural network +function nnbayes_activation(activation_function) + if activation_function == "tanh" + mytanh(x) = tanh(x) + return mytanh + elseif activation_function == "sigmoid" + mysigmoid(x) = 1/(1+exp(-x)) + return mysigmoid + elseif activation_function == "relu" + myrelu(x) = max(0, x) + return myrelu + elseif activation_function == "leakyrelu" + myleakyrelu(x) = max(0.01x, x) + return myleakyrelu + elseif activation_function == "linear" + mylinear(x) = x + return mylinear + else + error("invalid actication function") + end +end + + +# below function is to modify mme from multi-trait model to multiple single trait models +# coded by Hao +function nnbayes_mega_trait(mme) + #mega_trait + if mme.nModels == 1 + error("more than 1 trait is required for MegaLMM analysis.") + end + mme.MCMCinfo.constraint = true + + ##sample from scale-inv-⁠χ2, not InverseWishart + mme.df.residual = mme.df.residual - mme.nModels + mme.scaleR = diag(mme.scaleR/(mme.df.residual - 1))*(mme.df.residual-2)/mme.df.residual #diag(R_prior_mean)*(ν-2)/ν + if mme.M != 0 + for Mi in mme.M + Mi.df = Mi.df - mme.nModels + Mi.scale = diag(Mi.scale/(Mi.df - 1))*(Mi.df-2)/Mi.df + end + end + +end + + + +# below function is to modify essential parameters for partial connected NN +function nnbayes_partial_para_modify2(mme) + for Mi in mme.M + Mi.scale = Mi.scale[1] + Mi.G = Mi.G[1,1] + Mi.genetic_variance=Mi.genetic_variance[1,1] + end +end + + +# below function is to modify essential parameters for partial connected NN +function nnbayes_partial_para_modify3(mme) + for Mi in mme.M + Mi.meanVara = Mi.meanVara[1] + Mi.meanVara2 = Mi.meanVara2[1] + Mi.meanScaleVara = Mi.meanScaleVara[1] + Mi.meanScaleVara2 = Mi.meanScaleVara2[1] + Mi.π = Mi.π[1] + Mi.mean_pi = Mi.mean_pi[1] + Mi.mean_pi2 = Mi.mean_pi2[1] + end +end diff --git a/src/1.JWAS/src/build_MME.jl b/src/1.JWAS/src/build_MME.jl index feb21713..57fac50c 100644 --- a/src/1.JWAS/src/build_MME.jl +++ b/src/1.JWAS/src/build_MME.jl @@ -37,40 +37,11 @@ models = build_model(model_equations,R); function build_model(model_equations::AbstractString, R = false; df = 4.0, num_latent_traits = false, nonlinear_function = false, activation_function = false) #nonlinear_function(x1,x2) = x1+x2 if nonlinear_function != false #NNBayes - printstyled("Bayesian Neural Network is used with follwing information: \n",bold=false,color=:green) - #print activation info - if activation_function != false #e.g, activation_function="tanh" - printstyled("Activation function: $activation_function.\n Sampler: Hamiltonian Monte Carlo. \n",bold=false,color=:green) - elseif activation_function == false #e.g, nonlinear_function=f(z1,z2) - printstyled("Nonlinear function: user-defined nonlinear_function for the relationship between hidden nodes and observed trait is used.\n Sampler: Matropolis-Hastings.\n",bold=false,color=:green) - end + #NNBayes: check parameters + nnbayes_partial = nnbayes_check_print_parameter(num_latent_traits,nonlinear_function,activation_function) - #print connection info; re-write model equation - lhs, rhs = strip.(split(model_equations,"=")) - model_equations = "" - if num_latent_traits != false #e.g. num_latent_traits=5 - printstyled("Neural network: fully connected neural network \n",bold=false,color=:green) - printstyled("Number of hidden nodes: $num_latent_traits \n",bold=false,color=:green) - for i = 1:num_latent_traits - model_equations = model_equations*lhs*string(i)*"="*rhs*";" - end - elseif num_latent_traits == false #partially-connected - rhs_split=strip.(split(rhs,"+")) - geno_term=[] - for i in rhs_split - if isdefined(Main,Symbol(i)) && typeof(getfield(Main,Symbol(i))) == Genotypes - push!(geno_term,i) - end - end - non_gene_term = filter(x->x ∉ geno_term,rhs_split) - non_gene_term = join(non_gene_term,"+") - num_latent_traits = length(geno_term) - for i = 1:length(geno_term) - model_equations = model_equations*lhs*string(i)*"="*non_gene_term*"+"*geno_term[i]*";" - end - printstyled("NNBayes: partially connected with $num_latent_traits hidden nodes. \n",bold=false,color=:green) - end - model_equations = model_equations[1:(end-1)] + #NNBayes: re-write model equation + model_equations = nnbayes_model_equation(model_equations,num_latent_traits) end if R != false && !isposdef(map(AbstractFloat,R)) @@ -109,14 +80,16 @@ function build_model(model_equations::AbstractString, R = false; df = 4.0, whichterm = 1 for term in modelTerms term_symbol = Symbol(split(term.trmStr,":")[end]) - traiti = term.iModel if isdefined(Main,term_symbol) #@isdefined can be usde to tests whether a local variable or object field is defined if typeof(getfield(Main,term_symbol)) == Genotypes term.random_type = "genotypes" - if traiti == 1 #same genos are required in all traits - genotypei = getfield(Main,term_symbol) - genotypei.name = string(term_symbol) - genotypei.ntraits = nModels + genotypei = getfield(Main,term_symbol) + genotypei.name = string(term_symbol) + trait_names=[term.iTrait] + if genotypei.name ∉ map(x->x.name, genotypes) #only save unique genotype + is_nnbayes_partial = nonlinear_function != false && nnbayes_partial==true + genotypei.ntraits = is_nnbayes_partial ? 1 : nModels + genotypei.trait_names = is_nnbayes_partial ? trait_names : string.(lhsVec) if nModels != 1 genotypei.df = genotypei.df + nModels end @@ -130,6 +103,7 @@ function build_model(model_equations::AbstractString, R = false; df = 4.0, end end end + #crear mme with genotypes filter!(x->x.random_type != "genotypes",modelTerms) mme = MME(nModels,modelVec,modelTerms,dict,lhsVec,R == false ? R : Float32.(R),Float32(df)) @@ -137,31 +111,16 @@ function build_model(model_equations::AbstractString, R = false; df = 4.0, mme.M = genotypes end - #latent traits - if nonlinear_function != false #NNBayes - mme.latent_traits = true - mme.nonlinear_function = nonlinear_function - - if activation_function != false #e.g., "tanh" - if activation_function == "tanh" - mytanh(x) = tanh(x) - mme.activation_function = mytanh - elseif activation_function == "sigmoid" - mysigmoid(x) = 1/(1+exp(-x)) - mme.activation_function = mysigmoid - elseif activation_function == "relu" - myrelu(x) = max(0, x) - mme.activation_function = myrelu - elseif activation_function == "leakyrelu" - myleakyrelu(x) = max(0.01x, x) - mme.activation_function = myleakyrelu - elseif activation_function == "linear" - mylinear(x) = x - mme.activation_function = mylinear - else - error("Please select the activation function from tanh/sigmoid/relu/leakyrelu/linear.") - end - end + #NNBayes: + if nonlinear_function != false + #NNBayes: check parameters again + nnbayes_check_nhiddennode(num_latent_traits,mme) + + + mme.latent_traits = true + mme.nnbayes_partial = nnbayes_partial + mme.nonlinear_function = nonlinear_function + mme.activation_function = activation_function!=false ? nnbayes_activation(activation_function) : false end return mme diff --git a/src/1.JWAS/src/output.jl b/src/1.JWAS/src/output.jl index 557cb997..f3545368 100644 --- a/src/1.JWAS/src/output.jl +++ b/src/1.JWAS/src/output.jl @@ -102,7 +102,7 @@ function output_result(mme,output_folder, whicheffect = Mi.meanAlpha[traiti] whicheffectsd = sqrt.(abs.(Mi.meanAlpha2[traiti] .- Mi.meanAlpha[traiti] .^2)) whichdelta = Mi.meanDelta[traiti] - for traiti in 2:mme.nModels + for traiti in 2:Mi.ntraits whichtrait = vcat(whichtrait,fill(string(mme.lhsVec[traiti]),length(Mi.markerID))) whichmarker = vcat(whichmarker,Mi.markerID) whicheffect = vcat(whicheffect,Mi.meanAlpha[traiti]) @@ -205,6 +205,12 @@ end (internal function) Get breeding values for individuals defined by outputEBV(), defaulting to all genotyped individuals. This function is used inside MCMC functions for one MCMC samples from posterior distributions. +e.g., +non-NNBayes_partial (multi-classs Bayes) : y1=M1*α1[1]+M2*α2[1]+M3*α3[1] + y2=M1*α1[2]+M2*α2[2]+M3*α3[2]; +NNBayes_partial: y1=M1*α1[1] + y2=M2*α2[1] + y3=M3*α3[1]; """ function getEBV(mme,traiti) traiti_name = string(mme.lhsVec[traiti]) @@ -223,8 +229,15 @@ function getEBV(mme,traiti) end end if mme.M != 0 - for Mi in mme.M - EBV += Mi.output_genotypes*Mi.α[traiti] + for i in 1:length(mme.M) + Mi=mme.M[i] + if mme.nnbayes_partial==false #non-NNBayes_partial + EBV += Mi.output_genotypes*Mi.α[traiti] + else #NNBayes_partial + if i==traiti + EBV = Mi.output_genotypes*mme.M[i].α[1] + end + end end end return EBV @@ -253,8 +266,8 @@ function output_MCMC_samples_setup(mme,nIter,output_samples_frequency,file_name= end if mme.M !=0 #write samples for marker effects to a text file for Mi in mme.M - for traiti in 1:ntraits - push!(outvar,"marker_effects_"*Mi.name*"_"*string(mme.lhsVec[traiti])) + for traiti in Mi.trait_names + push!(outvar,"marker_effects_"*Mi.name*"_"*traiti) end push!(outvar,"marker_effects_variances"*"_"*Mi.name) push!(outvar,"pi"*"_"*Mi.name) @@ -325,8 +338,8 @@ function output_MCMC_samples_setup(mme,nIter,output_samples_frequency,file_name= if mme.M !=0 for Mi in mme.M - for traiti in 1:ntraits - writedlm(outfile["marker_effects_"*Mi.name*"_"*string(mme.lhsVec[traiti])],transubstrarr(Mi.markerID),',') + for traiti in Mi.trait_names + writedlm(outfile["marker_effects_"*Mi.name*"_"*traiti],transubstrarr(Mi.markerID),',') end end end @@ -374,8 +387,8 @@ function output_MCMC_samples(mme,vRes,G0, end if mme.M != 0 && outfile != false for Mi in mme.M - for traiti in 1:ntraits - writedlm(outfile["marker_effects_"*Mi.name*"_"*string(mme.lhsVec[traiti])],Mi.α[traiti]',',') + for traiti in 1:Mi.ntraits + writedlm(outfile["marker_effects_"*Mi.name*"_"*Mi.trait_names[traiti]],Mi.α[traiti],',') end if Mi.G != false if mme.nModels == 1 @@ -400,9 +413,11 @@ function output_MCMC_samples(mme,vRes,G0, writedlm(outfile["EBV_"*string(mme.lhsVec[1])],myEBV',',') for traiti in 2:ntraits myEBV = getEBV(mme,traiti) #actually BV - writedlm(outfile["EBV_"*string(mme.lhsVec[traiti])],myEBV',',') + trait_name = mme.nnbayes_partial ? mme.M[traiti].trait_names[1] : string(mme.lhsVec[traiti]) + writedlm(outfile["EBV_"*trait_name],myEBV',',') EBVmat = [EBVmat myEBV] end + if mme.MCMCinfo.output_heritability == true && mme.MCMCinfo.single_step_analysis == false mygvar = cov(EBVmat) genetic_variance = (ntraits == 1 ? mygvar : vec(mygvar)') diff --git a/src/1.JWAS/src/types.jl b/src/1.JWAS/src/types.jl index 432e31ae..aec2b395 100644 --- a/src/1.JWAS/src/types.jl +++ b/src/1.JWAS/src/types.jl @@ -77,7 +77,8 @@ mutable struct RandomEffect #Better to be a dict? key: term_array::Array{Abstr end mutable struct Genotypes - name #name for this category + name #name for this category, eg. "geno1" + trait_names #names for the corresponding traits, eg.["y1","y2"] obsID::Array{AbstractString,1} #row ID for (imputed) genotyped and phenotyped inds (finally) markerID @@ -124,8 +125,8 @@ mutable struct Genotypes output_genotypes #output genotypes isGRM #whether genotypes or relationship matirx is provided - - Genotypes(a1,a2,a3,a4,a5,a6,a7,a8,a9)=new(false, + + Genotypes(a1,a2,a3,a4,a5,a6,a7,a8,a9)=new(false,false, a1,a2,a3,a4,a5,a6,a7,a8,a4,false, false,false,false,false, false,true,true,false, @@ -246,6 +247,7 @@ mutable struct MME weights_NN σ2_yobs activation_function + nnbayes_partial function MME(nModels,modelVec,modelTerms,dict,lhsVec,R,ν) @@ -272,6 +274,6 @@ mutable struct MME 0, false,false,false, false, - false,false,false,1.0,false) + false,false,false,1.0,false,false) end end