Skip to content

Commit

Permalink
Merge pull request #610 from zinggAI/issue607
Browse files Browse the repository at this point in the history
filter methods in zframe
  • Loading branch information
sonalgoyal authored Jun 14, 2023
2 parents 7ff2bbb + 454c55f commit ad59f79
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 2 deletions.
10 changes: 9 additions & 1 deletion common/client/src/main/java/zingg/common/client/ZFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
public interface ZFrame<D, R, C> {

/**
 * Join-type name for a right join.
 * NOTE(review): presumably passed as the join-type argument of the join methods
 * declared elsewhere in this interface — confirm against the callers.
 */
public static final String RIGHT_JOIN = "right";

/**
 * Join-type name for a left join.
 * NOTE(review): presumably used the same way as {@link #RIGHT_JOIN} — confirm.
 */
public static final String LEFT_JOIN = "left";

/**
 * NOTE(review): presumably caches this frame's underlying data and returns a frame
 * usable for further chaining; the contract is not visible in this excerpt —
 * confirm against the implementations.
 *
 * @return a frame (whether it is this instance or a new one is not shown here)
 */
public ZFrame<D, R, C> cache();
/**
 * NOTE(review): presumably returns this frame under the alias {@code s};
 * the contract is not visible in this excerpt — confirm against the implementations.
 *
 * @param s the alias to apply (assumed)
 * @return the resulting frame
 */
public ZFrame<D, R, C> as(String s);
/**
 * NOTE(review): presumably the names of this frame's columns — confirm against
 * the implementations.
 *
 * @return an array of strings, assumed to be column names
 */
public String[] columns();
Expand Down Expand Up @@ -138,5 +139,12 @@ public interface ZFrame<D, R, C> {
/**
 * NOTE(review): presumably one {@code FieldData} descriptor per column of this
 * frame — confirm against the implementations.
 *
 * @return the field descriptors
 */
public FieldData[] fields();

/**
 * Returns the maximum value found in the given column.
 *
 * <p>The Spark implementation aggregates the column with {@code max} and returns
 * the first field of the resulting row, so the runtime type of the result follows
 * the column's type.
 *
 * @param colName name of the column to aggregate
 * @return the maximum value of {@code colName}
 */
public Object getMaxVal(String colName);

/**
 * Restricts this frame to the rows whose {@code colName} value is present in
 * column {@code innerDFCol} of {@code innerDF}.
 *
 * <p>NOTE(review): the Spark implementation realises this as a join on the column
 * rather than an IN predicate; if that join is an inner join, repeated values in
 * {@code innerDF} would repeat the matching rows — confirm the intended contract.
 *
 * @param colName    column of this frame to test
 * @param innerDF    frame supplying the allowed values
 * @param innerDFCol column of {@code innerDF} holding the allowed values
 * @return the filtered frame
 */
public ZFrame<D, R, C> filterInCond(String colName,ZFrame<D, R, C> innerDF, String innerDFCol);

/**
 * Keeps only the rows in which {@code colName} is not null.
 *
 * @param colName column to test for a non-null value
 * @return the filtered frame
 */
public ZFrame<D, R, C> filterNotNullCond(String colName);

/**
 * Keeps only the rows in which {@code colName} is null.
 *
 * @param colName column to test for a null value
 * @return the filtered frame
 */
public ZFrame<D, R, C> filterNullCond(String colName);


}
18 changes: 17 additions & 1 deletion spark/client/src/main/java/zingg/spark/client/SparkFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -354,5 +354,21 @@ public Object getMaxVal(String colName) {
Row r = df.agg(functions.max(colName)).head();
return r.get(0);
}


/**
 * Keeps the rows of this frame whose {@code colName} value also occurs in column
 * {@code innerDFCol} of {@code innerDF}, implemented as a join on that column.
 *
 * <p>NOTE(review): if {@code joinOnCol} performs an inner join, duplicate values in
 * {@code innerDF} would duplicate the matching rows of this frame — confirm whether
 * callers guarantee uniqueness.
 *
 * @param colName    column of this frame to match on
 * @param innerDF    frame supplying the allowed values
 * @param innerDFCol column of {@code innerDF} holding the allowed values
 * @return the joined (filtered) frame
 */
@Override
public ZFrame<Dataset<Row>, Row, Column> filterInCond(String colName,
        ZFrame<Dataset<Row>, Row, Column> innerDF, String innerDFCol) {
    // Reduce the inner frame to the single lookup column, renamed so that both
    // sides share the join column name.
    Column lookupValues = innerDF.col(innerDFCol).alias(colName);
    ZFrame<Dataset<Row>, Row, Column> lookup = innerDF.select(lookupValues);
    return joinOnCol(lookup, colName);
}

/**
 * Keeps only the rows in which {@code colName} is not null.
 *
 * @param colName column to test for a non-null value
 * @return the filtered frame
 */
@Override
public ZFrame<Dataset<Row>, Row, Column> filterNotNullCond(String colName) {
    Column hasValue = df.col(colName).isNotNull();
    return filter(hasValue);
}

/**
 * Keeps only the rows in which {@code colName} is null.
 *
 * @param colName column to test for a null value
 * @return the filtered frame
 */
@Override
public ZFrame<Dataset<Row>, Row, Column> filterNullCond(String colName) {
    Column isMissing = df.col(colName).isNull();
    return filter(isMissing);
}

}
22 changes: 22 additions & 0 deletions spark/client/src/test/java/zingg/client/TestSparkFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,27 @@ public void testRightJoinMultiCol(){
assertEquals(10,joinedData.count());
}

/**
 * filterInCond should keep the input rows whose id occurs in the prefixed id
 * column of the cluster fixture; five rows are expected.
 */
@Test
public void testFilterInCond() {
    SparkFrame input = getInputData();
    SparkFrame clusters = getClusterDataWithNull();
    // The cluster fixture stores its ids under the prefixed id column name.
    String clusterIdCol = ColName.COL_PREFIX + ColName.ID_COL;

    ZFrame<Dataset<Row>, Row, Column> matched =
            input.filterInCond(ColName.ID_COL, clusters, clusterIdCol);

    assertEquals(5, matched.count());
}

/**
 * filterNotNullCond on the source column should keep the three fixture rows
 * that carry a non-null source value.
 */
@Test
public void testFilterNotNullCond() {
    SparkFrame clusters = getClusterDataWithNull();

    ZFrame<Dataset<Row>, Row, Column> withSource =
            clusters.filterNotNullCond(ColName.SOURCE_COL);

    assertEquals(3, withSource.count());
}

/**
 * filterNullCond on the source column should keep the two fixture rows whose
 * source value is null.
 */
@Test
public void testFilterNullCond() {
    SparkFrame clusters = getClusterDataWithNull();

    ZFrame<Dataset<Row>, Row, Column> withoutSource =
            clusters.filterNullCond(ColName.SOURCE_COL);

    assertEquals(2, withoutSource.count());
}


}
17 changes: 17 additions & 0 deletions spark/client/src/test/java/zingg/client/TestSparkFrameBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,23 @@ protected SparkFrame getClusterData() {
return df;
}

/**
 * Builds a five-row cluster fixture in which two rows (ids 3 and 5) have a null
 * source value and the other three have sources "b", "a" and "c".
 *
 * @return a new SparkFrame wrapping the fixture data
 */
protected SparkFrame getClusterDataWithNull() {
    // Only the source column is declared nullable; the three integer columns are not.
    StructField[] columns = {
        new StructField(ColName.COL_PREFIX + ColName.ID_COL, DataTypes.IntegerType, false, Metadata.empty()),
        new StructField(ColName.CLUSTER_COLUMN, DataTypes.IntegerType, false, Metadata.empty()),
        new StructField(ColName.SCORE_COL, DataTypes.IntegerType, false, Metadata.empty()),
        new StructField(ColName.SOURCE_COL, DataTypes.StringType, true, Metadata.empty())
    };
    StructType schema = new StructType(columns);

    // Values are listed in schema order: prefixed id, cluster, score, source.
    Row[] fixtureRows = {
        RowFactory.create(1, 100, 1001, "b"),
        RowFactory.create(2, 100, 1002, "a"),
        RowFactory.create(3, 100, 2001, null),
        RowFactory.create(4, 900, 2002, "c"),
        RowFactory.create(5, 111, 9002, null)
    };

    return new SparkFrame(spark.createDataFrame(Arrays.asList(fixtureRows), schema));
}

/**
 * Asserts that {@code sf1.except(sf2)} is empty, failing with {@code message} otherwise.
 *
 * @param sf1     frame whose rows are checked
 * @param sf2     frame passed to {@code except}
 * @param message failure message handed to the assertion
 */
protected void assertTrueCheckingExceptOutput(ZFrame<Dataset<Row>, Row, Column> sf1,
        ZFrame<Dataset<Row>, Row, Column> sf2, String message) {
    boolean nothingLeft = sf1.except(sf2).isEmpty();
    assertTrue(nothingLeft, message);
}
Expand Down

0 comments on commit ad59f79

Please sign in to comment.