6565import org .apache .sysml .runtime .controlprogram .caching .MatrixObject ;
6666import org .apache .sysml .runtime .controlprogram .context .ExecutionContext ;
6767import org .apache .sysml .runtime .controlprogram .context .ExecutionContextFactory ;
68+ import org .apache .sysml .runtime .controlprogram .context .SparkExecutionContext ;
6869import org .apache .sysml .runtime .instructions .Instruction ;
6970import org .apache .sysml .runtime .instructions .cp .Data ;
7071import org .apache .sysml .runtime .instructions .spark .data .RDDObject ;
@@ -476,25 +477,6 @@ public void registerInput(String varName, RDD<String> rdd, String format, long r
476477 registerInput (varName , rdd .toJavaRDD ().mapToPair (new ConvertStringToLongTextPair ()), format , rlen , clen , nnz , null );
477478 }
478479
479- public void registerInput (String varName , MatrixBlock mb ) throws DMLRuntimeException {
480- MatrixCharacteristics mc = new MatrixCharacteristics (mb .getNumRows (), mb .getNumColumns (), OptimizerUtils .DEFAULT_BLOCKSIZE , OptimizerUtils .DEFAULT_BLOCKSIZE , mb .getNonZeros ());
481- registerInput (varName , mb , mc );
482- }
483-
484- public void registerInput (String varName , MatrixBlock mb , MatrixCharacteristics mc ) throws DMLRuntimeException {
485- if (_variables == null )
486- _variables = new LocalVariableMap ();
487- if (_inVarnames == null )
488- _inVarnames = new ArrayList <String >();
489-
490- MatrixObject mo = new MatrixObject (ValueType .DOUBLE , "temp" , new MatrixFormatMetaData (mc , OutputInfo .BinaryBlockOutputInfo , InputInfo .BinaryBlockInputInfo ));
491- mo .acquireModify (mb );
492- mo .release ();
493- _variables .put (varName , mo );
494- _inVarnames .add (varName );
495- checkIfRegisteringInputAllowed ();
496- }
497-
498480 // All CSV related methods call this ... It provides access to dimensions, nnz, file properties.
499481 private void registerInput (String varName , JavaPairRDD <LongWritable , Text > textOrCsv_rdd , String format , long rlen , long clen , long nnz , FileFormatProperties props ) throws DMLRuntimeException {
500482 if (!(DMLScript .rtplatform == RUNTIME_PLATFORM .SPARK || DMLScript .rtplatform == RUNTIME_PLATFORM .HYBRID_SPARK )) {
@@ -618,6 +600,24 @@ public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock>
618600 checkIfRegisteringInputAllowed ();
619601 }
620602
603+ public void registerInput (String varName , MatrixBlock mb ) throws DMLRuntimeException {
604+ MatrixCharacteristics mc = new MatrixCharacteristics (mb .getNumRows (), mb .getNumColumns (), OptimizerUtils .DEFAULT_BLOCKSIZE , OptimizerUtils .DEFAULT_BLOCKSIZE , mb .getNonZeros ());
605+ registerInput (varName , mb , mc );
606+ }
607+
608+ public void registerInput (String varName , MatrixBlock mb , MatrixCharacteristics mc ) throws DMLRuntimeException {
609+ if (_variables == null )
610+ _variables = new LocalVariableMap ();
611+ if (_inVarnames == null )
612+ _inVarnames = new ArrayList <String >();
613+ MatrixObject mo = new MatrixObject (ValueType .DOUBLE , "temp" , new MatrixFormatMetaData (mc , OutputInfo .BinaryBlockOutputInfo , InputInfo .BinaryBlockInputInfo ));
614+ mo .acquireModify (mb );
615+ mo .release ();
616+ _variables .put (varName , mo );
617+ _inVarnames .add (varName );
618+ checkIfRegisteringInputAllowed ();
619+ }
620+
621621 // =============================================================================================
622622
623623 /**
@@ -1240,56 +1240,80 @@ private MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] arg
12401240 * @throws ParseException
12411241 */
12421242 private synchronized MLOutput compileAndExecuteScript (String dmlScriptFilePath , String [] args , boolean isFile , boolean isNamedArgument , boolean isPyDML , String configFilePath ) throws IOException , DMLException {
1243- // Set active MLContext.
1244- _activeMLContext = this ;
1245-
1246- if (_monitorUtils != null ) {
1247- _monitorUtils .resetMonitoringData ();
1248- }
1249-
1250- if (DMLScript .rtplatform == RUNTIME_PLATFORM .SPARK || DMLScript .rtplatform == RUNTIME_PLATFORM .HYBRID_SPARK ) {
1251-
1252- // Depending on whether registerInput/registerOutput was called initialize the variables
1253- String [] inputs ; String [] outputs ;
1254- if (_inVarnames != null ) {
1255- inputs = _inVarnames .toArray (new String [0 ]);
1256- }
1257- else {
1258- inputs = new String [0 ];
1259- }
1260- if (_outVarnames != null ) {
1261- outputs = _outVarnames .toArray (new String [0 ]);
1262- }
1263- else {
1264- outputs = new String [0 ];
1243+ try {
1244+ if (getActiveMLContext () != null ) {
1245+ throw new DMLRuntimeException ("SystemML (and hence by definition MLContext) doesnot support parallel execute() calls from same or different MLContexts. "
1246+ + "As a temporary fix, please do explicit synchronization, i.e. synchronized(MLContext.class) { ml.execute(...) } " );
12651247 }
1266- Map <String , MatrixCharacteristics > outMetadata = new HashMap <String , MatrixCharacteristics >();
12671248
1268- Map <String , String > argVals = DMLScript .createArgumentsMap (isNamedArgument , args );
1249+ // Set active MLContext.
1250+ _activeMLContext = this ;
12691251
1270- // Run the DML script
1271- ExecutionContext ec = executeUsingSimplifiedCompilationChain (dmlScriptFilePath , isFile , argVals , isPyDML , inputs , outputs , _variables , configFilePath );
1252+ if (_monitorUtils != null ) {
1253+ _monitorUtils .resetMonitoringData ();
1254+ }
12721255
1273- // Now collect the output
1274- if (_outVarnames != null ) {
1275- if (_variables == null ) {
1276- throw new DMLRuntimeException ("The symbol table returned after executing the script is empty" );
1256+ if (DMLScript .rtplatform == RUNTIME_PLATFORM .SPARK || DMLScript .rtplatform == RUNTIME_PLATFORM .HYBRID_SPARK ) {
1257+
1258+ Map <String , JavaPairRDD <MatrixIndexes ,MatrixBlock >> retVal = null ;
1259+
1260+ // Depending on whether registerInput/registerOutput was called initialize the variables
1261+ String [] inputs ; String [] outputs ;
1262+ if (_inVarnames != null ) {
1263+ inputs = _inVarnames .toArray (new String [0 ]);
1264+ }
1265+ else {
1266+ inputs = new String [0 ];
1267+ }
1268+ if (_outVarnames != null ) {
1269+ outputs = _outVarnames .toArray (new String [0 ]);
12771270 }
1271+ else {
1272+ outputs = new String [0 ];
1273+ }
1274+ Map <String , MatrixCharacteristics > outMetadata = new HashMap <String , MatrixCharacteristics >();
1275+
1276+ Map <String , String > argVals = DMLScript .createArgumentsMap (isNamedArgument , args );
12781277
1279- for ( String ovar : _outVarnames ) {
1280- if ( _variables .keySet ().contains (ovar ) ) {
1281- outMetadata .put (ovar , ec .getMatrixCharacteristics (ovar )); // For converting output to dataframe
1278+ // Run the DML script
1279+ ExecutionContext ec = executeUsingSimplifiedCompilationChain (dmlScriptFilePath , isFile , argVals , isPyDML , inputs , outputs , _variables , configFilePath );
1280+
1281+ // Now collect the output
1282+ if (_outVarnames != null ) {
1283+ if (_variables == null ) {
1284+ throw new DMLRuntimeException ("The symbol table returned after executing the script is empty" );
12821285 }
1283- else {
1284- throw new DMLException ("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript." );
1286+
1287+ for ( String ovar : _outVarnames ) {
1288+ if ( _variables .keySet ().contains (ovar ) ) {
1289+ if (retVal == null ) {
1290+ retVal = new HashMap <String , JavaPairRDD <MatrixIndexes ,MatrixBlock >>();
1291+ }
1292+ retVal .put (ovar , ((SparkExecutionContext ) ec ).getBinaryBlockRDDHandleForVariable (ovar ));
1293+ outMetadata .put (ovar , ec .getMatrixCharacteristics (ovar )); // For converting output to dataframe
1294+ }
1295+ else {
1296+ throw new DMLException ("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript." );
1297+ }
12851298 }
12861299 }
1300+
1301+ return new MLOutput (retVal , outMetadata );
12871302 }
1288-
1289- return new MLOutput (_variables , ec , outMetadata );
1303+ else {
1304+ throw new DMLRuntimeException ("Unsupported runtime:" + DMLScript .rtplatform .name ());
1305+ }
1306+
12901307 }
1291- else {
1292- throw new DMLRuntimeException ("Unsupported runtime:" + DMLScript .rtplatform .name ());
1308+ finally {
1309+ // Remove global dml config and all thread-local configs
1310+ // TODO enable cleanup whenever invalid GNMF MLcontext is fixed
1311+ // (the test is invalid because it assumes that status of previous execute is kept)
1312+ //ConfigurationManager.setGlobalConfig(new DMLConfig());
1313+ //ConfigurationManager.clearLocalConfigs();
1314+
1315+ // Reset active MLContext.
1316+ _activeMLContext = null ;
12931317 }
12941318 }
12951319
@@ -1451,4 +1475,4 @@ public MLMatrix read(SQLContext sqlContext, String filePath, String format) thro
14511475// return MLMatrix.createMLMatrix(this, sqlContext, blocks, mc);
14521476// }
14531477
1454- }
1478+ }
0 commit comments