diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..63a74db --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,35 @@ +name: Go Test Automation with Coverage + +on: + pull_request: + branches: + - main + +jobs: + test: + name: Run Go Tests and Coverage + runs-on: ubuntu-latest # Use GitHub-hosted Ubuntu runner + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 1.23 # Adjust to your required Go version + + - name: Install dependencies + run: go mod tidy + + - name: Run tests with coverage + run: go test -coverprofile=coverage.out -covermode=atomic ./... + + - name: Generate coverage report + run: go tool cover -func=coverage.out + + - name: Upload coverage report as artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.out diff --git a/caches.go b/caches.go index 8ad5166..88756b0 100644 --- a/caches.go +++ b/caches.go @@ -19,6 +19,7 @@ package genai import ( "context" "fmt" + "iter" "net/http" ) @@ -384,6 +385,70 @@ func updateCachedContentParametersToVertex(ac *apiClient, fromObject map[string] return toObject, nil } +func listCachedContentsConfigToMldev(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromPageSize := getValueByPath(fromObject, []string{"pageSize"}) + if fromPageSize != nil { + setValueByPath(parentObject, []string{"_query", "pageSize"}, fromPageSize) + } + + fromPageToken := getValueByPath(fromObject, []string{"pageToken"}) + if fromPageToken != nil { + setValueByPath(parentObject, []string{"_query", "pageToken"}, fromPageToken) + } + + return toObject, nil +} + +func listCachedContentsConfigToVertex(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromPageSize := getValueByPath(fromObject, []string{"pageSize"}) + if fromPageSize != nil { + setValueByPath(parentObject, []string{"_query", "pageSize"}, fromPageSize) + } + + fromPageToken := getValueByPath(fromObject, []string{"pageToken"}) + if fromPageToken != nil { + setValueByPath(parentObject, []string{"_query", "pageToken"}, fromPageToken) + } + + return toObject, nil +} + +func listCachedContentsParametersToMldev(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromConfig := getValueByPath(fromObject, []string{"config"}) + if fromConfig != nil { + fromConfig, err = listCachedContentsConfigToMldev(ac, fromConfig.(map[string]any), toObject) + if err != nil { + return nil, err + } + + setValueByPath(toObject, []string{"config"}, fromConfig) + } + + return toObject, nil +} + +func listCachedContentsParametersToVertex(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromConfig := getValueByPath(fromObject, []string{"config"}) + if fromConfig != nil { + fromConfig, err = listCachedContentsConfigToVertex(ac, fromConfig.(map[string]any), toObject) + if err != nil { + return nil, err + } + + setValueByPath(toObject, []string{"config"}, fromConfig) + } + + return toObject, nil +} + func cachedContentFromMldev(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { toObject = make(map[string]any) @@ -478,6 +543,48 @@ func deleteCachedContentResponseFromVertex(ac *apiClient, fromObject map[string] return toObject, nil } +func listCachedContentsResponseFromMldev(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromNextPageToken := getValueByPath(fromObject, []string{"nextPageToken"}) + if fromNextPageToken != nil { + setValueByPath(toObject, []string{"nextPageToken"}, fromNextPageToken) + } + + fromCachedContents := getValueByPath(fromObject, []string{"cachedContents"}) + if fromCachedContents != nil { + fromCachedContents, err = applyConverterToSlice(ac, fromCachedContents.([]any), cachedContentFromMldev) + if err != nil { + return nil, err + } + + setValueByPath(toObject, []string{"cachedContents"}, fromCachedContents) + } + + return toObject, nil +} + +func listCachedContentsResponseFromVertex(ac *apiClient, fromObject map[string]any, parentObject map[string]any) (toObject map[string]any, err error) { + toObject = make(map[string]any) + + fromNextPageToken := getValueByPath(fromObject, []string{"nextPageToken"}) + if fromNextPageToken != nil { + setValueByPath(toObject, []string{"nextPageToken"}, fromNextPageToken) + } + + fromCachedContents := getValueByPath(fromObject, []string{"cachedContents"}) + if fromCachedContents != nil { + fromCachedContents, err = applyConverterToSlice(ac, fromCachedContents.([]any), cachedContentFromVertex) + if err != nil { + return nil, err + } + + setValueByPath(toObject, []string{"cachedContents"}, fromCachedContents) + } + + return toObject, nil +} + type Caches struct { apiClient *apiClient } @@ -518,6 +625,14 @@ func (m Caches) Create(ctx context.Context, model string, config *CreateCachedCo if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -573,6 +688,14 @@ func (m Caches) Get(ctx context.Context, name string, config *GetCachedContentCo if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -628,6 +751,14 @@ func (m Caches) Delete(ctx context.Context, name string, config *DeleteCachedCon if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -683,6 +814,14 @@ func (m Caches) Update(ctx context.Context, name string, config *UpdateCachedCon if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -701,3 +840,92 @@ func (m Caches) Update(ctx context.Context, name string, config *UpdateCachedCon } return response, nil } + +func (m Caches) list(ctx context.Context, config *ListCachedContentsConfig) (*ListCachedContentsResponse, error) { + parameterMap := make(map[string]any) + + kwargs := map[string]any{"config": config} + deepMarshal(kwargs, ¶meterMap) + + var response = new(ListCachedContentsResponse) + var responseMap map[string]any + var fromConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error) + var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error) + if m.apiClient.clientConfig.Backend == BackendVertexAI { + toConverter = listCachedContentsParametersToVertex + fromConverter = listCachedContentsResponseFromVertex + } else { + toConverter = listCachedContentsParametersToMldev + fromConverter = listCachedContentsResponseFromMldev + } + + body, err := toConverter(m.apiClient, parameterMap, nil) + if err != nil { + return nil, err + } + var path string + var urlParams map[string]any + if _, ok := body["_url"]; ok { + urlParams = body["_url"].(map[string]any) + delete(body, "_url") + } + if m.apiClient.clientConfig.Backend == BackendVertexAI { + path, err = formatMap("cachedContents", urlParams) + } else { + path, err = formatMap("cachedContents", urlParams) + } + if err != nil { + return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) + } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } + + if _, ok := body["config"]; ok { + delete(body, "config") + } + responseMap, err = sendRequest(ctx, m.apiClient, path, http.MethodGet, body) + if err != nil { + return nil, err + } + responseMap, err = fromConverter(m.apiClient, responseMap, nil) + if err != nil { + return nil, err + } + err = mapToStruct(responseMap, response) + if err != nil { + return nil, err + } + return response, nil +} + +func (m Caches) List(ctx context.Context, config *ListCachedContentsConfig) (Page[CachedContent, ListCachedContentsConfig], error) { + listFunc := func(ctx context.Context, config *ListCachedContentsConfig) ([]*CachedContent, string, error) { + resp, err := m.list(ctx, config) + if err != nil { + return nil, "", err + } + return resp.CachedContents, resp.NextPageToken, nil + } + return newPage(ctx, "cachedContents", config, listFunc) +} + +func (m Caches) All(ctx context.Context) iter.Seq2[*CachedContent, error] { + listFunc := func(ctx context.Context, config *ListCachedContentsConfig) ([]*CachedContent, string, error) { + resp, err := m.list(ctx, config) + if err != nil { + return nil, "", err + } + return resp.CachedContents, resp.NextPageToken, nil + } + p, err := newPage(ctx, "cachedContents", genai.ListCachedContentsConfig{}, listFunc) + if err != nil { + return yieldErrorAndEndIterator[CachedContent](err) + } + return p.all(ctx) +} diff --git a/common.go b/common.go index d4b322c..3f1f124 100644 --- a/common.go +++ b/common.go @@ -19,7 +19,10 @@ import ( "encoding/json" "errors" "fmt" + "iter" + "net/url" "reflect" + "strconv" ) // Ptr returns a pointer to its argument. @@ -128,3 +131,34 @@ func deepMarshal(input any, output *map[string]any) error { } return nil } + +func createURLQuery(query map[string]any) (string, error) { + v := url.Values{} + for key, value := range query { + switch value := value.(type) { + case string: + v.Add(key, value) + case int: + v.Add(key, strconv.Itoa(value)) + case float64: + v.Add(key, strconv.FormatFloat(value, 'f', -1, 64)) + case bool: + v.Add(key, strconv.FormatBool(value)) + case []any: + for _, item := range value { + v.Add(key, item.(string)) + } + default: + return "", fmt.Errorf("unsupported type: %T", value) + } + } + return v.Encode(), nil +} + +func yieldErrorAndEndIterator[T any](err error) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + if !yield(nil, err) { + return + } + } +} diff --git a/models.go b/models.go index b220b5b..d68ecc3 100644 --- a/models.go +++ b/models.go @@ -2501,6 +2501,14 @@ func (m Models) generateContent(ctx context.Context, model string, contents []*C if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -2619,6 +2627,14 @@ func (m Models) GenerateImages(ctx context.Context, model string, prompt string, if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -2678,11 +2694,19 @@ func (m Models) upscaleImage(ctx context.Context, model string, image *Image, up if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") } - responseMap, err = sendRequest(ctx, m.apiClient, path, http.MethodPost, &body) + responseMap, err = sendRequest(ctx, m.apiClient, path, http.MethodPost, body) if err != nil { return nil, err } @@ -2733,6 +2757,14 @@ func (m Models) CountTokens(ctx context.Context, model string, contents []*Conte if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") @@ -2792,6 +2824,14 @@ func (m Models) ComputeTokens(ctx context.Context, model string, contents []*Con if err != nil { return nil, fmt.Errorf("invalid url params: %#v.\n%w", urlParams, err) } + if _, ok := body["_query"]; ok { + query, err := createURLQuery(body["_query"].(map[string]any)) + if err != nil { + return nil, err + } + path = path + "?" + query + delete(body, "_query") + } if _, ok := body["config"]; ok { delete(body, "config") diff --git a/pages.go b/pages.go new file mode 100644 index 0000000..f93a1ab --- /dev/null +++ b/pages.go @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package genai + +import ( + "context" + "fmt" + "io" + "iter" + "reflect" +) + +// PageDone is the error returned by Next when no more page is available. +var PageDone = errors.New("PageDone") + +type Page[T any, C any] struct { + Name string // The name of the paged items. + Items []*T // The items in the current page. + + config C // The configuration used for the API call. + listFunc func(ctx context.Context, config *C) ([]*T, string, error) // The function used to retrieve the next page. + nextPageToken string // The token to use to retrieve the next page of results. +} + +func newPage[T any, C any](ctx context.Context, name string, config C, listFunc func(ctx context.Context, config *C) ([]*T, string, error)) (page[T, C], error) { + p := Page[T, C]{ + Name: name, + config: config, + listFunc: listFunc, + } + items, nextPageToken, err := listFunc(ctx, &config) + if err != nil { + return p, err + } + p.Items = items + p.nextPageToken = nextPageToken + return p, nil +} + +// all returns an iterator that yields all items across all pages of results. +// +// The iterator retrieves each page sequentially and yields each item within +// the page. If an error occurs during retrieval, the iterator will stop +// and the error will be returned as the second value in the next call to Next(). +// An io.EOF error indicates that all pages have been processed. +func (p page[T, C]) all(ctx context.Context) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + for { + for _, item := range p.Items { + if !yield(item, nil) { + return + } + } + var err error + p, err = p.next(ctx) + if err == io.EOF { + return + } + if err != nil { + yield(nil, err) + return + } + } + } +} + +// next retrieves the next page of results. +// +// If there are no more pages, PageDone is returned. Otherwise, +// a new Page struct containing the next set of results is returned. +// Any other errors encountered during retrieval will also be returned. +func (p page[T, C]) Next(ctx context.Context) (page[T, C], error) { + if p.nextPageToken == "" { + return p, PageDone + } + configPtr := reflect.ValueOf(&p.config) // Note the & operator + configValue := configPtr.Elem() + pageTokenField := configValue.FieldByName("PageToken") + if !pageTokenField.IsValid() || !pageTokenField.CanSet() { + return p, fmt.Errorf("pageToken field is invalid or not settable") + } + pageTokenField.SetString(p.nextPageToken) + + return newPage[T, C](ctx, p.Name, p.config, p.listFunc) +} diff --git a/samples/create_cached_content.go b/samples/cached_content.go similarity index 71% rename from samples/create_cached_content.go rename to samples/cached_content.go index b3e549d..789a2bb 100644 --- a/samples/create_cached_content.go +++ b/samples/cached_content.go @@ -40,9 +40,18 @@ import ( var model = flag.String("model", "gemini-1.5-pro-002", "the model name, e.g. gemini-1.5-pro-002") +func debugPrint(r any) { + // Marshal the result to JSON. + response, err := json.MarshalIndent(r, "", " ") + if err != nil { + log.Fatal(err) + } + // Log the output. + fmt.Println(string(response)) +} + func createCachedContent(ctx context.Context) { client, err := genai.NewClient(ctx, nil) - fmt.Println("client: ", client.ClientConfig()) if err != nil { log.Fatal(err) } @@ -76,29 +85,35 @@ func createCachedContent(ctx context.Context) { if err != nil { log.Fatal(err) } - // Marshal the result to JSON and pretty-print it to a byte array. - response, err := json.MarshalIndent(*result, "", " ") - if err != nil { - log.Fatal(err) + debugPrint(result) + + // Iterate over the cached contents. + // Option 1: using the All method. + for item, err := range client.Caches.All(ctx) { + if err != nil { + log.Fatal(err) + } + debugPrint(item) } - // Log the output. - fmt.Println(string(response)) - // Retrieve the cached content. - resp, err := client.Caches.Get(ctx, result.Name, nil) - if err != nil { - log.Fatal(err) + // Iterate over the cached contents. + // Option 2: using the List method for more control. + // Example 2.1 - List the first page. + page, err := client.Caches.List(ctx, &genai.ListCachedContentsConfig{PageSize: 2}) + // Example 2.2 - Continue to the next page. + page, err = page.Next(ctx) + // Example 2.3 - Resume the page iteration using the next page token. + page, err = client.Caches.List(ctx, &genai.ListCachedContentsConfig{PageSize: 2, PageToken: page.NextPageToken}) + if err == genai.PageDone { + fmt.Println("No more cached content to retrieve.") + return } - // Marshal the result to JSON and pretty-print it to a byte array. - respJSON, err := json.MarshalIndent(resp, "", " ") if err != nil { log.Fatal(err) } - // Log the output. - fmt.Println(string(respJSON)) + debugPrint(page.Items) } - func main() { ctx := context.Background() flag.Parse() diff --git a/types.go b/types.go index fc0c65c..f3165c2 100644 --- a/types.go +++ b/types.go @@ -1493,6 +1493,31 @@ type UpdateCachedContentParameters struct { Config *UpdateCachedContentConfig `json:"config,omitempty"` } +// Config for caches.list method. +type ListCachedContentsConfig struct { + // PageSize specifies the maximum number of cached contents to return per API call. + // This setting does not affect the total number of cached contents returned by the + // All() function during iteration; it only controls how many items are retrieved in + // each individual request to the server. If zero, the server will use a default value. + // Setting a positive value can be useful for managing the size and frequency of API + // calls. + PageSize int64 `json:"pageSize,omitempty"` + + PageToken string `json:"pageToken,omitempty"` +} + +// Parameters for caches.list method. +type ListCachedContentsParameters struct { + // Configuration that contains optional parameters. + Config *ListCachedContentsConfig `json:"config,omitempty"` +} + +type ListCachedContentsResponse struct { + NextPageToken string `json:"nextPageToken,omitempty"` + // List of cached contents. + CachedContents []*CachedContent `json:"cachedContents,omitempty"` +} + type testTableItem struct { // The name of the test. This is used to derive the replay id. Name string `json:"name,omitempty"`