Skip to content

Commit

Permalink
Merge pull request #4 from katehoffshutta/dragon-r
Browse files Browse the repository at this point in the history
Dragon r
  • Loading branch information
katehoffshutta authored Nov 30, 2022
2 parents a76aeb2 + e0cee54 commit f7d35ca
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 4 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
Package: netZooR
Type: Package
Title: Unified methods for the inference and analysis of gene regulatory networks

Version: 1.1.16

Date: 2022-07-07
Authors@R: c(person("Marouen", "Ben Guebila",
email = "benguebila@hsph.harvard.edu", role = c("aut","cre"), comment = c(ORCID = "0000-0001-5934-966X")),
Expand Down
59 changes: 58 additions & 1 deletion R/DRAGON.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ risk = function(gamma, const, t11, t12, t21, t22, t3, t4)

R = const + (1-gamma1^2)*t11 + (1-gamma2^2)*t12 +
(1-gamma1^2)^2*t21 + (1-gamma2^2)^2*t22 +
(1-gamma1^2)*(1-gamma2^2)*t3 + gamma1*gamma2*t4
(1-gamma1^2)*(1-gamma2^2)*t3 + gamma1*gamma2*t4
return(R)
}

Expand Down Expand Up @@ -242,6 +242,7 @@ estimatePenaltyParameters = function(X1,X2)
upper=c(1,1),
control = list(pgtol = 1e-12))

return(res)
# penalty_parameters = (1.-res.x[0]**2), (1.-res.x[1]**2)
#
# def risk_orig(lam):
Expand All @@ -254,6 +255,62 @@ estimatePenaltyParameters = function(X1,X2)
#
}

get_shrunken_covariance_dragon = function(X1,X2, lambdas)
{
n = nrow(X1)
p1 = ncol(X1)
p2 = ncol(X2)
p = p1 + p2
X = cbind.data.frame(X1,X2)
S = cov(X) # the R implementation of cov() uses the unbiased (1/(n-1)); we need the unbiased version for the lemma of Ledoit and Wolf

# target matrix
Targ = diag(diag(S))
Sigma = matrix(nrow=p, ncol=p)

# Sigma = np.zeros((p,p))
# IDs = np.cumsum([0,p1,p2])
IDs = c(cumsum(c(p1,p2)))

idx1 = 1:IDs[1]
idx2 = (IDs[1]+1):IDs[2]

# Fill in Sigma_11
Sigma[idx1,idx1] = (1-lambdas[1])*S[idx1,idx1] + lambdas[1]*Targ[idx1,idx1]

# Fill in Sigma_22
Sigma[idx2,idx2] = (1-lambdas[2])*S[idx2,idx2] + lambdas[2]*Targ[idx2,idx2]

# Fill in Sigma_12
Sigma[idx1,idx2] = sqrt((1-lambdas[1])*(1-lambdas[2]))*S[idx1,idx2] + sqrt(lambdas[1]*lambdas[2])*Targ[idx1,idx2]

# Fill in Sigma_21
Sigma[idx2,idx1] = sqrt((1-lambdas[1])*(1-lambdas[2]))*S[idx2,idx1] + sqrt(lambdas[1]*lambdas[2])*Targ[idx2,idx1]

return(Sigma)
}

get_partial_correlation_from_precision = function(Theta)
{

}

get_precision_matrix_dragon = function(X1, X2, lambdas)
{

}

get_partial_correlation_dragon = function(X1,X2,lambdas)
{

}

estimate_kappa = function(n, p, lambda0, seed)
{

}


dragon = function()
{

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
,0,1,2
0,4.0,7.5,-0.6123724356957945
1,7.5,37.0,0.30618621784789724
2,-0.6123724356957945,0.30618621784789724,1.0
24 changes: 21 additions & 3 deletions tests/testthat/test-dragon.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# test dragon
# unit-tests for DRAGON

context("test DRAGON helper functions")

test_that("[DRAGON] scale() function yields expected results", {
Expand Down Expand Up @@ -42,12 +43,29 @@ test_that("[DRAGON] risk() function calculates correct risk",{
T22 = 6
T3 = 7
T4 = 8
const = 1

# manual calc
R_hand = (1-Gamma1^2)*T11 + (1-Gamma2^2)*T12 +
R_hand = 1+(1-Gamma1^2)*T11 + (1-Gamma2^2)*T12 +
(1-Gamma1^2)^2*T21 +(1-Gamma2^2)^2*T22 +
(1-Gamma1^2)*(1-Gamma2^2)*T3 +
Gamma1*Gamma2*T4
R = risk(Gamma1,Gamma2,T11,T12,T21,T22,T3,T4)
R = risk(c(Gamma1,Gamma2),const,T11,T12,T21,T22,T3,T4)
expect_equal(R,R_hand, tolerance = 1e-15)
}
)

test_that("[DRAGON] get_shrunken_covariance_dragon() function returns the right values",{
# confirm that matches python results
myX = matrix(c(1,2,9,3,1,7,5,12,8),byrow=T,ncol=3)
X1 = as.matrix(myX[,1:2])
X2 = as.matrix(myX[,3])
lambdas = c(0.25,0.5)
res = get_shrunken_covariance_dragon(X1,X2,lambdas)
res_py = as.matrix(read.csv("tests/testthat/dragon-test-files/dragon_test_get_shrunken_covariance.csv",row.names=1))
expect_equal(as.vector(res),as.vector(res_py),tolerance = 1e-15)
}
)

#testing format
#test_that(,{})

0 comments on commit f7d35ca

Please sign in to comment.