pickle-less average perceptron tagger
alvations committed Jul 5, 2024
1 parent 8c233dc commit 8669576
Showing 3 changed files with 67 additions and 38 deletions.
12 changes: 6 additions & 6 deletions nltk/tag/__init__.py
@@ -93,16 +93,16 @@

from nltk.data import load, find

RUS_PICKLE = (
"taggers/averaged_perceptron_tagger_ru/averaged_perceptron_tagger_ru.pickle"
)

PRETRAINED_TAGGERS = {
"rus": "taggers/averaged_perceptron_tagger_rus/",
"eng": "taggers/averaged_perceptron_tagger_eng/",
}


def _get_tagger(lang=None):
if lang == "rus":
tagger = PerceptronTagger(False)
ap_russian_model_loc = "file:" + str(find(RUS_PICKLE))
tagger.load(ap_russian_model_loc)
tagger = PerceptronTagger(lang=lang)
else:
tagger = PerceptronTagger()
return tagger
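For orientation, the public entry point is unchanged; a call like the one below (the standard example from NLTK's documentation, shown here as an illustrative sketch that assumes the ``averaged_perceptron_tagger_eng`` data package is installed) still routes through ``_get_tagger``, and ``lang='rus'`` selects the Russian model instead:

>>> from nltk.tag import pos_tag
>>> from nltk.tokenize import word_tokenize
>>> pos_tag(word_tokenize("John's big idea isn't all that bad."))
[('John', 'NNP'), ("'s", 'POS'), ('big', 'JJ'), ('idea', 'NN'), ('is', 'VBZ'), ("n't", 'RB'), ('all', 'PDT'), ('that', 'DT'), ('bad', 'JJ'), ('.', '.')]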
84 changes: 58 additions & 26 deletions nltk/tag/perceptron.py
@@ -9,7 +9,7 @@
# This module is provided under the terms of the MIT License.

import logging
import pickle
import json
import os
import random
from collections import defaultdict

@@ -22,7 +22,26 @@
except ImportError:
pass

PICKLE = "averaged_perceptron_tagger.pickle"
TRAINED_TAGGER_PATH = "averaged_perceptron_tagger/"

TAGGER_JSONS = {
'eng': {
'weights': 'averaged_perceptron_tagger_eng.weights.json',
'tagdict': 'averaged_perceptron_tagger_eng.tagdict.json',
'classes': 'averaged_perceptron_tagger_eng.classes.json'
},
'rus': {
'weights': 'averaged_perceptron_tagger_rus.weights.json',
'tagdict': 'averaged_perceptron_tagger_rus.tagdict.json',
'classes': 'averaged_perceptron_tagger_rus.classes.json'
},
'xxx': {
'weights': 'averaged_perceptron_tagger.xxx.weights.json',
'tagdict': 'averaged_perceptron_tagger.xxx.tagdict.json',
'classes': 'averaged_perceptron_tagger.xxx.classes.json'
}

}
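Each supported language maps to three JSON files, one each for the model weights, the tag dictionary, and the class set; for example:

>>> TAGGER_JSONS['rus']['tagdict']
'averaged_perceptron_tagger_rus.tagdict.json'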


@jsontags.register_tag
@@ -103,13 +122,14 @@ def average_weights(self):
self.weights[feat] = new_feat_weights

def save(self, path):
"""Save the pickled model weights."""
with open(path, "wb") as fout:
return pickle.dump(dict(self.weights), fout)
"""Save the model weights as json"""
with open(path, 'w') as fout:
return json.dump(self.weights, fout)

def load(self, path):
"""Load the pickled model weights."""
self.weights = load(path)
"""Load the json model weights."""
with open(path) as fin:
self.weights = json.load(fin)
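A minimal round-trip sketch of the new pair (the path is hypothetical, and the example assumes ``AveragedPerceptron()`` can be constructed without arguments and that ``weights`` is a plain dict of feature-to-weight dicts):

>>> ap = AveragedPerceptron()
>>> ap.weights = {'bias': {'NN': 1.5, 'VB': -0.3}}
>>> ap.save('/tmp/ap.weights.json')
>>> restored = AveragedPerceptron()
>>> restored.load('/tmp/ap.weights.json')
>>> restored.weights == {'bias': {'NN': 1.5, 'VB': -0.3}}
True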

def encode_json_obj(self):
return self.weights
@@ -155,18 +175,16 @@ class PerceptronTagger(TaggerI):
START = ["-START-", "-START2-"]
END = ["-END-", "-END2-"]

def __init__(self, load=True):
def __init__(self, load=True, lang='eng'):
"""
:param load: Load the pickled model upon instantiation.
:param load: Load the json model upon instantiation.
:param lang: The language of the pretrained model to load, e.g. 'eng' or 'rus'.
"""
self.model = AveragedPerceptron()
self.tagdict = {}
self.classes = set()
if load:
AP_MODEL_LOC = "file:" + str(
find("taggers/averaged_perceptron_tagger/" + PICKLE)
)
self.load(AP_MODEL_LOC)
self.load_from_json(lang)
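With the new ``lang`` keyword, instantiation might look like the following sketch (it assumes the ``averaged_perceptron_tagger_eng`` and ``averaged_perceptron_tagger_rus`` data packages have been downloaded into ``nltk_data``):

>>> eng_tagger = PerceptronTagger()              # default: English JSON model
>>> rus_tagger = PerceptronTagger(lang='rus')    # Russian JSON model
>>> blank_tagger = PerceptronTagger(load=False)  # no model, e.g. before train()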


def tag(self, tokens, return_conf=False, use_tagdict=True):
"""
@@ -198,7 +216,7 @@ def train(self, sentences, save_loc=None, nr_iter=5):
:param sentences: A list or iterator of sentences, where each sentence
is a list of (words, tags) tuples.
:param save_loc: If not ``None``, saves a pickled model in this location.
:param save_loc: If not ``None``, saves a json model in this location.
:param nr_iter: Number of training iterations.
"""
# We'd like to allow ``sentences`` to be either a list or an iterator,
@@ -233,23 +251,37 @@ def train(self, sentences, save_loc=None, nr_iter=5):
logging.info(f"Iter {iter_}: {c}/{n}={_pc(c, n)}")

# We don't need the training sentences anymore, and we don't want to
# waste space on them when we pickle the trained tagger.
# waste space on them when we save the trained tagger.
self._sentences = None

self.model.average_weights()
# Pickle as a binary file
# Save to json files.
if save_loc is not None:
with open(save_loc, "wb") as fout:
# changed protocol from -1 to 2 to make pickling Python 2 compatible
pickle.dump((self.model.weights, self.tagdict, self.classes), fout, 2)
self.save_to_json(save_loc)

def load(self, loc):
"""
:param loc: Load a pickled model at location.
:type loc: str
"""

self.model.weights, self.tagdict, self.classes = load(loc)
def save_to_json(self, loc, lang='xxx'):
# TODO:
assert os.path.isdir(loc), "Path set for saving needs to be a directory"

with open(loc + TAGGER_JSONS[lang]['weights'], 'w') as fout:
json.dump(self.model.weights, fout)
with open(loc + TAGGER_JSONS[lang]['tagdict'], 'w') as fout:
json.dump(self.tagdict, fout)
with open(loc + TAGGER_JSONS[lang]['classes'], 'w') as fout:
json.dump(list(self.classes), fout)


def load_from_json(self, lang='eng'):
# Automatically find path to the tagger if location is not specified.
loc = find(f"taggers/averaged_perceptron_tagger_{lang}/")
with open(loc + TAGGER_JSONS[lang]['weights']) as fin:
self.model.weights = json.load(fin)
with open(loc + TAGGER_JSONS[lang]['tagdict']) as fin:
self.tagdict = json.load(fin)
with open(loc + TAGGER_JSONS[lang]['classes']) as fin:
self.classes = set(json.load(fin))

self.model.classes = self.classes
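A save-and-reload sketch tying the two methods together; the directory name is hypothetical and must already exist, ``train_sents`` stands in for a list of tagged sentences, and reloading assumes the JSON files end up under an ``nltk_data`` ``taggers/averaged_perceptron_tagger_xxx/`` directory so that ``find()`` can locate them:

>>> tagger = PerceptronTagger(load=False)
>>> tagger.train(train_sents)
>>> tagger.save_to_json('averaged_perceptron_tagger_xxx/', lang='xxx')
>>> restored = PerceptronTagger(load=False)
>>> restored.load_from_json(lang='xxx')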

def encode_json_obj(self):
@@ -362,7 +394,7 @@ def _get_pretrain_model():
testing = _load_data_conll_format("english_ptb_test.conll")
print("Size of training and testing (sentence)", len(training), len(testing))
# Train and save the model
tagger.train(training, PICKLE)
tagger.train(training, TRAINED_TAGGER_PATH)
print("Accuracy : ", tagger.accuracy(testing))


9 changes: 3 additions & 6 deletions nltk/test/tag.doctest
@@ -150,8 +150,7 @@ to tag to the learned data.
The lackluster accuracy here can be explained with the following example:

>>> unigram_tagger.tag(["I", "would", "like", "this", "sentence", "to", "be", "tagged"])
[('I', 'NNP'), ('would', 'MD'), ('like', None), ('this', 'DT'), ('sentence', None),
('to', 'TO'), ('be', 'VB'), ('tagged', None)]
[('I', 'NNP'), ('would', 'MD'), ('like', None), ('this', 'DT'), ('sentence', None), ('to', 'TO'), ('be', 'VB'), ('tagged', None)]

As you can see, many tokens are tagged as ``None``, as these tokens are OOV (out of vocabulary).
The ``UnigramTagger`` has never seen them, and as a result they are not in its database of known terms.
@@ -429,9 +428,7 @@ templates to attempt to improve the performance of the tagger.

>>> tagged, test_stats = tagger1.batch_tag_incremental(testing_data, gold_data)
>>> tagged[33][12:]
[('foreign', 'NN'), ('debt', 'NN'), ('of', 'IN'), ('$', '$'), ('64', 'CD'),
('billion', 'CD'), ('*U*', '-NONE-'), ('--', ':'), ('the', 'DT'), ('third-highest', 'NN'),
('in', 'IN'), ('the', 'DT'), ('developing', 'VBG'), ('world', 'NN'), ('.', '.')]
[('foreign', 'NN'), ('debt', 'NN'), ('of', 'IN'), ('$', '$'), ('64', 'CD'), ('billion', 'CD'), ('*U*', '-NONE-'), ('--', ':'), ('the', 'DT'), ('third-highest', 'NN'), ('in', 'IN'), ('the', 'DT'), ('developing', 'VBG'), ('world', 'NN'), ('.', '.')]

Regression Tests
~~~~~~~~~~~~~~~~
@@ -472,4 +469,4 @@ strictly defined.

>>> from nltk.tag import pos_tag
>>> pos_tag(['', 'is', 'a', 'beautiful', 'day'])
[...]
[('', 'NN'), ('is', 'VBZ'), ('a', 'DT'), ('beautiful', 'JJ'), ('day', 'NN')]
