From 1df2b75141919db2b959e2d4a0e7747ed232358c Mon Sep 17 00:00:00 2001 From: Frank Austin Nothaft Date: Wed, 8 Feb 2017 20:46:05 -0800 Subject: [PATCH] [ADAM-1018] Add support for Spark SQL Datasets. Resolves #1018. Adds the `adam-codegen` module, which generates classes that: 1. Implement the Scala Product interface and thus can be read into a Spark SQL Dataset. 2. Have a complete constructor that is compatible with the constructor that Spark SQL expects to see when exporting a Dataset back to Scala. 3. And, that have methods for converting to/from the bdg-formats Avro models. Then, we build these model classes in the `org.bdgenomics.adam.sql` package, and use them for export from the Avro based GenomicRDDs. With a Dataset, we can then export to a DataFrame, which enables us to expose data through Python via RDD->Dataset->DataFrame. This is important since the Avro classes generated by bdg-formats can't be pickled, and thus we can't do a Java RDD to Python RDD crossing with them. --- adam-apis/pom.xml | 4 + adam-cli/pom.xml | 4 + adam-codegen/pom.xml | 98 ++++++ .../adam/codegen/DumpSchemasToProduct.scala | 280 ++++++++++++++++++ adam-core/pom.xml | 42 +++ .../contig/NucleotideContigFragmentRDD.scala | 24 ++ .../adam/rdd/feature/FeatureRDD.scala | 24 ++ .../adam/rdd/fragment/FragmentRDD.scala | 24 ++ .../adam/rdd/read/AlignmentRecordRDD.scala | 24 ++ .../adam/rdd/variant/GenotypeRDD.scala | 24 ++ .../adam/rdd/variant/VariantRDD.scala | 24 ++ .../serialization/ADAMKryoRegistrator.scala | 5 + .../adam/rdd/ADAMContextSuite.scala | 14 +- .../adam/rdd/fragment/FragmentRDDSuite.scala | 2 + adam-python/src/bdgenomics/adam/rdd.py | 50 +++- .../bdgenomics/adam/test/adamContext_test.py | 8 + pom.xml | 25 ++ 17 files changed, 671 insertions(+), 5 deletions(-) create mode 100644 adam-codegen/pom.xml create mode 100644 adam-codegen/src/main/scala/org/bdgenomics/adam/codegen/DumpSchemasToProduct.scala diff --git a/adam-apis/pom.xml b/adam-apis/pom.xml index 38df5c197b..25edd12acd 100644 --- a/adam-apis/pom.xml +++ b/adam-apis/pom.xml @@ -142,5 +142,9 @@ scalatest_${scala.version.prefix} test + + org.apache.spark + spark-sql_${scala.version.prefix} + diff --git a/adam-cli/pom.xml b/adam-cli/pom.xml index aef9762cb7..b54c1f5c0f 100644 --- a/adam-cli/pom.xml +++ b/adam-cli/pom.xml @@ -193,5 +193,9 @@ scala-guice_${scala.version.prefix} compile + + org.apache.spark + spark-sql_${scala.version.prefix} + diff --git a/adam-codegen/pom.xml b/adam-codegen/pom.xml new file mode 100644 index 0000000000..764eeb3d7b --- /dev/null +++ b/adam-codegen/pom.xml @@ -0,0 +1,98 @@ + + + 4.0.0 + + org.bdgenomics.adam + adam-parent_2.10 + 0.23.0-SNAPSHOT + ../pom.xml + + + adam-codegen_2.10 + jar + ADAM_${scala.version.prefix}: Avro-to-Dataset codegen utils + + ${maven.build.timestamp} + yyyy-MM-dd + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-source + generate-sources + + add-source + + + + src/main/scala + + + + + add-test-source + generate-test-sources + + add-test-source + + + + src/test/scala + + + + + + + org.scalatest + scalatest-maven-plugin + + ${project.build.directory}/scalatest-reports + . 
+ ADAMTestSuite.txt + + -Xmx1024m -Dsun.io.serialization.extendedDebugInfo=true + F + + + + test + + test + + + + + + + + + org.scala-lang + scala-library + + + org.apache.avro + avro + + + org.scalatest + scalatest_${scala.version.prefix} + test + + + diff --git a/adam-codegen/src/main/scala/org/bdgenomics/adam/codegen/DumpSchemasToProduct.scala b/adam-codegen/src/main/scala/org/bdgenomics/adam/codegen/DumpSchemasToProduct.scala new file mode 100644 index 0000000000..ab87d2b181 --- /dev/null +++ b/adam-codegen/src/main/scala/org/bdgenomics/adam/codegen/DumpSchemasToProduct.scala @@ -0,0 +1,280 @@ +/** + * Licensed to Big Data Genomics (BDG) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The BDG licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.bdgenomics.adam.codegen + +import java.io.{ File, FileWriter } +import org.apache.avro.Schema +import org.apache.avro.reflect.ReflectData +import scala.collection.JavaConversions._ + +object DumpSchemasToProduct { + + def main(args: Array[String]) { + new DumpSchemasToProduct()(args) + } +} + +class DumpSchemasToProduct { + + private def getSchemaByReflection(className: String): Schema = { + + // load the class + val classLoader = Thread.currentThread().getContextClassLoader() + val klazz = classLoader.loadClass(className) + + // get the schema through reflection + ReflectData.get().getSchema(klazz) + } + + private def toMatch(fields: Seq[(String, String)]): String = { + fields.map(_._1) + .zipWithIndex + .map(vk => { + val (field, idx) = vk + " case %d => %s".format(idx, field) + }).mkString("\n") + } + + private def getType(schema: Schema): String = schema.getType match { + case Schema.Type.DOUBLE => "Double" + case Schema.Type.FLOAT => "Float" + case Schema.Type.INT => "Int" + case Schema.Type.LONG => "Long" + case Schema.Type.BOOLEAN => "Boolean" + case Schema.Type.STRING => "String" + case Schema.Type.ENUM => "String" + case Schema.Type.RECORD => schema.getName() + case other => throw new IllegalStateException("Unsupported type %s.".format(other)) + } + + private def getUnionType(schema: Schema): Schema = { + val unionTypes = schema.getTypes() + .filter(t => { + t.getType != Schema.Type.NULL + }) + assert(unionTypes.size == 1) + unionTypes.head + } + + private def fields(schema: Schema): Seq[(String, String)] = { + schema.getFields() + .map(field => { + val name = field.name + val fieldSchema = field.schema + val fieldType = fieldSchema.getType match { + case Schema.Type.ARRAY => { + "Seq[%s]".format(getType(fieldSchema.getElementType())) + } + case Schema.Type.MAP => { + "scala.collection.Map[String,%s]".format(getType(fieldSchema.getValueType())) + } + case Schema.Type.UNION => { + "Option[%s]".format(getType(getUnionType(fieldSchema))) + } + case other => { + throw new IllegalStateException("Unsupported type %s in field %s.".format(other, name)) + } + } + (name, fieldType) + }).toSeq + } 
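  // Illustration only (hypothetical record, not emitted by this patch): for an Avro
  // record "Interval" with a nullable string field "name" and a nullable long field
  // "start", fields() above yields Seq(("name", "Option[String]"), ("start", "Option[Long]")),
  // and generateClassDump/dumpObject below render code along the lines of:
  //
  //   object Interval extends Serializable {
  //     def apply(name: Option[String], start: Option[Long]): Interval = ...
  //     def fromAvro(record: org.bdgenomics.formats.avro.Interval): Interval = ...
  //   }
  //
  //   class Interval (
  //     val name: Option[String],
  //     val start: Option[Long]) extends Product {
  //     def productArity: Int = 2
  //     def productElement(i: Int): Any = i match {
  //       case 0 => name
  //       case 1 => start
  //     }
  //     def toAvro: org.bdgenomics.formats.avro.Interval = ...
  //     def canEqual(that: Any): Boolean = ...
  //   }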
+ + private def conversion(schema: Schema, mapFn: String): String = schema.getType match { + case Schema.Type.DOUBLE => ".%s(d => d: java.lang.Double)".format(mapFn) + case Schema.Type.FLOAT => ".%s(f => f: java.lang.Float)".format(mapFn) + case Schema.Type.INT => ".%s(i => i: java.lang.Integer)".format(mapFn) + case Schema.Type.LONG => ".%s(l => l: java.lang.Long)".format(mapFn) + case Schema.Type.BOOLEAN => ".%s(b => b: java.lang.Boolean)".format(mapFn) + case Schema.Type.STRING => "" + case Schema.Type.ENUM => ".%s(e => %s.valueOf(e))".format(mapFn, schema.getFullName) + case Schema.Type.RECORD => ".%s(r => r.toAvro)".format(mapFn) + case other => throw new IllegalStateException("Unsupported type %s.".format(other)) + } + + private def setters(schema: Schema): String = { + schema.getFields + .map(field => { + val name = field.name + + field.schema.getType match { + case Schema.Type.UNION => { + getUnionType(field.schema).getType match { + case Schema.Type.RECORD => " %s.foreach(field => record.set%s(field.toAvro))".format(name, name.capitalize) + case Schema.Type.ENUM => " %s.foreach(field => record.set%s(%s.valueOf(field)))".format(name, name.capitalize, getUnionType(field.schema).getFullName) + case Schema.Type.DOUBLE | Schema.Type.FLOAT | + Schema.Type.INT | Schema.Type.LONG | + Schema.Type.BOOLEAN | Schema.Type.STRING => " %s.foreach(field => record.set%s(field))".format(name, name.capitalize) + case other => throw new IllegalStateException("Unsupported type %s.".format(other)) + } + } + case Schema.Type.ARRAY => { + val convAction = conversion(field.schema.getElementType(), "map") + " if (%s.nonEmpty) {\n record.set%s(%s%s)\n }".format(name, name.capitalize, name, convAction) + } + case Schema.Type.MAP => { + val convAction = conversion(field.schema.getValueType(), "mapValues") + " if (%s.nonEmpty) {\n record.set%s(%s%s.asJava)\n }".format(name, name.capitalize, name, convAction) + } + case _ => { + throw new IllegalArgumentException("Bad type %s.".format(field.schema)) + } + } + }).mkString("\n") + } + + private def dumpToAvroFn(schema: Schema): String = { + " val record = new %s()\n%s\n record".format(schema.getFullName, + setters(schema)) + } + + private def generateClassDump(className: String): String = { + + // get schema + val schema = getSchemaByReflection(className) + + // get class name without package + val classNameNoPackage = className.split('.').last + + "\n%s\n\nclass %s (\n%s) extends Product {\n def productArity: Int = %d\n def productElement(i: Int): Any = i match {\n%s\n }\n def toAvro: %s = {\n%s\n }\n def canEqual(that: Any): Boolean = that match {\n case %s => true\n case _ => false\n }\n}".format( + dumpObject(schema), + classNameNoPackage, + fields(schema).map(p => " val %s: %s".format(p._1, p._2)).mkString(",\n"), + schema.getFields().size, + toMatch(fields(schema)), + schema.getFullName, + dumpToAvroFn(schema), + classNameNoPackage + ) + } + + private def getConversion(schema: Schema, mapFn: String): String = schema.getType match { + case Schema.Type.DOUBLE => ".%s(d => d: Double)".format(mapFn) + case Schema.Type.FLOAT => ".%s(f => f: Float)".format(mapFn) + case Schema.Type.INT => ".%s(i => i: Int)".format(mapFn) + case Schema.Type.LONG => ".%s(l => l: Long)".format(mapFn) + case Schema.Type.BOOLEAN => ".%s(b => b: Boolean)".format(mapFn) + case Schema.Type.STRING => "" + case Schema.Type.ENUM => ".%s(e => e.toString)".format(mapFn) + case Schema.Type.RECORD => ".%s(r => %s.fromAvro(r))".format(mapFn, schema.getName) + case other => throw new 
IllegalStateException("Unsupported type %s.".format(other)) + } + + private def getters(schema: Schema): String = { + schema.getFields + .map(field => { + val name = field.name + + field.schema.getType match { + case Schema.Type.UNION => { + getUnionType(field.schema).getType match { + case Schema.Type.RECORD => " Option(record.get%s).map(field => %s.fromAvro(field))".format(name.capitalize, getUnionType(field.schema).getName) + case Schema.Type.ENUM => " Option(record.get%s).map(field => field.toString)".format(name.capitalize) + case Schema.Type.DOUBLE | Schema.Type.FLOAT | + Schema.Type.INT | Schema.Type.LONG | + Schema.Type.BOOLEAN | Schema.Type.STRING => " Option(record.get%s)%s".format(name.capitalize, getConversion(getUnionType(field.schema), "map")) + case other => throw new IllegalStateException("Unsupported type %s.".format(other)) + } + } + case Schema.Type.ARRAY => { + val convAction = getConversion(field.schema.getElementType(), "map") + " record.get%s().toSeq%s".format(name.capitalize, convAction) + } + case Schema.Type.MAP => { + val convAction = getConversion(field.schema.getValueType(), "mapValues") + " record.get%s()%s.asScala".format(name.capitalize, convAction) + } + case _ => { + throw new IllegalArgumentException("Bad type %s.".format(field.schema)) + } + } + }).mkString(",\n") + } + + private def dumpObject(schema: Schema): String = { + "object %s extends Serializable {\n def apply(\n%s): %s = {\n new %s(\n%s)\n }\n def fromAvro(record: %s): %s = {\n new %s (\n%s)\n }\n}".format( + schema.getName, + fields(schema).map(p => " %s: %s".format(p._1, p._2)).mkString(",\n"), + schema.getName, + schema.getName, + fields(schema).map(_._1).map(s => " %s".format(s)).mkString(",\n"), + schema.getFullName, + schema.getName, + schema.getName, + getters(schema)) + } + + private def writeHeader(fw: FileWriter, packageName: String) { + val hdr = Seq( + "/**", + "* Licensed to Big Data Genomics (BDG) under one", + "* or more contributor license agreements. See the NOTICE file", + "* distributed with this work for additional information", + "* regarding copyright ownership. The BDG licenses this file", + "* to you under the Apache License, Version 2.0 (the", + "* \"License\"); you may not use this file except in compliance", + "* with the License. You may obtain a copy of the License at", + "*", + "* http://www.apache.org/licenses/LICENSE-2.0", + "*", + "* Unless required by applicable law or agreed to in writing, software", + "* distributed under the License is distributed on an \"AS IS\" BASIS,", + "* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "* See the License for the specific language governing permissions and", + "* limitations under the License.", + "*/", + "package %s".format(packageName), + "", + "import scala.collection.JavaConversions._", + "import scala.collection.JavaConverters._").mkString("\n") + + fw.write(hdr) + } + + def apply(args: Array[String]) { + + if (args.length < 3) { + println("DumpSchemas ... 
") + System.exit(1) + } else { + + // drop the file to write and the package name + val classesToDump = args.drop(1).dropRight(1) + + // open the file to write + val dir = new File(args.last).getParentFile + if (!dir.exists()) { + dir.mkdirs() + } + val fw = new FileWriter(args.last) + + // write the header + writeHeader(fw, args.head) + + // loop and dump the classes + classesToDump.foreach(className => { + val dumpString = generateClassDump(className) + + fw.write("\n") + fw.write(dumpString) + }) + + // we are done, so close and flush + fw.close() + } + } +} diff --git a/adam-core/pom.xml b/adam-core/pom.xml index 0a47a2f4a7..568f4e002e 100644 --- a/adam-core/pom.xml +++ b/adam-core/pom.xml @@ -69,6 +69,7 @@ src/main/scala + target/generated-sources/src/main/scala @@ -86,6 +87,39 @@ + + org.codehaus.mojo + exec-maven-plugin + + + generate-scala-products + generate-sources + + java + + + org.bdgenomics.adam.codegen.DumpSchemasToProduct + + org.bdgenomics.adam.sql + org.bdgenomics.formats.avro.AlignmentRecord + org.bdgenomics.formats.avro.Contig + org.bdgenomics.formats.avro.Dbxref + org.bdgenomics.formats.avro.Feature + org.bdgenomics.formats.avro.Fragment + org.bdgenomics.formats.avro.Genotype + org.bdgenomics.formats.avro.NucleotideContigFragment + org.bdgenomics.formats.avro.OntologyTerm + org.bdgenomics.formats.avro.TranscriptEffect + org.bdgenomics.formats.avro.Variant + org.bdgenomics.formats.avro.VariantAnnotation + org.bdgenomics.formats.avro.VariantCallingAnnotations + adam-core/target/generated-sources/src/main/scala/org/bdgenomics/adam/sql/Schemas.scala + + compile + + + + @@ -146,6 +180,10 @@ spark-core_${scala.version.prefix} + + org.apache.spark + spark-sql_${scala.version.prefix} + it.unimi.dsi fastutil @@ -201,5 +239,9 @@ guava compile + + org.bdgenomics.adam + adam-codegen_${scala.version.prefix} + diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/contig/NucleotideContigFragmentRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/contig/NucleotideContigFragmentRDD.scala index 384cab56e4..8cd6c98c51 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/contig/NucleotideContigFragmentRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/contig/NucleotideContigFragmentRDD.scala @@ -20,6 +20,7 @@ package org.bdgenomics.adam.rdd.contig import com.google.common.base.Splitter import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.bdgenomics.adam.converters.FragmentConverter import org.bdgenomics.adam.models.{ ReferenceRegion, @@ -30,6 +31,7 @@ import org.bdgenomics.adam.models.{ import org.bdgenomics.adam.rdd.{ AvroGenomicRDD, JavaSaveArgs } import org.bdgenomics.adam.serialization.AvroSerializer import org.bdgenomics.adam.util.ReferenceFile +import org.bdgenomics.adam.sql.{ NucleotideContigFragment => NucleotideContigFragmentProduct } import org.bdgenomics.formats.avro.{ AlignmentRecord, NucleotideContigFragment } import org.bdgenomics.utils.interval.array.{ IntervalArray, @@ -147,6 +149,28 @@ case class NucleotideContigFragmentRDD( ReferenceRegion(elem).toSeq } + /** + * @return Creates a SQL Dataset of contig fragments. 
+ */ + def toDataset(): Dataset[NucleotideContigFragmentProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(NucleotideContigFragmentProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[NucleotideContigFragmentProduct] => Dataset[NucleotideContigFragmentProduct]): NucleotideContigFragmentRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Save nucleotide contig fragments as Parquet or FASTA. * diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/feature/FeatureRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/feature/FeatureRDD.scala index f588a86b58..621dfe9ce1 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/feature/FeatureRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/feature/FeatureRDD.scala @@ -21,6 +21,7 @@ import com.google.common.collect.ComparisonChain import java.util.Comparator import org.apache.hadoop.fs.{ FileSystem, Path } import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.apache.spark.storage.StorageLevel import org.bdgenomics.adam.instrumentation.Timers._ import org.bdgenomics.adam.models._ @@ -31,6 +32,7 @@ import org.bdgenomics.adam.rdd.{ SAMHeaderWriter } import org.bdgenomics.adam.serialization.AvroSerializer +import org.bdgenomics.adam.sql.{ Feature => FeatureProduct } import org.bdgenomics.formats.avro.{ Feature, Strand } import org.bdgenomics.utils.interval.array.{ IntervalArray, @@ -256,6 +258,28 @@ case class FeatureRDD(rdd: RDD[Feature], iterableRdds.map(_.sequences).fold(sequences)(_ ++ _)) } + /** + * @return Creates a SQL Dataset of genotypes. + */ + def toDataset(): Dataset[FeatureProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(FeatureProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[FeatureProduct] => Dataset[FeatureProduct]): FeatureRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Java friendly save function. Automatically detects the output format. 
* diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDD.scala index 8a49396463..fdca62726c 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDD.scala @@ -18,6 +18,7 @@ package org.bdgenomics.adam.rdd.fragment import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.bdgenomics.adam.converters.AlignmentRecordConverter import org.bdgenomics.adam.instrumentation.Timers._ import org.bdgenomics.adam.models.{ @@ -34,6 +35,7 @@ import org.bdgenomics.adam.rdd.read.{ QualityScoreBin } import org.bdgenomics.adam.serialization.AvroSerializer +import org.bdgenomics.adam.sql.{ Fragment => FragmentProduct } import org.bdgenomics.formats.avro._ import org.bdgenomics.utils.interval.array.{ IntervalArray, @@ -132,6 +134,28 @@ case class FragmentRDD(rdd: RDD[Fragment], iterableRdds.map(_.recordGroups).fold(recordGroups)(_ ++ _)) } + /** + * @return Creates a SQL Dataset of fragments. + */ + def toDataset(): Dataset[FragmentProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(FragmentProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[FragmentProduct] => Dataset[FragmentProduct]): FragmentRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Essentially, splits up the reads in a Fragment. * diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDD.scala index f92b48104c..214e52627f 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/AlignmentRecordRDD.scala @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.MetricsContext._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.apache.spark.storage.StorageLevel import org.bdgenomics.adam.algorithms.consensus.{ ConsensusGenerator, @@ -51,6 +52,7 @@ import org.bdgenomics.adam.rdd.read.realignment.RealignIndels import org.bdgenomics.adam.rdd.read.recalibration.BaseQualityRecalibration import org.bdgenomics.adam.rdd.fragment.FragmentRDD import org.bdgenomics.adam.rdd.variant.VariantRDD +import org.bdgenomics.adam.sql.{ AlignmentRecord => AlignmentRecordProduct } import org.bdgenomics.adam.serialization.AvroSerializer import org.bdgenomics.adam.util.ReferenceFile import org.bdgenomics.formats.avro._ @@ -131,6 +133,28 @@ case class AlignmentRecordRDD( sequences: SequenceDictionary, recordGroups: RecordGroupDictionary) extends AvroReadGroupGenomicRDD[AlignmentRecord, AlignmentRecordRDD] { + /** + * @return Creates a SQL Dataset of reads. 
+ */ + def toDataset(): Dataset[AlignmentRecordProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(AlignmentRecordProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[AlignmentRecordProduct] => Dataset[AlignmentRecordProduct]): AlignmentRecordRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Replaces the underlying RDD and SequenceDictionary and emits a new object. * diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/GenotypeRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/GenotypeRDD.scala index 7803fb1350..b67fffa532 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/GenotypeRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/GenotypeRDD.scala @@ -20,6 +20,7 @@ package org.bdgenomics.adam.rdd.variant import htsjdk.samtools.ValidationStringency import htsjdk.variant.vcf.VCFHeaderLine import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.bdgenomics.adam.converters.DefaultHeaderLines import org.bdgenomics.adam.models.{ ReferencePosition, @@ -31,6 +32,7 @@ import org.bdgenomics.adam.models.{ import org.bdgenomics.adam.rdd.{ JavaSaveArgs, MultisampleAvroGenomicRDD } import org.bdgenomics.adam.rich.RichVariant import org.bdgenomics.adam.serialization.AvroSerializer +import org.bdgenomics.adam.sql.{ Genotype => GenotypeProduct } import org.bdgenomics.utils.cli.SaveArgs import org.bdgenomics.utils.interval.array.{ IntervalArray, @@ -92,6 +94,28 @@ case class GenotypeRDD(rdd: RDD[Genotype], IntervalArray(rdd, GenotypeArray.apply(_, _)) } + /** + * @return Creates a SQL Dataset of genotypes. + */ + def toDataset(): Dataset[GenotypeProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(GenotypeProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[GenotypeProduct] => Dataset[GenotypeProduct]): GenotypeRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Java-friendly method for saving. 
* diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/VariantRDD.scala b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/VariantRDD.scala index a1bb8f3ee3..39dd5fc113 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/VariantRDD.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/rdd/variant/VariantRDD.scala @@ -21,6 +21,7 @@ import htsjdk.samtools.ValidationStringency import htsjdk.variant.vcf.{ VCFHeader, VCFHeaderLine } import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ Dataset, SQLContext } import org.bdgenomics.adam.converters.DefaultHeaderLines import org.bdgenomics.adam.models.{ ReferenceRegion, @@ -34,6 +35,7 @@ import org.bdgenomics.adam.rdd.{ VCFHeaderUtils } import org.bdgenomics.adam.serialization.AvroSerializer +import org.bdgenomics.adam.sql.{ Variant => VariantProduct } import org.bdgenomics.formats.avro.{ Contig, Sample, @@ -111,6 +113,28 @@ case class VariantRDD(rdd: RDD[Variant], (headerLines ++ iterableRdds.flatMap(_.headerLines)).distinct) } + /** + * @return Creates a SQL Dataset of variants. + */ + def toDataset(): Dataset[VariantProduct] = { + val sqlContext = SQLContext.getOrCreate(rdd.context) + import sqlContext.implicits._ + sqlContext.createDataset(rdd.map(VariantProduct.fromAvro)) + } + + /** + * Applies a function that transforms the underlying RDD into a new RDD using + * the Spark SQL API. + * + * @param tFn A function that transforms the underlying RDD as a Dataset. + * @return A new RDD where the RDD of genomic data has been replaced, but the + * metadata (sequence dictionary, and etc) is copied without modification. + */ + def transformDataset( + tFn: Dataset[VariantProduct] => Dataset[VariantProduct]): VariantRDD = { + replaceRdd(tFn(toDataset()).rdd.map(_.toAvro)) + } + /** * Java-friendly method for saving to Parquet. 
* diff --git a/adam-core/src/main/scala/org/bdgenomics/adam/serialization/ADAMKryoRegistrator.scala b/adam-core/src/main/scala/org/bdgenomics/adam/serialization/ADAMKryoRegistrator.scala index f15eacd683..cd7e8581eb 100644 --- a/adam-core/src/main/scala/org/bdgenomics/adam/serialization/ADAMKryoRegistrator.scala +++ b/adam-core/src/main/scala/org/bdgenomics/adam/serialization/ADAMKryoRegistrator.scala @@ -277,10 +277,15 @@ class ADAMKryoRegistrator extends KryoRegistrator { kryo.register(classOf[org.codehaus.jackson.node.BooleanNode]) kryo.register(classOf[org.codehaus.jackson.node.TextNode]) + // org.apache.spark + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.UnsafeRow]) + // scala + kryo.register(classOf[scala.Array[scala.Array[Byte]]]) kryo.register(classOf[scala.Array[htsjdk.variant.vcf.VCFHeader]]) kryo.register(classOf[scala.Array[java.lang.Long]]) kryo.register(classOf[scala.Array[java.lang.Object]]) + kryo.register(classOf[scala.Array[org.apache.spark.sql.catalyst.InternalRow]]) kryo.register(classOf[scala.Array[org.bdgenomics.formats.avro.AlignmentRecord]]) kryo.register(classOf[scala.Array[org.bdgenomics.formats.avro.Contig]]) kryo.register(classOf[scala.Array[org.bdgenomics.formats.avro.Dbxref]]) diff --git a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/ADAMContextSuite.scala b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/ADAMContextSuite.scala index bdf13eb4bf..fb4e4f5eab 100644 --- a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/ADAMContextSuite.scala +++ b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/ADAMContextSuite.scala @@ -78,8 +78,10 @@ class ADAMContextSuite extends ADAMFunSuite { sparkTest("can read a small .SAM file") { val path = testFile("small.sam") - val reads: RDD[AlignmentRecord] = sc.loadAlignments(path).rdd - assert(reads.count() === 20) + val reads = sc.loadAlignments(path) + assert(reads.rdd.count() === 20) + assert(reads.toDataset.count === 20) + assert(reads.toDataset.rdd.count === 20) } sparkTest("loading a sam file with a bad header and strict stringency should fail") { @@ -125,8 +127,10 @@ class ADAMContextSuite extends ADAMFunSuite { sparkTest("Can read a .gtf file") { val path = testFile("Homo_sapiens.GRCh37.75.trun20.gtf") - val features: RDD[Feature] = sc.loadFeatures(path).rdd - assert(features.count === 15) + val features = sc.loadFeatures(path) + assert(features.rdd.count === 15) + assert(features.toDataset.count === 15) + assert(features.toDataset.rdd.count === 15) } sparkTest("Can read a .bed file") { @@ -291,6 +295,8 @@ class ADAMContextSuite extends ADAMFunSuite { val variants = sc.loadVariants(path) assert(variants.rdd.count === 681) + assert(variants.toDataset.count === 681) + assert(variants.toDataset.rdd.count === 681) val loc = tmpLocation() variants.saveAsParquet(loc, 1024, 1024) // force more than one row group (block) diff --git a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDDSuite.scala b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDDSuite.scala index 33dc1e9e1f..c598be313a 100644 --- a/adam-core/src/test/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDDSuite.scala +++ b/adam-core/src/test/scala/org/bdgenomics/adam/rdd/fragment/FragmentRDDSuite.scala @@ -36,6 +36,8 @@ class FragmentRDDSuite extends ADAMFunSuite { val ardd = sc.loadFragments(fragmentsPath) val records = ardd.rdd.count assert(records === 3) + assert(ardd.toDataset.count === 3) + assert(ardd.toDataset.rdd.count === 3) implicit val tFormatter = InterleavedFASTQInFormatter implicit val 
uFormatter = new AnySAMOutFormatter diff --git a/adam-python/src/bdgenomics/adam/rdd.py b/adam-python/src/bdgenomics/adam/rdd.py index d5875f59da..22baa15278 100644 --- a/adam-python/src/bdgenomics/adam/rdd.py +++ b/adam-python/src/bdgenomics/adam/rdd.py @@ -18,7 +18,7 @@ from pyspark.rdd import RDD - +from pyspark.sql import DataFrame, SQLContext from bdgenomics.adam.stringency import LENIENT, _toJava @@ -115,6 +115,14 @@ def _replaceRdd(self, newRdd): return AlignmentRecordRDD(newRdd, self.sc) + + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + def toFragments(self): """ @@ -550,6 +558,14 @@ def __init__(self, jvmRdd, sc): GenomicRDD.__init__(self, jvmRdd, sc) + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + + def save(self, filePath, asSingleFile = False): """ Saves coverage, autodetecting the file type from the extension. @@ -600,6 +616,14 @@ def __init__(self, jvmRdd): GenomicRDD.__init__(self, jvmRdd, sc) + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + + def toReads(self): """ Splits up the reads in a Fragment, and creates a new RDD. @@ -653,6 +677,14 @@ def __init__(self, jvmRdd, sc): GenomicRDD.__init__(self, jvmRdd, sc) + + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + def save(self, filePath): """ @@ -720,6 +752,14 @@ def __init__(self, jvmRdd, sc): GenomicRDD.__init__(self, jvmRdd, sc) + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + + def save(self, fileName): """ Save nucleotide contig fragments as Parquet or FASTA. @@ -783,6 +823,14 @@ def __init__(self, jvmRdd, sc): GenomicRDD.__init__(self, jvmRdd, sc) + def toDataFrame(self): + """ + :return: Returns a dataframe representing this RDD. + """ + + return DataFrame(self._jvmRdd.toDataset().toDF(), SQLContext(self.sc)) + + def save(self, filePath): """ Saves this RDD of variants to disk. 
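Each of the Python toDataFrame helpers above wraps the same JVM call chain, which is what lets data cross from the JVM into Python as a DataFrame rather than as unpicklable Avro objects. A minimal Scala sketch of that chain follows; the method name and input path are illustrative only:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.DataFrame
    import org.bdgenomics.adam.rdd.ADAMContext._

    def readsToDataFrame(sc: SparkContext, path: String): DataFrame = {
      // Avro-backed RDD -> Dataset of generated Product wrappers -> untyped DataFrame
      sc.loadAlignments(path).toDataset().toDF()
    }
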
diff --git a/adam-python/src/bdgenomics/adam/test/adamContext_test.py b/adam-python/src/bdgenomics/adam/test/adamContext_test.py index 7144f567a8..5d4345ba9a 100644 --- a/adam-python/src/bdgenomics/adam/test/adamContext_test.py +++ b/adam-python/src/bdgenomics/adam/test/adamContext_test.py @@ -31,6 +31,7 @@ def test_load_alignments(self): reads = ac.loadAlignments(testFile) + self.assertEqual(reads.toDataFrame().count(), 20) self.assertEqual(reads._jvmRdd.jrdd().count(), 20) @@ -41,6 +42,7 @@ def test_load_gtf(self): reads = ac.loadFeatures(testFile) + self.assertEqual(reads.toDataFrame().count(), 15) self.assertEqual(reads._jvmRdd.jrdd().count(), 15) @@ -51,6 +53,7 @@ def test_load_bed(self): reads = ac.loadFeatures(testFile) + self.assertEqual(reads.toDataFrame().count(), 10) self.assertEqual(reads._jvmRdd.jrdd().count(), 10) @@ -61,6 +64,7 @@ def test_load_narrowPeak(self): reads = ac.loadFeatures(testFile) + self.assertEqual(reads.toDataFrame().count(), 10) self.assertEqual(reads._jvmRdd.jrdd().count(), 10) @@ -72,6 +76,7 @@ def test_load_interval_list(self): reads = ac.loadFeatures(testFile) + self.assertEqual(reads.toDataFrame().count(), 369) self.assertEqual(reads._jvmRdd.jrdd().count(), 369) @@ -83,6 +88,7 @@ def test_load_genotypes(self): reads = ac.loadGenotypes(testFile) + self.assertEqual(reads.toDataFrame().count(), 18) self.assertEqual(reads._jvmRdd.jrdd().count(), 18) @@ -94,6 +100,7 @@ def test_load_variants(self): reads = ac.loadVariants(testFile) + self.assertEqual(reads.toDataFrame().count(), 6) self.assertEqual(reads._jvmRdd.jrdd().count(), 6) @@ -105,4 +112,5 @@ def test_load_sequence(self): reads = ac.loadSequence(testFile) + self.assertEqual(reads.toDataFrame().count(), 1) self.assertEqual(reads._jvmRdd.jrdd().count(), 1) diff --git a/pom.xml b/pom.xml index b582fd76f3..7d4d517882 100644 --- a/pom.xml +++ b/pom.xml @@ -37,6 +37,7 @@ + adam-codegen adam-core adam-apis adam-cli @@ -461,6 +462,24 @@ + + net.razorvine + pyrolite + 4.9 + provided + + + net.razorvine + serpent + + + + + org.apache.spark + spark-sql_${scala.version.prefix} + ${spark.version} + provided + org.apache.spark spark-core_${scala.version.prefix} @@ -579,6 +598,11 @@ scala-guice_${scala.version.prefix} 4.1.0 + + org.bdgenomics.adam + adam-codegen_${scala.version.prefix} + ${project.version} + @@ -586,6 +610,7 @@ python + adam-codegen adam-core adam-apis adam-python
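
For reference, a usage sketch of the new transformDataset hook on the Scala side. This is illustrative only: the method name and input path are made up, and the readMapped field is assumed from the bdg-formats AlignmentRecord schema.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.Dataset
    import org.bdgenomics.adam.rdd.ADAMContext._
    import org.bdgenomics.adam.rdd.read.AlignmentRecordRDD
    import org.bdgenomics.adam.sql.{ AlignmentRecord => AlignmentRecordProduct }

    def dropUnmappedReads(sc: SparkContext, path: String): AlignmentRecordRDD = {
      val reads = sc.loadAlignments(path)
      // the transformation runs through the Spark SQL Dataset API, while the
      // sequence dictionary and record groups are carried over unchanged
      reads.transformDataset((ds: Dataset[AlignmentRecordProduct]) =>
        ds.filter(r => r.readMapped.getOrElse(false)))
    }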