Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public class DMLOptions {
public ExecMode execMode = OptimizerUtils.getDefaultExecutionMode(); // Execution mode standalone, MR, Spark or a hybrid
public boolean gpu = false; // Whether to use the GPU
public boolean forceGPU = false; // Whether to ignore memory & estimates and always use the GPU
public boolean ooc = false; // Whether to use the OOC backend
public boolean debug = false; // to go into debug mode to be able to step through a program
public String filePath = null; // path to script
public String script = null; // the script itself
Expand Down Expand Up @@ -109,6 +110,7 @@ public String toString() {
", execMode=" + execMode +
", gpu=" + gpu +
", forceGPU=" + forceGPU +
", ooc=" + ooc +
", debug=" + debug +
", filePath='" + filePath + '\'' +
", script='" + script + '\'' +
Expand Down Expand Up @@ -182,6 +184,7 @@ else if (lineageType.equalsIgnoreCase("debugger"))
}
}
}
dmlOptions.ooc = line.hasOption("ooc");
if (line.hasOption("exec")){
String execMode = line.getOptionValue("exec");
if (execMode.equalsIgnoreCase("singlenode")) dmlOptions.execMode = ExecMode.SINGLE_NODE;
Expand Down Expand Up @@ -388,6 +391,8 @@ private static Options createCLIOptions() {
Option gpuOpt = OptionBuilder.withArgName("force")
.withDescription("uses CUDA instructions when reasonable; set <force> option to skip conservative memory estimates and use GPU wherever possible; default off")
.hasOptionalArg().create("gpu");
Option oocOpt = OptionBuilder.withDescription("uses OOC backend")
.create("ooc");
Option debugOpt = OptionBuilder.withDescription("runs in debug mode; default off")
.create("debug");
Option pythonOpt = OptionBuilder
Expand Down Expand Up @@ -441,6 +446,7 @@ private static Options createCLIOptions() {
options.addOption(explainOpt);
options.addOption(execOpt);
options.addOption(gpuOpt);
options.addOption(oocOpt);
options.addOption(debugOpt);
options.addOption(lineageOpt);
options.addOption(fedOpt);
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public class DMLScript
public static boolean FORCE_ACCELERATOR = DMLOptions.defaultOptions.forceGPU;
// Enable synchronizing GPU after every instruction
public static boolean SYNCHRONIZE_GPU = true;
// Set OOC backend
public static boolean USE_OOC = DMLOptions.defaultOptions.ooc;
// Enable eager CUDA free on rmvar
public static boolean EAGER_CUDA_FREE = false;

Expand Down Expand Up @@ -266,6 +268,7 @@ public static boolean executeScript( String[] args )
JMLC_MEM_STATISTICS = dmlOptions.memStats;
USE_ACCELERATOR = dmlOptions.gpu;
FORCE_ACCELERATOR = dmlOptions.forceGPU;
USE_OOC = dmlOptions.ooc;
EXPLAIN = dmlOptions.explainType;
EXEC_MODE = dmlOptions.execMode;
LINEAGE = dmlOptions.lineage;
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface Types {
* Execution mode for entire script. This setting specify which {@link ExecType}s are allowed.
*/
public enum ExecMode {
/** Execute all operations in {@link ExecType#CP} and if available {@link ExecType#GPU} */
/** Execute all operations in {@link ExecType#CP}, {@link ExecType#OOC} and if available {@link ExecType#GPU} */
SINGLE_NODE,
/**
* The default and encouraged ExecMode. Execute operations while leveraging all available options:
Expand All @@ -58,6 +58,8 @@ public enum ExecType {
GPU,
/** FED: indicate that the instruction should be executed as a Federated instruction */
FED,
/** Out of Core: indicate that the operation should be executed out of core. */
OOC,
/** invalid is used for debugging or if it is undecided where the current instruction should be executed */
INVALID
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/hops/Hop.java
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ else if ( DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE && _etypeForced
if(_etypeForced != ExecType.CP && _etypeForced != ExecType.GPU)
_etypeForced = ExecType.CP;
}
else if (DMLScript.USE_OOC){
_etypeForced = ExecType.OOC;
}
else {
// enabled with -exec singlenode option
_etypeForced = ExecType.CP;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public enum IType {
BREAKPOINT,
SPARK,
GPU,
FEDERATED
FEDERATED,
OUT_OF_CORE
}

protected static final Log LOG = LogFactory.getLog(Instruction.class.getName());
Expand All @@ -53,6 +54,7 @@ protected Instruction(Operator _optr){
public static final String SP_INST_PREFIX = "sp_";
public static final String GPU_INST_PREFIX = "gpu_";
public static final String FEDERATED_INST_PREFIX = "fed_";
public static final String OOC_INST_PREFIX = "ooc_";

//basic instruction meta data
protected String instString = null;
Expand Down Expand Up @@ -184,6 +186,8 @@ else if( getType() == IType.GPU )
extendedOpcode = GPU_INST_PREFIX + getOpcode();
else if( getType() == IType.FEDERATED)
extendedOpcode = FEDERATED_INST_PREFIX + getOpcode();
else if( getType() == IType.OUT_OF_CORE)
extendedOpcode = OOC_INST_PREFIX + getOpcode();
else
extendedOpcode = getOpcode();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public static Instruction parseSingleInstruction ( String str ) {
if( fedtype == null )
throw new DMLRuntimeException("Unknown FEDERATED instruction: " + str);
return FEDInstructionParser.parseSingleInstruction (fedtype, str);
case OOC:
InstructionType ooctype = InstructionUtils.getOOCType(str);
if( ooctype == null )
throw new DMLRuntimeException("Unknown OOC instruction: " + str);
return OOCInstructionParser.parseSingleInstruction (ooctype, str);
default:
throw new DMLRuntimeException("Unknown execution type in instruction: " + str);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,11 @@ public static InstructionType getFEDType(String str) {
return Opcodes.getTypeByOpcode(op, Types.ExecType.FED);
}

public static InstructionType getOOCType(String str) {
String op = getOpCode(str);
return Opcodes.getTypeByOpcode(op, Types.ExecType.OOC);
}

public static boolean isBuiltinFunction( String opcode ) {
Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode);
return (bfc != null);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.InstructionType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;

public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());

public static OOCInstruction parseSingleInstruction(String str) {
if(str == null || str.isEmpty())
return null;
InstructionType ooctype = InstructionUtils.getOOCType(str);
if(ooctype == null)
throw new DMLRuntimeException("Unable derive ooctype for instruction: " + str);
OOCInstruction oocinst = parseSingleInstruction(ooctype, str);
if(oocinst == null)
throw new DMLRuntimeException("Unable to parse instruction: " + str);
return oocinst;
}

public static OOCInstruction parseSingleInstruction(InstructionType ooctype, String str) {
if(str == null || str.isEmpty())
return null;
switch(ooctype) {

// TODO:
case AggregateUnary:
case Binary:

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

public abstract class OOCInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());

public enum OOCType {
AggregateUnary, Binary
}

protected final OOCInstruction.OOCType _ooctype;
protected final boolean _requiresLabelUpdate;

protected OOCInstruction(OOCInstruction.OOCType type, String opcode, String istr) {
this(type, null, opcode, istr);
}

protected OOCInstruction(OOCInstruction.OOCType type, Operator op, String opcode, String istr) {
super(op);
_ooctype = type;
instString = istr;
instOpcode = opcode;

_requiresLabelUpdate = super.requiresLabelUpdate();
}

@Override
public IType getType() {
return IType.OUT_OF_CORE;
}

public OOCInstruction.OOCType getOOCInstructionType() {
return _ooctype;
}

@Override
public boolean requiresLabelUpdate() {
return _requiresLabelUpdate;
}

@Override
public String getGraphString() {
return getOpcode();
}

@Override
public Instruction preprocessInstruction(ExecutionContext ec) {
// TODO
return null;
}

@Override
public abstract void processInstruction(ExecutionContext ec);

@Override
public void postprocessInstruction(ExecutionContext ec) {
if(DMLScript.LINEAGE_DEBUGGER)
ec.maintainLineageDebuggerInfo(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class SumScalarMultiplicationTest extends AutomatedTestBase {

private static final String TEST_NAME = "SumScalarMultiplication";
private static final String TEST_DIR = "functions/ooc/";
private static final String TEST_CLASS_DIR = TEST_DIR + SumScalarMultiplicationTest.class.getSimpleName() + "/";
private static final String INPUT_NAME = "X";
private static final String OUTPUT_NAME = "res";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
addTestConfiguration(TEST_NAME, config);
}

/**
* Test the sum of scalar multiplication, "sum(X*7)", with OOC backend.
*/
@Test
public void testSumScalarMult() {

Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;

try {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), output(OUTPUT_NAME)};

int rows = 3;
int cols = 4;
double sparsity = 0.8;

double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
writeInputMatrixWithMTD(INPUT_NAME, X, true);

runTest(true, false, null, -1);

HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir(OUTPUT_NAME);
// only one entry
Double result = dmlfile.get(new MatrixValue.CellIndex(1, 1));

double expected = 0.0;
for(int i = 0; i < rows; i++) {
for(int j = 0; j < cols; j++) {
expected += X[i][j] * 7;
}
}

Assert.assertEquals(expected, result, 1e-10);

String prefix = Instruction.OOC_INST_PREFIX;

boolean usedOOCMult = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.MULT);
Assert.assertTrue("OOC wasn't used for MULT", usedOOCMult);

boolean usedOOCSum = Statistics.getCPHeavyHitterOpCodes().contains(prefix + Opcodes.UAKP);
Assert.assertTrue("OOC wasn't used for SUM", usedOOCSum);

}
finally {
// reset
rtplatform = platformOld;
}
}
}
Loading
Loading