Skip to content

Commit 59b3c47

Browse files
committed
feat(go bindings): add state abstraction
1 parent edea8a9 commit 59b3c47

File tree

8 files changed

+618
-66
lines changed

8 files changed

+618
-66
lines changed

bindings/go/pkg/whisper/consts.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ var (
1616
ErrProcessingFailed = errors.New("processing failed")
1717
ErrUnsupportedLanguage = errors.New("unsupported language")
1818
ErrModelNotMultilingual = errors.New("model is not multilingual")
19+
ErrUnableToCreateState = errors.New("unable to create state")
1920
)
2021

2122
///////////////////////////////////////////////////////////////////////////////

bindings/go/pkg/whisper/context.go

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ type context struct {
2020
params whisper.Params
2121
}
2222

23-
// Make sure context adheres to the interface
24-
var _ Context = (*context)(nil)
25-
2623
///////////////////////////////////////////////////////////////////////////////
2724
// LIFECYCLE
2825

@@ -241,26 +238,49 @@ func (context *context) Process(
241238
return nil
242239
}
243240

244-
// Return the next segment of tokens
241+
// NextSegment returns the next segment from the context buffer
245242
func (context *context) NextSegment() (Segment, error) {
246243
if context.model.ctx == nil {
247244
return Segment{}, ErrInternalAppError
248245
}
249246
if context.n >= context.model.ctx.Whisper_full_n_segments() {
250247
return Segment{}, io.EOF
251248
}
252-
253-
// Populate result
254249
result := toSegment(context.model.ctx, context.n)
255-
256-
// Increment the cursor
257250
context.n++
258-
259-
// Return success
260251
return result, nil
261252
}
262253

263-
// Test for text tokens
254+
///////////////////////////////////////////////////////////////////////////////
255+
// PRIVATE METHODS
256+
257+
func toSegment(ctx *whisper.Context, n int) Segment {
258+
return Segment{
259+
Num: n,
260+
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
261+
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
262+
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
263+
Tokens: toTokens(ctx, n),
264+
}
265+
}
266+
267+
func toTokens(ctx *whisper.Context, n int) []Token {
268+
result := make([]Token, ctx.Whisper_full_n_tokens(n))
269+
for i := 0; i < len(result); i++ {
270+
data := ctx.Whisper_full_get_token_data(n, i)
271+
272+
result[i] = Token{
273+
Id: int(ctx.Whisper_full_get_token_id(n, i)),
274+
Text: ctx.Whisper_full_get_token_text(n, i),
275+
P: ctx.Whisper_full_get_token_p(n, i),
276+
Start: time.Duration(data.T0()) * time.Millisecond * 10,
277+
End: time.Duration(data.T1()) * time.Millisecond * 10,
278+
}
279+
}
280+
return result
281+
}
282+
283+
// Token helpers
264284
func (context *context) IsText(t Token) bool {
265285
switch {
266286
case context.IsBEG(t):
@@ -280,70 +300,34 @@ func (context *context) IsText(t Token) bool {
280300
}
281301
}
282302

283-
// Test for "begin" token
284303
func (context *context) IsBEG(t Token) bool {
285304
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
286305
}
287306

288-
// Test for "start of transcription" token
289307
func (context *context) IsSOT(t Token) bool {
290308
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
291309
}
292310

293-
// Test for "end of transcription" token
294311
func (context *context) IsEOT(t Token) bool {
295312
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
296313
}
297314

298-
// Test for "start of prev" token
299315
func (context *context) IsPREV(t Token) bool {
300316
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
301317
}
302318

303-
// Test for "start of lm" token
304319
func (context *context) IsSOLM(t Token) bool {
305320
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
306321
}
307322

308-
// Test for "No timestamps" token
309323
func (context *context) IsNOT(t Token) bool {
310324
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
311325
}
312326

313-
// Test for token associated with a specific language
314327
func (context *context) IsLANG(t Token, lang string) bool {
315328
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
316329
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
317330
} else {
318331
return false
319332
}
320333
}
321-
322-
///////////////////////////////////////////////////////////////////////////////
323-
// PRIVATE METHODS
324-
325-
func toSegment(ctx *whisper.Context, n int) Segment {
326-
return Segment{
327-
Num: n,
328-
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
329-
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
330-
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
331-
Tokens: toTokens(ctx, n),
332-
}
333-
}
334-
335-
func toTokens(ctx *whisper.Context, n int) []Token {
336-
result := make([]Token, ctx.Whisper_full_n_tokens(n))
337-
for i := 0; i < len(result); i++ {
338-
data := ctx.Whisper_full_get_token_data(n, i)
339-
340-
result[i] = Token{
341-
Id: int(ctx.Whisper_full_get_token_id(n, i)),
342-
Text: ctx.Whisper_full_get_token_text(n, i),
343-
P: ctx.Whisper_full_get_token_p(n, i),
344-
Start: time.Duration(data.T0()) * time.Millisecond * 10,
345-
End: time.Duration(data.T1()) * time.Millisecond * 10,
346-
}
347-
}
348-
return result
349-
}

bindings/go/pkg/whisper/interface.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ type Context interface {
8585
SystemInfo() string
8686
}
8787

88+
// State is a per-request speech recognition state which shares the loaded model
89+
// but isolates recognition results. It embeds Context, so any state-specific
90+
// methods can be added later without breaking existing API.
91+
type State interface {
92+
io.Closer
93+
94+
Context
95+
}
96+
8897
// Segment is the text result of a speech recognition.
8998
type Segment struct {
9099
// Segment Number

bindings/go/pkg/whisper/model.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,19 @@ func (model *model) NewContext() (Context, error) {
9999
// Return new context
100100
return newContext(model, params)
101101
}
102+
103+
// NewState returns a new per-request state sharing the loaded model
104+
func (model *model) NewState() (State, error) {
105+
if model.ctx == nil {
106+
return nil, ErrInternalAppError
107+
}
108+
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
109+
params.SetTranslate(false)
110+
params.SetPrintSpecial(false)
111+
params.SetPrintProgress(false)
112+
params.SetPrintRealtime(false)
113+
params.SetPrintTimestamps(false)
114+
params.SetThreads(runtime.NumCPU())
115+
params.SetNoContext(true)
116+
return newState(model, params)
117+
}

bindings/go/pkg/whisper/state.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package whisper
2+
3+
import (
4+
"io"
5+
"strings"
6+
"time"
7+
8+
// Bindings
9+
whisper "github.com/ggerganov/whisper.cpp/bindings/go"
10+
)
11+
12+
// state embeds context behavior and carries a low-level state pointer
13+
// for isolated processing results.
14+
type state struct {
15+
*context
16+
st *whisper.State
17+
}
18+
19+
// NewState creates a new per-request State from a Model without changing the Model interface.
20+
func NewState(m Model) (State, error) {
21+
impl, ok := m.(*model)
22+
if !ok {
23+
return nil, ErrInternalAppError
24+
}
25+
params := impl.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
26+
params.SetTranslate(false)
27+
params.SetPrintSpecial(false)
28+
params.SetPrintProgress(false)
29+
params.SetPrintRealtime(false)
30+
params.SetPrintTimestamps(false)
31+
return newState(impl, params)
32+
}
33+
34+
// internal constructor used by model.NewState
35+
func newState(model *model, params whisper.Params) (State, error) {
36+
ctx := &context{model: model, params: params}
37+
st := model.ctx.Whisper_init_state()
38+
if st == nil {
39+
return nil, ErrUnableToCreateState
40+
}
41+
return &state{context: ctx, st: st}, nil
42+
}
43+
44+
// Process using an isolated state for concurrency
45+
func (s *state) Process(
46+
data []float32,
47+
callEncoderBegin EncoderBeginCallback,
48+
callNewSegment SegmentCallback,
49+
callProgress ProgressCallback,
50+
) error {
51+
if s.model.ctx == nil || s.st == nil {
52+
return ErrInternalAppError
53+
}
54+
if callNewSegment != nil {
55+
s.params.SetSingleSegment(true)
56+
}
57+
if err := s.model.ctx.Whisper_full_with_state(s.st, s.params, data, callEncoderBegin,
58+
func(new int) {
59+
if callNewSegment != nil {
60+
num_segments := s.model.ctx.Whisper_full_n_segments_from_state(s.st)
61+
s0 := num_segments - new
62+
for i := s0; i < num_segments; i++ {
63+
callNewSegment(toSegmentFromState(s.model.ctx, s.st, i))
64+
}
65+
}
66+
}, func(progress int) {
67+
if callProgress != nil {
68+
callProgress(progress)
69+
}
70+
}); err != nil {
71+
return err
72+
}
73+
return nil
74+
}
75+
76+
// Return the next segment of tokens for state
77+
func (s *state) NextSegment() (Segment, error) {
78+
if s.model.ctx == nil {
79+
return Segment{}, ErrInternalAppError
80+
}
81+
if s.n >= s.model.ctx.Whisper_full_n_segments_from_state(s.st) {
82+
return Segment{}, io.EOF
83+
}
84+
result := toSegmentFromState(s.model.ctx, s.st, s.n)
85+
s.n++
86+
return result, nil
87+
}
88+
89+
func (s *state) Close() error {
90+
if s.st != nil {
91+
s.st.Whisper_free_state()
92+
s.st = nil
93+
}
94+
return nil
95+
}
96+
97+
// Helpers specific to state-based results
98+
func toSegmentFromState(ctx *whisper.Context, st *whisper.State, n int) Segment {
99+
return Segment{
100+
Num: n,
101+
Text: stringsTrim(ctx.Whisper_full_get_segment_text_from_state(st, n)),
102+
Start: duration10x(ctx.Whisper_full_get_segment_t0_from_state(st, n)),
103+
End: duration10x(ctx.Whisper_full_get_segment_t1_from_state(st, n)),
104+
Tokens: toTokensFromState(ctx, st, n),
105+
}
106+
}
107+
108+
func toTokensFromState(ctx *whisper.Context, st *whisper.State, n int) []Token {
109+
result := make([]Token, ctx.Whisper_full_n_tokens_from_state(st, n))
110+
for i := 0; i < len(result); i++ {
111+
data := ctx.Whisper_full_get_token_data_from_state(st, n, i)
112+
result[i] = Token{
113+
Id: int(ctx.Whisper_full_get_token_id_from_state(st, n, i)),
114+
Text: ctx.Whisper_full_get_token_text_from_state(st, n, i),
115+
P: ctx.Whisper_full_get_token_p_from_state(st, n, i),
116+
Start: duration10x(data.T0()),
117+
End: duration10x(data.T1()),
118+
}
119+
}
120+
return result
121+
}
122+
123+
// small shared helpers to avoid importing time/strings here unnecessarily
124+
func stringsTrim(s string) string { return strings.TrimSpace(s) }
125+
func duration10x(ms10 int64) time.Duration { return time.Duration(ms10) * time.Millisecond * 10 }

0 commit comments

Comments
 (0)