Skip to content

Commit

Permalink
formalized api for string and io.Reader
Browse files Browse the repository at this point in the history
  • Loading branch information
n3integration committed Jan 31, 2018
1 parent 2716ce5 commit e449119
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 80 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# classifier
A concurrent naive bayes text classifier.
A naive bayes text classifier.

[ ![Codeship Status for n3integration/classifier](https://app.codeship.com/projects/a9a8adf0-d14a-0135-6c51-26e28af241d2/status?branch=master)](https://app.codeship.com/projects/262403)
[![codecov](https://codecov.io/gh/n3integration/classifier/branch/master/graph/badge.svg)](https://codecov.io/gh/n3integration/classifier)
Expand All @@ -18,11 +18,11 @@ go get github.com/n3integration/classifier
import "github.com/n3integration/classifier/naive"

classifier := naive.New()
classifier.Train("The quick brown fox jumped over the lazy dog", "ham")
classifier.Train("Earn a degree online", "ham")
classifier.Train("Earn cash quick online", "spam")
classifier.TrainString("The quick brown fox jumped over the lazy dog", "ham")
classifier.TrainString("Earn a degree online", "ham")
classifier.TrainString("Earn cash quick online", "spam")

if classification, err := classifier.Classify("Earn your masters degree online"); err == nil {
if classification, err := classifier.ClassifyString("Earn your masters degree online"); err == nil {
fmt.Println("Classification => ", classification) // ham
} else {
fmt.Println("error: ", err)
Expand All @@ -34,12 +34,12 @@ if classification, err := classifier.Classify("Earn your masters degree online")
- Fork the repository
- Create a local feature branch
- Run `gofmt`
- Bump the `VERSION` file according to [semantic versioning](https://semver.org/)
- Bump the `VERSION` file using [semantic versioning](https://semver.org/)
- Submit a pull request

## License

Copyright 2016 n3integration@gmail.com
Copyright 2018 n3integration@gmail.com

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.2.0
version=0.3.0
20 changes: 10 additions & 10 deletions classifier.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
package classifier

import (
"regexp"
"strings"
)
import "io"

// Classifier provides a simple interface for different text classifiers
type Classifier interface {
// Train allows clients to train the classifier
Train(doc string, category string) error
Train(io.Reader, string) error
// TrainString allows clients to train the classifier using a string
TrainString(string, string) error
// Classify performs a classification on the input corpus and assumes that
// the underlying classifier has been trained.
Classify(doc string) (string, error)
Classify(io.Reader) (string, error)
// ClassifyString performs text classification using a string
ClassifyString(string) (string, error)
}

// WordCounts extracts term frequencies from a text corpus
func WordCounts(doc string) (map[string]int, error) {
tokens := Tokenize(doc)
func WordCounts(r io.Reader) (map[string]int, error) {
instream := NewTokenizer().Tokenize(r)
wc := make(map[string]int)
for _, token := range tokens {
for token := range instream {
wc[token] = wc[token] + 1
}

return wc, nil
}
17 changes: 3 additions & 14 deletions classifier_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
package classifier

import "testing"

var (
text = "The quick brown fox jumped over the lazy dog"
expected = 7
import (
"testing"
)

func TestTokenize(t *testing.T) {
tokens := Tokenize(text)

if len(tokens) != expected {
t.Errorf("Expected %d tokens; actual: %d", expected, len(tokens))
}
}

func TestWordCounts(t *testing.T) {
wc, err := WordCounts(text)
wc, err := WordCounts(toReader(text))

if err != nil {
t.Error("failed to get word counts:", err)
Expand Down
32 changes: 19 additions & 13 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,28 @@ type Predicate func(string) bool
type Mapper func(string) string

// Map applies f to each element of the supplied input
func Map(vs []string, f Mapper) []string {
vsm := make([]string, len(vs))
for i, v := range vs {
vsm[i] = f(v)
}
return vsm
// Map starts a goroutine that receives every value from vs, applies f to it
// and sends the result on the returned unbuffered channel. Results are sent in
// the order the inputs were received. The returned channel is closed once vs
// has been closed and fully drained.
func Map(vs chan string, f Mapper) chan string {
	mapped := make(chan string)
	go func() {
		for {
			item, open := <-vs
			if !open {
				break
			}
			mapped <- f(item)
		}
		close(mapped)
	}()
	return mapped
}

// Filter keeps only the elements of the input for which the supplied
// predicate is satisfied
func Filter(vs []string, f Predicate) []string {
vsf := make([]string, 0)
for _, v := range vs {
if f(v) {
vsf = append(vsf, v)
func Filter(vs chan string, f Predicate) chan string {
outstream := make(chan string)
go func() {
for v := range vs {
if f(v) {
outstream <- v
}
}
}
return vsf
close(outstream)
}()
return outstream
}
33 changes: 25 additions & 8 deletions func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,42 @@ var words = []string{
"hello", "world",
}

// streamWords returns an unbuffered channel fed by a background goroutine that
// sends each entry of words in order and closes the channel after the last one.
func streamWords() chan string {
	out := make(chan string)
	go func() {
		pending := words
		for len(pending) > 0 {
			out <- pending[0]
			pending = pending[1:]
		}
		close(out)
	}()
	return out
}

func TestMap(t *testing.T) {
result := Map(words, strings.ToUpper)
for i, word := range result {
i := 0
results := Map(streamWords(), strings.ToUpper)
for word := range results {
expected := strings.ToUpper(words[i])
if expected != word {
t.Errorf("did not match expected result %v <> %v", expected, word)
}
i++
}
}

func TestFilter(t *testing.T) {
result := Filter(words, func(s string) bool {
return s != "hello"
results := Filter(streamWords(), func(s string) bool {
return s != words[0]
})

if len(result) != 1 {
t.Error("incorrect number of results:", len(result))
i := 0
for word := range results {
i++
if word != words[1] {
t.Error("incorrect result:", word)
}
}
if result[0] != "world" {
t.Error("incorrect result:", result[0])
if i != 1 {
t.Error("incorrect number of results:", i)
}
}
46 changes: 31 additions & 15 deletions naive/naive.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package naive

import (
"bytes"
"errors"
"io"
"sync"

"github.com/n3integration/classifier"
Expand All @@ -12,36 +14,42 @@ var ErrNotClassified = errors.New("unable to classify document")

// Classifier implements a naive bayes classifier
type Classifier struct {
feat2cat map[string]map[string]int
catCount map[string]int
feat2cat map[string]map[string]int
catCount map[string]int
tokenizer classifier.Tokenizer
sync.RWMutex
}

// New initializes a new naive Classifier
func New() *Classifier {
return &Classifier{
feat2cat: make(map[string]map[string]int),
catCount: make(map[string]int),
feat2cat: make(map[string]map[string]int),
catCount: make(map[string]int),
tokenizer: classifier.NewTokenizer(),
}
}

// Train provides supervisory training to the classifier
func (c *Classifier) Train(doc string, category string) error {
features := classifier.Tokenize(doc)

func (c *Classifier) Train(r io.Reader, category string) error {
c.Lock()
defer c.Unlock()

for _, feature := range features {
for feature := range c.tokenizer.Tokenize(r) {
c.addFeature(feature, category)
}

c.addCategory(category)
return nil
}

// TrainString is the string convenience form of Train: doc is wrapped in a
// reader via asReader and passed to Train together with category. Whatever
// error Train reports is returned unchanged.
func (c *Classifier) TrainString(doc string, category string) error {
	reader := asReader(doc)
	return c.Train(reader, category)
}

// Classify attempts to classify a document. If the document cannot be classified
// (e.g. because the classifier has not been trained), an error is returned.
func (c *Classifier) Classify(doc string) (string, error) {
func (c *Classifier) Classify(r io.Reader) (string, error) {
max := 0.0
var err error
classification := ""
Expand All @@ -51,7 +59,7 @@ func (c *Classifier) Classify(doc string) (string, error) {
defer c.RUnlock()

for _, category := range c.categories() {
if probabilities[category], err = c.probability(doc, category); err != nil {
if probabilities[category], err = c.probability(r, category); err != nil {
return "", err
}
if probabilities[category] > max {
Expand All @@ -65,6 +73,11 @@ func (c *Classifier) Classify(doc string) (string, error) {
return classification, err
}

// ClassifyString is the string convenience form of Classify: doc is wrapped in
// a reader via asReader and passed to Classify, whose classification and error
// are returned unchanged.
//
// NOTE(review): Classify hands the same reader to probability once per
// category; if tokenizing drains the reader on the first pass, later
// categories would see no tokens — TODO confirm against the tokenizer.
func (c *Classifier) ClassifyString(doc string) (string, error) {
	reader := asReader(doc)
	return c.Classify(reader)
}

func (c *Classifier) addFeature(feature string, category string) {
if _, ok := c.feat2cat[feature]; !ok {
c.feat2cat[feature] = make(map[string]int)
Expand Down Expand Up @@ -126,20 +139,23 @@ func (c *Classifier) variableWeightedProbability(feature string, category string
return ((weight * assumedProb) + (sum * probability)) / (weight + sum)
}

func (c *Classifier) probability(doc string, category string) (float64, error) {
func (c *Classifier) probability(r io.Reader, category string) (float64, error) {
categoryProbability := c.categoryCount(category) / float64(c.count())
docProbability, err := c.docProbability(doc, category)
docProbability, err := c.docProbability(r, category)
if err != nil {
return 0.0, nil
}
return docProbability * categoryProbability, nil
}

func (c *Classifier) docProbability(doc string, category string) (float64, error) {
features := classifier.Tokenize(doc)
func (c *Classifier) docProbability(r io.Reader, category string) (float64, error) {
probability := 1.0
for _, feature := range features {
for feature := range c.tokenizer.Tokenize(r) {
probability *= c.weightedProbability(feature, category)
}
return probability, nil
}

// asReader exposes text as an io.Reader backed by an in-memory bytes.Buffer
// holding a copy of the string's bytes.
func asReader(text string) io.Reader {
	return bytes.NewBufferString(text)
}
14 changes: 7 additions & 7 deletions naive/naive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"testing"
)

const (
var (
ham = "The quick brown fox jumps over the lazy dog"
spam = "Earn cash quick online"
)
Expand Down Expand Up @@ -32,11 +32,11 @@ func TestAddCategory(t *testing.T) {
func TestTrain(t *testing.T) {
classifier := New()

if err := classifier.Train(ham, "good"); err != nil {
if err := classifier.TrainString(ham, "good"); err != nil {
t.Error("classifier training failed")
}

if err := classifier.Train(spam, "bad"); err != nil {
if err := classifier.TrainString(spam, "bad"); err != nil {
t.Error("classifier training failed")
}

Expand All @@ -51,16 +51,16 @@ func TestClassify(t *testing.T) {
text := "Quick way to make cash"

t.Run("Empty classifier", func(t *testing.T) {
if _, err := classifier.Classify(text); err != ErrNotClassified {
if _, err := classifier.ClassifyString(text); err != ErrNotClassified {
t.Errorf("expected classification error; received: %v", err)
}
})

t.Run("Trained classifier", func(t *testing.T) {
classifier.Train(ham, "good")
classifier.Train(spam, "bad")
classifier.TrainString(ham, "good")
classifier.TrainString(spam, "bad")

if _, err := classifier.Classify(text); err != nil {
if _, err := classifier.ClassifyString(text); err != nil {
t.Error("document incorrectly classified")
}
})
Expand Down
Loading

0 comments on commit e449119

Please sign in to comment.