diff --git a/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala b/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala index 8a492834a3..9ed7ea5634 100644 --- a/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala +++ b/src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala @@ -532,7 +532,10 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP val droppedDF = outputDF.drop(outputDF.columns.filter(_.startsWith(coercionPrefix)): _*) val unbatchedDF = if (getBatchInput) { - new FlattenBatch().transform(droppedDF) + // TODO: The cache call is a workaround for issue 1075: + // https://github.com/Azure/mmlspark/issues/1075 + val cacheAttempted = if (droppedDF.isStreaming) droppedDF else droppedDF.cache() + new FlattenBatch().transform(cacheAttempted) } else { droppedDF }