Skip to content

Commit

Permalink
Fix the displayed score so that it is a value between 0 and the maximum number of patterns
Browse files (browse the repository at this point in the history)
  • Loading branch information
pebbe committed Dec 5, 2024
1 parent b916192 commit 05083a2
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions textcat.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,9 +34,8 @@ type TextCat struct {
}

type resultType struct {
score int
lang string
tokens int
score float64
lang string
}

type resultsType []*resultType
Expand Down Expand Up @@ -215,13 +214,10 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
pattypes = append(pattypes, false)
}
for _, utf8 := range pattypes {

patt := GetPatterns(text, utf8)
suffix := ".raw"
tokens := rawsize
if utf8 {
suffix = ".utf8"
tokens = utfsize
}
for lang := range tc.lang {
if !tc.lang[lang] || !strings.HasSuffix(lang, suffix) {
Expand All @@ -244,15 +240,15 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
score += i - n
}
}
scores = append(scores, &resultType{score, lang, tokens})
scores = append(scores, &resultType{float64(score) / float64(len(patt)), lang})
}
}
if len(scores) == 0 {
err = errAvail
return
}

minScore := MaxPatterns * MaxPatterns
minScore := float64(MaxPatterns * MaxPatterns)
for _, sco := range scores {
if sco.score < minScore {
minScore = sco.score
Expand All @@ -261,7 +257,7 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {
threshold := float64(minScore) * tc.thresholdValue
nCandidates := 0
for _, sco := range scores {
if float64(sco.score) <= threshold {
if sco.score <= threshold {
nCandidates += 1
}
}
Expand All @@ -272,14 +268,14 @@ func (tc *TextCat) Classify(text string) (languages []string, err error) {

lowScores := make([]*resultType, 0, nCandidates)
for _, sco := range scores {
if float64(sco.score) <= threshold {
if sco.score <= threshold {
lowScores = append(lowScores, sco)
}
}
sort.Sort(resultsType(lowScores))
for _, sco := range lowScores {
if tc.verbose {
languages = append(languages, fmt.Sprintf("%s|%.3f", sco.lang, float64(sco.score)/float64(sco.tokens)))
languages = append(languages, fmt.Sprintf("%s|%.2f", sco.lang, sco.score))
} else {
languages = append(languages, sco.lang)
}
Expand Down

0 comments on commit 05083a2

Please sign in to comment.