Fix get_last_documents to return valid Document objects #201

Merged · 3 commits · Dec 22, 2018
5 changes: 4 additions & 1 deletion CHANGELOG.rst
@@ -21,7 +21,7 @@ Added

 max_docs = 10

-# Define specific parser for the third column (index 2), which takes ``text``,
+# Define specific parser for the third column (index 2), which takes ``text``,
 # ``name=None``, ``type="text"``, and ``delim=None`` as input and generate
 # ``(content type, content name, content)`` for ``build_node``
 # in ``fonduer.utils.utils_parser``.
@@ -46,6 +46,9 @@ Added
 Fixed
 ^^^^^
 * `@HiromuHota`_: Modify docstring of functions that return get_sparse_matrix
+* `@lukehsiao`_: Fix the behavior of ``get_last_documents`` to return Documents
+  that are correctly linked to the database and can be navigated by the user.
+  (`#201 <https://github.com/HazyResearch/fonduer/pull/201>`_)

 Changed
 ^^^^^^^
12 changes: 12 additions & 0 deletions src/fonduer/parser/parser.py
@@ -117,6 +117,18 @@ def clear(self, pdf_path=None):
"""
self.session.query(Context).delete()

def get_last_documents(self):
"""Return the most recently parsed list of ``Documents``.

:rtype: A list of the most recently parsed ``Documents`` ordered by name.
"""
return (
self.session.query(Document)
.filter(Document.name.in_(self.last_docs))
.order_by(Document.name)
.all()
)

def get_documents(self):
"""Return all the parsed ``Documents`` in the database.

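For context, here is a minimal usage sketch of the new method. It is not part of the PR: the connection string, corpus path, and preprocessor settings are placeholders, and the constructor calls follow current Fonduer releases, which may differ slightly from the version this PR targeted.

    from fonduer import Meta
    from fonduer.parser import Parser
    from fonduer.parser.preprocessors import HTMLDocPreprocessor

    # Placeholder database and corpus; any Fonduer-backed Postgres DB works.
    session = Meta.init("postgresql://localhost:5432/fonduer_demo").Session()
    doc_preprocessor = HTMLDocPreprocessor("data/html/", max_docs=10)

    corpus_parser = Parser(session, structural=True, lingual=True)
    corpus_parser.apply(doc_preprocessor)

    # The fix in action: these Documents are re-queried from the database by
    # name, so they are session-bound and relationships like ``doc.sentences``
    # can be navigated.
    for doc in corpus_parser.get_last_documents():
        print(doc.name, len(doc.sentences))

Previously, ``get_last_documents`` lived on ``UDFRunner`` and simply returned the raw objects accumulated during ``apply()``, which were not reliably linked to the database session, hence this fix.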
17 changes: 3 additions & 14 deletions src/fonduer/utils/udf.py
@@ -50,8 +50,8 @@ def apply(
         if clear:
             self.clear(**kwargs)

-        # Clear the last operated documents
-        self.last_docs.clear()
+        # Track the last documents parsed by apply
+        self.last_docs = set(doc.name for doc in doc_loader)

         # Execute the UDF
         self.logger.info("Running UDF...")
@@ -80,20 +80,12 @@ def clear(self, **kwargs):
"""Clear the associated data from the database."""
raise NotImplementedError()

def get_last_documents(self):
"""Return the last set of documents that was operated on with apply().

:rtype: list of ``Documents`` operated on in the last call to ``apply()``.
"""
return list(self.last_docs)

def _apply_st(self, doc_loader, **kwargs):
"""Run the UDF single-threaded, optionally with progress bar"""
udf = self.udf_class(**self.udf_init_kwargs)

# Run single-thread
for doc in doc_loader:
self.last_docs.add(doc)
if self.pb is not None:
self.pb.update(1)

@@ -120,16 +112,13 @@ def fill_input_queue(in_queue, doc_loader, terminal_signal):

         total_count = len(doc_loader)

-        for doc in doc_loader:
-            self.last_docs.add(doc)
-
         # Start UDF Processes
         for i in range(parallelism):
             udf = self.udf_class(
                 in_queue=in_queue,
                 out_queue=out_queue,
                 worker_id=i,
-                **self.udf_init_kwargs
+                **self.udf_init_kwargs,
             )
             udf.apply_kwargs = kwargs
             self.udfs.append(udf)
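The design point in this file: ``apply()`` now snapshots document *names* (plain strings) once, up front, instead of accumulating ORM objects inside the single-threaded and multiprocessed loops. Names are cheap and picklable, and re-querying by name, as ``Parser.get_last_documents`` above does, returns rows bound to the caller's session. Below is a self-contained sketch of that snapshot-then-requery pattern using a generic SQLAlchemy 1.4+ model, not Fonduer's actual schema.

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Document(Base):
        __tablename__ = "document"
        id = Column(Integer, primary_key=True)
        name = Column(String, unique=True)

    engine = create_engine("sqlite://")  # stand-in for Fonduer's Postgres DB
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Document(name="doc_b"), Document(name="doc_a")])
        session.commit()

        # Snapshot plain names (process-safe), not ORM instances.
        last_docs = {"doc_a", "doc_b"}

        # Re-query by name so the results are bound to *this* session,
        # mirroring the new Parser.get_last_documents implementation.
        docs = (
            session.query(Document)
            .filter(Document.name.in_(last_docs))
            .order_by(Document.name)
            .all()
        )
        assert [d.name for d in docs] == ["doc_a", "doc_b"]

Snapshotting names once in ``apply()`` also keeps the single-threaded and parallel paths consistent, since neither loop has to maintain its own bookkeeping.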
11 changes: 9 additions & 2 deletions tests/e2e/test_e2e.py
@@ -90,6 +90,9 @@ def test_incremental(caplog):
     assert num_docs == max_docs

     docs = corpus_parser.get_documents()
+    last_docs = corpus_parser.get_last_documents()
+
+    assert len(docs[0].sentences) == len(last_docs[0].sentences)

     # Mention Extraction
     part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
@@ -242,9 +245,13 @@ def test_e2e(caplog):
     logger.info("Sentences: {}".format(num_sentences))

     # Divide into test and train
-    docs = corpus_parser.get_documents()
+    docs = sorted(corpus_parser.get_documents())
+    last_docs = sorted(corpus_parser.get_last_documents())
+
     ld = len(docs)
-    assert ld == len(corpus_parser.get_last_documents())
+    assert ld == len(last_docs)
+    assert len(docs[0].sentences) == len(last_docs[0].sentences)
+
     assert len(docs[0].sentences) == 799
     assert len(docs[1].sentences) == 663
     assert len(docs[2].sentences) == 784