Skip to content

Commit

Permalink
Upgrade dedupe dep, tweak+retrain model (#117)
Browse files Browse the repository at this point in the history
* build: Bump dedupe, 2.0.23 => 2.0.24

* build: Bump dedupe, 2.0.24 => 3.0.1

* Update dedupe usage for v3 api

* Tweak dedupe fields, data preproc

* Add new deduper model training artifacts

* Use new dedupe model in app
  • Loading branch information
bdewilde authored Jul 10, 2024
1 parent 220a03b commit a4dd491
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 75 deletions.
2 changes: 1 addition & 1 deletion colandr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
# files-on-disk config
COLANDR_APP_DIR = os.environ.get("COLANDR_APP_DIR", "/tmp")
DEDUPE_MODELS_DIR = os.path.join(
COLANDR_APP_DIR, "colandr_data", "dedupe-v2", "model_202403"
COLANDR_APP_DIR, "colandr_data", "dedupe-v2", "model_202407"
)
RANKING_MODELS_DIR = os.path.join(COLANDR_APP_DIR, "colandr_data", "ranking_models")
CITATIONS_DIR = os.path.join(COLANDR_APP_DIR, "colandr_data", "citations")
Expand Down
122 changes: 50 additions & 72 deletions colandr/lib/models/deduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Iterable

import dedupe
from dedupe import variables
from textacy import preprocessing

from .. import utils
Expand All @@ -15,62 +16,26 @@
LOGGER = logging.getLogger(__name__)

RE_DOI_HTTP = re.compile(r"^https?(://)?", flags=re.IGNORECASE)
RE_DOI_PREFIX = re.compile(r"^https?(://)?dx\.doi\.org/", flags=re.IGNORECASE)
RE_SPACE_OR_DOT = re.compile(r"[\s.]+")
RE_SPACED_HYPHEN = re.compile(r" *(–|-) *")

SETTINGS_FNAME = "deduper_settings"
TRAINING_FNAME = "deduper_training.json"
VARIABLES: list[dict[str, t.Any]] = [
{"field": "type_of_reference", "type": "Exact"},
{"field": "doi", "type": "String", "has missing": True},
{"field": "title", "type": "String", "variable name": "title"},
{
"field": "authors_joined",
"type": "String",
"has missing": True,
"variable name": "authors_joined",
},
{
"field": "authors_initials",
"type": "Set",
"has missing": True,
"variable name": "authors_initials",
},
{
"field": "pub_year",
"type": "Exact",
"has missing": True,
"variable name": "pub_year",
},
{
"field": "journal_name",
"type": "String",
"has missing": True,
"variable name": "journal_name",
},
{
"field": "journal_volume",
"type": "Exact",
"has missing": True,
"variable name": "journal_volume",
},
{
"field": "journal_issue_number",
"type": "Exact",
"has missing": True,
"variable name": "journal_issue_number",
},
{"field": "issn", "type": "String", "has missing": True, "variable name": "issn"},
{"field": "abstract", "type": "Text", "has missing": True},
{"type": "Interaction", "interaction variables": ["journal_name", "pub_year"]},
{
"type": "Interaction",
"interaction variables": [
"journal_name",
"journal_volume",
"journal_issue_number",
],
},
{"type": "Interaction", "interaction variables": ["issn", "pub_year"]},
{"type": "Interaction", "interaction variables": ["title", "authors_joined"]},
# Field definitions for the dedupe v3 model: each entry declares how one
# preprocessed record field (see _preprocess_record) is compared.
# ``name=...`` registers a variable so it can be referenced by the
# Interaction entries at the bottom of the list.
VARIABLES: list[variables.base.Variable] = [
    variables.Exact("doi"),
    variables.String("title", name="title"),
    variables.String("authors_joined", has_missing=True, name="authors_joined"),
    variables.Set("authors_initials", name="authors_initials"),
    variables.Exact("pub_year", has_missing=True, name="pub_year"),
    variables.String("journal_name", has_missing=True, name="journal_name"),
    variables.Exact("journal_volume", name="journal_volume"),
    variables.Exact("journal_number", name="journal_number"),
    variables.String("issn", name="issn"),
    # abstracts are truncated to 500 chars during preprocessing, and may be absent
    variables.Text("abstract", has_missing=True),
    # interaction terms combine the named variables above so the model can
    # weight correlated fields jointly — NOTE(review): semantics per dedupe
    # v3 variables API; confirm against library docs
    variables.Interaction("journal_name", "journal_volume", "journal_number"),
    variables.Interaction("issn", "pub_year"),
    variables.Interaction("title", "authors_joined"),
]


Expand Down Expand Up @@ -115,42 +80,32 @@ def preprocess_data(
data: Iterable[dict[str, t.Any]],
id_key: str,
) -> dict[t.Any, dict[str, t.Any]]:
fields = [pv.field for pv in self.model.data_model.primary_variables]
fields = [pv.field for pv in self.model.data_model.field_variables]
LOGGER.info("preprocessing data with fields %s ...", fields)
return {record.pop(id_key): self._preprocess_record(record) for record in data}

def _preprocess_record(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
# base fields
record = {
"type_of_reference": (
record["type_of_reference"].strip().lower()
if record.get("type_of_reference")
else None
),
"doi": (_sanitize_doi(record["doi"]) if record.get("doi") else None),
"doi": (_standardize_doi(record["doi"]) if record.get("doi") else None),
"title": (
_standardize_str(record["title"]) if record.get("title") else None
),
"authors": (
tuple(
sorted(
_standardize_str(author.replace("-", " "))
for author in record["authors"]
)
sorted(_standardize_author(author) for author in record["authors"])
)
if record.get("authors")
else None
),
"pub_year": record.get("pub_year"),
"journal_name": (
preprocessing.remove.brackets(
_standardize_str(record["journal_name"]), only="round"
)
_standardize_journal_name(record["journal_name"])
if record.get("journal_name")
else None
),
"journal_volume": record.get("volume"),
"journal_issue_number": record.get("issue_number"),
"journal_number": record.get("issue_number"),
"issn": record["issn"].strip().lower() if record.get("issn") else None,
"abstract": (
_standardize_str(record["abstract"][:500]) # truncated for performance
Expand All @@ -161,8 +116,7 @@ def _preprocess_record(self, record: dict[str, t.Any]) -> dict[str, t.Any]:
# derivative fields
if record.get("authors"):
record["authors_initials"] = tuple(
"".join(name[0] for name in author.split())
for author in record["authors"]
_compute_author_initials(author) for author in record["authors"]
)
record["authors_joined"] = " ".join(record["authors"])
else:
Expand Down Expand Up @@ -209,14 +163,38 @@ def save(self, dir_path: str | pathlib.Path) -> None:
self.model.write_training(f)


def _sanitize_doi(value: str) -> str:
def _standardize_doi(value: str) -> str:
    """Normalize a DOI string for comparison.

    Lowercases and strips whitespace; for URL-form DOIs, percent-decodes the
    value, then removes the ``dx.doi.org`` resolver prefix and/or the bare
    URL scheme.
    """
    value = value.strip().lower()
    # percent-encoded characters only appear in URL-form DOIs, so decode
    # before stripping the scheme/prefix
    if value.startswith(("http://", "https://")):
        value = urllib.parse.unquote(value)
    # Strip the full resolver prefix (e.g. "https://dx.doi.org/") BEFORE the
    # bare scheme: RE_DOI_PREFIX requires a leading "http", so stripping the
    # scheme first (the previous order) left a dangling "dx.doi.org/" on the
    # DOI for resolver-style URLs.
    value = RE_DOI_PREFIX.sub("", value)
    value = RE_DOI_HTTP.sub("", value)
    return value


def _standardize_author(value: str) -> str:
    """Normalize an author name for dedupe comparison.

    Removes any spaces surrounding a hyphen or en-dash (keeping the dash
    itself), then applies the module-wide string-standardization pipeline.
    """
    collapsed = RE_SPACED_HYPHEN.sub(r"\1", value)
    return _standardize_str(collapsed)


def _standardize_journal_name(
    value: str, *, abbrevs_map: t.Optional[dict[str, str]] = None
) -> str:
    """Normalize a journal name for dedupe comparison.

    If ``abbrevs_map`` is given, the name is lowercased and split on runs of
    whitespace/periods, and each token is replaced by its mapped expansion
    (unmapped tokens pass through unchanged). The result is then run through
    the module-wide string-standardization pipeline, with round-bracketed
    spans removed.
    """
    if abbrevs_map:
        tokens = RE_SPACE_OR_DOT.split(value.lower())
        value = " ".join(abbrevs_map.get(token, token) for token in tokens)
    standardized = _standardize_str(value)
    return preprocessing.remove.brackets(standardized, only="round")


# def _standardize_issn(value: str) -> str:
# return sorted(
# preprocessing.remove.brackets(_standardize_str(value), only="round").split()
# )[0]


def _compute_author_initials(author: str) -> str:
return "".join(token[0] for token in author.split())


_standardize_str = preprocessing.make_pipeline(
functools.partial(
preprocessing.remove.punctuation, only=[".", "?", "!", ",", ";", "—"]
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions colandr_data/dedupe-v2/model_202407/deduper_training.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"bibtexparser~=1.4.0",
"celery~=5.4.0",
"click~=8.0",
"dedupe~=2.0.23",
"dedupe~=3.0.1",
"flask~=3.0",
"flask-caching~=2.1.0",
"flask-jwt-extended~=4.6.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements/prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ arrow~=1.3.0
bibtexparser~=1.4.0
celery~=5.4.0
click~=8.0
dedupe~=2.0.23
dedupe~=3.0.1
flask~=3.0.0
flask-caching~=2.1.0
flask-jwt-extended~=4.6.0
Expand Down

0 comments on commit a4dd491

Please sign in to comment.