feat: implement StreamingReasoningFunc for reasoning models (#1125)
* feat: implement StreamingReasoningFunc for reasoning models

* Use a DeepSeek example for streaming reasoning content
douglarek authored Feb 13, 2025
1 parent 96d0b28 commit d3e43b6
Showing 5 changed files with 67 additions and 13 deletions.
@@ -32,8 +32,13 @@ func main() {
content,
llms.WithMaxTokens(2000),
llms.WithTemperature(0.7),
-llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
-fmt.Print(string(chunk))
+llms.WithStreamingReasoningFunc(func(ctx context.Context, reasoningChunk []byte, chunk []byte) error {
+if len(reasoningChunk) > 0 {
+fmt.Printf("Streaming Reasoning: %s\n", string(reasoningChunk))
+}
+if len(chunk) > 0 {
+fmt.Printf("Streaming Content: %s\n", string(chunk))
+}
return nil
}),
)
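For reference, a minimal end-to-end sketch of how the new option can be wired up, in the spirit of the DeepSeek example this commit updates. The model name, base URL, and prompt are illustrative assumptions, and the client picks up its API token from the environment as the langchaingo OpenAI client normally does.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	ctx := context.Background()

	// Assumed configuration: any OpenAI-compatible endpoint that emits
	// reasoning_content deltas (e.g. deepseek-reasoner) should behave the same.
	llm, err := openai.New(
		openai.WithModel("deepseek-reasoner"),
		openai.WithBaseURL("https://api.deepseek.com/v1"),
	)
	if err != nil {
		log.Fatal(err)
	}

	_, err = llms.GenerateFromSinglePrompt(ctx, llm, "How many r's are in the word strawberry?",
		llms.WithMaxTokens(2000),
		llms.WithTemperature(0.7),
		llms.WithStreamingReasoningFunc(func(_ context.Context, reasoningChunk, chunk []byte) error {
			// Reasoning tokens and answer tokens arrive in separate arguments.
			if len(reasoningChunk) > 0 {
				fmt.Printf("Streaming Reasoning: %s\n", reasoningChunk)
			}
			if len(chunk) > 0 {
				fmt.Printf("Streaming Content: %s\n", chunk)
			}
			return nil
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
}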
17 changes: 14 additions & 3 deletions llms/openai/internal/openaiclient/chat.go
@@ -67,6 +67,10 @@ type ChatRequest struct {
// Return an error to stop streaming early.
StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`

// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
// Return an error to stop streaming early.
StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`

// Deprecated: use Tools instead.
Functions []FunctionDefinition `json:"functions,omitempty"`
// Deprecated: use ToolChoice instead.
@@ -380,7 +384,7 @@ type FunctionCall struct {
}

func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCompletionResponse, error) {
-if payload.StreamingFunc != nil {
+if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
payload.Stream = true
if payload.StreamOptions == nil {
payload.StreamOptions = &StreamOptions{IncludeUsage: true}
@@ -421,7 +425,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCompletionResponse, error) {

return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113
}
-if payload.StreamingFunc != nil {
+if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
return parseStreamingChatResponse(ctx, r, payload)
}
// Parse response
@@ -493,9 +497,10 @@ func combineStreamingChatResponse(
}
choice := streamResponse.Choices[0]
chunk := []byte(choice.Delta.Content)
reasoningChunk := []byte(choice.Delta.ReasoningContent) // TODO: unclear whether reasoning content can accompany a function-call delta; pass it through unchanged for now
response.Choices[0].Message.Content += choice.Delta.Content
response.Choices[0].FinishReason = choice.FinishReason
-response.Choices[0].Message.ReasoningContent = choice.Delta.ReasoningContent
+response.Choices[0].Message.ReasoningContent += choice.Delta.ReasoningContent

if choice.Delta.FunctionCall != nil {
chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
@@ -512,6 +517,12 @@ func combineStreamingChatResponse(
return nil, fmt.Errorf("streaming func returned an error: %w", err)
}
}
if payload.StreamingReasoningFunc != nil {
err := payload.StreamingReasoningFunc(ctx, reasoningChunk, chunk)
if err != nil {
return nil, fmt.Errorf("streaming reasoning func returned an error: %w", err)
}
}
}
return &response, nil
}
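Because the combine loop above wraps any error returned by the callback ("streaming reasoning func returned an error: ..."), the new hook can also be used to abort a long chain-of-thought stream early. A small illustrative helper, not part of this commit; the package name, byte budget, and sentinel error are assumptions:

package streamutil

import (
	"context"
	"errors"
)

// ErrReasoningBudget is a sentinel error; returning any non-nil error from the
// callback stops the stream, and GenerateContent reports it wrapped.
var ErrReasoningBudget = errors.New("reasoning budget exceeded")

// ReasoningBudget builds a callback for llms.WithStreamingReasoningFunc that
// cancels the stream once more than maxBytes of reasoning have been received.
func ReasoningBudget(maxBytes int) func(ctx context.Context, reasoningChunk, chunk []byte) error {
	seen := 0
	return func(_ context.Context, reasoningChunk, _ []byte) error {
		seen += len(reasoningChunk)
		if seen > maxBytes {
			return ErrReasoningBudget
		}
		return nil
	}
}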
27 changes: 27 additions & 0 deletions llms/openai/internal/openaiclient/chat_test.go
@@ -56,6 +56,33 @@ func TestParseStreamingChatResponse_ReasoningContent(t *testing.T) {
assert.Equal(t, FinishReason("stop"), resp.Choices[0].FinishReason)
}

func TestParseStreamingChatResponse_ReasoningFunc(t *testing.T) {
t.Parallel()
mockBody := `
data: {"id":"fa7e4fc5-a05d-4e7b-9a66-a2dd89e91a4e","object":"chat.completion.chunk","created":1738492867,"model":"deepseek-reasoner","system_fingerprint":"fp_7e73fd9a08","choices":[{"index":0,"delta":{"content":null,"reasoning_content":"Okay"},"logprobs":null,"finish_reason":null}]}
`
r := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(mockBody)),
}

req := &ChatRequest{
StreamingReasoningFunc: func(_ context.Context, reasoningChunk, chunk []byte) error {
t.Logf("reasoningChunk: %s", string(reasoningChunk))
t.Logf("chunk: %s", string(chunk))
return nil
},
}

resp, err := parseStreamingChatResponse(context.Background(), r, req)

require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "", resp.Choices[0].Message.Content)
assert.Equal(t, "Okay", resp.Choices[0].Message.ReasoningContent)
assert.Equal(t, FinishReason(""), resp.Choices[0].FinishReason)
}

func TestChatMessage_MarshalUnmarshal(t *testing.T) {
t.Parallel()
msg := ChatMessage{
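A possible follow-up test, sketched here rather than taken from the commit, that also asserts the callback receives the streamed bytes by accumulating them in the closure; it reuses the helpers and mock-body shape of the test above, with the stream payload trimmed for brevity:

func TestParseStreamingChatResponse_ReasoningFuncCollects(t *testing.T) {
	t.Parallel()
	mockBody := `
data: {"id":"1","object":"chat.completion.chunk","created":1738492867,"model":"deepseek-reasoner","choices":[{"index":0,"delta":{"content":null,"reasoning_content":"Okay"},"logprobs":null,"finish_reason":null}]}
`
	r := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(mockBody)),
	}

	// Collect what the callback sees so we can assert on it afterwards.
	var gotReasoning, gotContent bytes.Buffer
	req := &ChatRequest{
		StreamingReasoningFunc: func(_ context.Context, reasoningChunk, chunk []byte) error {
			gotReasoning.Write(reasoningChunk)
			gotContent.Write(chunk)
			return nil
		},
	}

	_, err := parseStreamingChatResponse(context.Background(), r, req)
	require.NoError(t, err)
	assert.Equal(t, "Okay", gotReasoning.String())
	assert.Empty(t, gotContent.String())
}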
17 changes: 9 additions & 8 deletions llms/openai/openaillm.go
@@ -96,14 +96,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent,
chatMsgs = append(chatMsgs, msg)
}
req := &openaiclient.ChatRequest{
-Model: opts.Model,
-StopWords: opts.StopWords,
-Messages: chatMsgs,
-StreamingFunc: opts.StreamingFunc,
-Temperature: opts.Temperature,
-N: opts.N,
-FrequencyPenalty: opts.FrequencyPenalty,
-PresencePenalty: opts.PresencePenalty,
+Model: opts.Model,
+StopWords: opts.StopWords,
+Messages: chatMsgs,
+StreamingFunc: opts.StreamingFunc,
+StreamingReasoningFunc: opts.StreamingReasoningFunc,
+Temperature: opts.Temperature,
+N: opts.N,
+FrequencyPenalty: opts.FrequencyPenalty,
+PresencePenalty: opts.PresencePenalty,

MaxCompletionTokens: opts.MaxTokens,

10 changes: 10 additions & 0 deletions llms/options.go
@@ -21,6 +21,9 @@ type CallOptions struct {
// StreamingFunc is a function to be called for each chunk of a streaming response.
// Return an error to stop streaming early.
StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
// Return an error to stop streaming early.
StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`
// TopK is the number of tokens to consider for top-k sampling.
TopK int `json:"top_k"`
// TopP is the cumulative probability for top-p sampling.
@@ -162,6 +165,13 @@ func WithStreamingFunc(streamingFunc func(ctx context.Context, chunk []byte) error) CallOption {
}
}

// WithStreamingReasoningFunc specifies the streaming reasoning function to use.
func WithStreamingReasoningFunc(streamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error) CallOption {
return func(o *CallOptions) {
o.StreamingReasoningFunc = streamingReasoningFunc
}
}

// WithTopK will add an option to use top-k sampling.
func WithTopK(topK int) CallOption {
return func(o *CallOptions) {
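Since CallOptions now carries both callbacks, callers that already have a plain content-only streaming function can adapt it to the new signature. A purely illustrative adapter, not part of the llms package:

package streamadapt

import "context"

// ContentOnly wraps an existing content-only streaming callback so it can be
// passed to llms.WithStreamingReasoningFunc; reasoning chunks are discarded.
func ContentOnly(f func(ctx context.Context, chunk []byte) error) func(ctx context.Context, reasoningChunk, chunk []byte) error {
	return func(ctx context.Context, _, chunk []byte) error {
		return f(ctx, chunk)
	}
}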
