DB mode issues fixes #184
Merged: 3 commits, Oct 18, 2022
Showing changes from 2 commits.
19 changes: 9 additions & 10 deletions convokit/model/corpus.py
@@ -72,7 +72,10 @@ def __init__(
"You are in DB mode, but no collection prefix was specified and no filename was given from which to infer one."
"Will use a randomly generated unique prefix " + db_collection_prefix
)
self.id = get_corpus_id(db_collection_prefix, filename)
self.id = get_corpus_id(
db_collection_prefix, filename, check_mongodb_compatibility=(storage_type == "db")
)
self.storage_type = storage_type
self.storage = initialize_storage(self, storage, storage_type, db_host)

self.meta_index = ConvoKitIndex(self)
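
For context, the new check_mongodb_compatibility flag only takes effect when a corpus is constructed in DB mode. A hedged sketch of that path, assuming a reachable MongoDB server and using hypothetical names (the keyword names follow the locals visible in the hunk above):

    from convokit import Corpus

    # Hypothetical prefix and dump directory; with storage_type="db" the prefix is
    # now validated as a MongoDB collection name before it becomes the corpus id.
    corpus = Corpus(
        filename="my_corpus_dump",          # hypothetical local dump to load
        storage_type="db",
        db_collection_prefix="my_corpus",   # passes the new name check
    )
    print(corpus.id)  # "my_corpus": the collection prefix doubles as the corpus id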
@@ -622,7 +625,7 @@ def filter_conversations_by(self, selector: Callable[[Conversation], bool]):
for speaker in self.iter_speakers():
meta_ids.append(speaker.meta.storage_key)
self.storage.purge_obsolete_entries(
self.get_utterance_ids, self.get_conversation_ids(), self.get_speaker_ids(), meta_ids
self.get_utterance_ids(), self.get_conversation_ids(), self.get_speaker_ids(), meta_ids
)

return self
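
The one-character change above fixes a real bug: the old call handed purge_obsolete_entries the bound method get_utterance_ids itself rather than the list it returns. A standalone sketch of the difference (Demo is a hypothetical stand-in for Corpus):

    class Demo:
        def get_ids(self):
            return ["u1", "u2"]

    d = Demo()
    print(d.get_ids)    # <bound method Demo.get_ids of ...>: a callable, not the ids
    print(d.get_ids())  # ['u1', 'u2']: the list that purge_obsolete_entries actually needs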
@@ -1278,14 +1281,10 @@ def load_info(self, obj_type, fields=None, dir_name=None):
for field in fields:
# self.aux_info[field] = self.load_jsonlist_to_dict(
# os.path.join(dir_name, 'feat.%s.jsonl' % field))
getter = lambda oid: self.get_object(obj_type, oid)
entries = load_jsonlist_to_dict(os.path.join(dir_name, "info.%s.jsonl" % field))
for k, v in entries.items():
try:
obj = getter(k)
obj.add_meta(field, v)
except:
continue
if self.storage_type == "mem":
load_info_to_mem(self, dir_name, obj_type, field)
elif self.storage_type == "db":
load_info_to_db(self, dir_name, obj_type, field)

def dump_info(self, obj_type, fields, dir_name=None):
"""
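
With the dispatch above, load_info behaves the same from the caller's point of view in both mem and DB mode. A usage sketch continuing from a Corpus constructed as in the earlier sketch (the "parsed" field name is hypothetical and assumes a matching info.parsed.jsonl file in the dump directory):

    # corpus is an already-constructed Corpus; "parsed" is a hypothetical extra-info field.
    corpus.load_info("utterance", fields=["parsed"])
    utt = next(corpus.iter_utterances())
    print(utt.meta.get("parsed"))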
106 changes: 101 additions & 5 deletions convokit/model/corpus_helpers.py
@@ -11,7 +11,7 @@
import bson
from pymongo import UpdateOne

from convokit.util import warn
from convokit.util import warn, create_safe_id
from .conversation import Conversation
from .convoKitIndex import ConvoKitIndex
from .convoKitMeta import ConvoKitMeta
@@ -33,15 +33,44 @@
JSONLIST_BUFFER_SIZE = 1000


def get_corpus_id(db_collection_prefix: Optional[str], filename: Optional[str]) -> Optional[str]:
def get_corpus_id(
db_collection_prefix: Optional[str], filename: Optional[str], check_mongodb_compatibility: bool
) -> Optional[str]:
if db_collection_prefix is not None:
# treat the unique collection prefix as the ID (even if a filename is specified)
return db_collection_prefix
corpus_id = db_collection_prefix
elif filename is not None:
# automatically derive an ID from the file path
return os.path.basename(os.path.normpath(filename))
corpus_id = os.path.basename(os.path.normpath(filename))
else:
return None
corpus_id = None

if check_mongodb_compatibility and corpus_id is not None:
compatibility_msg = check_id_for_mongodb(corpus_id)
if compatibility_msg is not None:
random_id = create_safe_id()
warn(
f'Attempting to use "{corpus_id}" as DB collection prefix failed because: {compatibility_msg}. Will instead use randomly generated prefix {random_id}.'
)
corpus_id = random_id

return corpus_id


def check_id_for_mongodb(corpus_id):
# List of collection name restrictions from official MongoDB docs:
# https://www.mongodb.com/docs/manual/reference/limits/#mongodb-limit-Restriction-on-Collection-Names
if "$" in corpus_id:
return "contains the restricted character '$'"
if len(corpus_id) == 0:
return "string is empty"
if "\0" in corpus_id:
return "contains a null character"
if "system." in corpus_id:
return 'starts with the restricted prefix "system."'
if not (corpus_id[0] == "_" or corpus_id[0].isalpha()):
return "name must start with an underscore or letter character"
return None
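
These checks mirror MongoDB's documented restrictions on collection names. A standalone sketch of how a few candidate prefixes fare under the same rules (the helper is duplicated here only so the snippet runs on its own):

    def check_prefix(corpus_id):
        # Same rules as check_id_for_mongodb above.
        if "$" in corpus_id:
            return "contains the restricted character '$'"
        if len(corpus_id) == 0:
            return "string is empty"
        if "\0" in corpus_id:
            return "contains a null character"
        if "system." in corpus_id:
            return 'starts with the restricted prefix "system."'
        if not (corpus_id[0] == "_" or corpus_id[0].isalpha()):
            return "name must start with an underscore or letter character"
        return None

    for candidate in ["wiki_corpus", "_a1b2c3", "2021-data", "price$table", "system.logs"]:
        print(candidate, "->", check_prefix(candidate) or "OK")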


def get_corpus_dirpath(filename: str) -> Optional[str]:
@@ -648,6 +677,7 @@ def load_jsonlist_to_db(
if reply_key is None:
# fix for misnamed reply_to in subreddit corpora
reply_key = "reply-to" if "reply-to" in utt_obj else "reply_to"
utt_obj = defaultdict(lambda: None, utt_obj)
utt_insertion_buffer.append(
UpdateOne(
{"_id": utt_obj["id"]},
@@ -767,6 +797,72 @@ def load_corpus_info_to_db(filename, db, collection_prefix, exclude_meta=None, b
)


def load_info_to_mem(corpus, dir_name, obj_type, field):
"""
Helper for load_info in mem mode that reads the file for the specified extra
info field, loads it into memory, and assigns the entries to their
corresponding corpus components.
"""
getter = lambda oid: corpus.get_object(obj_type, oid)
entries = load_jsonlist_to_dict(os.path.join(dir_name, "info.%s.jsonl" % field))
for k, v in entries.items():
try:
obj = getter(k)
obj.add_meta(field, v)
except:
continue


def load_info_to_db(corpus, dir_name, obj_type, field, index_key="id", value_key="value"):
"""
Helper for load_info in DB mode that reads the jsonlist file for the
specified extra info field in a batched line-by-line manner, populates
its contents into the DB, and updates the Corpus' metadata index
"""
filename = os.path.join(dir_name, "info.%s.jsonl" % field)
meta_collection = corpus.storage.get_collection("meta")

# attempt to use saved type information
index_file = os.path.join(dir_name, "index.json")
with open(index_file) as f:
raw_index = json.load(f)
try:
field_type = raw_index[f"{obj_type}s-index"][field]
corpus.meta_index.get_index(obj_type)[field] = field_type
index_updated = True
except:
# field not recorded in the index file; we will need to infer
# types during insertion time
index_updated = False

# iteratively insert the info in the DB in batched fashion
with open(filename) as f:
info_insertion_buffer = []
for line in f:
info_json = json.loads(line)
obj_id, info_val = info_json[index_key], info_json[value_key]
if not index_updated:
# we were previously unable to fetch the type info from the
# index file, so we must infer it now
ConvoKitMeta._check_type_and_update_index(
corpus.meta_index, obj_type, field, info_val
)
info_insertion_buffer.append(
UpdateOne(
{"_id": "{}_{}".format(obj_type, obj_id)},
{"$set": {field: info_val}},
upsert=True,
)
)
if len(info_insertion_buffer) >= JSONLIST_BUFFER_SIZE:
meta_collection.bulk_write(info_insertion_buffer)
info_insertion_buffer = []
# after loop termination, insert any remaining items in the buffer
if len(info_insertion_buffer) > 0:
meta_collection.bulk_write(info_insertion_buffer)
info_insertion_buffer = []
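
load_info_to_db streams the jsonlist in fixed-size batches so that large info files never sit in memory all at once. A minimal sketch of the same read-buffer-flush pattern, with the MongoDB bulk_write replaced by a stub so the snippet runs without a database (the field names here are made up):

    import io
    import json

    BUFFER_SIZE = 3  # stands in for JSONLIST_BUFFER_SIZE

    def flush(batch):
        # In load_info_to_db this is meta_collection.bulk_write(...) of UpdateOne upserts.
        print("writing batch of", len(batch))

    # A fake info.<field>.jsonl payload; real files have one JSON object per line.
    fake_file = io.StringIO("\n".join(json.dumps({"id": i, "value": i * i}) for i in range(7)))

    buffer = []
    for line in fake_file:
        entry = json.loads(line)
        buffer.append(({"_id": entry["id"]}, {"$set": {"score": entry["value"]}}))
        if len(buffer) >= BUFFER_SIZE:
            flush(buffer)
            buffer = []
    if buffer:  # flush whatever remains after the loop ends
        flush(buffer)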


def clean_up_excluded_meta(meta_index, exclude_meta):
"""
Remove excluded metadata from the metadata index
12 changes: 4 additions & 8 deletions convokit/model/storageManager.py
@@ -241,14 +241,10 @@ def _get_collection_name(self, component_type: str) -> str:
return f"{self.collection_prefix}_{component_type}"

def get_collection_ids(self, component_type: str):
# from StackOverflow: get all keys in a MongoDB collection
# https://stackoverflow.com/questions/2298870/get-names-of-all-keys-in-the-collection
map = bson.Code("function() { for (var key in this) { emit(key, null); } }")
reduce = bson.Code("function(key, stuff) { return null; }")
result = self.db[self._get_collection_name(component_type)].map_reduce(
map, reduce, "get_collection_ids_result"
)
return result.distinct("_id")
return [
doc["_id"]
for doc in self.db[self._get_collection_name(component_type)].find(projection=["_id"])
]
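
The rewrite above replaces a map-reduce round trip (the map_reduce helper is deprecated server-side and no longer available in recent PyMongo releases) with a plain projected find over _id. A hedged PyMongo sketch of the same idea, assuming a local MongoDB server and hypothetical database and collection names:

    from pymongo import MongoClient

    client = MongoClient("localhost", 27017)                 # assumes a reachable local MongoDB
    collection = client["convokit_demo"]["demo_utterance"]   # hypothetical names

    # Project only _id so the server ships back keys rather than whole documents.
    ids = [doc["_id"] for doc in collection.find(projection=["_id"])]
    print(len(ids), "component ids")
    # collection.distinct("_id") is an equivalent one-liner for modestly sized collections.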

def has_data_for_component(self, component_type: str, component_id: str) -> bool:
collection = self.get_collection(component_type)
18 changes: 9 additions & 9 deletions convokit/prompt_types/promptTypes.py
@@ -213,30 +213,30 @@ def transform(
reference_min_dists = reference_dists.min(axis=1)

corpus.set_vector_matrix(
self.output_field + "__prompt_dists.%s" % self.default_n_types,
self.output_field + "__prompt_dists__%s" % self.default_n_types,
ids=prompt_df.index,
matrix=prompt_dists,
columns=["type_%d_dist" % x for x in range(prompt_dists.shape[1])],
)
corpus.set_vector_matrix(
self.output_field + "__reference_dists.%s" % self.default_n_types,
self.output_field + "__reference_dists__%s" % self.default_n_types,
ids=reference_df.index,
matrix=reference_dists,
columns=["type_%d_dist" % x for x in range(prompt_dists.shape[1])],
)
for id, assign, dist in zip(prompt_df.index, prompt_assigns, prompt_min_dists):
corpus.get_utterance(id).add_meta(
self.output_field + "__prompt_type.%s" % self.default_n_types, assign
self.output_field + "__prompt_type__%s" % self.default_n_types, assign
)
corpus.get_utterance(id).add_meta(
self.output_field + "__prompt_type_dist.%s" % self.default_n_types, float(dist)
self.output_field + "__prompt_type_dist__%s" % self.default_n_types, float(dist)
)
for id, assign, dist in zip(reference_df.index, reference_assigns, reference_min_dists):
corpus.get_utterance(id).add_meta(
self.output_field + "__reference_type.%s" % self.default_n_types, assign
self.output_field + "__reference_type__%s" % self.default_n_types, assign
)
corpus.get_utterance(id).add_meta(
self.output_field + "__reference_type_dist.%s" % self.default_n_types, float(dist)
self.output_field + "__reference_type_dist__%s" % self.default_n_types, float(dist)
)
return corpus

@@ -274,13 +274,13 @@ def _transform_utterance_side(self, utterance, side):
min_dist = min(dists)
assign = vals[-1]
utterance.add_meta(
self.output_field + "__%s_type.%s" % (side, self.default_n_types), assign
self.output_field + "__%s_type__%s" % (side, self.default_n_types), assign
)
utterance.add_meta(
self.output_field + "__%s_type_dist.%s" % (side, self.default_n_types), float(min_dist)
self.output_field + "__%s_type_dist__%s" % (side, self.default_n_types), float(min_dist)
)
utterance.add_meta(
self.output_field + "__%s_dists.%s" % (side, self.default_n_types),
self.output_field + "__%s_dists__%s" % (side, self.default_n_types),
[float(x) for x in dists],
)
utterance.add_meta(self.output_field + "__%s_repr" % side, [float(x) for x in utt_vects[0]])
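
Every rename in this file follows one pattern: the "." that used to separate the field name from the number of types becomes "__", presumably because "." in a MongoDB field path is read as a nested-field separator, so dotted metadata names do not round-trip in DB mode. A small runnable illustration with hypothetical values for output_field and default_n_types:

    output_field = "prompt_types"   # hypothetical
    default_n_types = 8             # hypothetical

    old_key = output_field + "__prompt_type.%s" % default_n_types
    new_key = output_field + "__prompt_type__%s" % default_n_types
    print(old_key)  # prompt_types__prompt_type.8  (dot would be treated as a nested path)
    print(new_key)  # prompt_types__prompt_type__8 (safe as a flat MongoDB field name)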
2 changes: 1 addition & 1 deletion convokit/tests/test_utils.py
@@ -264,7 +264,7 @@ def burr_spacy_sentence_doc_4():


def reload_corpus_in_db_mode(corpus):
corpus_id = uuid4().hex
corpus_id = "_" + uuid4().hex
try:
corpus.dump(corpus_id, base_path=".")
db_corpus = Corpus(corpus_id, storage_type="db")
2 changes: 1 addition & 1 deletion convokit/util.py
@@ -378,4 +378,4 @@ def deprecation(prev_name: str, new_name: str, stacklevel: int = 3):


def create_safe_id():
return uuid.uuid4().hex
return "_" + uuid.uuid4().hex