Skip to content

Commit

Permalink
integrated 2 inputs for ifft
Browse files Browse the repository at this point in the history
  • Loading branch information
fzoepffel committed Jan 21, 2024
1 parent 1f46e22 commit 070327b
Show file tree
Hide file tree
Showing 8 changed files with 2,562 additions and 2,290 deletions.
3,292 changes: 1,662 additions & 1,630 deletions src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.apache.sysds.runtime.instructions.cp;


import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
Expand All @@ -38,15 +37,16 @@

public abstract class CPInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(CPInstruction.class.getName());

public enum CPType {
AggregateUnary, AggregateBinary, AggregateTernary,
Unary, Binary, Ternary, Quaternary, BuiltinNary, Ctable,
MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin,
MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, MultiReturnMatrixMatrixBuiltin,
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused,
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
NoOp,
}
}

protected final CPType _cptype;
protected final boolean _requiresLabelUpdate;
Expand All @@ -64,7 +64,7 @@ protected CPInstruction(CPType type, Operator op, String opcode, String istr) {
instOpcode = opcode;
_requiresLabelUpdate = super.requiresLabelUpdate();
}

@Override
public IType getType() {
return IType.CONTROL_PROGRAM;
Expand All @@ -73,7 +73,7 @@ public IType getType() {
public CPType getCPInstructionType() {
return _cptype;
}

@Override
public boolean requiresLabelUpdate() {
return _requiresLabelUpdate;
Expand All @@ -86,92 +86,96 @@ public String getGraphString() {

@Override
public Instruction preprocessInstruction(ExecutionContext ec) {
//default preprocess behavior (e.g., debug state, lineage)
// default preprocess behavior (e.g., debug state, lineage)
Instruction tmp = super.preprocessInstruction(ec);

//instruction patching
if( tmp.requiresLabelUpdate() ) { //update labels only if required
//note: no exchange of updated instruction as labels might change in the general case
// instruction patching
if (tmp.requiresLabelUpdate()) { // update labels only if required
// note: no exchange of updated instruction as labels might change in the
// general case
String updInst = updateLabels(tmp.toString(), ec.getVariables());
tmp = CPInstructionParser.parseSingleInstruction(updInst);
// Corrected lineage trace for patched instructions
if (DMLScript.LINEAGE)
ec.traceLineage(tmp);
}
//robustness federated instructions (runtime assignment)
if( ConfigurationManager.isFederatedRuntimePlanner() ) {

// robustness federated instructions (runtime assignment)
if (ConfigurationManager.isFederatedRuntimePlanner()) {
tmp = FEDInstructionUtils.checkAndReplaceCP(tmp, ec);
//NOTE: Retracing of lineage is not needed as the lineage trace
//is same for an instruction and its FED version.
// NOTE: Retracing of lineage is not needed as the lineage trace
// is same for an instruction and its FED version.
}

tmp = PrivacyPropagator.preprocessInstruction(tmp, ec);
return tmp;
}

@Override
@Override
public abstract void processInstruction(ExecutionContext ec);

@Override
public void postprocessInstruction(ExecutionContext ec) {
if (DMLScript.LINEAGE_DEBUGGER)
ec.maintainLineageDebuggerInfo(this);
}

/**
* Takes a delimited string of instructions, and replaces ALL placeholder labels
* Takes a delimited string of instructions, and replaces ALL placeholder labels
* (such as ##mVar2## and ##Var5##) in ALL instructions.
*
* @param instList instruction list as string
*
* @param instList instruction list as string
* @param labelValueMapping local variable map
* @return instruction list after replacement
*/
public static String updateLabels (String instList, LocalVariableMap labelValueMapping) {
public static String updateLabels(String instList, LocalVariableMap labelValueMapping) {

if ( !instList.contains(Lop.VARIABLE_NAME_PLACEHOLDER) )
if (!instList.contains(Lop.VARIABLE_NAME_PLACEHOLDER))
return instList;

StringBuilder updateInstList = new StringBuilder();
String[] ilist = instList.split(Lop.INSTRUCTION_DELIMITOR);
for ( int i=0; i < ilist.length; i++ ) {
if ( i > 0 )
String[] ilist = instList.split(Lop.INSTRUCTION_DELIMITOR);

for (int i = 0; i < ilist.length; i++) {
if (i > 0)
updateInstList.append(Lop.INSTRUCTION_DELIMITOR);
updateInstList.append( updateInstLabels(ilist[i], labelValueMapping));

updateInstList.append(updateInstLabels(ilist[i], labelValueMapping));
}
return updateInstList.toString();
}

/**
* Replaces ALL placeholder strings (such as ##mVar2## and ##Var5##) in a single instruction.
*
/**
* Replaces ALL placeholder strings (such as ##mVar2## and ##Var5##) in a single
* instruction.
*
* @param inst string instruction
* @param map local variable map
* @param map local variable map
* @return string instruction after replacement
*/
private static String updateInstLabels(String inst, LocalVariableMap map) {
if ( inst.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ) {
if (inst.contains(Lop.VARIABLE_NAME_PLACEHOLDER)) {
int skip = Lop.VARIABLE_NAME_PLACEHOLDER.length();
while ( inst.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ) {
int startLoc = inst.indexOf(Lop.VARIABLE_NAME_PLACEHOLDER)+skip;
while (inst.contains(Lop.VARIABLE_NAME_PLACEHOLDER)) {
int startLoc = inst.indexOf(Lop.VARIABLE_NAME_PLACEHOLDER) + skip;
String varName = inst.substring(startLoc, inst.indexOf(Lop.VARIABLE_NAME_PLACEHOLDER, startLoc));
String replacement = getVarNameReplacement(inst, varName, map);
inst = inst.replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER + varName + Lop.VARIABLE_NAME_PLACEHOLDER, replacement);
inst = inst.replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER + varName + Lop.VARIABLE_NAME_PLACEHOLDER,
replacement);
}
}
return inst;
}

/**
* Computes the replacement string for a given variable name placeholder string
* (e.g., ##mVar2## or ##Var5##). The replacement is a HDFS filename for matrix
* variables, and is the actual value (stored in symbol table) for scalar variables.
* Computes the replacement string for a given variable name placeholder string
* (e.g., ##mVar2## or ##Var5##). The replacement is a HDFS filename for matrix
* variables, and is the actual value (stored in symbol table) for scalar
* variables.
*
* @param inst instruction
* @param inst instruction
* @param varName variable name
* @param map local variable map
* @param map local variable map
* @return string variable name
*/
private static String getVarNameReplacement(String inst, String varName, LocalVariableMap map) {
Expand All @@ -186,7 +190,8 @@ private static String getVarNameReplacement(String inst, String varName, LocalVa
replacement = "" + ((ScalarObject) val).getStringValue();
return replacement;
} else {
throw new DMLRuntimeException("Variable (" + varName + ") in Instruction (" + inst + ") is not found in the variablemap.");
throw new DMLRuntimeException(
"Variable (" + varName + ") in Instruction (" + inst + ") is not found in the variablemap.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,119 +43,114 @@ private MultiReturnBuiltinCPInstruction(Operator op, CPOperand input1, ArrayList
super(CPType.MultiReturnBuiltin, op, input1, null, outputs.get(0), opcode, istr);
_outputs = outputs;
}

public CPOperand getOutput(int i) {
return _outputs.get(i);
}

public List<CPOperand> getOutputs(){
public List<CPOperand> getOutputs() {
return _outputs;
}

public String[] getOutputNames(){
public String[] getOutputNames() {
return _outputs.parallelStream().map(output -> output.getName()).toArray(String[]::new);
}
public static MultiReturnBuiltinCPInstruction parseInstruction ( String str ) {

public static MultiReturnBuiltinCPInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
ArrayList<CPOperand> outputs = new ArrayList<>();
// first part is always the opcode
String opcode = parts[0];
if ( opcode.equalsIgnoreCase("qr") ) {

if (opcode.equalsIgnoreCase("qr")) {
// one input and two ouputs
CPOperand in1 = new CPOperand(parts[1]);
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );
outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);
}
else if ( opcode.equalsIgnoreCase("lu") ) {
} else if (opcode.equalsIgnoreCase("lu")) {
CPOperand in1 = new CPOperand(parts[1]);

// one input and three outputs
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX) );

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);

}
else if ( opcode.equalsIgnoreCase("eigen") ) {
// one input and two outputs
CPOperand in1 = new CPOperand(parts[1]);
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );

outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX));

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);

}
else if ( opcode.equalsIgnoreCase("fft") ) {

} else if (opcode.equalsIgnoreCase("eigen")) {
// one input and two outputs
CPOperand in1 = new CPOperand(parts[1]);
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );
outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);

}
else if ( opcode.equalsIgnoreCase("ifft") ) {
} else if (opcode.equalsIgnoreCase("fft")) {
// one input and two outputs
CPOperand in1 = new CPOperand(parts[1]);
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );
outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);
}
else if ( opcode.equalsIgnoreCase("svd") ) {

// } else if (opcode.equalsIgnoreCase("ifft")) {
// // one input and two outputs
// CPOperand in1 = new CPOperand(parts[1]);
// CPOperand in2 = new CPOperand(parts[2]);
// outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));
// outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX));

// return new MultiReturnBuiltinCPInstruction(null, in1, in2, outputs, opcode,
// str);
} else if (opcode.equalsIgnoreCase("svd")) {
CPOperand in1 = new CPOperand(parts[1]);

// one input and three outputs
outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) );
outputs.add ( new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX) );
outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX));

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);

}
else {
} else {
throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
}

}

public int getNumOutputs() {
return _outputs.size();
}

@Override
@Override
public void processInstruction(ExecutionContext ec) {
if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
if (!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode());

MatrixBlock in = ec.getMatrixInput(input1.getName());
MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in, getOpcode());
ec.releaseMatrixInput(input1.getName());
for(int i=0; i < _outputs.size(); i++) {
for (int i = 0; i < _outputs.size(); i++) {
ec.setMatrixOutput(_outputs.get(i).getName(), out[i]);
}
}

@Override
public boolean hasSingleLineage() {
return false;
}


@Override
@SuppressWarnings("unchecked")
public Pair<String, LineageItem>[] getLineageItems(ExecutionContext ec) {
LineageItem[] inputLineage = LineageItemUtils.getLineage(ec, input1, input2, input3);
final Pair<String,LineageItem>[] ret = new Pair[_outputs.size()];
for(int i = 0; i < _outputs.size(); i++){
final Pair<String, LineageItem>[] ret = new Pair[_outputs.size()];
for (int i = 0; i < _outputs.size(); i++) {
CPOperand out = _outputs.get(i);
ret[i] = Pair.of(out.getName(), new LineageItem(getOpcode(), inputLineage));
}
return ret;
return ret;
}
}
Loading

0 comments on commit 070327b

Please sign in to comment.