diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index 9976adc09f8..03fc6a16a15 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -32,6 +32,7 @@ import java.util.Collections; import java.util.Date; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Scanner; @@ -48,6 +49,7 @@ import org.apache.hadoop.util.GenericOptionsParser; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.ConfigurationManager; @@ -65,13 +67,14 @@ import org.apache.sysml.parser.ParserWrapper; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.io.IOUtilFunctions; import org.apache.sysml.runtime.matrix.CleanupMR; @@ -79,13 +82,9 @@ import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration; import org.apache.sysml.runtime.util.LocalFileUtils; import org.apache.sysml.runtime.util.MapReduceTool; -import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.NativeHelper; -import org.apache.sysml.utils.Explain.ExplainCounts; import org.apache.sysml.utils.Explain.ExplainType; import org.apache.sysml.utils.Statistics; -import org.apache.sysml.yarn.DMLAppMasterUtils; -import org.apache.sysml.yarn.DMLYarnClientProxy; public class DMLScript @@ -111,10 +110,9 @@ public enum EvictionPolicy { // ARC, // https://dbs.uni-leipzig.de/file/ARC.pdf // LOOP_AWARE // different policies for operations in for/while/parfor loop vs out-side the loop } - - // TODO: Anthony - public static boolean JMLC_MEM_STATISTICS = false; // whether to gather memory use stats in JMLC - + + public static RUNTIME_PLATFORM rtplatform = DMLOptions.defaultOptions.execMode; // the execution mode + // debug mode is deprecated and will be removed soon. public static boolean ENABLE_DEBUG_MODE = DMLOptions.defaultOptions.debug; // debug mode @@ -141,7 +139,7 @@ public enum EvictionPolicy { public static boolean VALIDATOR_IGNORE_ISSUES = false; public static String _uuid = IDHandler.createDistributedUniqueID(); - private static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); + static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); /////////////////////////////// // public external interface @@ -195,6 +193,7 @@ public static void main(String[] args) } } + /** * Single entry point for all public invocation alternatives (e.g., * main, executeScript, JaqlUdf etc) @@ -215,8 +214,7 @@ public static boolean executeScript( Configuration conf, String[] args ) { { dmlOptions = DMLOptions.parseCLArguments(args); ConfigurationManager.setGlobalOptions(dmlOptions); - - JMLC_MEM_STATISTICS = dmlOptions.memStats; + EXPLAIN = dmlOptions.explainType; ENABLE_DEBUG_MODE = dmlOptions.debug; SCRIPT_TYPE = dmlOptions.scriptType; @@ -378,104 +376,56 @@ else if (debug.equalsIgnoreCase("trace")) // (core compilation and execute) //////// - /** - * The running body of DMLScript execution. This method should be called after execution properties have been correctly set, - * and customized parameters have been put into _argVals - * - * @param dmlScriptStr DML script string - * @param fnameOptConfig configuration file - * @param argVals map of argument values - * @param allArgs arguments - * @param scriptType type of script (DML or PyDML) - * @throws IOException if IOException occurs - */ - private static void execute(String dmlScriptStr, String fnameOptConfig, Map argVals, String[] allArgs, ScriptType scriptType) - throws IOException - { - SCRIPT_TYPE = scriptType; - - //print basic time and environment info - printStartExecInfo( dmlScriptStr ); - - //Step 1: parse configuration files & write any configuration specific global variables - DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig); - ConfigurationManager.setGlobalConfig(dmlconf); - CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(dmlconf); - ConfigurationManager.setGlobalConfig(cconf); - LOG.debug("\nDML config: \n" + dmlconf.getConfigInfo()); - - setGlobalFlags(dmlconf); - - //Step 2: set local/remote memory if requested (for compile in AM context) - if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ - DMLAppMasterUtils.setupConfigRemoteMaxMemory(dmlconf); - } - - //Step 3: parse dml script - Statistics.startCompileTimer(); - ParserWrapper parser = ParserFactory.createParser(scriptType); - DMLProgram prog = parser.parse(DML_FILE_PATH_ANTLR_PARSER, dmlScriptStr, argVals); - - //Step 4: construct HOP DAGs (incl LVA, validate, and setup) - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - - //init working directories (before usage by following compilation steps) - initHadoopExecution( dmlconf ); - - //Step 5: rewrite HOP DAGs (incl IPA and memory estimates) - dmlt.rewriteHopsDAG(prog); - - //Step 6: construct lops (incl exec type and op selection) - dmlt.constructLops(prog); - - if (LOG.isDebugEnabled()) { - LOG.debug("\n********************** LOPS DAG *******************"); - dmlt.printLops(prog); - dmlt.resetLopsDAGVisitStatus(prog); - } - - //Step 7: generate runtime program, incl codegen - Program rtprog = dmlt.getRuntimeProgram(prog, dmlconf); - - //launch SystemML appmaster (if requested and not already in launched AM) - if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ - if( !isActiveAM() && DMLYarnClientProxy.launchDMLYarnAppmaster(dmlScriptStr, dmlconf, allArgs, rtprog) ) - return; //if AM launch unsuccessful, fall back to normal execute - if( isActiveAM() ) //in AM context (not failed AM launch) - DMLAppMasterUtils.setupProgramMappingRemoteMaxMemory(rtprog); - } - - //Step 9: prepare statistics [and optional explain output] - //count number compiled MR jobs / SP instructions - ExplainCounts counts = Explain.countDistributedOperations(rtprog); - Statistics.resetNoOfCompiledJobs( counts.numJobs ); - - //explain plan of program (hops or runtime) - if( EXPLAIN != ExplainType.NONE ) - System.out.println(Explain.display(prog, rtprog, EXPLAIN, counts)); - - Statistics.stopCompileTimer(); - - //double costs = CostEstimationWrapper.getTimeEstimate(rtprog, ExecutionContextFactory.createContext()); - //System.out.println("Estimated costs: "+costs); - - //Step 10: execute runtime program - ExecutionContext ec = null; - try { - ec = ExecutionContextFactory.createContext(rtprog); - ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, dmlconf, ConfigurationManager.isStatistics() ? ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, null); - } - finally { - if(ec != null && ec instanceof SparkExecutionContext) - ((SparkExecutionContext) ec).close(); - LOG.info("END DML run " + getDateTime() ); - //cleanup scratch_space and all working dirs - cleanupHadoopExecution( dmlconf ); - } - } + /** + * The running body of DMLScript execution. This method should be called after execution properties have been correctly set, + * and customized parameters have been put into _argVals + * + * @param dmlScriptStr DML script string + * @param fnameOptConfig configuration file + * @param argVals map of argument values + * @param allArgs arguments + * @param scriptType type of script (DML or PyDML) + * @throws IOException if IOException occurs + */ + private static void execute(String dmlScriptStr, String fnameOptConfig, Map argVals, String[] allArgs, ScriptType scriptType) + throws IOException + { + SCRIPT_TYPE = scriptType; + + //print basic time and environment info + printStartExecInfo( dmlScriptStr ); + + //Step 1: parse configuration files & write any configuration specific global variables + DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig); + ConfigurationManager.setGlobalConfig(dmlconf); + CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(dmlconf); + ConfigurationManager.setGlobalConfig(cconf); + LOG.debug("\nDML config: \n" + dmlconf.getConfigInfo()); + + setGlobalFlags(dmlconf); + Program rtprog = ScriptExecutorUtils.compileRuntimeProgram(dmlScriptStr, argVals, allArgs, + scriptType, dmlconf, SystemMLAPI.DMLScript); + List gCtxs = ConfigurationManager.getDMLOptions().gpu ? GPUContextPool.getAllGPUContexts() : null; + + //double costs = CostEstimationWrapper.getTimeEstimate(rtprog, ExecutionContextFactory.createContext()); + //System.out.println("Estimated costs: "+costs); + + //Step 10: execute runtime program + ExecutionContext ec = null; + try { + ec = ScriptExecutorUtils.executeRuntimeProgram( + rtprog, dmlconf, ConfigurationManager.isStatistics() ? + ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, + new LocalVariableMap(), null, SystemMLAPI.DMLScript, gCtxs); + } + finally { + if(ec != null && ec instanceof SparkExecutionContext) + ((SparkExecutionContext) ec).close(); + LOG.info("END DML run " + getDateTime() ); + //cleanup scratch_space and all working dirs + cleanupHadoopExecution( dmlconf ); + } + } /** * Sets the global flags in DMLScript based on user provided configuration diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java index 995651889d2..7b2da91698e 100644 --- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java +++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,74 +19,271 @@ package org.apache.sysml.api; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; -import java.util.Set; +import java.util.Map; -import org.apache.sysml.api.mlcontext.ScriptExecutor; +import org.apache.sysml.api.jmlc.JMLCUtils; +import org.apache.sysml.api.mlcontext.MLContextUtil; +import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.hops.codegen.SpoofCompiler; -import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.hops.rewrite.ProgramRewriter; +import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DMLTranslator; +import org.apache.sysml.parser.LanguageException; +import org.apache.sysml.parser.ParseException; +import org.apache.sysml.parser.ParserFactory; +import org.apache.sysml.parser.ParserWrapper; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.instructions.gpu.context.GPUObject; +import org.apache.sysml.runtime.util.UtilFunctions; +import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.Statistics; +import org.apache.sysml.utils.Explain.ExplainCounts; +import org.apache.sysml.utils.Explain.ExplainType; +import org.apache.sysml.yarn.DMLAppMasterUtils; +import org.apache.sysml.yarn.DMLYarnClientProxy; +import org.apache.sysml.runtime.DMLRuntimeException; public class ScriptExecutorUtils { + public static final boolean IS_JCUDA_AVAILABLE; + static { + // Early detection of JCuda libraries avoids synchronization overhead for common JMLC scenario: + // i.e. CPU-only multi-threaded execution + boolean isJCudaAvailable = false; + try { + Class.forName("jcuda.Pointer"); + isJCudaAvailable = true; + } + catch (ClassNotFoundException e) { } + IS_JCUDA_AVAILABLE = isJCudaAvailable; + } + + public static enum SystemMLAPI { + DMLScript, + MLContext, + JMLC + } + + public static Program compileRuntimeProgram(String script, Map nsscripts, Map args, + String[] inputs, String[] outputs, ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api) { + return compileRuntimeProgram(script, nsscripts, args, null, null, inputs, outputs, + scriptType, dmlconf, api, true, false, false); + } + + public static Program compileRuntimeProgram(String script, Map args, String[] allArgs, + ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api) { + return compileRuntimeProgram(script, Collections.emptyMap(), args, allArgs, null, null, null, + scriptType, dmlconf, api, true, false, false); + } + /** - * Execute the runtime program. This involves execution of the program - * blocks that make up the runtime program and may involve dynamic - * recompilation. - * - * @param se - * script executor - * @param statisticsMaxHeavyHitters - * maximum number of statistics to print + * Compile a runtime program + * + * @param script string representing of the DML or PyDML script + * @param nsscripts map (name, script) of the DML or PyDML namespace scripts + * @param args map of input parameters ($) and their values + * @param allArgs commandline arguments + * @param symbolTable symbol table associated with MLContext + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param scriptType is this script DML or PyDML + * @param dmlconf configuration provided by the user + * @param api API used to execute the runtime program + * @param performHOPRewrites should perform hop rewrites + * @param maintainSymbolTable whether or not all values should be maintained in the symbol table after execution. + * @return compiled runtime program */ - public static void executeRuntimeProgram(ScriptExecutor se, int statisticsMaxHeavyHitters) { - Program prog = se.getRuntimeProgram(); - ExecutionContext ec = se.getExecutionContext(); - DMLConfig config = se.getConfig(); - executeRuntimeProgram(prog, ec, config, statisticsMaxHeavyHitters, se.getScript().getOutputVariables()); + public static Program compileRuntimeProgram(String script, Map nsscripts, Map args, String[] allArgs, + // Input/Outputs registered in MLContext and JMLC. These are set to null by DMLScript + LocalVariableMap symbolTable, String[] inputs, String[] outputs, + ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api, + // MLContext-specific flags + boolean performHOPRewrites, boolean maintainSymbolTable, + boolean init) { + DMLScript.SCRIPT_TYPE = scriptType; + + Program rtprog = null; + + if (ConfigurationManager.isGPU() && !IS_JCUDA_AVAILABLE) + throw new RuntimeException("Incorrect usage: Cannot use the GPU backend without JCuda libraries. Hint: Include systemml-*-extra.jar (compiled using mvn package -P distribution) into the classpath."); + else if (!ConfigurationManager.isGPU() && ConfigurationManager.isForcedGPU()) + throw new RuntimeException("Incorrect usage: Cannot force a GPU-execution without enabling GPU"); + + if(api == SystemMLAPI.JMLC) { + //check for valid names of passed arguments + String[] invalidArgs = args.keySet().stream() + .filter(k -> k==null || !k.startsWith("$")).toArray(String[]::new); + if( invalidArgs.length > 0 ) + throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs)); + + //check for valid names of input and output variables + String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream() + .filter(k -> k==null || k.startsWith("$")).toArray(String[]::new); + if( invalidVars.length > 0 ) + throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars)); + } + + String dmlParserFilePath = (api == SystemMLAPI.JMLC) ? null : DMLScript.DML_FILE_PATH_ANTLR_PARSER; + + try { + //Step 1: set local/remote memory if requested (for compile in AM context) + if(api == SystemMLAPI.DMLScript && dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ + DMLAppMasterUtils.setupConfigRemoteMaxMemory(dmlconf); + } + + // Start timer (disabled for JMLC) + if(api != SystemMLAPI.JMLC) + Statistics.startCompileTimer(); + + //Step 2: parse dml script + ParserWrapper parser = ParserFactory.createParser(scriptType, nsscripts); + DMLProgram prog = parser.parse(dmlParserFilePath, script, args); + + //Step 3: construct HOP DAGs (incl LVA, validate, and setup) + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + + //init working directories (before usage by following compilation steps) + if(api != SystemMLAPI.JMLC) + if ((api == SystemMLAPI.MLContext && init) || api != SystemMLAPI.MLContext) + DMLScript.initHadoopExecution( dmlconf ); + + + //Step 4: rewrite HOP DAGs (incl IPA and memory estimates) + if(performHOPRewrites) + dmlt.rewriteHopsDAG(prog); + + //Step 5: Remove Persistent Read/Writes + if(api == SystemMLAPI.JMLC) { + //rewrite persistent reads/writes + RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); + ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); + rewriter2.rewriteProgramHopDAGs(prog); + } + else if(api == SystemMLAPI.MLContext) { + //rewrite persistent reads/writes + RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs, symbolTable); + ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); + rewriter2.rewriteProgramHopDAGs(prog); + } + + //Step 6: construct lops (incl exec type and op selection) + dmlt.constructLops(prog); + + if(DMLScript.LOG.isDebugEnabled()) { + DMLScript.LOG.debug("\n********************** LOPS DAG *******************"); + dmlt.printLops(prog); + dmlt.resetLopsDAGVisitStatus(prog); + } + + //Step 7: generate runtime program, incl codegen + rtprog = dmlt.getRuntimeProgram(prog, dmlconf); + + // Step 8: Cleanup/post-processing + if(api == SystemMLAPI.JMLC) { + JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); + } + else if(api == SystemMLAPI.DMLScript) { + //launch SystemML appmaster (if requested and not already in launched AM) + if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ + if( !DMLScript.isActiveAM() && DMLYarnClientProxy.launchDMLYarnAppmaster(script, dmlconf, allArgs, rtprog) ) + return null; //if AM launch unsuccessful, fall back to normal execute + if( DMLScript.isActiveAM() ) //in AM context (not failed AM launch) + DMLAppMasterUtils.setupProgramMappingRemoteMaxMemory(rtprog); + } + } + else if(api == SystemMLAPI.MLContext) { + if (maintainSymbolTable) { + MLContextUtil.deleteRemoveVariableInstructions(rtprog); + } else { + JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); + } + } + + //Step 9: prepare statistics [and optional explain output] + //count number compiled MR jobs / SP instructions + if(api != SystemMLAPI.JMLC) { + ExplainCounts counts = Explain.countDistributedOperations(rtprog); + Statistics.resetNoOfCompiledJobs( counts.numJobs ); + //explain plan of program (hops or runtime) + if( DMLScript.EXPLAIN != ExplainType.NONE ) + System.out.println(Explain.display(prog, rtprog, DMLScript.EXPLAIN, counts)); + + Statistics.stopCompileTimer(); + } + } + catch(ParseException pe) { + // don't chain ParseException (for cleaner error output) + throw pe; + } + catch(IOException ex) { + throw new DMLException(ex); + } + catch(Exception ex) { + throw new DMLException(ex); + } + return rtprog; } /** * Execute the runtime program. This involves execution of the program * blocks that make up the runtime program and may involve dynamic * recompilation. - * + * * @param rtprog * runtime program - * @param ec - * execution context * @param dmlconf * dml configuration * @param statisticsMaxHeavyHitters * maximum number of statistics to print + * @param symbolTable + * symbol table (that were registered as input as part of MLContext) * @param outputVariables - * output variables that were registered as part of MLContext + * output variables (that were registered as output as part of MLContext) + * @param api + * API used to execute the runtime program + * @param gCtxs + * list of GPU contexts + * @return execution context */ - public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters, Set outputVariables) { + public static ExecutionContext executeRuntimeProgram(Program rtprog, DMLConfig dmlconf, int statisticsMaxHeavyHitters, + LocalVariableMap symbolTable, HashSet outputVariables, + SystemMLAPI api, List gCtxs) { boolean exceptionThrown = false; - + + // Start timer Statistics.startRunTimer(); + + // Create execution context and attach registered outputs + ExecutionContext ec = ExecutionContextFactory.createContext(symbolTable, rtprog); + if(outputVariables != null) + ec.getVariables().setRegisteredOutputs(outputVariables); + + // Assign GPUContext to the current ExecutionContext + if(gCtxs != null) { + gCtxs.get(0).initializeThread(); + ec.setGPUContexts(gCtxs); + } + Exception finalizeException = null; try { // run execute (w/ exception handling to ensure proper shutdown) - if (ConfigurationManager.isGPU() && ec != null) { - List gCtxs = GPUContextPool.reserveAllGPUContexts(); - if (gCtxs == null) { - throw new DMLRuntimeException( - "GPU : Could not create GPUContext, either no GPU or all GPUs currently in use"); - } - gCtxs.get(0).initializeThread(); - ec.setGPUContexts(gCtxs); - } rtprog.execute(ec); } catch (Throwable e) { exceptionThrown = true; @@ -116,25 +313,32 @@ public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DM for(GPUContext gCtx : ec.getGPUContexts()) { gCtx.clearTemporaryMemory(); } - GPUContextPool.freeAllGPUContexts(); } catch (Exception e1) { exceptionThrown = true; finalizeException = e1; // do not throw exception while cleanup } + } if( ConfigurationManager.isCodegenEnabled() ) SpoofCompiler.cleanupCodeGenerator(); - - // display statistics (incl caching stats if enabled) + + //cleanup unnecessary outputs + if (outputVariables != null) + symbolTable.removeAllNotIn(outputVariables); + + // Display statistics (disabled for JMLC) Statistics.stopRunTimer(); - (exceptionThrown ? System.err : System.out) - .println(Statistics.display(statisticsMaxHeavyHitters > 0 ? - statisticsMaxHeavyHitters : ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters())); - ConfigurationManager.resetStatistics(); + if(api != SystemMLAPI.JMLC) { + (exceptionThrown ? System.err : System.out) + .println(Statistics.display(statisticsMaxHeavyHitters > 0 ? + statisticsMaxHeavyHitters : + ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters())); + } } if(finalizeException != null) { throw new DMLRuntimeException("Error occured while GPU memory cleanup.", finalizeException); } - } + return ec; + } } diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java b/src/main/java/org/apache/sysml/api/jmlc/Connection.java index ea0d503f69c..71923c96d20 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java @@ -25,15 +25,17 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.CompilerConfig.ConfigType; @@ -41,18 +43,13 @@ import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.conf.DMLOptions; import org.apache.sysml.hops.codegen.SpoofCompiler; -import org.apache.sysml.hops.rewrite.ProgramRewriter; -import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite; -import org.apache.sysml.parser.DMLProgram; -import org.apache.sysml.parser.DMLTranslator; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.LanguageException; -import org.apache.sysml.parser.ParseException; -import org.apache.sysml.parser.ParserFactory; -import org.apache.sysml.parser.ParserWrapper; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.io.FrameReader; import org.apache.sysml.runtime.io.FrameReaderFactory; import org.apache.sysml.runtime.io.IOUtilFunctions; @@ -151,6 +148,8 @@ public Connection(DMLConfig dmlconfig, CompilerConfig.ConfigType... cconfigs) { * @param dmlconfig a dml configuration. */ public Connection(DMLConfig dmlconfig) { + DMLScript.rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; + //setup basic parameters for embedded execution //(parser, compiler, and runtime parameters) CompilerConfig cconf = new CompilerConfig(); @@ -193,7 +192,25 @@ public PreparedScript prepareScript( String script, String[] inputs, String[] ou /** * Prepares (precompiles) a script and registers input and output variables. - * + * + * @param script string representing the DML or PyDML script + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param useGpu {@code true} if prepare the script with GPU support, {@code false} + * @param forceGpu {@code true} if prepare the script with forced GPU support, {@code false} + * @param gpuIndex the GPU to use to execute the given prepared script + * @return PreparedScript object representing the precompiled script + */ + public PreparedScript prepareScript( + String script, String[] inputs, String[] outputs, boolean useGpu, boolean forceGpu, int gpuIndex) { + return prepareScript( + script, Collections.emptyMap(), Collections.emptyMap(), + inputs, outputs, false, useGpu, forceGpu, gpuIndex); + } + + /** + * Prepares (precompiles) a script and registers input and output variables. + * * @param script string representing the DML or PyDML script * @param inputs string array of input variables to register * @param outputs string array of output variables to register @@ -230,67 +247,72 @@ public PreparedScript prepareScript( String script, Map args, St * @return PreparedScript object representing the precompiled script */ public PreparedScript prepareScript(String script, Map nsscripts, Map args, String[] inputs, String[] outputs, boolean parsePyDML) { - DMLScript.SCRIPT_TYPE = parsePyDML ? ScriptType.PYDML : ScriptType.DML; + return prepareScript(script, nsscripts, args, inputs, outputs, parsePyDML, false, false, -1); + } + + /** + * List of available GPU contexts: + */ + static GPUContext [] AVAILABLE_GPU_CONTEXTS; + + + /** + * Prepares (precompiles) a script, sets input parameter values, and registers input and output variables. + * + * @param script string representing of the DML or PyDML script + * @param nsscripts map (name, script) of the DML or PyDML namespace scripts + * @param args map of input parameters ($) and their values + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param parsePyDML {@code true} if PyDML, {@code false} if DML + * @param useGPU {@code true} if prepare the script with GPU support, {@code false} + * @param forceGPU {@code true} if prepare the script with forced GPU support, {@code false} + * @param gpuIndex the GPU to use to execute the given prepared script + * @return PreparedScript object representing the precompiled script + */ + public PreparedScript prepareScript(String script, Map nsscripts, Map args, String[] inputs, String[] outputs, + boolean parsePyDML, boolean useGPU, boolean forceGPU, int gpuIndex) { - // Set DML Options here: - boolean gpu = false; boolean forceGPU = false; - ConfigurationManager.setLocalOptions(new DMLOptions(args, - false, 10, false, Explain.ExplainType.NONE, RUNTIME_PLATFORM.SINGLE_NODE, gpu, forceGPU, + DMLScript.SCRIPT_TYPE = parsePyDML ? ScriptType.PYDML : ScriptType.DML; + ConfigurationManager.setLocalOptions(new DMLOptions(args, + false, 10, false, + Explain.ExplainType.NONE, RUNTIME_PLATFORM.SINGLE_NODE, useGPU, forceGPU, parsePyDML ? ScriptType.PYDML : ScriptType.DML, null, script)); - - //check for valid names of passed arguments - String[] invalidArgs = args.keySet().stream() - .filter(k -> k==null || !k.startsWith("$")).toArray(String[]::new); - if( invalidArgs.length > 0 ) - throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs)); - - //check for valid names of input and output variables - String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream() - .filter(k -> k==null || k.startsWith("$")).toArray(String[]::new); - if( invalidVars.length > 0 ) - throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars)); - setLocalConfigs(); - - //simplified compilation chain - Program rtprog = null; - try { - //parsing - ParserWrapper parser = ParserFactory.createParser( - parsePyDML ? ScriptType.PYDML : ScriptType.DML, nsscripts); - DMLProgram prog = parser.parse(null, script, args); - - //language validate - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - - //hop construct/rewrite - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - - //rewrite persistent reads/writes - RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); - ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); - rewriter2.rewriteProgramHopDAGs(prog); - - //lop construct and runtime prog generation - dmlt.constructLops(prog); - rtprog = dmlt.getRuntimeProgram(prog, _dmlconf); - - //final cleanup runtime prog - JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); - } - catch(ParseException pe) { - // don't chain ParseException (for cleaner error output) - throw pe; - } - catch(Exception ex) { - throw new DMLException(ex); + + List _gpuCtx = new ArrayList<>(); + if (useGPU) { + if (AVAILABLE_GPU_CONTEXTS == null) { + synchronized (Connection.class) { + if (AVAILABLE_GPU_CONTEXTS == null) { + // Initialize the GPUs if not already + String oldAvailableGpus = GPUContextPool.AVAILABLE_GPUS; + GPUContextPool.AVAILABLE_GPUS = "-1"; // use all the GPUs in JMLC mode + List availableCtx = GPUContextPool.getAllGPUContexts(); + AVAILABLE_GPU_CONTEXTS = availableCtx.toArray(new GPUContext[availableCtx.size()]); + GPUContextPool.AVAILABLE_GPUS = oldAvailableGpus; + } + } + } + if (AVAILABLE_GPU_CONTEXTS.length == 0) + throw new DMLRuntimeException("No GPU Context in available"); + else if (gpuIndex < 0 || gpuIndex >= AVAILABLE_GPU_CONTEXTS.length) + throw new DMLRuntimeException("Cannot use the GPU " + gpuIndex + + ". Valid values: [0, " + (AVAILABLE_GPU_CONTEXTS.length - 1) + "]"); + // For simplicity of the API, the initial version statically associates a GPU to the prepared script. + // We can revisit this assumption if it turns out to be the overhead. + _gpuCtx.add(AVAILABLE_GPU_CONTEXTS[gpuIndex]); } - - //return newly create precompiled script - return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); + + Program rtprog = ScriptExecutorUtils.compileRuntimeProgram(script, nsscripts, args, inputs, outputs, + parsePyDML ? ScriptType.PYDML : ScriptType.DML, _dmlconf, SystemMLAPI.JMLC); + + + //return newly create precompiled script + PreparedScript ret = new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); + + if (useGPU) ret._gpuCtx = _gpuCtx; + return ret; } /** diff --git a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java index d5955f4bcd8..84601637f61 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java @@ -30,6 +30,8 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.ConfigurableAPI; import org.apache.sysml.api.DMLException; +import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.ConfigurationManager; @@ -53,6 +55,7 @@ import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MetaDataFormat; import org.apache.sysml.runtime.matrix.data.FrameBlock; @@ -80,8 +83,10 @@ public class PreparedScript implements ConfigurableAPI private final LocalVariableMap _vars; private final DMLConfig _dmlconf; private final CompilerConfig _cconf; + private boolean _isStatisticsEnabled = false; - + private boolean _gatherMemStats = false; + private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -141,7 +146,7 @@ protected PreparedScript( Program prog, String[] inputs, String[] outputs, DMLCo */ public void gatherMemStats(boolean stats) { this._isStatisticsEnabled = this._isStatisticsEnabled || ConfigurationManager.isStatistics(); - DMLScript.JMLC_MEM_STATISTICS = stats; + this._gatherMemStats = stats; } @Override @@ -433,7 +438,12 @@ public void setFrame(String varname, FrameBlock frame, boolean reuse) { public void clearParameters() { _vars.removeAll(); } - + + /** + * GPU Context to use for execution + */ + List _gpuCtx = null; + /** * Executes the prepared script over the bound inputs, creating the * result variables according to bound and registered outputs. @@ -443,20 +453,20 @@ public void clearParameters() { public ResultVariables executeScript() { //add reused variables _vars.putAll(_inVarReuse); - + //set thread-local configurations ConfigurationManager.setLocalConfig(_dmlconf); ConfigurationManager.setLocalConfig(_cconf); - + ConfigurationManager.setStatistics(_isStatisticsEnabled); + ConfigurationManager.setJMLCMemStats(_gatherMemStats); + ConfigurationManager.setFinegrainedStatistics(_gatherMemStats); + //create and populate execution context - ExecutionContext ec = ExecutionContextFactory.createContext(_vars, _prog); - - //core execute runtime program - _prog.execute(ec); - - //cleanup unnecessary outputs - _vars.removeAllNotIn(_outVarnames); - + ScriptExecutorUtils.executeRuntimeProgram( + _prog, _dmlconf, ConfigurationManager.isStatistics() ? + ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, + _vars, _outVarnames, SystemMLAPI.JMLC, _gpuCtx); + //construct results ResultVariables rvars = new ResultVariables(); for( String ovar : _outVarnames ) { @@ -464,10 +474,9 @@ public ResultVariables executeScript() { if( tmpVar != null ) rvars.addResult(ovar, tmpVar); } - - //clear thread-local configurations + + // clear prior thread local configurations (for subsequent run) ConfigurationManager.clearLocalConfigs(); - ConfigurationManager.resetStatistics(); return rvars; diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index dee40605cda..f05e4dac010 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -242,8 +242,7 @@ public MLContext(JavaSparkContext javaSparkContext) { * execution mode, set MLContextProxy, set default config, set compiler * config. * - * @param sc - * SparkContext object. + * @param spark SparkContext object. */ private void initMLContext(SparkSession spark) { diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index 135e1cda05e..5c861c42a86 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -20,7 +20,9 @@ package org.apache.sysml.api.mlcontext; import java.io.IOException; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -28,6 +30,7 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.jmlc.JMLCUtils; import org.apache.sysml.api.mlcontext.MLContext.ExecutionType; import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; @@ -50,7 +53,8 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.Explain.ExplainCounts; import org.apache.sysml.utils.Explain.ExplainType; @@ -113,6 +117,7 @@ public class ScriptExecutor { protected ExecutionType executionType; protected int statisticsMaxHeavyHitters = 10; protected boolean maintainSymbolTable = false; + protected List gCtxs = null; /** * ScriptExecutor constructor. @@ -207,20 +212,6 @@ protected void countCompiledMRJobsAndSparkInstructions() { Statistics.resetNoOfCompiledJobs(counts.numJobs); } - /** - * Create an execution context and set its variables to be the symbol table - * of the script. - */ - protected void createAndInitializeExecutionContext() { - executionContext = ExecutionContextFactory.createContext(runtimeProgram); - LocalVariableMap symbolTable = script.getSymbolTable(); - if (symbolTable != null) - executionContext.setVariables(symbolTable); - //attach registered outputs (for dynamic recompile) - executionContext.getVariables().setRegisteredOutputs( - new HashSet(script.getOutputVariables())); - } - /** * Set the global flags (for example: statistics, gpu, etc). */ @@ -291,27 +282,23 @@ public void compile(Script script) { */ public void compile(Script script, boolean performHOPRewrites) { - // main steps in script execution setup(script); - if (statistics) { - Statistics.startCompileTimer(); - } - parseScript(); - liveVariableAnalysis(); - validateScript(); - constructHops(); - if(performHOPRewrites) - rewriteHops(); - rewritePersistentReadsAndWrites(); - constructLops(); - generateRuntimeProgram(); - showExplanation(); - countCompiledMRJobsAndSparkInstructions(); - initializeCachingAndScratchSpace(); - cleanupRuntimeProgram(); - if (statistics) { - Statistics.stopCompileTimer(); + + LocalVariableMap symbolTable = script.getSymbolTable(); + String[] inputs = null; String[] outputs = null; + if (symbolTable != null) { + inputs = (script.getInputVariables() == null) ? new String[0] + : script.getInputVariables().toArray(new String[0]); + outputs = (script.getOutputVariables() == null) ? new String[0] + : script.getOutputVariables().toArray(new String[0]); } + + Map args = MLContextUtil + .convertInputParametersForParser(script.getInputParameters(), script.getScriptType()); + runtimeProgram = ScriptExecutorUtils.compileRuntimeProgram(script.getScriptExecutionString(), Collections.emptyMap(), + args, null, symbolTable, inputs, outputs, script.getScriptType(), config, SystemMLAPI.MLContext, + performHOPRewrites, isMaintainSymbolTable(), init); + gCtxs = ConfigurationManager.isGPU() ? GPUContextPool.getAllGPUContexts() : null; } @@ -321,8 +308,6 @@ public void compile(Script script, boolean performHOPRewrites) { * *
    *
  1. {@link #compile(Script)}
  2. - *
  3. {@link #createAndInitializeExecutionContext()}
  4. - *
  5. {@link #executeRuntimeProgram()}
  6. *
  7. {@link #cleanupAfterExecution()}
  8. *
* @@ -352,8 +337,11 @@ public MLResults execute(Script script) { compile(script); try { - createAndInitializeExecutionContext(); - executeRuntimeProgram(); + executionContext = ScriptExecutorUtils.executeRuntimeProgram(getRuntimeProgram(), getConfig(), + statistics ? statisticsMaxHeavyHitters : 0, script.getSymbolTable(), + new HashSet(getScript().getOutputVariables()), SystemMLAPI.MLContext, gCtxs); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred while executing runtime program", e); } finally { cleanupAfterExecution(); } @@ -376,8 +364,17 @@ public MLResults execute(Script script) { */ protected void setup(Script script) { this.script = script; - checkScriptHasTypeAndString(); + if (script == null) { + throw new MLContextException("Script is null"); + } else if (script.getScriptType() == null) { + throw new MLContextException("ScriptType (DML or PYDML) needs to be specified"); + } else if (script.getScriptString() == null) { + throw new MLContextException("Script string is null"); + } else if (StringUtils.isBlank(script.getScriptString())) { + throw new MLContextException("Script string is blank"); + } script.setScriptExecutor(this); + // Set global variable indicating the script type DMLScript.SCRIPT_TYPE = script.getScriptType(); setGlobalFlags(); @@ -385,6 +382,7 @@ protected void setup(Script script) { Statistics.resetNoOfExecutedJobs(); if (statistics) Statistics.reset(); + DMLScript.EXPLAIN = (explainLevel != null) ? explainLevel.getExplainType() : ExplainType.NONE; } /** @@ -428,19 +426,6 @@ protected void cleanupRuntimeProgram() { } } - /** - * Execute the runtime program. This involves execution of the program - * blocks that make up the runtime program and may involve dynamic - * recompilation. - */ - protected void executeRuntimeProgram() { - try { - ScriptExecutorUtils.executeRuntimeProgram(this, statistics ? statisticsMaxHeavyHitters : 0); - } catch (DMLRuntimeException e) { - throw new MLContextException("Exception occurred while executing runtime program", e); - } - } - /** * Check security, create scratch space, cleanup working directories, * initialize caching, and reset statistics. @@ -467,12 +452,9 @@ protected void initializeCachingAndScratchSpace() { protected void parseScript() { try { ParserWrapper parser = ParserFactory.createParser(script.getScriptType()); - Map inputParameters = script.getInputParameters(); - Map inputParametersStringMaps = MLContextUtil - .convertInputParametersForParser(inputParameters, script.getScriptType()); - - String scriptExecutionString = script.getScriptExecutionString(); - dmlProgram = parser.parse(null, scriptExecutionString, inputParametersStringMaps); + Map args = MLContextUtil + .convertInputParametersForParser(script.getInputParameters(), script.getScriptType()); + dmlProgram = parser.parse(null, script.getScriptExecutionString(), args); } catch (ParseException e) { throw new MLContextException("Exception occurred while parsing script", e); } diff --git a/src/main/java/org/apache/sysml/conf/ConfigurationManager.java b/src/main/java/org/apache/sysml/conf/ConfigurationManager.java index 96c3885a8be..b64a9f70420 100644 --- a/src/main/java/org/apache/sysml/conf/ConfigurationManager.java +++ b/src/main/java/org/apache/sysml/conf/ConfigurationManager.java @@ -155,7 +155,7 @@ public static DMLOptions getDMLOptions() { /** * Sets the current thread-local dml configuration to the given options. * - * @param conf the configuration + * @param opts the configuration */ public static void setLocalOptions( DMLOptions opts ) { _dmlOptions = opts; @@ -275,6 +275,7 @@ public static RUNTIME_PLATFORM getExecutionMode() { // _dmlconf.getBooleanValue(DMLConfig.EXTRA_FINEGRAINED_STATS); private static boolean STATISTICS = false; private static boolean FINEGRAINED_STATISTICS = false; + private static boolean JMLC_MEM_STATISTICS = false; /** * @return true if statistics is enabled @@ -289,7 +290,12 @@ public static boolean isStatistics() { public static boolean isFinegrainedStatistics() { return FINEGRAINED_STATISTICS; } - + + /** + * @return true if JMLC memory statistics are enabled + */ + public static boolean isJMLCMemStatistics() { return JMLC_MEM_STATISTICS; } + /** * Whether or not statistics about the DML/PYDML program should be output to * standard output. @@ -301,7 +307,31 @@ public static boolean isFinegrainedStatistics() { public static void setStatistics(boolean enabled) { STATISTICS = enabled; } - + + /** + * Whether or not detailed statistics about program memory use should be output + * to standard output when running under JMLC + * + * @param enabled + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public static void setJMLCMemStats(boolean enabled) { + JMLC_MEM_STATISTICS = enabled; + } + + + /** + * Whether or not finegrained statistics should be enabled + * + * @param enabled + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public static void setFinegrainedStatistics(boolean enabled) { + FINEGRAINED_STATISTICS = enabled; + } + /** * Reset the statistics flag. */ diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java index 6b241c175cf..7d4dd221f08 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java @@ -115,11 +115,8 @@ public void removeAllNotIn(Set blacklist) { } public boolean hasReferences( Data d ) { - //perf: avoid java streams here for reduced overhead in rmvar - for( Data o : localMap.values() ) - if( o instanceof ListObject ? ((ListObject)o).getData().contains(d) : o == d ) - return true; - return false; + return localMap.values().stream().anyMatch(e -> (e instanceof ListObject) ? + ((ListObject)e).getData().contains(d) : e == d); } public void setRegisteredOutputs(HashSet outputs) { @@ -143,7 +140,7 @@ public double getPinnedDataSize() { if( !dict.containsKey(hash) && e.getValue() instanceof CacheableData ) { dict.put(hash, e.getValue()); double size = ((CacheableData) e.getValue()).getDataSize(); - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) Statistics.maintainCPHeavyHittersMem(e.getKey(), size); total += size; } diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java index feb12348e9b..2b65c8d8105 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java @@ -259,7 +259,7 @@ private void executeSingleInstruction( Instruction currInst, ExecutionContext ec Statistics.maintainCPHeavyHitters( tmp.getExtendedOpcode(), System.nanoTime()-t0); } - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) ec.getVariables().getPinnedDataSize(); // optional trace information (instruction and runtime) diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java index 15dd23e570f..6cfea793248 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java @@ -500,7 +500,7 @@ public T acquireModify(T newData) { if( ConfigurationManager.isStatistics() ){ long t1 = System.nanoTime(); CacheStatistics.incrementAcquireMTime(t1-t0); - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.addCPMemObject(System.identityHashCode(this), getDataSize()); } diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index f310d76a2ca..abedf31c6da 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -644,7 +644,7 @@ else if( dat instanceof ListObject ) } public void cleanupCacheableData(CacheableData mo) { - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.removeCPMemObject(System.identityHashCode(mo)); //early abort w/o scan of symbol table if no cleanup required boolean fileExists = (mo.isHDFSFileExists() && mo.getFileName() != null); diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 879133b847a..db2994698e6 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -42,6 +42,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; +import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContextUtil; @@ -1100,7 +1101,7 @@ public void cleanupCacheableData(CacheableData mo) //and hence is transparently used by rmvar instructions and other users. The //core difference is the lineage-based cleanup of RDD and broadcast variables. - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.removeCPMemObject(System.identityHashCode(mo)); if( !mo.isCleanupEnabled() ) diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java index 242127a601b..4db3ee040e4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java @@ -28,6 +28,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.utils.GPUStatistics; @@ -71,9 +72,9 @@ public class GPUContextPool { static List pool = new LinkedList<>(); /** - * Whether the pool of GPUs is reserved or not + * Used to throw an error in case of incorrect usage */ - static boolean reserved = false; + private static String oldAvailableGpus; /** * Static initialization of the number of devices @@ -99,6 +100,7 @@ public synchronized static void initializeGPU() { try { ArrayList listOfGPUs = parseListString(AVAILABLE_GPUS, deviceCount); + oldAvailableGpus = AVAILABLE_GPUS; // Initialize the list of devices & the pool of GPUContexts for (int i : listOfGPUs) { @@ -202,17 +204,17 @@ public static ArrayList parseListString(String str, int max) { } /** - * Reserves and gets an initialized list of GPUContexts + * Gets an initialized list of GPUContexts * * @return null if no GPUContexts in pool, otherwise a valid list of GPUContext */ - public static synchronized List reserveAllGPUContexts() { - if (reserved) - throw new DMLRuntimeException("Trying to re-reserve GPUs"); + public static synchronized List getAllGPUContexts() { if (!initialized) initializeGPU(); - reserved = true; - LOG.trace("GPU : Reserved all GPUs"); + if(!oldAvailableGpus.equals(AVAILABLE_GPUS)) { + LOG.warn("GPUContextPool was already initialized with " + DMLConfig.AVAILABLE_GPUS + "=" + oldAvailableGpus + + ". Cannot reinitialize it with " + DMLConfig.AVAILABLE_GPUS + "=" + AVAILABLE_GPUS); + } return pool; } @@ -249,17 +251,6 @@ public static int getDeviceCount() { return deviceCount; } - /** - * Unreserves all GPUContexts - */ - public static synchronized void freeAllGPUContexts() { - if (!reserved) - throw new DMLRuntimeException("Trying to free unreserved GPUs"); - reserved = false; - LOG.trace("GPU : Unreserved all GPUs"); - - } - /** * Gets the initial GPU memory budget. This is the minimum of the * available memories across all the GPUs on the machine(s) diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index ca555641e12..f3ba9621d55 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -975,7 +975,7 @@ public static String display(int maxHeavyHitters) sb.append("Cache hits (Mem, WB, FS, HDFS):\t" + CacheStatistics.displayHits() + ".\n"); sb.append("Cache writes (WB, FS, HDFS):\t" + CacheStatistics.displayWrites() + ".\n"); sb.append("Cache times (ACQr/m, RLS, EXP):\t" + CacheStatistics.displayTime() + " sec.\n"); - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) sb.append("Max size of live objects:\t" + byteCountToDisplaySize(getSizeofPinnedObjects()) + " (" + getNumPinnedObjects() + " total objects)" + "\n"); sb.append("HOP DAGs recompiled (PRED, SB):\t" + getHopRecompiledPredDAGs() + "/" + getHopRecompiledSBDAGs() + ".\n"); sb.append("HOP DAGs recompile time:\t" + String.format("%.3f", ((double)getHopRecompileTime())/1000000000) + " sec.\n"); @@ -1029,7 +1029,7 @@ public static String display(int maxHeavyHitters) sb.append("Total JVM GC time:\t\t" + ((double)getJVMgcTime())/1000 + " sec.\n"); LibMatrixDNN.appendStatistics(sb); sb.append("Heavy hitter instructions:\n" + getHeavyHitters(maxHeavyHitters)); - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) sb.append("Heavy hitter objects:\n" + getCPHeavyHittersMem(maxHeavyHitters)); } diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java index e1ae1ae9ff0..9068ffc7e5e 100644 --- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java +++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java @@ -56,9 +56,9 @@ public abstract class GPUTests extends AutomatedTestBase { protected static SparkSession spark; protected final double DOUBLE_PRECISION_THRESHOLD = 1e-9; // for relative error private static final boolean PRINT_MAT_ERROR = false; - - // We will use this flag until lower precision is supported on CP. - private final static String FLOATING_POINT_PRECISION = "double"; + + // We will use this flag until lower precision is supported on CP. + private final static String FLOATING_POINT_PRECISION = "double"; protected final double SINGLE_PRECISION_THRESHOLD = 1e-3; // for relative error @@ -106,7 +106,7 @@ protected synchronized void clearGPUMemory() { int freeCount = GPUContextPool.getAvailableCount(); Assert.assertTrue("All GPUContexts have not been returned to the GPUContextPool", count == freeCount); - List gCtxs = GPUContextPool.reserveAllGPUContexts(); + List gCtxs = GPUContextPool.getAllGPUContexts(); for (GPUContext gCtx : gCtxs) { gCtx.initializeThread(); try { @@ -116,9 +116,6 @@ protected synchronized void clearGPUMemory() { throw e; } } - GPUContextPool.freeAllGPUContexts(); - - } catch (DMLRuntimeException e) { // Ignore } diff --git a/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java b/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java new file mode 100644 index 00000000000..745edccf6db --- /dev/null +++ b/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java @@ -0,0 +1,108 @@ +package org.apache.sysml.test.gpu; + +import java.util.Random; +import org.junit.Test; +import org.junit.Assert; +import org.apache.sysml.api.jmlc.Connection; +import org.apache.sysml.api.jmlc.PreparedScript; + + +public class JMLCTests extends GPUTests { + + static class ScriptContainer { + String dml; + String[] inputVarNames; + } + + @Test + public void testJMLC() { + try { + Connection conn = new Connection(); + + int numMatrices = 10; + int matrixNumRows = 100; + int numScriptInvocations = 10; + + ScriptContainer SC = generateDMLScript(numMatrices); + + PreparedScript script = conn.prepareScript( + SC.dml, SC.inputVarNames, new String[] { "Z" }, true, true, 0); + + // execute the script without pinning input matrices between invocations + executeDMLScript(script, numScriptInvocations, matrixNumRows, numMatrices, false); + + // execute the script while pinning input matrices between invocations + executeDMLScript(script, numScriptInvocations, matrixNumRows, numMatrices, true); + } catch (Exception e) { + Assert.fail("An unexpected exception occurred: " + e.getMessage()); + } + } + + // Generates a simple synthetic DML script which multiplies a sequence of square matrices. + // I.e. Z = X %*% W1 %*% W2 %*% W3 ... + // numMatrices determines the number of matrices in the sequences. The size of the matrices can be set + // in executeDMLScript + static ScriptContainer generateDMLScript(int numMatrices) { + ScriptContainer SC = new ScriptContainer(); + String[] inputVarNames = new String[numMatrices + 1]; + inputVarNames[0] = "x"; + + StringBuilder dml = new StringBuilder("x = read(\"/tmp/X.mtx\", rows=-1, cols=-1)\n"); + for (int ix=0; ix 1)\n print(as.scalar(Z[1,1]))\n"); + + SC.dml = dml.toString(); + SC.inputVarNames = inputVarNames; + + return SC; + } + + // Executes a PreparedScript generated by generateDMLScript. The parameter n determines the + // number of times the script is invoked. The parameter rows controls the shape of the matrices. + // Set this parameter larger to use more memory. The parameter numMatrices must be set to the same value as + // in generateDMLScript. The parameter pinWeights controls whether weight matrices should be + // pinned in memory between script invocations. + static void executeDMLScript(PreparedScript script, int n, int rows, int numMatrices, boolean pinWeights) { + for (int ix=0; ix sparsity) { + continue; + } + matrix[i][j] = (random.nextDouble() * (max - min) + min); + } + } + return matrix; + } + +} + diff --git a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java index 2f2022ebaf8..8cdba382ff3 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java @@ -63,11 +63,12 @@ private void runJMLCParFor2ForTest(boolean par) PreparedScript pscript = conn.prepareScript( script, new String[]{}, new String[]{}, false); - ConfigurationManager.setStatistics(true); + pscript.setStatistics(true); pscript.executeScript(); conn.close(); + //check for existing or non-existing parfor - Assert.assertTrue(Statistics.getParforOptCount()==(par?1:0)); + Assert.assertTrue("INCORRECT PARFOR COUNT", Statistics.getParforOptCount()==(par?1:0)); } catch(Exception ex) { Assert.fail("JMLC parfor test failed: "+ex.getMessage());