Skip to content

Commit

Permalink
resolved #12 fixed using pointer to Decoder interface in Tokenizer st…
Browse files Browse the repository at this point in the history
…ruct
  • Loading branch information
sugarme committed Oct 17, 2020
1 parent 18e4c34 commit 4e049d7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 8 additions & 0 deletions example/basic/bert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"

"github.com/sugarme/tokenizer"
"github.com/sugarme/tokenizer/decoder"
"github.com/sugarme/tokenizer/model/wordpiece"
"github.com/sugarme/tokenizer/normalizer"
"github.com/sugarme/tokenizer/pretokenizer"
Expand Down Expand Up @@ -39,6 +40,10 @@ func getBert() (retVal *tokenizer.Tokenizer) {
bertPreTokenizer := pretokenizer.NewBertPreTokenizer()
tk.WithPreTokenizer(bertPreTokenizer)

wpDecoder := decoder.NewWordPieceDecoder("", false)

tk.WithDecoder(wpDecoder)

return tk
}

Expand Down Expand Up @@ -100,6 +105,9 @@ func bertTokenize() {
fmt.Printf("offsets: %v\n", en.GetOffsets())
fmt.Printf("word Ids: %v\n", en.GetWords())

decodedStr := tk.Decode(en.Ids, true)
fmt.Printf("decodedStr: '%v'\n", decodedStr)

// Output:
// original: 'Hello, y'all! How are you 😁 ?'
// tokens: [[CLS] hello , y ' all ! how are you [UNK] ? [SEP]]
Expand Down
8 changes: 4 additions & 4 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ type Tokenizer struct {
preTokenizer *PreTokenizer // optional
model Model
postProcessor *PostProcessor // optional
decoder *Decoder // optional
decoder Decoder // optional - interface

// Added vocabulary capability
addedVocabulary AddedVocabulary
Expand Down Expand Up @@ -222,11 +222,11 @@ func (t *Tokenizer) GetPostProcessor() *PostProcessor {
return t.postProcessor
}

func (t *Tokenizer) WithDecoder(decoder *Decoder) {
func (t *Tokenizer) WithDecoder(decoder Decoder) {
t.decoder = decoder
}

func (t *Tokenizer) GetDecoder() *Decoder {
func (t *Tokenizer) GetDecoder() Decoder {
return t.decoder
}

Expand Down Expand Up @@ -430,7 +430,7 @@ func (t *Tokenizer) Decode(ids []int, skipSpecialTokens bool) (retVal string) {
}

if t.decoder != nil {
return (*t.decoder).Decode(tokens)
return (t.decoder).Decode(tokens)
}

return strings.Join(tokens, " ")
Expand Down

0 comments on commit 4e049d7

Please sign in to comment.