Skip to content

Commit

Permalink
Caching & more informative logging for data loading
Browse files Browse the repository at this point in the history
  • Loading branch information
tuetschek committed Mar 3, 2019
1 parent c594812 commit 0618c71
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions ratpred/futil.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tgen.data import DA
from tgen.futil import tokenize
from tgen.delex import delex_sent
from tgen.logf import log_info


def preprocess_sent(da, sent, delex_slots, delex_slot_names):
Expand Down Expand Up @@ -61,45 +62,63 @@ def read_data(filename, target_cols, das_type='cambridge',
delex_slots=set(), delex_slot_names=False, delex_das=False):
"""Read the input data from a TSV file."""

refs_cache = {}
def cached_preprocess_sent(da, sent):
"""we're caching since with generated data, we're likely to parse the same sentence many times."""
if (da, sent) not in refs_cache:
refs_cache[(da, sent)] = preprocess_sent(da, sent, delex_slots, delex_slot_names)
return list(refs_cache[(da, sent)])

log_info("Reading %s..." % filename)
data = pd.read_csv(filename, sep=b"\t", encoding='UTF-8')
log_info("Loaded %d instances." % len(data))

# force data type to string if the data set doesn't contain human references
data['orig_ref'] = data['orig_ref'].apply(lambda x: '' if not isinstance(x, basestring) else x)
log_info("Adapted refs data type.")

if das_type == 'text': # for MT output classification
das = [[(tok, None) for tok in preprocess_sent(None, sent, False, False)]
for sent in data['mr']]
else:
das = [DA.parse_cambridge_da(da) for da in data['mr']]
log_info("Parsed DAs.")

texts_ref = [[(tok, None) for tok in preprocess_sent(da, sent, delex_slots, delex_slot_names)]
texts_ref = [[(tok, None) for tok in cached_preprocess_sent(da, sent)]
for da, sent in zip(das, data['orig_ref'])]
texts_hyp = [[(tok, None) for tok in preprocess_sent(da, sent, delex_slots, delex_slot_names)]
log_info("Preprocessed human refs.")
texts_hyp = [[(tok, None) for tok in cached_preprocess_sent(da, sent)]
for da, sent in zip(das, data['system_ref'])]
log_info("Preprocessed system outputs.")

# alternative reference with rating difference / use to compare
if 'system_ref2' in data.columns:
texts_hyp2 = [[(tok, None) for tok in preprocess_sent(da, sent, delex_slots, delex_slot_names)]
texts_hyp2 = [[(tok, None) for tok in cached_preprocess_sent(da, sent)]
if isinstance(sent, basestring) else None
for da, sent in zip(das, data['system_ref2'])]
else:
texts_hyp2 = [None] * len(texts_hyp)
log_info("Preprocessed 2nd system outputs.")

# DA delexicalization must take place after text delexicalization
if das_type != 'text' and delex_das:
das = [da.get_delexicalized(delex_slots) for da in das]
log_info("Delexicalized DAs.")

# fake data indicator
if 'is_real' in data.columns:
real_indics = [0 if indic == 0 else 1 for indic in data['is_real']]
else:
real_indics = [1 for _ in xrange(len(data))]
log_info("Retrieved is_real indications.")

inputs = [(da, ref, hyp, hyp2, ri)
for da, ref, hyp, hyp2, ri in zip(das, texts_ref, texts_hyp, texts_hyp2, real_indics)]
log_info("Built inputs list.")

targets = np.array(data[[target_cols] if not isinstance(target_cols, list) else target_cols],
dtype=np.float)
log_info("Built targets list.")

return inputs, targets

Expand Down

0 comments on commit 0618c71

Please sign in to comment.