Skip to content

Commit

Permalink
update check embeddings to resolve error in which vectors were being …
Browse files Browse the repository at this point in the history
…generated even when they were not requested
  • Loading branch information
xehu committed Aug 8, 2024
1 parent c9d946a commit 83c33d3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/featurize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
output_file_path_chat_level = "./jury_TINY_output_chat_level.csv",
output_file_path_user_level = "./jury_TINY_output_user_level.csv",
output_file_path_conv_level = "./jury_TINY_output_conversation_level.csv",
turns = False,
turns = False
)
tiny_juries_feature_builder.featurize(col="message")

Expand Down
2 changes: 1 addition & 1 deletion src/team_comm_tools/feature_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"vect_data": False,
"bert_sentiment_data": False
},
"Positivity (BERT)": {
"Positivity (RoBERTa)": {
"columns": ["positive_bert", "negative_bert", "neutral_bert"],
"file": "./utils/check_embeddings.py",
"level": "Chat",
Expand Down
39 changes: 24 additions & 15 deletions src/team_comm_tools/utils/check_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,34 @@ def check_embeddings(chat_data, vect_path, bert_path, need_sentence, need_sentim
:return: None
:rtype: None
"""
if regenerate_vectors or (not os.path.isfile(vect_path)):
if (regenerate_vectors or (not os.path.isfile(vect_path))) and need_sentence:
generate_vect(chat_data, vect_path, message_col)
if regenerate_vectors or (not os.path.isfile(bert_path)):
if (regenerate_vectors or (not os.path.isfile(bert_path))) and need_sentiment:
generate_bert(chat_data, bert_path, message_col)
if (not os.path.isfile(Path(__file__).resolve().parent.parent/"features/lexicons/certainty.txt")):
# unpickle certainty
unpickle_certainty()

vector_df = pd.read_csv(vect_path)
bert_df = pd.read_csv(bert_path)
# check is given vector and bert data matches length of chat data
if len(vector_df) != len(chat_data):
print("ERROR: The length of the vector data does not match the length of the chat data.")
generate_vect(chat_data, vect_path, message_col)

if len(bert_df) != len(chat_data):
print("ERROR: The length of the bert data does not match the length of the chat data.")
generate_bert(chat_data, bert_path, message_col)

try:
vector_df = pd.read_csv(vect_path)
# check whether the given vector and bert data matches length of chat data
if len(vector_df) != len(chat_data):
print("ERROR: The length of the vector data does not match the length of the chat data. Regenerating...")
generate_vect(chat_data, vect_path, message_col)
except FileNotFoundError: # It's OK if we don't have the path, if the sentence vectors are not necessary
if need_sentence:
generate_vect(chat_data, vect_path, message_col)

try:
bert_df = pd.read_csv(bert_path)
if len(bert_df) != len(chat_data):
print("ERROR: The length of the sentiment data does not match the length of the chat data. Regenerating...")
generate_bert(chat_data, bert_path, message_col)
except FileNotFoundError:
if need_sentiment: # It's OK if we don't have the path, if the sentiment features are not necessary
generate_bert(chat_data, bert_path, message_col)

# Get the lexicon pickle
current_script_directory = Path(__file__).resolve().parent
LEXICON_PATH_STATIC = current_script_directory.parent/"features/lexicons_dict.pkl"
if (not os.path.isfile(LEXICON_PATH_STATIC)):
Expand Down Expand Up @@ -135,7 +144,7 @@ def generate_vect(chat_data, output_path, message_col):
:rtype: None
"""

print(f"Generating sentence vectors....")
print(f"Generating SBERT sentence vectors...")

embedding_arr = [row.tolist() for row in model_vect.encode(chat_data[message_col])]
embedding_df = pd.DataFrame({'message': chat_data[message_col], 'message_embedding': embedding_arr})
Expand All @@ -158,7 +167,7 @@ def generate_bert(chat_data, output_path, message_col):
:return: None
:rtype: None
"""
print(f"Generating BERT sentiments....")
print(f"Generating RoBERTa sentiments...")

messages = chat_data[message_col]
sentiments = messages.apply(get_sentiment)
Expand Down

0 comments on commit 83c33d3

Please sign in to comment.