Skip to content

Commit

Permalink
Merge pull request #34 from allenai/more-than-single-added-node-crashes
Browse files Browse the repository at this point in the history
More than single added node crashes
  • Loading branch information
aryehgigi authored Apr 27, 2021
2 parents 6727be7 + 543232e commit 2ea220a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
7 changes: 6 additions & 1 deletion pybart/spacy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def parse_spacy_sent(sent):
def enhance_to_spacy_doc(orig_doc, converted_sentences, remove_enhanced_extra_info, remove_bart_extra_info):
offset = 0
for orig_span, converted_sentence in zip(orig_doc.sents, converted_sentences):
added_nodes_counter = 0
node_indices_map = dict()
nodes = []
edges = []
Expand All @@ -70,7 +71,11 @@ def enhance_to_spacy_doc(orig_doc, converted_sentences, remove_enhanced_extra_in
if new_id == '0':
continue
node_indices_map[new_id.token_str] = idx
_ = nodes.append((new_id.major - 1 + offset,) if new_id.minor == 0 else ())
if new_id.minor == 0:
_ = nodes.append((new_id.major - 1 + offset,))
else:
_ = nodes.append((len(converted_sentence) + added_nodes_counter,))
added_nodes_counter += 1
for idx, tok in enumerate(converted_sentence):
new_id = tok.get_conllu_field("id").token_str
if new_id == '0':
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="pybart-nlp",
version="3.2.1",
version="3.2.2",
author="Aryeh Tiktinsky",
author_email="aryehgigi@gmail.com",
description="python converter from UD-tree to BART-graph representations",
Expand Down
25 changes: 10 additions & 15 deletions tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import pybart
from pybart.conllu_wrapper import parse_conllu, serialize_conllu
from pybart import converter, pybart_globals
from pybart import converter
from pybart import api
from pybart.graph_token import add_basic_edges
from pybart.converter import convert
from pybart.converter import Convert


class TestConversions:
Expand Down Expand Up @@ -52,22 +52,17 @@ def setup_class(cls):
cur_gold[test_name][specification] = [gold_line.split()]
else:
cur_gold[test_name] = {specification: [gold_line.split()]}

@staticmethod
def setup_method():
pybart_globals.g_remove_enhanced_extra_info = False
pybart_globals.g_remove_bart_extra_info = False
pybart.converter.g_remove_node_adding_conversions = False


@classmethod
def common_logic(cls, cur_name):
name = cur_name.split("test_")[1]
for spec, sent_ in cls.out[name].items():
sent = [v.copy() for v in sent_]
add_basic_edges(sent)
converted, _ = convert([sent], True, True, True, math.inf, False, False, False, False, False,
funcs_to_cancel=list(set(api.get_conversion_names()).difference({name, "extra_inner_weak_modifier_verb_reconstruction"})))
serialized_conllu = serialize_conllu(converted, [None], False)
con = Convert([sent], True, True, True, math.inf, False, False, False, False, False,
list(set(api.get_conversion_names()).difference({name, "extra_inner_weak_modifier_verb_reconstruction"})))
converted, _ = con()
serialized_conllu = serialize_conllu(converted, [None], False, False, False)
for gold_line, out_line in zip(cls.gold[name][spec], serialized_conllu.split("\n")):
assert out_line.split() == gold_line, spec + str(print("\n")) + str([print(s) for s in serialized_conllu.split("\n")])

Expand All @@ -77,9 +72,9 @@ def common_logic_combined(cls, cur_name, rnac=False):
for spec, sent_ in cls.out[name].items():
sent = [v.copy() for v in sent_]
add_basic_edges(sent)
converted, _ = \
convert([sent], True, True, True, math.inf, False, False, rnac, False, False, funcs_to_cancel=[])
serialized_conllu = serialize_conllu(converted, [None], False)
con = Convert([sent], True, True, True, math.inf, False, False, rnac, False, False, [])
converted, _ = con()
serialized_conllu = serialize_conllu(converted, [None], False, False, False)
for gold_line, out_line in zip(cls.gold_combined[name][spec], serialized_conllu.split("\n")):
assert out_line.split() == gold_line, spec + str(print("\n")) + str([print(s) for s in serialized_conllu.split("\n")])

Expand Down

0 comments on commit 2ea220a

Please sign in to comment.