Remove deprecated args / funcs (#176)
calebchiam authored Sep 24, 2022
1 parent 9dce288 commit 685c97a
Showing 20 changed files with 95 additions and 304 deletions.
21 changes: 5 additions & 16 deletions convokit/classifier/classifier.py
@@ -1,12 +1,11 @@
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

from convokit import Transformer
from convokit.classifier.util import *
from convokit import Transformer, CorpusComponent
from convokit.util import deprecation


class Classifier(Transformer):
@@ -34,9 +33,7 @@ def __init__(
labeller: Callable[[CorpusComponent], bool] = lambda x: True,
clf=None,
clf_attribute_name: str = "prediction",
clf_feat_name=None,
clf_prob_attribute_name: str = "pred_score",
clf_prob_feat_name=None,
):
self.pred_feats = pred_feats
self.labeller = labeller
@@ -50,16 +47,8 @@ def __init__(
)
print("Initialized default classification model (standard scaled logistic regression).")
self.clf = clf
self.clf_attribute_name = clf_attribute_name if clf_feat_name is None else clf_feat_name
self.clf_prob_attribute_name = (
clf_prob_attribute_name if clf_prob_feat_name is None else clf_prob_feat_name
)

if clf_feat_name is not None:
deprecation("Classifier's clf_feat_name parameter", "clf_attribute_name")

if clf_prob_feat_name is not None:
deprecation("Classifier's clf_prob_feat_name parameter", "clf_prob_attribute_name")
self.clf_attribute_name = clf_attribute_name
self.clf_prob_attribute_name = clf_prob_attribute_name

def fit(
self, corpus: Corpus, y=None, selector: Callable[[CorpusComponent], bool] = lambda x: True
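Migration note: callers that still pass the removed clf_feat_name / clf_prob_feat_name keywords must switch to the new names. A minimal sketch follows; obj_type and the feature name are placeholders not taken from this diff.

from convokit import Classifier

# Before this commit (now raises TypeError: unexpected keyword argument):
#   Classifier(obj_type="utterance", pred_feats=["my_feats"],
#              clf_feat_name="prediction", clf_prob_feat_name="pred_score")

# After this commit:
clf = Classifier(
    obj_type="utterance",
    pred_feats=["my_feats"],               # metadata fields used as predictive features
    clf_attribute_name="prediction",       # attribute that will hold the predicted label
    clf_prob_attribute_name="pred_score",  # attribute that will hold the prediction score
)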
17 changes: 7 additions & 10 deletions convokit/coordination/coordination.py
@@ -1,11 +1,12 @@
import pkg_resources
from convokit.model import Corpus, Speaker, Utterance
from collections import defaultdict
from typing import Callable, Tuple, List, Dict, Optional, Collection, Union
from .coordinationScore import CoordinationScore, CoordinationWordCategories

import pkg_resources

from convokit.model import Corpus, Speaker, Utterance
from convokit.transformer import Transformer
from convokit.util import deprecation
from .coordinationScore import CoordinationScore, CoordinationWordCategories


class Coordination(Transformer):
@@ -445,15 +446,13 @@ def _scores_over_utterances(
speaker_thresh_indiv: int,
target_thresh_indiv: int,
utterances_thresh_indiv: int,
utterance_thresh_func: Optional[Callable[[Tuple[Utterance, Utterance]], bool]] = None,
utterance_thresh_func: Optional[Callable[[Utterance, Utterance], bool]] = None,
focus: str = "speakers",
split_by_attribs: Optional[List[str]] = None,
speaker_utterance_selector: Callable[
[Tuple[Utterance, Utterance]], bool
] = lambda utt1, utt2: True,
target_utterance_selector: Callable[
[Tuple[Utterance, Utterance]], bool
[Utterance, Utterance], bool
] = lambda utt1, utt2: True,
target_utterance_selector: Callable[[Utterance, Utterance], bool] = lambda utt1, utt2: True,
) -> CoordinationScore:
assert not isinstance(speakers, str)
assert focus == "speakers" or focus == "targets"
@@ -479,8 +478,6 @@ def _scores_over_utterances(
speaker, utt2, split_by_attribs
), Coordination._annot_speaker(target, utt1, split_by_attribs)

# speaker_has_attribs = Coordination._utterance_has_attribs(utt2, speaker_attribs)
# target_has_attribs = Coordination._utterance_has_attribs(utt1, target_attribs)
speaker_filter = speaker_utterance_selector(utt2, utt1)
target_filter = target_utterance_selector(utt2, utt1)

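Migration note: the threshold-function and selector annotations now match how these callables are actually invoked, with two Utterance arguments rather than a single (utt1, utt2) tuple. A hedged sketch of a selector written against the corrected signature; the conversation_id comparison is only an illustrative condition.

from convokit import Utterance

def same_conversation(utt1: Utterance, utt2: Utterance) -> bool:
    # Illustrative filter: only score utterance pairs drawn from the same conversation.
    return utt1.conversation_id == utt2.conversation_id

A function like this can be supplied wherever the API accepts a speaker_utterance_selector or target_utterance_selector.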
10 changes: 3 additions & 7 deletions convokit/coordination/coordinationScore.py
@@ -1,7 +1,7 @@
from convokit.model import Speaker
from convokit.util import deprecation
from collections import defaultdict
from typing import Callable, Tuple, List, Dict, Optional, Collection, Hashable, Union
from typing import Dict, Optional, Hashable, Union

from convokit.model import Speaker

CoordinationWordCategories = [
"article",
@@ -48,10 +48,6 @@ def scores_for_marker(self, marker: str) -> Dict[Union[Speaker, Hashable], float
"""
return {speaker: scores[marker] for speaker, scores in self.items()}

def averages_by_user(self):
deprecation("averages_by_user()", "averages_by_speaker()")
return {speaker: sum(scores.values()) / len(scores) for speaker, scores in self.items()}

def averages_by_speaker(self) -> Dict[Union[Speaker, Hashable], float]:
"""Return a dictionary from speakers to the average of each speaker's
marker scores."""
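Migration note: with the averages_by_user() shim removed, per-speaker averages come from averages_by_speaker(). A toy sketch, assuming CoordinationScore keeps behaving like the dict of speaker-to-marker-scores its methods iterate over.

from convokit.coordination.coordinationScore import CoordinationScore

scores = CoordinationScore({"alice": {"article": 0.10, "pronoun": -0.05}})
# scores.averages_by_user()   # removed in this commit
print(scores.averages_by_speaker())   # -> {'alice': 0.025} (up to float rounding)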
36 changes: 19 additions & 17 deletions convokit/expected_context_framework/demos/demo_text_pipelines.py
@@ -1,6 +1,6 @@
from convokit.text_processing import TextProcessor, TextParser, TextToArcs
from convokit.phrasing_motifs import CensorNouns, QuestionSentences
from convokit.convokitPipeline import ConvokitPipeline
from convokit.phrasing_motifs import CensorNouns, QuestionSentences
from convokit.text_processing import TextProcessor, TextParser, TextToArcs

"""
Some pipelines to compute the feature representations used in each Expected Context Model demo.
@@ -11,11 +11,12 @@ def parliament_arc_pipeline():
return ConvokitPipeline(
[
# to avoid most computations, we'll only run the pipeline if the desired attributes don't exist
("parser", TextParser(input_filter=lambda utt, aux: utt.get_info("arcs") is None)),
("parser", TextParser(input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None)),
(
"censor_nouns",
CensorNouns(
"parsed_censored", input_filter=lambda utt, aux: utt.get_info("arcs") is None
"parsed_censored",
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
@@ -24,15 +25,15 @@ def parliament_arc_pipeline():
"arc_arr",
input_field="parsed_censored",
root_only=True,
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
"question_sentence_filter",
QuestionSentences(
"q_arc_arr",
input_field="arc_arr",
input_filter=lambda utt, aux: utt.get_info("q_arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("q_arcs") is None,
),
),
(
@@ -41,7 +42,7 @@ def parliament_arc_pipeline():
output_field="arcs",
input_field="arc_arr",
proc_fn=lambda x: "\n".join(x),
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
@@ -50,7 +51,7 @@ def parliament_arc_pipeline():
output_field="q_arcs",
input_field="q_arc_arr",
proc_fn=lambda x: "\n".join(x),
input_filter=lambda utt, aux: utt.get_info("q_arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("q_arcs") is None,
),
),
]
@@ -63,14 +64,15 @@ def wiki_arc_pipeline():
(
"parser",
TextParser(
input_filter=lambda utt, aux: (utt.get_info("arcs") is None)
and (utt.get_info("parsed") is None)
input_filter=lambda utt, aux: (utt.retrieve_meta("arcs") is None)
and (utt.retrieve_meta("parsed") is None)
),
),
(
"censor_nouns",
CensorNouns(
"parsed_censored", input_filter=lambda utt, aux: utt.get_info("arcs") is None
"parsed_censored",
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
@@ -79,7 +81,7 @@ def wiki_arc_pipeline():
"arc_arr",
input_field="parsed_censored",
root_only=False,
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
@@ -88,7 +90,7 @@ def wiki_arc_pipeline():
output_field="arcs",
input_field="arc_arr",
proc_fn=lambda x: "\n".join(x),
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
]
@@ -98,14 +100,14 @@ def wiki_arc_pipeline():
def scotus_arc_pipeline():
return ConvokitPipeline(
[
("parser", TextParser(input_filter=lambda utt, aux: utt.get_info("arcs") is None)),
("parser", TextParser(input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None)),
(
"arcs",
TextToArcs(
"arc_arr",
input_field="parsed",
root_only=False,
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
(
@@ -114,7 +116,7 @@ def scotus_arc_pipeline():
output_field="arcs",
input_field="arc_arr",
proc_fn=lambda x: "\n".join(x),
input_filter=lambda utt, aux: utt.get_info("arcs") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None,
),
),
]
@@ -130,7 +132,7 @@ def switchboard_text_pipeline():
TextProcessor(
proc_fn=lambda x: x,
output_field="alpha_text",
input_filter=lambda utt, aux: utt.get_info("alpha_text") is None,
input_filter=lambda utt, aux: utt.retrieve_meta("alpha_text") is None,
),
)
]
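Migration note: every pipeline above switches from the removed utt.get_info(...) accessor to utt.retrieve_meta(...); the input_filter idiom is unchanged and still skips work when the target metadata field already exists. A minimal sketch of that idiom in isolation, reusing the "arcs" attribute name from the demo.

from convokit.text_processing import TextParser

# Re-parse only utterances that do not already carry an "arcs" metadata entry.
parser = TextParser(input_filter=lambda utt, aux: utt.retrieve_meta("arcs") is None)
# corpus = parser.transform(corpus)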
20 changes: 2 additions & 18 deletions convokit/forecaster/forecasterModel.py
@@ -1,35 +1,19 @@
from abc import ABC, abstractmethod
from convokit.util import deprecation


class ForecasterModel(ABC):
def __init__(
self,
forecast_attribute_name: str = "prediction",
forecast_feat_name=None,
forecast_prob_attribute_name: str = "score",
forecast_prob_feat_name=None,
):
"""
:param forecast_attribute_name: name for DataFrame column containing predictions, default: "prediction"
:param forecast_prob_attribute_name: name for column containing prediction scores, default: "score"
"""
self.forecast_attribute_name = (
forecast_attribute_name if forecast_feat_name is None else forecast_feat_name
)
self.forecast_prob_attribute_name = (
forecast_prob_attribute_name
if forecast_prob_feat_name is None
else forecast_prob_feat_name
)

for deprecated_set in [
(forecast_feat_name, "forecast_feat_name", "forecast_attribute_name"),
(forecast_prob_feat_name, "forecast_prob_feat_name", "forecast_prob_attribute_name"),
]:
if deprecated_set[0] is not None:
deprecation(f"Forecaster's {deprecated_set[1]} parameter", f"{deprecated_set[2]}")
self.forecast_attribute_name = forecast_attribute_name
self.forecast_prob_attribute_name = forecast_prob_attribute_name

@abstractmethod
def train(self, id_to_context_reply_label):
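Migration note: subclasses of ForecasterModel now pass only the new attribute names to super().__init__(). A hypothetical subclass sketch; the forecast() signature and its return format below are assumptions for illustration, not taken from this diff.

from convokit.forecaster.forecasterModel import ForecasterModel

class ConstantForecasterModel(ForecasterModel):
    def __init__(self):
        # forecast_feat_name / forecast_prob_feat_name are no longer accepted.
        super().__init__(
            forecast_attribute_name="prediction",
            forecast_prob_attribute_name="score",
        )

    def train(self, id_to_context_reply_label):
        pass  # a constant baseline has nothing to learn

    def forecast(self, id_to_context_reply_label):
        # Placeholder output: predict the positive class with probability 0.5 for every id.
        return {cid: (1, 0.5) for cid in id_to_context_reply_label}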
14 changes: 5 additions & 9 deletions convokit/hyperconvo/hyperconvo.py
@@ -1,12 +1,12 @@
from typing import Dict, Optional, Callable

import numpy as np
import scipy.stats
import pandas as pd
import scipy.stats
from scipy.sparse import csr_matrix
from typing import Dict, Optional, Callable

from convokit.util import deprecation
from convokit.transformer import Transformer
from convokit.model import Corpus, Conversation
from convokit.transformer import Transformer
from .hypergraph import Hypergraph


@@ -71,15 +71,11 @@ def __init__(
prefix_len: int = 10,
min_convo_len: int = 10,
vector_name: str = "hyperconvo",
feat_name=None,
invalid_val: float = np.nan,
):
self.prefix_len = prefix_len
self.min_convo_len = min_convo_len
self.vector_name = vector_name if feat_name is None else feat_name
if feat_name is not None:
deprecation("HyperConvo's feat_name parameter", "vector_name")

self.vector_name = vector_name
self.invalid_val = invalid_val

def transform(
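Migration note: HyperConvo callers that relied on feat_name should pass vector_name instead. A minimal sketch using the defaults shown above.

from convokit import HyperConvo

# Before: HyperConvo(feat_name="hyperconvo", ...)   # feat_name removed in this commit
hc = HyperConvo(prefix_len=10, min_convo_len=10, vector_name="hyperconvo")
# corpus = hc.transform(corpus)   # features are stored under the "hyperconvo" vector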
1 change: 0 additions & 1 deletion convokit/model/__init__.py
@@ -5,6 +5,5 @@
from .corpusComponent import CorpusComponent
from .corpus_helpers import *
from .speaker import Speaker
from .user import User
from .utterance import Utterance
from .utteranceNode import UtteranceNode
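Migration note: the deprecated User alias is no longer re-exported from convokit.model; code that constructed User objects should construct Speaker directly, as in this one-line sketch.

from convokit import Speaker

alice = Speaker(id="alice")   # formerly created via the deprecated User class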
36 changes: 6 additions & 30 deletions convokit/model/conversation.py
@@ -1,11 +1,12 @@
from collections import defaultdict
from typing import Dict, List, Callable, Generator, Optional
from .utterance import Utterance
from .speaker import Speaker
from convokit.util import deprecation, warn

from convokit.util import warn
from .corpusComponent import CorpusComponent
from collections import defaultdict
from .utteranceNode import UtteranceNode
from .corpusUtil import *
from .speaker import Speaker
from .utterance import Utterance
from .utteranceNode import UtteranceNode


class Conversation(CorpusComponent):
@@ -90,23 +91,6 @@ def get_utterances_dataframe(
"""
return get_utterances_dataframe(self, selector, exclude_meta)

def get_usernames(self) -> List[str]:
"""Produces a list of names of all speakers in the Conversation, which can
be used in calls to get_speaker() to retrieve specific speakers. Provides no
ordering guarantees for the list.
:return: a list of usernames
"""
deprecation("get_usernames()", "get_speaker_ids()")
if self._speaker_ids is None:
# first call to get_usernames or iter_speakers; precompute cached list
# of usernames
self._speaker_ids = set()
for ut_id in self._utterance_ids:
ut = self._owner.get_utterance(ut_id)
self._speaker_ids.add(ut.speaker.name)
return list(self._speaker_ids)

def get_speaker_ids(self) -> List[str]:
"""
Produces a list of ids of all speakers in the Conversation, which can be used in calls to get_speaker()
@@ -133,10 +117,6 @@ def get_speaker(self, speaker_id: str) -> Speaker:
# any Utterances
return self._owner.get_speaker(speaker_id)

def get_user(self, speaker_id: str):
deprecation("get_user()", "get_speaker()")
return self.get_speaker(speaker_id)

def iter_speakers(
self, selector: Callable[[Speaker], bool] = lambda speaker: True
) -> Generator[Speaker, None, None]:
@@ -176,10 +156,6 @@ def get_speakers_dataframe(
"""
return get_speakers_dataframe(self, selector, exclude_meta)

def iter_users(self, selector=lambda speaker: True):
deprecation("iter_users()", "iter_speakers()")
return self.iter_speakers(selector)

def print_conversation_stats(self):
"""
Helper function for printing the number of Utterances and Speakers in the Conversation.
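Migration note: for the removed Conversation helpers, the replacements are get_speaker_ids(), get_speaker(), and iter_speakers(). A sketch against an arbitrary corpus; the dataset name is only a placeholder.

from convokit import Corpus, download

corpus = Corpus(filename=download("subreddit-Cornell"))   # placeholder dataset
convo = next(corpus.iter_conversations())

speaker_ids = convo.get_speaker_ids()        # was get_usernames(), now removed
first = convo.get_speaker(speaker_ids[0])    # was get_user(), now removed
for speaker in convo.iter_speakers():        # was iter_users(), now removed
    print(speaker.id)
    break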
[Diffs for the remaining 12 changed files are not shown.]
