Skip to content

Commit

Permalink
Merge pull request #583 from zinggAI/dbconnect
Browse files Browse the repository at this point in the history
Running using Databricks Connect #582
  • Loading branch information
sonalgoyal authored Jun 2, 2023
2 parents ede80ab + a3ddd46 commit 28dd2a9
Show file tree
Hide file tree
Showing 20 changed files with 764 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
@JsonInclude(Include.NON_NULL)
public class Arguments implements Serializable {

private static final long serialVersionUID = 1L;
// creates DriverArgs and invokes the main object
Pipe[] output;
Pipe[] data;
Expand Down
10 changes: 10 additions & 0 deletions common/client/src/main/java/zingg/common/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*
*/
public abstract class Client<S,D,R,C,T> implements Serializable {
private static final long serialVersionUID = 1L;
protected Arguments arguments;
protected IZingg<S,D,R,C> zingg;
protected ClientOptions options;
Expand Down Expand Up @@ -283,4 +284,13 @@ public ZFrame<D,R,C> getUnmarkedRecords() {
return zingg.getUnmarkedRecords();
}

public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException {
return zingg.getTrainingDataModel();
}

public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
return zingg.getLabelDataViewHelper();
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package zingg.common.client;

import java.util.List;

public interface ILabelDataViewHelper<S, D, R, C> {

ZFrame<D, R, C> getClusterIdsFrame(ZFrame<D, R, C> lines);

List<R> getClusterIds(ZFrame<D, R, C> lines);

List<C> getDisplayColumns(ZFrame<D, R, C> lines, Arguments args);

ZFrame<D, R, C> getCurrentPair(ZFrame<D, R, C> lines, int index, List<R> clusterIds, ZFrame<D, R, C> clusterLines);

double getScore(ZFrame<D, R, C> currentPair);

double getPrediction(ZFrame<D, R, C> currentPair);

String getMsg1(int index, int totalPairs);

String getMsg2(double prediction, double score);

void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage);

void printMarkedRecordsStat(long positivePairsCount, long negativePairsCount, long notSurePairsCount,
long totalCount);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package zingg.common.client;

import zingg.common.client.pipe.Pipe;

public interface ITrainingDataModel<S, D, R, C> {

public void setMarkedRecordsStat(ZFrame<D, R, C> markedRecords);

public ZFrame<D, R, C> updateRecords(int matchValue, ZFrame<D, R, C> newRecords, ZFrame<D, R, C> updatedRecords);

public void updateLabellerStat(int selected_option, int increment);

public void writeLabelledOutput(ZFrame<D, R, C> records, Arguments args) throws ZinggClientException;

public void writeLabelledOutput(ZFrame<D,R,C> records, Arguments args, Pipe p) throws ZinggClientException;

public long getPositivePairsCount();

public long getNegativePairsCount();

public long getNotSurePairsCount() ;

public long getTotalCount();


}
6 changes: 5 additions & 1 deletion common/client/src/main/java/zingg/common/client/IZingg.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,9 @@ public void init(Arguments args, String license)
public ClientOptions getClientOptions();

public void setSession(S session);


public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException;

public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package zingg.common.core.executor;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.Arguments;
import zingg.common.client.ClientOptions;
import zingg.common.client.ILabelDataViewHelper;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.Context;
import zingg.common.core.util.LabelMatchType;

public class LabelDataViewHelper<S,D,R,C,T> extends ZinggBase<S, D, R, C, T> implements ILabelDataViewHelper<S, D, R, C> {

private static final long serialVersionUID = 1L;
public static final Log LOG = LogFactory.getLog(LabelDataViewHelper.class);

public LabelDataViewHelper(Context<S,D,R,C,T> context, ZinggOptions zinggOptions, ClientOptions clientOptions) {
setContext(context);
setZinggOptions(zinggOptions);
setClientOptions(clientOptions);
setName(this.getClass().getName());
}

@Override
public ZFrame<D,R,C> getClusterIdsFrame(ZFrame<D,R,C> lines) {
return lines.select(ColName.CLUSTER_COLUMN).distinct();
}


@Override
public List<R> getClusterIds(ZFrame<D,R,C> lines) {
return lines.collectAsList();
}


@Override
public List<C> getDisplayColumns(ZFrame<D,R,C> lines, Arguments args) {
return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
}


@Override
public ZFrame<D,R,C> getCurrentPair(ZFrame<D,R,C> lines, int index, List<R> clusterIds, ZFrame<D,R,C> clusterLines) {
return lines.filter(lines.equalTo(ColName.CLUSTER_COLUMN,
clusterLines.getAsString(clusterIds.get(index), ColName.CLUSTER_COLUMN))).cache();
}


@Override
public double getScore(ZFrame<D,R,C> currentPair) {
return currentPair.getAsDouble(currentPair.head(),ColName.SCORE_COL);
}


@Override
public double getPrediction(ZFrame<D,R,C> currentPair) {
return currentPair.getAsDouble(currentPair.head(), ColName.PREDICTION_COL);
}


@Override
public String getMsg1(int index, int totalPairs) {
return String.format("\tCurrent labelling round : %d/%d pairs labelled\n", index, totalPairs);
}


@Override
public String getMsg2(double prediction, double score) {
String msg2 = "";
String matchType = LabelMatchType.get(prediction).msg;
if (prediction == ColValues.IS_NOT_KNOWN_PREDICTION) {
msg2 = String.format(
"\tZingg does not do any prediction for the above pairs as Zingg is still collecting training data to build the preliminary models.");
} else {
msg2 = String.format("\tZingg predicts the above records %s with a similarity score of %.2f",
matchType, Math.floor(score * 100) * 0.01);
}
return msg2;
}


@Override
public void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage) {
//System.out.println();
System.out.println(preMessage);
records.show(false);
System.out.println(postMessage);
System.out.println("\tWhat do you think? Your choices are: ");
System.out.println();

System.out.println("\tNo, they do not match : 0");
System.out.println("\tYes, they match : 1");
System.out.println("\tNot sure : 2");
System.out.println();
System.out.println("\tTo exit : 9");
System.out.println();
System.out.print("\tPlease enter your choice [0,1,2 or 9]: ");
}

@Override
public void printMarkedRecordsStat(long positivePairsCount,long negativePairsCount,long notSurePairsCount,long totalCount) {
String msg = String.format(
"\tLabelled pairs so far : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount,
negativePairsCount, totalCount, notSurePairsCount, totalCount);

System.out.println();
System.out.println();
System.out.println();
System.out.println(msg);
}



@Override
public void execute() throws ZinggClientException {
throw new UnsupportedOperationException();
}

@Override
public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import zingg.common.core.util.LabelMatchType;

public abstract class LabelUpdater<S,D,R,C,T> extends Labeller<S,D,R,C,T> {
protected static String name = "zingg.LabelUpdater";
private static final long serialVersionUID = 1L;
protected static String name = "zingg.common.core.executor.LabelUpdater";
public static final Log LOG = LogFactory.getLog(LabelUpdater.class);

public LabelUpdater() {
Expand All @@ -33,12 +34,18 @@ public void execute() throws ZinggClientException {
}
}

public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
LOG.info("Processing Records for CLI updateLabelling");

if (lines != null && lines.count() > 0) {
getMarkedRecordsStat(lines);
printMarkedRecordsStat();
getTrainingDataModel().setMarkedRecordsStat(lines);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);


List<C> displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
try {
Expand All @@ -52,7 +59,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
do {
System.out.print("\n\tPlease enter the cluster id (or 9 to exit): ");
String cluster_id = sc.next();
if (cluster_id.equals("9")) {
if (cluster_id.equals(QUIT_LABELING.toString())) {
LOG.info("User has exit in the middle. Updating the records.");
break;
}
Expand All @@ -67,10 +74,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
String matchType = LabelMatchType.get(matchFlag).msg;
postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg);
updateLabellerStat(selectedOption, +1);
updateLabellerStat(matchFlag, -1);
printMarkedRecordsStat();
if (selectedOption == 9) {
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);

if (selectedOption == QUIT_LABELING) {
LOG.info("User has quit in the middle. Updating the records.");
break;
}
Expand All @@ -80,15 +93,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
updatedRecords = updatedRecords
.filter(updatedRecords.notEqual(ColName.CLUSTER_COLUMN,cluster_id));
}
updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != 9);
updatedRecords = getTrainingDataModel().updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != QUIT_LABELING);

if (updatedRecords != null) {
updatedRecords = updatedRecords.union(recordsToUpdate);
}
writeLabelledOutput(updatedRecords);
getTrainingDataModel().writeLabelledOutput(updatedRecords,args,getOutputPipe());
sc.close();
LOG.info("Processing finished.");
return updatedRecords;
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
e.printStackTrace();
Expand All @@ -98,6 +112,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
}
} else {
LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.");
return null;
}
}

Expand Down
Loading

0 comments on commit 28dd2a9

Please sign in to comment.