-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a simple mode based imputation in python hail interface. Added …
…a better error message in case of missing genotypes and imputation not requested (#174)
- Loading branch information
Showing
8 changed files
with
2,181 additions
and
26 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python | ||
''' | ||
Created on 24 Jan 2018 | ||
@author: szu004 | ||
''' | ||
import os | ||
import pkg_resources | ||
import hail as hl | ||
import varspark as vs | ||
import varspark.hail as vshl | ||
from pyspark.sql import SparkSession | ||
|
||
PROJECT_DIR=os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) | ||
|
||
def main(): | ||
vshl.init() | ||
|
||
data = hl.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000_missing.vcf')) | ||
labels = hl.import_table(os.path.join(PROJECT_DIR, 'data/chr22-labels-hail.csv'), delimiter=',', | ||
types={'x22_16050408':'float64'}).key_by('sample') | ||
|
||
mt = data.annotate_cols(hipster = labels[data.s]) | ||
print(mt.count()) | ||
|
||
rf_model = vshl.random_forest_model(y=mt.hipster.x22_16050408, | ||
x=mt.GT.n_alt_alleles(), seed = 13, mtry_fraction = 0.05, | ||
min_node_size = 5, max_depth = 10, imputation_type = "mode") | ||
rf_model.fit_trees(100, 50) | ||
|
||
print("OOB error: %s" % rf_model.oob_error()) | ||
impTable = rf_model.variable_importance() | ||
impTable.show(3) | ||
|
||
rf_model.to_json(os.path.join(PROJECT_DIR, "target/chr22_1000_GRCh38-model.json"), True) | ||
|
||
rf_model.release() | ||
|
||
if __name__ == '__main__': | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
src/main/scala/au/csiro/variantspark/input/ImputationStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package au.csiro.variantspark.input | ||
|
||
object Missing { | ||
val BYTE_NA_VALUE: Byte = (-1).toByte | ||
def isNA(v: Byte): Boolean = (BYTE_NA_VALUE == v) | ||
def isNotNA(v: Byte): Boolean = (BYTE_NA_VALUE != v) | ||
def replaceNA(v: Byte, inputedValue: Byte): Byte = { | ||
if (isNA(v)) inputedValue else v | ||
} | ||
} | ||
|
||
trait ImputationStrategy { | ||
def impute(data: Array[Byte]): Array[Byte] | ||
} | ||
|
||
case object DisabledImputationStrategy extends ImputationStrategy { | ||
override def impute(data: Array[Byte]): Array[Byte] = { | ||
if (data.exists(Missing.isNA)) { | ||
throw new IllegalArgumentException( | ||
"Missing values present in data but imputation is not enabled.") | ||
} | ||
data | ||
} | ||
} | ||
|
||
case object ZeroImputationStrategy extends ImputationStrategy { | ||
override def impute(data: Array[Byte]): Array[Byte] = data.map(Missing.replaceNA(_, 0.toByte)) | ||
} | ||
|
||
case class ModeImputationStrategy(noLevels: Int) extends ImputationStrategy { | ||
|
||
require(noLevels > 0) | ||
|
||
override def impute(data: Array[Byte]): Array[Byte] = { | ||
val counters = Array.ofDim[Int](noLevels) | ||
for (i <- data.indices) { | ||
if (Missing.isNotNA(data(i))) { | ||
counters(data(i)) += 1 | ||
} | ||
} | ||
val modeValue: Byte = counters.indices.maxBy(counters).toByte | ||
data.map(Missing.replaceNA(_, modeValue)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
src/test/scala/au/csiro/variantspark/input/ModeImputationStrategyTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package au.csiro.variantspark.input | ||
|
||
import au.csiro.variantspark.input.Missing.BYTE_NA_VALUE | ||
import org.junit.Assert._ | ||
import org.junit.Test; | ||
|
||
class ModeImputationStrategyTest { | ||
|
||
@Test | ||
def imputesEmptyArrayCorrectly() { | ||
assertArrayEquals(Array.emptyByteArray, | ||
ModeImputationStrategy(1).impute(Array.emptyByteArray)) | ||
} | ||
|
||
@Test | ||
def imputesAllMissingToZeros { | ||
assertArrayEquals(Array.fill(3)(0.toByte), | ||
ModeImputationStrategy(1).impute(Array.fill(3)(BYTE_NA_VALUE))) | ||
} | ||
@Test | ||
def imputesMissingToTheMode { | ||
assertArrayEquals(Array(1.toByte, 1.toByte, 1.toByte, 0.toByte, 1.toByte, 1.toByte), | ||
ModeImputationStrategy(3).impute(Array(BYTE_NA_VALUE, BYTE_NA_VALUE, BYTE_NA_VALUE, 0.toByte, 1.toByte, 1.toByte))) | ||
} | ||
|
||
@Test | ||
def imputesMissingToFirstMode { | ||
assertArrayEquals(Array(1.toByte, 1.toByte, 1.toByte, 0.toByte, 1.toByte, 1.toByte, 2.toByte, | ||
2.toByte), | ||
ModeImputationStrategy(3).impute(Array(BYTE_NA_VALUE, BYTE_NA_VALUE, BYTE_NA_VALUE, 0.toByte, 1.toByte, 1.toByte, 2.toByte, 2.toByte))) | ||
} | ||
} |