Commit df74ede
breaking changes: improved customization support
n3integration committed Sep 28, 2018
1 parent 1cd7f48 commit df74ede
Showing 5 changed files with 102 additions and 66 deletions.
VERSION (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-version=0.3.1
+version=0.4.0
func.go (48 changes: 33 additions & 15 deletions)
@@ -1,34 +1,52 @@
 package classifier
 
+const defaultBufferSize = 50
+
 // Predicate provides a predicate function
 type Predicate func(string) bool
 
 // Mapper provides a map function
 type Mapper func(string) string
 
-// Map applies f to each element of the supplied input slice
-func Map(vs chan string, f Mapper) chan string {
-	outstream := make(chan string)
+// Map applies f to each element of the supplied input channel
+func Map(vs chan string, f ...Mapper) chan string {
+	stream := make(chan string, defaultBufferSize)
+
 	go func() {
 		for v := range vs {
-			outstream <- f(v)
+			for _, fn := range f {
+				v = fn(v)
+			}
+			stream <- v
 		}
-		close(outstream)
+		close(stream)
 	}()
-	return outstream
+
+	return stream
 }
 
-// Filter removes elements from the input slice where the supplied predicate
+// Filter removes elements from the input channel where the supplied predicate
 // is satisfied
-func Filter(vs chan string, f Predicate) chan string {
-	outstream := make(chan string)
+// Filter is a Predicate aggregation
+func Filter(vs chan string, filters ...Predicate) chan string {
+	stream := make(chan string, defaultBufferSize)
+	apply := func(text string) bool {
+		for _, f := range filters {
+			if !f(text) {
+				return false
+			}
+		}
+		return true
+	}
+
 	go func() {
-		for v := range vs {
-			if f(v) {
-				outstream <- v
+		for text := range vs {
+			if apply(text) {
+				stream <- text
 			}
 		}
-		close(outstream)
+		close(stream)
 	}()
-	return outstream
-}
+
+	return stream
+}
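Note: Map and Filter are now variadic, so one pass over a buffered channel can apply several Mappers and Predicates. A minimal sketch of the new signatures in use; the uppercasing mapper and the length predicate are illustrative only, and the import path is assumed to be github.com/n3integration/classifier.

package main

import (
	"fmt"
	"strings"

	"github.com/n3integration/classifier" // assumed import path
)

func main() {
	in := make(chan string)
	go func() {
		defer close(in)
		for _, w := range []string{"a", "naive", "bayes", "classifier"} {
			in <- w
		}
	}()

	// Filter drops tokens that fail any Predicate; Map then applies each Mapper in order.
	longEnough := func(s string) bool { return len(s) > 1 }
	out := classifier.Map(classifier.Filter(in, longEnough), strings.ToUpper)

	for token := range out {
		fmt.Println(token) // NAIVE, BAYES, CLASSIFIER
	}
}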
naive/naive.go (34 changes: 25 additions & 9 deletions)
@@ -12,27 +12,42 @@ import (
 // ErrNotClassified indicates that a document could not be classified
 var ErrNotClassified = errors.New("unable to classify document")
 
+// Option provides a functional setting for the Classifier
+type Option func(c *Classifier) error
+
 // Classifier implements a naive bayes classifier
 type Classifier struct {
 	feat2cat  map[string]map[string]int
 	catCount  map[string]int
 	tokenizer classifier.Tokenizer
-	sync.RWMutex
+	mu        sync.RWMutex
 }
 
-// New initializes a new naive Classifier
-func New() *Classifier {
-	return &Classifier{
+// New initializes a new naive Classifier using the standard tokenizer
+func New(opts ...Option) *Classifier {
+	c := &Classifier{
 		feat2cat:  make(map[string]map[string]int),
 		catCount:  make(map[string]int),
 		tokenizer: classifier.NewTokenizer(),
 	}
+	for _, opt := range opts {
+		opt(c)
+	}
+	return c
 }
 
+// Tokenizer overrides the classifier's default Tokenizer
+func Tokenizer(t classifier.Tokenizer) Option {
+	return func(c *Classifier) error {
+		c.tokenizer = t
+		return nil
+	}
+}
+
 // Train provides supervisory training to the classifier
 func (c *Classifier) Train(r io.Reader, category string) error {
-	c.Lock()
-	defer c.Unlock()
+	c.mu.Lock()
+	defer c.mu.Unlock()
 
 	for feature := range c.tokenizer.Tokenize(r) {
 		c.addFeature(feature, category)
@@ -55,8 +55,8 @@ func (c *Classifier) Classify(r io.Reader) (string, error) {
 	classification := ""
 	probabilities := make(map[string]float64)
 
-	c.RLock()
-	defer c.RUnlock()
+	c.mu.RLock()
+	defer c.mu.RUnlock()
 
 	for _, category := range c.categories() {
 		probabilities[category] = c.probability(r, category)
@@ -65,6 +65,7 @@ func (c *Classifier) Classify(r io.Reader) (string, error) {
 			classification = category
 		}
 	}
+
 	if classification == "" {
 		return "", ErrNotClassified
 	}
@@ -152,5 +168,5 @@ func (c *Classifier) docProbability(r io.Reader, category string) float64 {
 }
 
 func asReader(text string) io.Reader {
-	return bytes.NewBuffer([]byte(text))
+	return bytes.NewBufferString(text)
 }
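With the exported Option type, the naive classifier can now be customized at construction time; before this change New() took no arguments. A short usage sketch, assuming the packages live at github.com/n3integration/classifier and .../classifier/naive; the training texts are made up.

package main

import (
	"fmt"
	"strings"

	"github.com/n3integration/classifier"       // assumed import path
	"github.com/n3integration/classifier/naive" // assumed import path
)

func main() {
	// Tokenizer(...) overrides the default tokenizer through the new functional Option.
	c := naive.New(naive.Tokenizer(classifier.NewTokenizer()))

	if err := c.Train(strings.NewReader("tall trees and green grass"), "nature"); err != nil {
		panic(err)
	}
	if err := c.Train(strings.NewReader("fast cars and crowded streets"), "city"); err != nil {
		panic(err)
	}

	category, err := c.Classify(strings.NewReader("green trees beside the street"))
	if err != nil {
		fmt.Println(err) // ErrNotClassified when no category wins
		return
	}
	fmt.Println(category)
}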
tokens.go (80 changes: 43 additions & 37 deletions)
@@ -2,42 +2,48 @@ package classifier
 
 import (
 	"bufio"
-	"bytes"
 	"io"
-	"regexp"
 	"strings"
-	"unsafe"
 )
 
 // Tokenizer provides a common interface to tokenize documents
 type Tokenizer interface {
-	// Tokenize breaks the provided document into a token slice
-	Tokenize(r io.Reader) chan string
+	// Tokenize breaks the provided document into a channel of tokens
+	Tokenize(io.Reader) chan string
 }
 
-type regexTokenizer struct {
-	tokenizer *regexp.Regexp
-}
+// StdOption provides configuration settings for a StdTokenizer
+type StdOption func(*StdTokenizer)
 
-type stdTokenizer struct {
+// StdTokenizer provides a common document tokenizer that splits a
+// document by word boundaries
+type StdTokenizer struct {
+	transforms []Mapper
+	filters    []Predicate
+	bufferSize int
 }
 
 // NewTokenizer initializes a new standard Tokenizer instance
-func NewTokenizer() Tokenizer {
-	return &stdTokenizer{}
-}
-
-// NewRegexTokenizer initializes a new regular expression Tokenizer instance
-func NewRegexTokenizer() Tokenizer {
-	return &regexTokenizer{
-		tokenizer: regexp.MustCompile("\\W+"),
+func NewTokenizer(opts ...StdOption) *StdTokenizer {
+	tokenizer := &StdTokenizer{
+		bufferSize: 100,
+		transforms: []Mapper{
+			strings.ToLower,
+		},
+		filters: []Predicate{
+			IsNotStopWord,
+		},
 	}
+	for _, opt := range opts {
+		opt(tokenizer)
+	}
+	return tokenizer
 }
 
-func (t *stdTokenizer) Tokenize(r io.Reader) chan string {
+func (t *StdTokenizer) Tokenize(r io.Reader) chan string {
 	tokenizer := bufio.NewScanner(r)
 	tokenizer.Split(bufio.ScanWords)
-	tokens := make(chan string)
+	tokens := make(chan string, t.bufferSize)
 
 	go func() {
 		for tokenizer.Scan() {
@@ -46,27 +52,27 @@ func (t *stdTokenizer) Tokenize(r io.Reader) chan string {
 		close(tokens)
 	}()
 
-	return pipeline(tokens)
+	return t.pipeline(tokens)
 }
 
-// Tokenize extracts and normalizes all words from a text corpus
-func (t *regexTokenizer) Tokenize(r io.Reader) chan string {
-	buffer := new(bytes.Buffer)
-	buffer.ReadFrom(r)
-	b := buffer.Bytes()
-	doc := *(*string)(unsafe.Pointer(&b))
-	tokens := make(chan string)
-
-	go func() {
-		for _, token := range t.tokenizer.Split(doc, -1) {
-			tokens <- token
-		}
-		close(tokens)
-	}()
+func (t *StdTokenizer) pipeline(in chan string) chan string {
+	return Map(Filter(in, t.filters...), t.transforms...)
+}
 
-	return pipeline(tokens)
+func BufferSize(size int) StdOption {
+	return func(t *StdTokenizer) {
+		t.bufferSize = size
+	}
 }
 
-func pipeline(tokens chan string) chan string {
-	return Map(Filter(tokens, IsNotStopWord), strings.ToLower)
+func Transforms(m ...Mapper) StdOption {
+	return func(t *StdTokenizer) {
+		t.transforms = m
+	}
 }
+
+func Filters(f ...Predicate) StdOption {
+	return func(t *StdTokenizer) {
+		t.filters = f
+	}
+}
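The tokenizer is now exported as StdTokenizer and configured through StdOptions instead of being hidden behind the Tokenizer interface. A sketch of the three options introduced above; the buffer size and transform choices are arbitrary, and the import path is assumed.

package main

import (
	"fmt"
	"strings"

	"github.com/n3integration/classifier" // assumed import path
)

func main() {
	tokenizer := classifier.NewTokenizer(
		classifier.BufferSize(25),                                  // channel buffer; the default is 100
		classifier.Transforms(strings.ToLower, strings.TrimSpace),  // replaces the default Mapper list
		classifier.Filters(classifier.IsNotStopWord),               // replaces the default Predicate list
	)

	for token := range tokenizer.Tokenize(strings.NewReader("The Quick Brown Fox")) {
		fmt.Println(token) // likely quick / brown / fox, assuming "The" is in the stop-word list
	}
}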
tokens_test.go (4 changes: 0 additions & 4 deletions)
@@ -16,10 +16,6 @@ func TestTokenize(t *testing.T) {
 		tokens := NewTokenizer().Tokenize(toReader(text))
 		doTokenizeTest(t, tokens)
 	})
-	t.Run("Regexp Tokenizer", func(t *testing.T) {
-		tokens := NewRegexTokenizer().Tokenize(toReader(text))
-		doTokenizeTest(t, tokens)
-	})
 }
 
 func toReader(text string) io.Reader {
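The regular-expression tokenizer and its test are removed, but the exported Tokenizer interface plus the new naive.Tokenizer Option leave room to bring one back in user code. A hypothetical sketch, not part of this commit; the type name and the \W+ pattern simply mirror what was removed, and the import path is assumed.

package main

import (
	"io"
	"regexp"

	"github.com/n3integration/classifier/naive" // assumed import path
)

// regexTokenizer is a user-defined stand-in for the removed NewRegexTokenizer.
type regexTokenizer struct {
	re *regexp.Regexp
}

// Tokenize splits the whole document on non-word characters and streams the
// pieces over a buffered channel, satisfying the exported Tokenizer interface.
func (t *regexTokenizer) Tokenize(r io.Reader) chan string {
	tokens := make(chan string, 50)
	go func() {
		defer close(tokens)
		doc, err := io.ReadAll(r)
		if err != nil {
			return
		}
		for _, token := range t.re.Split(string(doc), -1) {
			if token != "" {
				tokens <- token
			}
		}
	}()
	return tokens
}

func main() {
	// Wire the custom tokenizer into the naive classifier via the new Option.
	_ = naive.New(naive.Tokenizer(&regexTokenizer{re: regexp.MustCompile(`\W+`)}))
}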
