Skip to content

Commit

Permalink
added one input and two input ifft to dml
Browse files Browse the repository at this point in the history
  • Loading branch information
fzoepffel committed Jan 22, 2024
1 parent 8f240eb commit 6101dc8
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnMatrixMatrixBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnComplexMatrixBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.PMMJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
Expand Down Expand Up @@ -331,7 +331,7 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put("lu", CPType.MultiReturnBuiltin);
String2CPInstructionType.put("eigen", CPType.MultiReturnBuiltin);
String2CPInstructionType.put("fft", CPType.MultiReturnBuiltin);
String2CPInstructionType.put("ifft", CPType.MultiReturnMatrixMatrixBuiltin);
String2CPInstructionType.put("ifft", CPType.MultiReturnComplexMatrixBuiltin);
String2CPInstructionType.put("svd", CPType.MultiReturnBuiltin);

String2CPInstructionType.put("partition", CPType.Partition);
Expand Down Expand Up @@ -423,8 +423,8 @@ public static CPInstruction parseSingleInstruction(CPType cptype, String str) {
case MultiReturnParameterizedBuiltin:
return MultiReturnParameterizedBuiltinCPInstruction.parseInstruction(str);

case MultiReturnMatrixMatrixBuiltin:
return MultiReturnMatrixMatrixBuiltinCPInstruction.parseInstruction(str);
case MultiReturnComplexMatrixBuiltin:
return MultiReturnComplexMatrixBuiltinCPInstruction.parseInstruction(str);

case MultiReturnBuiltin:
return MultiReturnBuiltinCPInstruction.parseInstruction(str);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public abstract class CPInstruction extends Instruction {
public enum CPType {
AggregateUnary, AggregateBinary, AggregateTernary,
Unary, Binary, Ternary, Quaternary, BuiltinNary, Ctable,
MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, MultiReturnMatrixMatrixBuiltin,
MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, MultiReturnComplexMatrixBuiltin,
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused,
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,6 @@ public static MultiReturnBuiltinCPInstruction parseInstruction(String str) {

return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str);

// } else if (opcode.equalsIgnoreCase("ifft")) {
// // one input and two outputs
// CPOperand in1 = new CPOperand(parts[1]);
// CPOperand in2 = new CPOperand(parts[2]);
// outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));
// outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX));

// return new MultiReturnBuiltinCPInstruction(null, in1, in2, outputs, opcode,
// str);
} else if (opcode.equalsIgnoreCase("svd")) {
CPOperand in1 = new CPOperand(parts[1]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,24 @@
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;

public class MultiReturnMatrixMatrixBuiltinCPInstruction extends ComputationCPInstruction {
public class MultiReturnComplexMatrixBuiltinCPInstruction extends ComputationCPInstruction {

protected ArrayList<CPOperand> _outputs;

private MultiReturnMatrixMatrixBuiltinCPInstruction(Operator op, CPOperand input1, CPOperand input2,
private MultiReturnComplexMatrixBuiltinCPInstruction(Operator op, CPOperand input1, CPOperand input2,
ArrayList<CPOperand> outputs, String opcode,
String istr) {
super(CPType.MultiReturnBuiltin, op, input1, input2, outputs.get(0), opcode, istr);
_outputs = outputs;
}

private MultiReturnComplexMatrixBuiltinCPInstruction(Operator op, CPOperand input1, ArrayList<CPOperand> outputs,
String opcode,
String istr) {
super(CPType.MultiReturnBuiltin, op, input1, null, outputs.get(0), opcode, istr);
_outputs = outputs;
}

public CPOperand getOutput(int i) {
return _outputs.get(i);
}
Expand All @@ -57,21 +64,30 @@ public String[] getOutputNames() {
return _outputs.parallelStream().map(output -> output.getName()).toArray(String[]::new);
}

public static MultiReturnMatrixMatrixBuiltinCPInstruction parseInstruction(String str) {
public static MultiReturnComplexMatrixBuiltinCPInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
ArrayList<CPOperand> outputs = new ArrayList<>();
// first part is always the opcode
String opcode = parts[0];

if (opcode.equalsIgnoreCase("ifft")) {
if (parts.length == 5 && opcode.equalsIgnoreCase("ifft")) {
// one input and two outputs
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX));

return new MultiReturnMatrixMatrixBuiltinCPInstruction(null, in1, in2, outputs, opcode, str);
} else {
return new MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, in2, outputs, opcode, str);
} else if (parts.length == 4 && opcode.equalsIgnoreCase("ifft")) {
// one input and two outputs
CPOperand in1 = new CPOperand(parts[1]);
outputs.add(new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX));
outputs.add(new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX));

return new MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, outputs, opcode, str);
}

{
throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode);
}

Expand All @@ -83,6 +99,27 @@ public int getNumOutputs() {

@Override
public void processInstruction(ExecutionContext ec) {
if (input2 == null)
processOneInputInstruction(ec);
else
processTwoInputInstruction(ec);
}

private void processOneInputInstruction(ExecutionContext ec) {
if (!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode());

MatrixBlock in = ec.getMatrixInput(input1.getName());
MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in, getOpcode());

ec.releaseMatrixInput(input1.getName());

for (int i = 0; i < _outputs.size(); i++) {
ec.setMatrixOutput(_outputs.get(i).getName(), out[i]);
}
}

private void processTwoInputInstruction(ExecutionContext ec) {
if (!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode());

Expand Down
Loading

0 comments on commit 6101dc8

Please sign in to comment.