Skip to content

Commit

Permalink
refactor get_basic_words (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda authored Sep 29, 2022
1 parent 879004e commit 187dc88
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 12 deletions.
17 changes: 5 additions & 12 deletions explainaboard/analysis/feature_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from explainaboard.info import SysOutputInfo
from explainaboard.utils import basic_words
from explainaboard.utils.logging import progress
from explainaboard.utils.tokenizer import Tokenizer
from explainaboard.utils.tokenizer import SingleSpaceTokenizer, Tokenizer
from explainaboard.utils.typing_utils import unwrap


Expand Down Expand Up @@ -68,17 +68,10 @@ def get_basic_words(text: str) -> float:
Returns:
The ratio of basic words.
"""
value_list = text.split(' ')
n_words = len(value_list)
n_basic_words = 0

for word in value_list:

lower = word.lower()
if lower in basic_words.BASIC_WORDS:
n_basic_words = n_basic_words + 1

return n_basic_words * 1.0 / n_words
tokens = SingleSpaceTokenizer()(text)
assert len(tokens) > 0, f"BUG: no tokens obtained from the text: '{text}'"
n_basic_words = sum(1 for t in tokens if t.lower() in basic_words.BASIC_WORDS)
return n_basic_words / len(tokens)


def get_lexical_richness(text: str) -> float:
Expand Down
35 changes: 35 additions & 0 deletions explainaboard/analysis/feature_funcs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Tests for explainaboard.analysis.feature_funcs."""


import unittest

from explainaboard.analysis.feature_funcs import get_basic_words


class FeatureFuncsTest(unittest.TestCase):
    """Tests for feature functions in explainaboard.analysis.feature_funcs."""

    def test_get_basic_words(self) -> None:
        """get_basic_words returns the exact basic-word ratio for each case."""
        cases = [
            # zero word
            ("", 0.0),
            (" ", 0.0),
            # one word
            ("the", 1.0),
            (" the", 0.5),
            (" the ", 1 / 3),
            ("USA", 0.0),
            # two words
            ("United States", 0.0),
            ("The USA", 0.5),
            ("The country", 1.0),
            # capitalization is ignored when matching
            ("The THE the tHE", 1.0),
            # punctuation attached to a word prevents a match
            ("It is.", 0.5),
            ("It is .", 2 / 3),
            ("It, is", 0.5),
            ("It , is", 2 / 3),
        ]
        for text, expected in cases:
            self.assertEqual(get_basic_words(text), expected)
32 changes: 32 additions & 0 deletions explainaboard/utils/tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,38 @@ def test_from_orig_and_tokens_invalid(self) -> None:


class TokenizerSerializerTest(unittest.TestCase):
def test_empty(self) -> None:
    """Tokenizing the empty string yields one empty token at offset 0."""
    toks = SingleSpaceTokenizer()("")
    # "".split(" ") == [""], so exactly one (empty) token is expected.
    self.assertEqual(toks.strs, [""])
    self.assertEqual(toks.positions, [0])
    self.assertEqual(len(toks), 1)

def test_only_0x20(self) -> None:
    """A string of three spaces yields four empty tokens.

    NOTE(review): the displayed input literal was whitespace-collapsed by the
    page rendering; three spaces are reconstructed from the asserted token
    count (4) and positions [0, 1, 2, 3].
    """
    tokens = SingleSpaceTokenizer()("   ")
    self.assertEqual(len(tokens), 4)
    self.assertEqual(tokens.strs, ["", "", "", ""])
    self.assertEqual(tokens.positions, [0, 1, 2, 3])

def test_isspace(self) -> None:
    """Non-space whitespace (tabs, newlines, etc.) does not split tokens."""
    toks = SingleSpaceTokenizer()("\t\v \n\r\f")
    # Only the single 0x20 at index 2 acts as a separator.
    self.assertEqual(toks.strs, ["\t\v", "\n\r\f"])
    self.assertEqual(toks.positions, [0, 3])
    self.assertEqual(len(toks), 2)

def test_sentence(self) -> None:
    """A plain single-spaced sentence splits on every space."""
    toks = SingleSpaceTokenizer()("May the force be with you.")
    expected_strs = ["May", "the", "force", "be", "with", "you."]
    expected_positions = [0, 4, 8, 14, 17, 22]
    self.assertEqual(len(toks), len(expected_strs))
    self.assertEqual(toks.strs, expected_strs)
    self.assertEqual(toks.positions, expected_positions)

def test_sentence_with_extra_whitespaces(self) -> None:
    """Leading/trailing and doubled spaces produce empty tokens.

    NOTE(review): the displayed input literal was whitespace-collapsed by the
    page rendering; the double space after "May" is reconstructed from the
    asserted strs (empty token between "May" and "the") and positions
    [0, 1, 5, 6, 10, 19, 24, 29].
    """
    tokens = SingleSpaceTokenizer()(" May  the force\nbe with you. ")
    self.assertEqual(len(tokens), 8)
    self.assertEqual(
        tokens.strs, ["", "May", "", "the", "force\nbe", "with", "you.", ""]
    )
    self.assertEqual(tokens.positions, [0, 1, 5, 6, 10, 19, 24, 29])

def test_serialize(self) -> None:
serializer = PrimitiveSerializer()
self.assertEqual(
Expand Down

0 comments on commit 187dc88

Please sign in to comment.