Skip to content

Commit

Permalink
Added a simple mode based imputation in python hail interface. Added …
Browse files Browse the repository at this point in the history
…a better error message in case of missing genotypes and imputation not requested (aehrc#174)
  • Loading branch information
piotrszul authored and BMJHayward committed Jan 15, 2021
1 parent e3d67d5 commit 008f0de
Show file tree
Hide file tree
Showing 8 changed files with 2,181 additions and 26 deletions.
1,999 changes: 1,999 additions & 0 deletions data/chr22_1000_missing.vcf

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions python/examples/chr21_missing_hail.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python
'''
Created on 24 Jan 2018
@author: szu004
'''
import os
import pkg_resources
import hail as hl
import varspark as vs
import varspark.hail as vshl
from pyspark.sql import SparkSession

PROJECT_DIR=os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

def main():
    """Fit a variant-spark random forest on a VCF that contains missing genotype calls.

    Loads the genotypes and the sample labels with hail, requests "mode"
    imputation for the missing calls, fits the model, prints the OOB error and
    the first rows of the variable importance table, and writes the model to JSON.
    """
    vshl.init()

    data = hl.import_vcf(os.path.join(PROJECT_DIR, 'data/chr22_1000_missing.vcf'))
    labels = hl.import_table(os.path.join(PROJECT_DIR, 'data/chr22-labels-hail.csv'), delimiter=',',
                             types={'x22_16050408':'float64'}).key_by('sample')

    # Attach the label row of each sample as a column annotation.
    mt = data.annotate_cols(hipster = labels[data.s])
    print(mt.count())

    # imputation_type="mode" is required here because the input VCF has missing calls.
    rf_model = vshl.random_forest_model(y=mt.hipster.x22_16050408,
                                        x=mt.GT.n_alt_alleles(), seed = 13, mtry_fraction = 0.05,
                                        min_node_size = 5, max_depth = 10, imputation_type = "mode")
    try:
        # 100 trees; the second argument is presumably the batch size -- TODO confirm.
        rf_model.fit_trees(100, 50)

        print("OOB error: %s" % rf_model.oob_error())
        impTable = rf_model.variable_importance()
        impTable.show(3)

        # NOTE(review): the output name mentions GRCh38 -- confirm it matches the input data.
        rf_model.to_json(os.path.join(PROJECT_DIR, "target/chr22_1000_GRCh38-model.json"), True)
    finally:
        # Release the model even if fitting or exporting fails.
        rf_model.release()

if __name__ == '__main__':
    main()

8 changes: 5 additions & 3 deletions python/varspark/hail/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@
mtry_fraction=nullable(float),
min_node_size=nullable(int),
max_depth=nullable(int),
seed=nullable(int)
seed=nullable(int),
imputation_type=nullable(str)
)

def random_forest_model(y, x, covariates=(), oob=True, mtry_fraction=None,
min_node_size = None, max_depth=None, seed=None):
min_node_size = None, max_depth=None, seed=None, imputation_type=None):

mt = matrix_table_source('random_forest_model/x', x)
check_entry_indexed('random_forest_model/x', x)
Expand All @@ -44,4 +45,5 @@ def random_forest_model(y, x, covariates=(), oob=True, mtry_fraction=None,
mtry_fraction=mtry_fraction,
min_node_size = min_node_size,
max_depth = max_depth,
seed = seed)
seed = seed,
imputation_type = imputation_type)
7 changes: 4 additions & 3 deletions python/varspark/hail/rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ class RandomForestModel(object):
mtry_fraction=nullable(float),
min_node_size=nullable(int),
max_depth=nullable(int),
seed=nullable(int)
seed=nullable(int),
imputation_type=nullable(str)
)
def __init__(self,_mir, oob=True, mtry_fraction=None, min_node_size=None,
max_depth=None, seed=None):
max_depth=None, seed=None, imputation_type = None):
self._mir = _mir
self._jrf_model = Env.jvm().au.csiro.variantspark.hail.methods.RFModel.pyApply(
Env.spark_backend('rf')._to_java_ir(self._mir),
java.joption(mtry_fraction), oob, java.joption(min_node_size),
java.joption(max_depth), java.joption(seed))
java.joption(max_depth), java.joption(seed), java.joption(imputation_type))

@typecheck_method(
n_trees=int,
Expand Down
47 changes: 37 additions & 10 deletions src/main/scala/au/csiro/variantspark/hail/methods/RFModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import au.csiro.variantspark.algo.{
}
import au.csiro.variantspark.data.{BoundedOrdinalVariable, Feature, StdFeature}
import au.csiro.variantspark.external.ModelConverter
import au.csiro.variantspark.input.{
DisabledImputationStrategy,
ImputationStrategy,
Missing,
ModeImputationStrategy
}
import au.csiro.variantspark.utils.HdfsPath
import is.hail.annotations.Annotation
import is.hail.expr.ir.{Interpret, MatrixIR, MatrixValue, TableIR, TableLiteral, TableValue}
Expand All @@ -37,7 +43,8 @@ import scala.collection.IndexedSeq
* while the dependent variable is named `y`
* @param rfParams random forest parameters to use
*/
case class RFModel(mv: MatrixValue, rfParams: RandomForestParams) {
case class RFModel(mv: MatrixValue, rfParams: RandomForestParams,
imputationStrategy: Option[ImputationStrategy]) {

val responseVarName: String = "y"
val entryVarname: String = "e"
Expand All @@ -63,7 +70,9 @@ case class RFModel(mv: MatrixValue, rfParams: RandomForestParams) {
lazy val sig: TStruct = keySignature.insertFields(Array(("importance", TFloat64())))

lazy val rf: RandomForest = new RandomForest(rfParams)
val featuresRDD: RDD[Feature] = mv.rvd.toRows.map(RFModel.rowToFeature)

val featuresRDD: RDD[Feature] =
RFModel.mvToFeatureRDD(mv, imputationStrategy.getOrElse(DisabledImputationStrategy))
lazy val inputData: RDD[TreeFeature] =
DefTreeRepresentationFactory.createRepresentation(featuresRDD.zipWithIndex())

Expand Down Expand Up @@ -148,7 +157,7 @@ case class RFModel(mv: MatrixValue, rfParams: RandomForestParams) {
}

private def importanceMapBroadcast: Broadcast[Map[Long, Double]] = {
require(rfModel != null, "Traind the model first")
require(rfModel != null, "Train the model first")
if (impVarBroadcast != null) {
impVarBroadcast
} else {
Expand All @@ -175,19 +184,37 @@ object RFModel {
Row(Locus(elements(0), elements(1).toInt), alleles, impValue) // , elements.drop(2))
}

def rowToFeature(r: Row): Feature = {
def mvToFeatureRDD(mv: MatrixValue, imputationStrategy: ImputationStrategy): RDD[Feature] =
mv.rvd.toRows.map(rowToFeature(_, imputationStrategy))

def rowToFeature(r: Row, is: ImputationStrategy): Feature = {
val locus = r.getAs[Locus](0)
val varName =
(Seq(locus.contig, locus.position.toString) ++ r.getSeq[String](1)).mkString("_")
val data = r.getSeq[Row](2).map(_.getInt(0)).toArray
StdFeature.from(varName, BoundedOrdinalVariable(3), data)
// encode missing calls as Missing.BYTE_NA_VALUE and let the imputation strategy handle them
val data = r
.getSeq[Row](2)
.map(g => if (!g.isNullAt(0)) g.getInt(0).toByte else Missing.BYTE_NA_VALUE)
.toArray
StdFeature.from(varName, BoundedOrdinalVariable(3), is.impute(data))
}

/**
 * Resolves the name of an imputation type to the corresponding strategy.
 *
 * @param imputationType name of the imputation type; only "mode" is recognised
 * @return the imputation strategy registered under that name
 * @throws IllegalArgumentException if the name is not a known imputation type
 */
def imputationFromString(imputationType: String): ImputationStrategy =
  if (imputationType == "mode") {
    ModeImputationStrategy(3)
  } else {
    throw new IllegalArgumentException(
      "Unknown imputation type: '" + imputationType + "'. Valid types are: 'mode'")
  }

def pyApply(inputIR: MatrixIR, mTryFraction: Option[Double], oob: Boolean,
minNodeSize: Option[Int], maxDepth: Option[Int], seed: Option[Int]): RFModel = {
minNodeSize: Option[Int], maxDepth: Option[Int], seed: Option[Int],
imputationType: Option[String] = None): RFModel = {
var rfParams = RandomForestParams.fromOptions(mTryFraction = mTryFraction, oob = Some(oob),
minNodeSize = minNodeSize, maxDepth = maxDepth, seed = seed.map(_.longValue))
val mv = Interpret(inputIR)
RFModel(mv,
RandomForestParams.fromOptions(mTryFraction = mTryFraction, oob = Some(oob),
minNodeSize = minNodeSize, maxDepth = maxDepth, seed = seed.map(_.longValue)))
RFModel(mv, rfParams, imputationType.map(imputationFromString))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package au.csiro.variantspark.input

/** Helpers for the byte encoding of missing (NA) values. */
object Missing {

  /** Marker byte used to represent a missing value. */
  val BYTE_NA_VALUE: Byte = -1

  /** True when `v` is the missing-value marker. */
  def isNA(v: Byte): Boolean = v == BYTE_NA_VALUE

  /** True when `v` is anything other than the missing-value marker. */
  def isNotNA(v: Byte): Boolean = !isNA(v)

  /** Returns `inputedValue` when `v` is missing, otherwise `v` itself. */
  def replaceNA(v: Byte, inputedValue: Byte): Byte =
    if (isNotNA(v)) v else inputedValue
}

/**
 * Strategy deciding how missing values in a vector of byte-encoded data are handled.
 */
trait ImputationStrategy {

  /**
   * Processes the missing values of `data`.
   *
   * @param data byte-encoded values to process
   * @return the values to use in place of `data`
   */
  def impute(data: Array[Byte]): Array[Byte]
}

/**
 * Strategy used when imputation is not enabled: data without missing values is passed
 * through unchanged, while data with any missing value is rejected.
 */
case object DisabledImputationStrategy extends ImputationStrategy {

  /**
   * @param data byte-encoded values to check
   * @return `data` itself when it holds no missing values
   * @throws IllegalArgumentException if `data` contains a missing value
   */
  override def impute(data: Array[Byte]): Array[Byte] =
    if (data.forall(Missing.isNotNA)) {
      data
    } else {
      throw new IllegalArgumentException(
        "Missing values present in data but imputation is not enabled.")
    }
}

/** Strategy that replaces every missing value with 0. */
case object ZeroImputationStrategy extends ImputationStrategy {

  /**
   * @param data byte-encoded values to process
   * @return a new array where each missing value is replaced with 0
   */
  override def impute(data: Array[Byte]): Array[Byte] = {
    val zero: Byte = 0
    data.map(v => Missing.replaceNA(v, zero))
  }
}

/**
 * Strategy that replaces every missing value with the most frequent non-missing
 * level (the mode) of the same data vector.
 *
 * Non-missing values are expected to be levels in `0 until noLevels`. When several
 * levels are equally frequent the lowest one is used; when no value is observed at
 * all the missing values are replaced with level 0.
 *
 * @param noLevels number of levels the data can take; must be positive
 */
case class ModeImputationStrategy(noLevels: Int) extends ImputationStrategy {

  require(noLevels > 0, s"noLevels must be positive but was: $noLevels")

  /**
   * @param data byte-encoded levels, with missing values marked as defined in `Missing`
   * @return a new array where each missing value is replaced with the mode of `data`
   * @throws IllegalArgumentException if a non-missing value is outside `0 until noLevels`
   */
  override def impute(data: Array[Byte]): Array[Byte] = {
    val counters = Array.ofDim[Int](noLevels)
    data.foreach { v =>
      if (Missing.isNotNA(v)) {
        // Report unexpected levels explicitly rather than failing on the array access below.
        if (v < 0 || v >= noLevels) {
          throw new IllegalArgumentException(
            s"Value $v is outside of the expected range of levels [0, $noLevels)")
        }
        counters(v) += 1
      }
    }
    // maxBy keeps the first maximal index, so ties resolve to the lowest level.
    val modeValue: Byte = counters.indices.maxBy(i => counters(i)).toByte
    data.map(Missing.replaceNA(_, modeValue))
  }
}
29 changes: 19 additions & 10 deletions src/test/scala/au/csiro/variantspark/hail/HailIntegrationTest.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
package au.csiro.variantspark.hail

import au.csiro.variantspark.test.SparkTest
import org.junit.Test
import org.junit.Assert._

import au.csiro.variantspark.hail.methods.RFModel
import au.csiro.variantspark.test.{SparkTest, TestSparkContext}
import is.hail.HailContext
import is.hail.expr._
import au.csiro.variantspark.hail._
import au.csiro.variantspark.algo.metrics.ManhattanPairwiseMetric
import is.hail.expr.ir.IRParser
import au.csiro.variantspark.hail.methods.RFModel
import is.hail.table.Table
import is.hail.expr.ir.MatrixIR
import au.csiro.variantspark.test.TestSparkContext
import org.junit.Assert._
import org.junit.Test

object TestHailContext {
lazy val hc = HailContext(TestSparkContext.spark.sparkContext)
Expand Down Expand Up @@ -112,6 +106,21 @@ class HailIntegrationTest extends SparkTest {
assertEquals("All variables have reported importance", 1988, importanceTable.count())
rfModel.release()
}
@Test
def testRunImportanceAnalysisWithMissingCalls() {
  // Build the matrix IR from a VCF that contains missing genotype calls.
  val parsedIR = IRParser.parse_matrix_ir(
    loadDataToMatrixIr("data/chr22_1000_missing.vcf", "data/chr22-labels-hail.csv", "sample",
      "x22_16051480", "GRCh37"))
  // Mode imputation is requested so that the missing calls can be handled.
  val model = RFModel.pyApply(parsedIR, None, true, None, None, Some(13), Some("mode"))
  model.fitTrees(100, 50)
  assertTrue("OOB Error is defined", !model.oobError.isNaN)
  val importances = new Table(hc, model.variableImportance)
  val expectedFields = List("locus", "alleles", "importance")
  assertEquals(expectedFields, importances.signature.fieldNames.toList)
  assertEquals("All variables have reported importance", 1988, importances.count())
  model.release()
}

@Test
def testRunImportanceAnalysisForGRCh38() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package au.csiro.variantspark.input

import au.csiro.variantspark.input.Missing.BYTE_NA_VALUE
import org.junit.Assert._
import org.junit.Test;

/** Unit tests for [[ModeImputationStrategy]]. */
class ModeImputationStrategyTest {

  /** An empty input yields an empty output. */
  @Test
  def imputesEmptyArrayCorrectly(): Unit = {
    assertArrayEquals(Array.emptyByteArray,
      ModeImputationStrategy(1).impute(Array.emptyByteArray))
  }

  /** With no observed values every missing value becomes level 0. */
  @Test
  def imputesAllMissingToZeros(): Unit = {
    assertArrayEquals(Array.fill(3)(0.toByte),
      ModeImputationStrategy(1).impute(Array.fill(3)(BYTE_NA_VALUE)))
  }

  /** Missing values are replaced with the most frequent observed level. */
  @Test
  def imputesMissingToTheMode(): Unit = {
    val input = Array[Byte](BYTE_NA_VALUE, BYTE_NA_VALUE, BYTE_NA_VALUE, 0, 1, 1)
    val expected = Array[Byte](1, 1, 1, 0, 1, 1)
    assertArrayEquals(expected, ModeImputationStrategy(3).impute(input))
  }

  /** When levels 1 and 2 are equally frequent the lower one is used. */
  @Test
  def imputesMissingToFirstMode(): Unit = {
    val input = Array[Byte](BYTE_NA_VALUE, BYTE_NA_VALUE, BYTE_NA_VALUE, 0, 1, 1, 2, 2)
    val expected = Array[Byte](1, 1, 1, 0, 1, 1, 2, 2)
    assertArrayEquals(expected, ModeImputationStrategy(3).impute(input))
  }
}

0 comments on commit 008f0de

Please sign in to comment.