diff --git a/conversation/metadata.go b/conversation/metadata.go index 30a2412aaf..d3ed4f706a 100644 --- a/conversation/metadata.go +++ b/conversation/metadata.go @@ -26,4 +26,5 @@ type LangchainMetadata struct { Key string `json:"key"` Model string `json:"model"` CacheTTL string `json:"cacheTTL"` + Endpoint string `json:"endpoint"` } diff --git a/conversation/metadata_test.go b/conversation/metadata_test.go new file mode 100644 index 0000000000..58edab76e2 --- /dev/null +++ b/conversation/metadata_test.go @@ -0,0 +1,56 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conversation + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLangchainMetadata(t *testing.T) { + t.Run("json marshaling with endpoint", func(t *testing.T) { + metadata := LangchainMetadata{ + Key: "test-key", + Model: "gpt-4", + CacheTTL: "10m", + Endpoint: "https://custom-endpoint.example.com", + } + + bytes, err := json.Marshal(metadata) + require.NoError(t, err) + + var unmarshaled LangchainMetadata + err = json.Unmarshal(bytes, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, metadata.Key, unmarshaled.Key) + assert.Equal(t, metadata.Model, unmarshaled.Model) + assert.Equal(t, metadata.CacheTTL, unmarshaled.CacheTTL) + assert.Equal(t, metadata.Endpoint, unmarshaled.Endpoint) + }) + + t.Run("json unmarshaling with endpoint", func(t *testing.T) { + jsonStr := `{"key": "test-key", "endpoint": "https://api.openai.com/v1"}` + + var metadata LangchainMetadata + err := json.Unmarshal([]byte(jsonStr), &metadata) + require.NoError(t, err) + + assert.Equal(t, "test-key", metadata.Key) + assert.Equal(t, "https://api.openai.com/v1", metadata.Endpoint) + }) +} diff --git a/conversation/openai/metadata.yaml b/conversation/openai/metadata.yaml index 4ffb074291..2adf807141 100644 --- a/conversation/openai/metadata.yaml +++ b/conversation/openai/metadata.yaml @@ -27,6 +27,12 @@ metadata: The OpenAI LLM to use. Defaults to gpt-4o type: string example: 'gpt-4-turbo' + - name: endpoint + required: false + description: | + Custom API endpoint URL for OpenAI API-compatible services. If not specified, the default OpenAI API endpoint will be used. + type: string + example: 'https://api.openai.com/v1' - name: cacheTTL required: false description: | diff --git a/conversation/openai/openai.go b/conversation/openai/openai.go index da9894eb48..2c2dc63692 100644 --- a/conversation/openai/openai.go +++ b/conversation/openai/openai.go @@ -54,11 +54,18 @@ func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error { if md.Model != "" { model = md.Model } - - llm, err := openai.New( + // Create options for OpenAI client + options := []openai.Option{ openai.WithModel(model), openai.WithToken(md.Key), - ) + } + + // Add custom endpoint if provided + if md.Endpoint != "" { + options = append(options, openai.WithBaseURL(md.Endpoint)) + } + + llm, err := openai.New(options...) if err != nil { return err } diff --git a/conversation/openai/openai_test.go b/conversation/openai/openai_test.go new file mode 100644 index 0000000000..3b5229d8d1 --- /dev/null +++ b/conversation/openai/openai_test.go @@ -0,0 +1,95 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package openai + +import ( + "testing" + + "github.com/dapr/components-contrib/conversation" + "github.com/dapr/components-contrib/metadata" + "github.com/dapr/kit/logger" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInit(t *testing.T) { + testCases := []struct { + name string + metadata map[string]string + testFn func(*testing.T, *OpenAI, error) + }{ + { + name: "with default endpoint", + metadata: map[string]string{ + "key": "test-key", + "model": "gpt-4", + }, + testFn: func(t *testing.T, o *OpenAI, err error) { + require.NoError(t, err) + assert.NotNil(t, o.llm) + }, + }, + { + name: "with custom endpoint", + metadata: map[string]string{ + "key": "test-key", + "model": "gpt-4", + "endpoint": "https://api.openai.com/v1", + }, + testFn: func(t *testing.T, o *OpenAI, err error) { + require.NoError(t, err) + assert.NotNil(t, o.llm) + // Since we can't directly access the client's baseURL, + // we're mainly testing that initialization succeeds + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + o := NewOpenAI(logger.NewLogger("openai test")) + err := o.Init(t.Context(), conversation.Metadata{ + Base: metadata.Base{ + Properties: tc.metadata, + }, + }) + tc.testFn(t, o.(*OpenAI), err) + }) + } +} + +func TestEndpointInMetadata(t *testing.T) { + // Create an instance of OpenAI component + o := &OpenAI{} + + // This test relies on the metadata tag + md := o.GetComponentMetadata() + if len(md) == 0 { + t.Skip("Metadata is not enabled, skipping test") + } + + // Print all available metadata keys for debugging + t.Logf("Available metadata keys: %v", func() []string { + keys := make([]string, 0, len(md)) + for k := range md { + keys = append(keys, k) + } + return keys + }()) + + // Verify endpoint field exists (note: field names are capitalized in metadata) + _, exists := md["Endpoint"] + assert.True(t, exists, "Endpoint field should exist in metadata") +}