Decoupling language model computes from within Surprise #187

Open
wants to merge 9 commits into base: master
4 changes: 2 additions & 2 deletions convokit/speakerConvoDiversity/speakerConvoDiversity2.py
@@ -126,6 +126,7 @@ def transform(self, corpus):
            corpus,
            "speaker",
            target_text_func=lambda utt: self._get_utt_row(utt, input_table).tokens,
            smooth=False
        )
        self._set_output(corpus, input_table)
        return corpus
@@ -163,8 +164,7 @@ def _init_surprise(self, model_key_selector):
            surprise_attr_name=self.surprise_attr_name,
            target_sample_size=target_sample_size,
            context_sample_size=context_sample_size,
            n_samples=n_samples,
            smooth=False,
            n_samples=n_samples
        )

def _get_text_func(self, utt: Utterance, df: pd.DataFrame):
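Taken together, the two hunks above move `smooth` out of the `Surprise` constructor and forward it through `fit()` instead, so smoothing becomes a property of the language-model evaluation rather than of the transformer. A minimal sketch of the resulting call pattern, with the `fit()` signature inferred only from the first hunk and all argument values hypothetical:

# Sketch only, not the PR's exact API: constructor arguments mirror those in
# _init_surprise() above; smooth=False now travels with the fit() call rather
# than being fixed at construction time.
surprise = Surprise(
    model_key_selector=lambda utt: "corpus",  # hypothetical: one model for the whole corpus
    surprise_attr_name="surprise",
    target_sample_size=100,  # hypothetical sample sizes
    context_sample_size=1000,
    n_samples=50,
)
surprise.fit(
    corpus,
    "speaker",  # grouping argument as used in the hunk above
    target_text_func=lambda utt: utt.text.split(),
    smooth=False,  # smoothing is a per-evaluation keyword, not a transformer field
)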
14 changes: 14 additions & 0 deletions convokit/surprise/__init__.py
@@ -1 +1,15 @@
import importlib.util
import sys

from .convokit_lm import *
from .language_model import *
from .surprise import *

if "kenlm" in sys.modules:
from .kenlm import *
elif (spec := importlib.util.find_spec("kenlm")) is not None:
module = importlib.util.module_from_spec(spec)
sys.modules["kenlm"] = module
spec.loader.exec_module(module)

from .kenlm import *
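The guard above makes KenLM an optional backend: `.kenlm` is imported only when the `kenlm` package is importable. A hedged sketch of how calling code might select a backend under the same condition; the `Kenlm` constructor call is an assumption, since only `ConvoKitLanguageModel` is defined in the files shown here:

import importlib.util

from convokit import ConvoKitLanguageModel

# Prefer the KenLM-backed model when the optional dependency is present;
# otherwise fall back to the pure-Python ConvoKit language model.
if importlib.util.find_spec("kenlm") is not None:
    from convokit import Kenlm  # exported only when kenlm is installed
    lm = Kenlm()  # assumed default constructor; see convokit/surprise/kenlm.py
else:
    lm = ConvoKitLanguageModel(smooth=False)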
86 changes: 86 additions & 0 deletions convokit/surprise/convokit_lm.py
@@ -0,0 +1,86 @@
from collections import Counter
from typing import Optional, Any, Union, List

import numpy as np

from .language_model import LanguageModel


class ConvoKitLanguageModel(LanguageModel):
    """A simple language model to compute the deviation of a target from its context.

    This language model implements cross-entropy and perplexity evaluation functions,
    used to score the average deviation of the target text from the specified context.

    :param model_type: The name (identifier) of the :py:class:`~convokit.ConvoKitLanguageModel`,
        defaults to "convokit_lm". Note that the `model_type` can be accessed using the `type`
        property (e.g., `lm.type`).
    :param kwargs: Any additional keyword arguments needed in the language model evaluations.
        This language model currently uses the following keyword arguments:

        * `smooth`: Whether to use Laplace smoothing when computing cross-entropy scores,
          defaults to `True`.
        * `n_jobs`: The number of concurrent threads to be used for routines that are
          parallelized with `joblib`, defaults to 1.

        The language model configuration can be retrieved using the `config` property of the
        model class object (e.g., `lm.config`).
    """

    def __init__(self, model_type: str = "convokit_lm", **kwargs: Optional[Any]):
        super().__init__(model_type, **kwargs)

        # Laplace smoothing is enabled unless explicitly disabled via smooth=False.
        self._smooth = kwargs["smooth"] if "smooth" in kwargs else True

    def cross_entropy(
        self,
        target: Union[List[str], np.ndarray],
        context: Union[List[str], np.ndarray],
    ) -> float:
        r"""Implements the base class method to compute cross-entropy.

        Calculates :math:`H(P, Q) = -\sum_{x \in X}P(x) \times \ln(Q(x))`. Note that we use the
        natural logarithm; however, any base and its corresponding exponent can be employed. For
        instance, KenLM uses base 10 (see :py:class:`~convokit.Kenlm` for reference).

        The smoothing flag, `smooth`, is read from the setting passed to the language model
        constructor (defaults to `True` when unspecified).

        :param target: A list of tokens that make up the target text (P).
        :param context: A list of tokens that make up the context text (Q).
        :return: The cross-entropy score computed as :math:`H(P, Q)`.
        """
        n_target, n_context = len(target), len(context)
        if min(n_target, n_context) == 0:
            return np.nan

        context_counts = Counter(context)
        # With smoothing, use add-one counts over a vocabulary of size V + 1; without
        # smoothing, unseen target tokens fall back to a pseudo-count of 1 so that the
        # logarithm stays finite.
        smooth_v = len(context_counts) + 1 if self._smooth else 0
        smooth_k = 1 if self._smooth else 0
        value = 0 if self._smooth else 1

        return (
            sum(
                -np.log((context_counts.get(token, value) + smooth_k) / (n_context + smooth_v))
                for token in target
            )
            / n_target
        )

    def perplexity(
        self, target: Union[List[str], np.ndarray], context: Union[List[str], np.ndarray]
    ) -> float:
        r"""Implements the base class method to compute perplexity.

        Calculates :math:`\text{PPL}(P, Q) = \exp(-\sum_{x \in X}P(x) \times \ln(Q(x)))`. Note
        that we use the natural logarithm; however, any base and its corresponding exponent can
        be employed. For instance, KenLM uses base 10 (see :py:class:`~convokit.Kenlm` for
        reference).

        For convenience, the perplexity score is computed as the exponentiation of the
        cross-entropy calculated using the `cross_entropy()` method.

        :param target: A list of tokens that make up the target text (P).
        :param context: A list of tokens that make up the context text (Q).
        :return: The perplexity score computed as :math:`\text{PPL}(P, Q)`.
        """
        return np.exp(self.cross_entropy(target, context))
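A short usage sketch of the class above, assuming the package re-exports `ConvoKitLanguageModel` at the top level, as the `__init__.py` change suggests:

import numpy as np

from convokit import ConvoKitLanguageModel  # re-export assumed from __init__.py above

lm = ConvoKitLanguageModel(smooth=True)
target = "the cat sat on the mat".split()
context = "the dog sat on the mat".split()

# With Laplace smoothing, each target token t contributes
# -ln((count(t, context) + 1) / (len(context) + V + 1)), where V is the number
# of distinct context tokens, and the sum is averaged over the target length.
h = lm.cross_entropy(target, context)
ppl = lm.perplexity(target, context)
assert np.isclose(ppl, np.exp(h))  # perplexity is exp(cross-entropy) here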