Skip to content

Commit

Permalink
v2/verbose: changed score into score devided by number of tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
pebbe committed Dec 4, 2024
1 parent 130f12a commit b916192
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions v2/textcat.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ type TextCat struct {
}

type resultType struct {
score int
lang string
score int
lang string
tokens int
}

type resultsType []*resultType
Expand Down Expand Up @@ -193,11 +194,14 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {

languages = make([]string, 0, tc.maxCandidates)

if tc.raw && len(text) < tc.minDocSize {
rawsize := len(text)
utfsize := utf8.RuneCountInString(strings.TrimSpace(reInvalid.ReplaceAllString(text, " ")))

if tc.raw && rawsize < tc.minDocSize {
err = errShort
return
}
if tc.utf8 && utf8.RuneCountInString(strings.TrimSpace(reInvalid.ReplaceAllString(text, " "))) < tc.minDocSize {
if tc.utf8 && utfsize < tc.minDocSize {
err = errShort
return
}
Expand All @@ -211,10 +215,13 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
pattypes = append(pattypes, false)
}
for _, utf8 := range pattypes {

patt := GetPatterns(text, utf8)
suffix := ".raw"
tokens := rawsize
if utf8 {
suffix = ".utf8"
tokens = utfsize
}
for lang := range tc.lang {
if !tc.lang[lang] || !strings.HasSuffix(lang, suffix) {
Expand All @@ -237,7 +244,7 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
score += i - n
}
}
scores = append(scores, &resultType{score, lang})
scores = append(scores, &resultType{score, lang, tokens})
}
}
if len(scores) == 0 {
Expand Down Expand Up @@ -272,7 +279,7 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
sort.Sort(resultsType(lowScores))
for _, sco := range lowScores {
if tc.verbose {
languages = append(languages, fmt.Sprint(sco.lang, "|", sco.score))
languages = append(languages, fmt.Sprintf("%s|%.3f", sco.lang, float64(sco.score)/float64(sco.tokens)))
} else {
languages = append(languages, sco.lang)
}
Expand Down

0 comments on commit b916192

Please sign in to comment.