diff --git a/.gitignore b/.gitignore index f3c28571bdf..6db9af8e619 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,7 @@ venv/* # resource optimization scripts/resource/output *.pem +*.log +build_log.txt +*.log +build_log.txt diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 423679d038c..627e0ae2181 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -403,7 +403,9 @@ public enum Builtins { UNIQUE("unique", false, true), UPPER_TRI("upper.tri", false, true), XDUMMY1("xdummy1", true), //error handling test - XDUMMY2("xdummy2", true); //error handling test + XDUMMY2("xdummy2", true), //error handling test + GETCOLNAMES("getColNames", false, true), + SETCOLNAMES("setColNames", false, true); Builtins(String name, boolean script) { this(name, null, script, false, ReturnType.SINGLE_RETURN); diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 28c5a7a6a8e..d28117ed314 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -123,6 +123,7 @@ public enum Opcodes { FREPLICATE("freplicate", InstructionType.Binary), VALUESWAP("valueSwap", InstructionType.Binary), APPLYSCHEMA("applySchema", InstructionType.Binary), + SETCOLNAMES("setColNames", InstructionType.Binary), MAP("_map", InstructionType.Ternary), NMAX("nmax", InstructionType.BuiltinNary), @@ -164,6 +165,7 @@ public enum Opcodes { TYPEOF("typeOf", InstructionType.Unary), DETECTSCHEMA("detectSchema", InstructionType.Unary), COLNAMES("colnames", InstructionType.Unary), + GETCOLNAMES("getColNames", InstructionType.Unary), ISNA("isna", InstructionType.Unary), ISNAN("isnan", InstructionType.Unary), ISINF("isinf", InstructionType.Unary), diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index ae582b052b2..12f527bfaf6 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -751,12 +751,47 @@ else if(((ConstIdentifier) getThirdExpr().getOutput()) else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; + + case GETCOLNAMES: + checkNumParameters(1); + Expression getNamesExpr = getFirstExpr(); + validFrameInput(getNamesExpr, _opcode.toString()); + + DataIdentifier getNamesOut = (DataIdentifier) getOutputs()[0]; + getNamesOut.setDataType(DataType.FRAME); + getNamesOut.setValueType(ValueType.STRING); + getNamesOut.setDimensions(1, getNamesExpr.getOutput().getDim2()); + getNamesOut.setBlocksize(getNamesExpr.getOutput().getBlocksize()); + break; + + case SETCOLNAMES: + checkNumParameters(2); + Expression target = getFirstExpr(); + Expression nameRow = getSecondExpr(); + validFrameInput(target, _opcode + " (first parameter)"); + validFrameInput(nameRow, _opcode + " (second parameter)"); + if (nameRow.getOutput().getDim1() != 1) { + raiseValidateError("Second parameter of set names must be a single row frame", false); + } + DataIdentifier setNamesOut = (DataIdentifier) getOutputs()[0]; + setNamesOut.setDataType(DataType.FRAME); + setNamesOut.setValueType(target.getOutput().getValueType()); + setNamesOut.setDimensions(target.getOutput().getDim1(), target.getOutput().getDim2()); + setNamesOut.setBlocksize(target.getOutput().getBlocksize()); + break; default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } } + private void validFrameInput(Expression expr, String context) { + if (expr == null || expr.getOutput() == null || expr.getOutput().getDataType() != DataType.FRAME) { + String dtype = (expr != null && expr.getOutput() != null) ? expr.getOutput().getDataType().toString() : "null"; + raiseValidateError("Expecting frame parameter for " + context, false); + } + } + private static boolean isPowerOfTwo(long n) { return (n > 0) && ((n & (n - 1)) == 0); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 63cadb43cf4..e146e1f69b4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -360,6 +360,36 @@ public void setColumnName(int index, String name) { _colnames[index] = name; } + /** + * Returns the column names of FrameBlock + * If column names are not set, return default names (e.g. "C1", "C2"...) + * @return an array of column names + * The actual function is the same to getColumnNamesAsFrame() + */ + public FrameBlock getColNames() { + return getColumnNamesAsFrame(); + } + + public void setColNames(FrameBlock names) { + if (names == null){ + throw new DMLRuntimeException("Input FrameBlock can not be null."); + } + if (names.getNumRows() != 1) { + throw new DMLRuntimeException("Input FrameBlock must be single line."); + } + if (names.getNumColumns() != this.getNumColumns()) { + throw new DMLRuntimeException("Number of columns does not match."); + } + this._colnames = new String[names.getNumColumns()]; + for (int j = 0; j < names.getNumColumns(); j++) { + String name = names.getString(0, j); + if (name == null) { + throw new DMLRuntimeException("Column names can not contain null values"); + } + _colnames[j] = name; + } + } + public ColumnMetadata[] getColumnMetadata() { return _colmeta; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index e9771b2e7fe..30a30f4464c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -63,6 +63,12 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) { final FrameBlock out = FrameLibApplySchema.applySchema(inBlock1, inBlock2, k); ec.setFrameOutput(output.getName(), out); } + else if (getOpcode().equals(Opcodes.SETCOLNAMES.toString())) { + FrameBlock fb = ec.getFrameInput(input1.getName()); + FrameBlock nameRow = ec.getFrameInput(input2.getName()); + fb.setColNames(nameRow); + ec.setFrameOutput(output.getName(), fb); + } else { // Execute binary operations BinaryOperator dop = (BinaryOperator) _optr; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java index 107cab79d79..a39de10d30a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java @@ -52,6 +52,12 @@ else if(getOpcode().equals(Opcodes.COLNAMES.toString())) { ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), retBlock); } + else if (getOpcode().equals(Opcodes.GETCOLNAMES.toString())) { + FrameBlock inBlock = ec.getFrameInput(input1.getName()); + FrameBlock retBlock = inBlock.getColNames(); + ec.releaseFrameInput(input1.getName()); + ec.setFrameOutput(output.getName(), retBlock); + } else throw new DMLScriptException("Opcode '" + getOpcode() + "' is not a valid UnaryFrameCPInstruction"); } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java new file mode 100644 index 00000000000..ecc1a0763e8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesScriptTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part1; + +import org.junit.Test; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.common.Types.ExecMode; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; + +public class BuiltinGetSetNamesScriptTest extends AutomatedTestBase { + private static final String TEST_NAME = "BuiltinGetSetNamesTest"; + private static final String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGetSetNamesScriptTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"})); + setExecMode(ExecMode.SINGLE_NODE); + } + + @Test + public void testGetSetNames() { + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml"; + String tempFilePath = output("B"); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + programArgs = new String[]{"-args", tempFilePath}; + runTest(true, false, null, -1); + try (BufferedReader br = new BufferedReader(new FileReader(tempFilePath))) { + String header = br.readLine(); + if (header == null || !header.equals("ID,Value")) { + throw new AssertionError("Test failed: Expected header 'ID,Value', but got: " + header); + } + } catch (IOException e) { + throw new AssertionError("Test failed: Unable to read output file: " + e.getMessage()); + } + } +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java new file mode 100644 index 00000000000..d924ef1cebc --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinGetSetNamesTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.builtin.part1; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class BuiltinGetSetNamesTest { + @Test + public void testGetDefaultNames() { + FrameBlock fb = new FrameBlock(3, ValueType.STRING); + FrameBlock names = fb.getNames(); + assertEquals("C1", names.getString(0, 0)); + assertEquals("C2", names.getString(0, 1)); + assertEquals("C3", names.getString(0, 2)); + } + + @Test + public void testSetAndGetCustomNames() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + FrameBlock nameRow = new FrameBlock(2, ValueType.STRING); + nameRow.appendRow(new String[] {"name", "age"}); + + fb.setNames(nameRow); + + FrameBlock result = fb.getNames(); + assertEquals("name", result.getString(0, 0)); + assertEquals("age", result.getString(0, 1)); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesNullFrame() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + fb.setNames(null); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesWrongRowCount() { + FrameBlock fb = new FrameBlock(2, ValueType.STRING); + FrameBlock nameRows = new FrameBlock(2, ValueType.STRING); + nameRows.appendRow(new String[] {"name", "age"}); + nameRows.appendRow(new String[] {"x", "y"}); + + fb.setNames(nameRows); + } + + @Test(expected = DMLRuntimeException.class) + public void testSetNamesWrongColCount() { + FrameBlock fb = new FrameBlock(3, ValueType.STRING); + FrameBlock nameRow = new FrameBlock(2, ValueType.STRING); + nameRow.appendRow(new String[] {"a", "b"}); + + fb.setNames(nameRow); + } +} + diff --git a/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml new file mode 100644 index 00000000000..c8a87754abb --- /dev/null +++ b/src/test/scripts/functions/builtin/BuiltinGetSetNamesTest.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +tempFile = $1 + +data = matrix(c(1, 2, 3, 4), 2, 2) +frame1 = as.frame(data) +colNames1 = as.frame(matrix(c("ID", "Value"), 1, 2)) + +frame1 = setColNames(frame1, colNames1) +retrievedNames = getColNames(frame1) + +if (!all(retrievedNames == colNames1[1,])) { + stop("Name mismatch: Expected " + toString(colNames1[1,]) + " but got " + toString(retrievedNames)) +} + +write(frame1, tempFile, format="csv", header=TRUE) +frame2 = read(tempFile, format="csv", header=TRUE) + +reloadedNames = getColNames(frame2) +if (!all(reloadedNames == colNames1[1,])) { + stop("CSV reload name mismatch: Expected " + toString(colNames1[1,]) + " but got " + toString(reloadedNames)) +} + +print("All tests passed successfully!")