Skip to content

Commit 21e91c7

Browse files
author
Niketan Pansare
committed
Added BaseSystemMLClassifier and updated the classifier to use new
MLContext
1 parent 65eb888 commit 21e91c7

File tree

13 files changed

+479
-424
lines changed

13 files changed

+479
-424
lines changed

src/main/java/org/apache/sysml/api/MLContext.java

Lines changed: 83 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
6666
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
6767
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
68+
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
6869
import org.apache.sysml.runtime.instructions.Instruction;
6970
import org.apache.sysml.runtime.instructions.cp.Data;
7071
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
@@ -476,25 +477,6 @@ public void registerInput(String varName, RDD<String> rdd, String format, long r
476477
registerInput(varName, rdd.toJavaRDD().mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, nnz, null);
477478
}
478479

479-
public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
480-
MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, mb.getNonZeros());
481-
registerInput(varName, mb, mc);
482-
}
483-
484-
public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
485-
if(_variables == null)
486-
_variables = new LocalVariableMap();
487-
if(_inVarnames == null)
488-
_inVarnames = new ArrayList<String>();
489-
490-
MatrixObject mo = new MatrixObject(ValueType.DOUBLE, "temp", new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
491-
mo.acquireModify(mb);
492-
mo.release();
493-
_variables.put(varName, mo);
494-
_inVarnames.add(varName);
495-
checkIfRegisteringInputAllowed();
496-
}
497-
498480
// All CSV related methods call this ... It provides access to dimensions, nnz, file properties.
499481
private void registerInput(String varName, JavaPairRDD<LongWritable, Text> textOrCsv_rdd, String format, long rlen, long clen, long nnz, FileFormatProperties props) throws DMLRuntimeException {
500482
if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
@@ -618,6 +600,24 @@ public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock>
618600
checkIfRegisteringInputAllowed();
619601
}
620602

603+
public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
604+
MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, mb.getNonZeros());
605+
registerInput(varName, mb, mc);
606+
}
607+
608+
public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
609+
if(_variables == null)
610+
_variables = new LocalVariableMap();
611+
if(_inVarnames == null)
612+
_inVarnames = new ArrayList<String>();
613+
MatrixObject mo = new MatrixObject(ValueType.DOUBLE, "temp", new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
614+
mo.acquireModify(mb);
615+
mo.release();
616+
_variables.put(varName, mo);
617+
_inVarnames.add(varName);
618+
checkIfRegisteringInputAllowed();
619+
}
620+
621621
// =============================================================================================
622622

623623
/**
@@ -1240,56 +1240,80 @@ private MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] arg
12401240
* @throws ParseException
12411241
*/
12421242
private synchronized MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] args, boolean isFile, boolean isNamedArgument, boolean isPyDML, String configFilePath) throws IOException, DMLException {
1243-
// Set active MLContext.
1244-
_activeMLContext = this;
1245-
1246-
if(_monitorUtils != null) {
1247-
_monitorUtils.resetMonitoringData();
1248-
}
1249-
1250-
if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) {
1251-
1252-
// Depending on whether registerInput/registerOutput was called initialize the variables
1253-
String[] inputs; String[] outputs;
1254-
if(_inVarnames != null) {
1255-
inputs = _inVarnames.toArray(new String[0]);
1256-
}
1257-
else {
1258-
inputs = new String[0];
1259-
}
1260-
if(_outVarnames != null) {
1261-
outputs = _outVarnames.toArray(new String[0]);
1262-
}
1263-
else {
1264-
outputs = new String[0];
1243+
try {
1244+
if(getActiveMLContext() != null) {
1245+
throw new DMLRuntimeException("SystemML (and hence by definition MLContext) doesnot support parallel execute() calls from same or different MLContexts. "
1246+
+ "As a temporary fix, please do explicit synchronization, i.e. synchronized(MLContext.class) { ml.execute(...) } ");
12651247
}
1266-
Map<String, MatrixCharacteristics> outMetadata = new HashMap<String, MatrixCharacteristics>();
12671248

1268-
Map<String, String> argVals = DMLScript.createArgumentsMap(isNamedArgument, args);
1249+
// Set active MLContext.
1250+
_activeMLContext = this;
12691251

1270-
// Run the DML script
1271-
ExecutionContext ec = executeUsingSimplifiedCompilationChain(dmlScriptFilePath, isFile, argVals, isPyDML, inputs, outputs, _variables, configFilePath);
1252+
if(_monitorUtils != null) {
1253+
_monitorUtils.resetMonitoringData();
1254+
}
12721255

1273-
// Now collect the output
1274-
if(_outVarnames != null) {
1275-
if(_variables == null) {
1276-
throw new DMLRuntimeException("The symbol table returned after executing the script is empty");
1256+
if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) {
1257+
1258+
Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> retVal = null;
1259+
1260+
// Depending on whether registerInput/registerOutput was called initialize the variables
1261+
String[] inputs; String[] outputs;
1262+
if(_inVarnames != null) {
1263+
inputs = _inVarnames.toArray(new String[0]);
1264+
}
1265+
else {
1266+
inputs = new String[0];
1267+
}
1268+
if(_outVarnames != null) {
1269+
outputs = _outVarnames.toArray(new String[0]);
12771270
}
1271+
else {
1272+
outputs = new String[0];
1273+
}
1274+
Map<String, MatrixCharacteristics> outMetadata = new HashMap<String, MatrixCharacteristics>();
1275+
1276+
Map<String, String> argVals = DMLScript.createArgumentsMap(isNamedArgument, args);
12781277

1279-
for( String ovar : _outVarnames ) {
1280-
if( _variables.keySet().contains(ovar) ) {
1281-
outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
1278+
// Run the DML script
1279+
ExecutionContext ec = executeUsingSimplifiedCompilationChain(dmlScriptFilePath, isFile, argVals, isPyDML, inputs, outputs, _variables, configFilePath);
1280+
1281+
// Now collect the output
1282+
if(_outVarnames != null) {
1283+
if(_variables == null) {
1284+
throw new DMLRuntimeException("The symbol table returned after executing the script is empty");
12821285
}
1283-
else {
1284-
throw new DMLException("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript.");
1286+
1287+
for( String ovar : _outVarnames ) {
1288+
if( _variables.keySet().contains(ovar) ) {
1289+
if(retVal == null) {
1290+
retVal = new HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>>();
1291+
}
1292+
retVal.put(ovar, ((SparkExecutionContext) ec).getBinaryBlockRDDHandleForVariable(ovar));
1293+
outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
1294+
}
1295+
else {
1296+
throw new DMLException("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript.");
1297+
}
12851298
}
12861299
}
1300+
1301+
return new MLOutput(retVal, outMetadata);
12871302
}
1288-
1289-
return new MLOutput(_variables, ec, outMetadata);
1303+
else {
1304+
throw new DMLRuntimeException("Unsupported runtime:" + DMLScript.rtplatform.name());
1305+
}
1306+
12901307
}
1291-
else {
1292-
throw new DMLRuntimeException("Unsupported runtime:" + DMLScript.rtplatform.name());
1308+
finally {
1309+
// Remove global dml config and all thread-local configs
1310+
// TODO enable cleanup whenever invalid GNMF MLcontext is fixed
1311+
// (the test is invalid because it assumes that status of previous execute is kept)
1312+
//ConfigurationManager.setGlobalConfig(new DMLConfig());
1313+
//ConfigurationManager.clearLocalConfigs();
1314+
1315+
// Reset active MLContext.
1316+
_activeMLContext = null;
12931317
}
12941318
}
12951319

@@ -1451,4 +1475,4 @@ public MLMatrix read(SQLContext sqlContext, String filePath, String format) thro
14511475
// return MLMatrix.createMLMatrix(this, sqlContext, blocks, mc);
14521476
// }
14531477

1454-
}
1478+
}

src/main/java/org/apache/sysml/api/MLOutput.java

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,14 @@
3939
import org.apache.spark.sql.types.StructField;
4040
import org.apache.spark.sql.types.StructType;
4141
import org.apache.sysml.runtime.DMLRuntimeException;
42-
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
43-
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
4442
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
4543
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
4644
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
4745
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
4846
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
4947
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
5048
import org.apache.sysml.runtime.util.UtilFunctions;
51-
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
49+
5250
import scala.Tuple2;
5351

5452
/**
@@ -57,39 +55,31 @@
5755
*/
5856
public class MLOutput {
5957

60-
private LocalVariableMap _variables;
61-
private ExecutionContext _ec;
58+
Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs;
6259
private Map<String, MatrixCharacteristics> _outMetadata = null;
6360

64-
public MLOutput(LocalVariableMap variables, ExecutionContext ec, Map<String, MatrixCharacteristics> outMetadata) {
65-
this._variables = variables;
66-
this._ec = ec;
67-
this._outMetadata = outMetadata;
68-
}
69-
7061
public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
71-
if( _variables.keySet().contains(varName) ) {
72-
MatrixObject mo = _ec.getMatrixObject(varName);
73-
MatrixBlock mb = mo.acquireRead();
74-
mo.release();
75-
return mb;
76-
}
77-
else {
78-
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
79-
}
62+
MatrixCharacteristics mc = getMatrixCharacteristics(varName);
63+
// The matrix block is always pushed to an RDD and then we do collect
64+
// We can later avoid this by returning symbol table rather than "Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs"
65+
MatrixBlock mb = SparkExecutionContext.toMatrixBlock(getBinaryBlockedRDD(varName), (int) mc.getRows(), (int) mc.getCols(),
66+
mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
67+
return mb;
68+
}
69+
public MLOutput(Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
70+
this._outputs = outputs;
71+
this._outMetadata = outMetadata;
8072
}
8173

8274
public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
83-
if( _variables.keySet().contains(varName) ) {
84-
return ((SparkExecutionContext) _ec).getBinaryBlockRDDHandleForVariable(varName);
85-
}
86-
else {
87-
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
75+
if(_outputs.containsKey(varName)) {
76+
return _outputs.get(varName);
8877
}
78+
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
8979
}
9080

9181
public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
92-
if(_outMetadata.containsKey(varName)) {
82+
if(_outputs.containsKey(varName)) {
9383
return _outMetadata.get(varName);
9484
}
9585
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
@@ -255,15 +245,15 @@ public Iterable<Tuple2<Long, Tuple2<Long, Double[]>>> call(Tuple2<MatrixIndexes,
255245
int lclen = UtilFunctions.computeBlockSize(clen, blockColIndex, bclen);
256246
// ------------------------------------------------------------------
257247

258-
long startRowIndex = (kv._1.getRowIndex()-1) * bclen;
248+
long startRowIndex = (kv._1.getRowIndex()-1) * bclen + 1;
259249
MatrixBlock blk = kv._2;
260250
ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> retVal = new ArrayList<Tuple2<Long,Tuple2<Long,Double[]>>>();
261251
for(int i = 0; i < lrlen; i++) {
262252
Double[] partialRow = new Double[lclen];
263253
for(int j = 0; j < lclen; j++) {
264254
partialRow[j] = blk.getValue(i, j);
265255
}
266-
retVal.add(new Tuple2<Long, Tuple2<Long,Double[]>>(startRowIndex + i + 1, new Tuple2<Long,Double[]>(kv._1.getColumnIndex(), partialRow)));
256+
retVal.add(new Tuple2<Long, Tuple2<Long,Double[]>>(startRowIndex + i, new Tuple2<Long,Double[]>(kv._1.getColumnIndex(), partialRow)));
267257
}
268258
return retVal;
269259
}
@@ -427,4 +417,4 @@ public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
427417
return RowFactory.create(row);
428418
}
429419
}
430-
}
420+
}

src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import org.apache.spark.api.java.JavaPairRDD;
2323
import org.apache.spark.sql.DataFrame;
24+
import org.apache.sysml.runtime.DMLRuntimeException;
25+
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
2426
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
2527
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
2628
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -97,6 +99,13 @@ public BinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks,
9799
public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
98100
return binaryBlocks;
99101
}
102+
103+
public MatrixBlock getMatrixBlock() throws DMLRuntimeException {
104+
MatrixCharacteristics mc = getMatrixCharacteristics();
105+
MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(),
106+
mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
107+
return mb;
108+
}
100109

101110
/**
102111
* Obtain the SystemML binary-block matrix characteristics

src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ public static double[][] matrixObjectToDoubleMatrix(MatrixObject matrixObject) {
676676
* @return the {@code MatrixObject} converted to a {@code DataFrame}
677677
*/
678678
public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
679-
SparkExecutionContext sparkExecutionContext) {
679+
SparkExecutionContext sparkExecutionContext, boolean isVectorDF) {
680680
try {
681681
@SuppressWarnings("unchecked")
682682
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockMatrix = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext
@@ -686,8 +686,17 @@ public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
686686
MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
687687
SparkContext sc = activeMLContext.getSparkContext();
688688
SQLContext sqlContext = new SQLContext(sc);
689-
DataFrame df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
689+
DataFrame df = null;
690+
if(isVectorDF) {
691+
df = RDDConverterUtilsExt.binaryBlockToVectorDataFrame(binaryBlockMatrix, matrixCharacteristics,
692+
sqlContext);
693+
}
694+
else {
695+
df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
690696
sqlContext);
697+
}
698+
699+
691700
return df;
692701
} catch (DMLRuntimeException e) {
693702
throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e);

0 commit comments

Comments
 (0)