Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Running using Databricks Connect #582 #583

Merged
merged 25 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e54046d
using dbfs
vikasgupta78 May 2, 2023
64f7a06
refactor labeller to move common methods to training helper
vikasgupta78 May 5, 2023
681465e
supporting both normal and data bricks connect client
vikasgupta78 May 6, 2023
e0371aa
getting spark session thru new method in client
vikasgupta78 May 7, 2023
3eab8d7
getting spark session thru new method in client
vikasgupta78 May 7, 2023
99e9ca1
getting spark session thru new method in client
vikasgupta78 May 7, 2023
fe01f2c
getting spark session and JVM thru new method in client
vikasgupta78 May 8, 2023
6764867
label updater should overwrite and not append
vikasgupta78 May 8, 2023
4051246
Running using Databricks Connect
vikasgupta78 May 8, 2023
7097759
indentation issue
vikasgupta78 May 11, 2023
f23c1b2
util method to write using pandas df
vikasgupta78 May 12, 2023
645b03d
Merge pull request #1 from zinggAI/0.3.5
vikasgupta78 May 26, 2023
48b2134
refactor view and model into separate classes 1st cut
vikasgupta78 May 26, 2023
36878b7
extra null check removed
vikasgupta78 May 30, 2023
552d091
constant QUIT_LABELING = 9 defined
vikasgupta78 May 30, 2023
ea7e8f4
constant INCREMENT = 1 defined
vikasgupta78 May 30, 2023
47b493a
refactoring
vikasgupta78 May 31, 2023
18f9c77
lazy initialization
vikasgupta78 May 31, 2023
40b817a
compile error
vikasgupta78 May 31, 2023
26f1135
label update methods and refactoring
vikasgupta78 May 31, 2023
2df6704
validity check added
vikasgupta78 May 31, 2023
c5abe56
compile issues resolved
vikasgupta78 May 31, 2023
53ae447
DB Connect check added
vikasgupta78 May 31, 2023
880b4f6
syntax issues
vikasgupta78 May 31, 2023
a3ddd46
shell script changes
vikasgupta78 Jun 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
@JsonInclude(Include.NON_NULL)
public class Arguments implements Serializable {

private static final long serialVersionUID = 1L;
// creates DriverArgs and invokes the main object
Pipe[] output;
Pipe[] data;
Expand Down
10 changes: 10 additions & 0 deletions common/client/src/main/java/zingg/common/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*
*/
public abstract class Client<S,D,R,C,T> implements Serializable {
private static final long serialVersionUID = 1L;
protected Arguments arguments;
protected IZingg<S,D,R,C> zingg;
protected ClientOptions options;
Expand Down Expand Up @@ -283,4 +284,13 @@ public ZFrame<D,R,C> getUnmarkedRecords() {
return zingg.getUnmarkedRecords();
}

public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException {
return zingg.getTrainingDataModel();
}

public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
return zingg.getLabelDataViewHelper();
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package zingg.common.client;

import java.util.List;

public interface ILabelDataViewHelper<S, D, R, C> {

ZFrame<D, R, C> getClusterIdsFrame(ZFrame<D, R, C> lines);

List<R> getClusterIds(ZFrame<D, R, C> lines);

List<C> getDisplayColumns(ZFrame<D, R, C> lines, Arguments args);

ZFrame<D, R, C> getCurrentPair(ZFrame<D, R, C> lines, int index, List<R> clusterIds, ZFrame<D, R, C> clusterLines);

double getScore(ZFrame<D, R, C> currentPair);

double getPrediction(ZFrame<D, R, C> currentPair);

String getMsg1(int index, int totalPairs);

String getMsg2(double prediction, double score);

void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage);

void printMarkedRecordsStat(long positivePairsCount, long negativePairsCount, long notSurePairsCount,
long totalCount);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package zingg.common.client;

import zingg.common.client.pipe.Pipe;

public interface ITrainingDataModel<S, D, R, C> {

public void setMarkedRecordsStat(ZFrame<D, R, C> markedRecords);

public ZFrame<D, R, C> updateRecords(int matchValue, ZFrame<D, R, C> newRecords, ZFrame<D, R, C> updatedRecords);

public void updateLabellerStat(int selected_option, int increment);

public void writeLabelledOutput(ZFrame<D, R, C> records, Arguments args) throws ZinggClientException;

public void writeLabelledOutput(ZFrame<D,R,C> records, Arguments args, Pipe p) throws ZinggClientException;

public long getPositivePairsCount();

public long getNegativePairsCount();

public long getNotSurePairsCount() ;

public long getTotalCount();


}
6 changes: 5 additions & 1 deletion common/client/src/main/java/zingg/common/client/IZingg.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,9 @@ public void init(Arguments args, String license)
public ClientOptions getClientOptions();

public void setSession(S session);


public ITrainingDataModel<S, D, R, C> getTrainingDataModel() throws UnsupportedOperationException;

public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package zingg.common.core.executor;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.Arguments;
import zingg.common.client.ClientOptions;
import zingg.common.client.ILabelDataViewHelper;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.Context;
import zingg.common.core.util.LabelMatchType;

public class LabelDataViewHelper<S,D,R,C,T> extends ZinggBase<S, D, R, C, T> implements ILabelDataViewHelper<S, D, R, C> {

private static final long serialVersionUID = 1L;
public static final Log LOG = LogFactory.getLog(LabelDataViewHelper.class);

public LabelDataViewHelper(Context<S,D,R,C,T> context, ZinggOptions zinggOptions, ClientOptions clientOptions) {
setContext(context);
setZinggOptions(zinggOptions);
setClientOptions(clientOptions);
setName(this.getClass().getName());
}

@Override
public ZFrame<D,R,C> getClusterIdsFrame(ZFrame<D,R,C> lines) {
return lines.select(ColName.CLUSTER_COLUMN).distinct();
}


@Override
public List<R> getClusterIds(ZFrame<D,R,C> lines) {
return lines.collectAsList();
}


@Override
public List<C> getDisplayColumns(ZFrame<D,R,C> lines, Arguments args) {
return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
}


@Override
public ZFrame<D,R,C> getCurrentPair(ZFrame<D,R,C> lines, int index, List<R> clusterIds, ZFrame<D,R,C> clusterLines) {
return lines.filter(lines.equalTo(ColName.CLUSTER_COLUMN,
clusterLines.getAsString(clusterIds.get(index), ColName.CLUSTER_COLUMN))).cache();
}


@Override
public double getScore(ZFrame<D,R,C> currentPair) {
return currentPair.getAsDouble(currentPair.head(),ColName.SCORE_COL);
}


@Override
public double getPrediction(ZFrame<D,R,C> currentPair) {
return currentPair.getAsDouble(currentPair.head(), ColName.PREDICTION_COL);
}


@Override
public String getMsg1(int index, int totalPairs) {
return String.format("\tCurrent labelling round : %d/%d pairs labelled\n", index, totalPairs);
}


@Override
public String getMsg2(double prediction, double score) {
String msg2 = "";
String matchType = LabelMatchType.get(prediction).msg;
if (prediction == ColValues.IS_NOT_KNOWN_PREDICTION) {
msg2 = String.format(
"\tZingg does not do any prediction for the above pairs as Zingg is still collecting training data to build the preliminary models.");
} else {
msg2 = String.format("\tZingg predicts the above records %s with a similarity score of %.2f",
matchType, Math.floor(score * 100) * 0.01);
}
return msg2;
}


@Override
public void displayRecords(ZFrame<D, R, C> records, String preMessage, String postMessage) {
//System.out.println();
System.out.println(preMessage);
records.show(false);
System.out.println(postMessage);
System.out.println("\tWhat do you think? Your choices are: ");
System.out.println();

System.out.println("\tNo, they do not match : 0");
System.out.println("\tYes, they match : 1");
System.out.println("\tNot sure : 2");
System.out.println();
System.out.println("\tTo exit : 9");
System.out.println();
System.out.print("\tPlease enter your choice [0,1,2 or 9]: ");
}

@Override
public void printMarkedRecordsStat(long positivePairsCount,long negativePairsCount,long notSurePairsCount,long totalCount) {
String msg = String.format(
"\tLabelled pairs so far : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount,
negativePairsCount, totalCount, notSurePairsCount, totalCount);

System.out.println();
System.out.println();
System.out.println();
System.out.println(msg);
}



@Override
public void execute() throws ZinggClientException {
throw new UnsupportedOperationException();
}

@Override
public ILabelDataViewHelper<S, D, R, C> getLabelDataViewHelper() throws UnsupportedOperationException {
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import zingg.common.core.util.LabelMatchType;

public abstract class LabelUpdater<S,D,R,C,T> extends Labeller<S,D,R,C,T> {
protected static String name = "zingg.LabelUpdater";
private static final long serialVersionUID = 1L;
protected static String name = "zingg.common.core.executor.LabelUpdater";
public static final Log LOG = LogFactory.getLog(LabelUpdater.class);

public LabelUpdater() {
Expand All @@ -33,12 +34,18 @@ public void execute() throws ZinggClientException {
}
}

public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
public ZFrame<D,R,C> processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to return a zframe here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done so that writing of labelled output happens in a a separate method. This is needed for python api to work.

LOG.info("Processing Records for CLI updateLabelling");

if (lines != null && lines.count() > 0) {
getMarkedRecordsStat(lines);
printMarkedRecordsStat();
getTrainingDataModel().setMarkedRecordsStat(lines);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);


List<C> displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise());
try {
Expand All @@ -52,7 +59,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
do {
System.out.print("\n\tPlease enter the cluster id (or 9 to exit): ");
String cluster_id = sc.next();
if (cluster_id.equals("9")) {
if (cluster_id.equals(QUIT_LABELING.toString())) {
LOG.info("User has exit in the middle. Updating the records.");
break;
}
Expand All @@ -67,10 +74,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
String matchType = LabelMatchType.get(matchFlag).msg;
postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg);
updateLabellerStat(selectedOption, +1);
updateLabellerStat(matchFlag, -1);
printMarkedRecordsStat();
if (selectedOption == 9) {
getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT);
getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT);
getLabelDataViewHelper().printMarkedRecordsStat(
getTrainingDataModel().getPositivePairsCount(),
getTrainingDataModel().getNegativePairsCount(),
getTrainingDataModel().getNotSurePairsCount(),
getTrainingDataModel().getTotalCount()
);

if (selectedOption == QUIT_LABELING) {
LOG.info("User has quit in the middle. Updating the records.");
break;
}
Expand All @@ -80,15 +93,16 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
updatedRecords = updatedRecords
.filter(updatedRecords.notEqual(ColName.CLUSTER_COLUMN,cluster_id));
}
updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != 9);
updatedRecords = getTrainingDataModel().updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != QUIT_LABELING);

if (updatedRecords != null) {
updatedRecords = updatedRecords.union(recordsToUpdate);
}
writeLabelledOutput(updatedRecords);
getTrainingDataModel().writeLabelledOutput(updatedRecords,args,getOutputPipe());
sc.close();
LOG.info("Processing finished.");
return updatedRecords;
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
e.printStackTrace();
Expand All @@ -98,6 +112,7 @@ public void processRecordsCli(ZFrame<D,R,C> lines) throws ZinggClientException {
}
} else {
LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.");
return null;
}
}

Expand Down
Loading