From b9161929133448edbfc4e97bdb4177df80ed9161 Mon Sep 17 00:00:00 2001 From: Peter Kleiweg Date: Thu, 5 Dec 2024 00:09:45 +0100 Subject: [PATCH] v2/verbose: changed score into score divided by number of tokens --- v2/textcat.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/v2/textcat.go b/v2/textcat.go index 29227c1..b73df62 100644 --- a/v2/textcat.go +++ b/v2/textcat.go @@ -34,8 +34,9 @@ type TextCat struct { } type resultType struct { - score int - lang string + score int + lang string + tokens int } type resultsType []*resultType @@ -193,11 +194,14 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) { languages = make([]string, 0, tc.maxCandidates) - if tc.raw && len(text) < tc.minDocSize { + rawsize := len(text) + utfsize := utf8.RuneCountInString(strings.TrimSpace(reInvalid.ReplaceAllString(text, " "))) + + if tc.raw && rawsize < tc.minDocSize { err = errShort return } - if tc.utf8 && utf8.RuneCountInString(strings.TrimSpace(reInvalid.ReplaceAllString(text, " "))) < tc.minDocSize { + if tc.utf8 && utfsize < tc.minDocSize { err = errShort return } @@ -211,10 +215,13 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) { pattypes = append(pattypes, false) } for _, utf8 := range pattypes { + patt := GetPatterns(text, utf8) suffix := ".raw" + tokens := rawsize if utf8 { suffix = ".utf8" + tokens = utfsize } for lang := range tc.lang { if !tc.lang[lang] || !strings.HasSuffix(lang, suffix) { @@ -237,7 +244,7 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) { score += i - n } } - scores = append(scores, &resultType{score, lang}) + scores = append(scores, &resultType{score, lang, tokens}) } } if len(scores) == 0 { @@ -272,7 +279,7 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) { sort.Sort(resultsType(lowScores)) for _, sco := range lowScores { if tc.verbose { - languages = append(languages, fmt.Sprint(sco.lang, "|", 
sco.score)) + languages = append(languages, fmt.Sprintf("%s|%.3f", sco.lang, float64(sco.score)/float64(sco.tokens))) } else { languages = append(languages, sco.lang) }