diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..62e9c47f --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,130 @@ +name: lint + +on: + workflow_call: + + push: + branches: + - "main" + - "develop" + + pull_request: + branches: + - "main" + - "develop" + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + check-latest: true + + - name: Install gofumpt + run: go install mvdan.cc/gofumpt@latest + + - name: Go Format + run: gofmt -s -w . && git diff --exit-code + + - name: Gofumpt + run: gofumpt -l -w . && git diff --exit-code + + - name: Go Vet + run: go vet ./... + + - name: Go Tidy + run: go mod tidy && git diff --exit-code + + - name: Go Mod + run: go mod download + + - name: Go Mod Verify + run: go mod verify + + build: + name: Build + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + check-latest: true + - name: Build + run: go build -v ./... + + static-checks: + name: Static Checks + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + check-latest: true + + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@latest + + - name: Install nilaway + run: go install go.uber.org/nilaway/cmd/nilaway@latest + + - name: GolangCILint + uses: golangci/golangci-lint-action@v3.4.0 + with: + version: latest + args: --timeout 5m + + - name: Staticcheck + run: staticcheck ./... +# TODO: Ignore the issue in https://github.com/modelgateway/Glide/issues/32 +# - name: Nilaway +# run: nilaway ./... + + tests: + name: Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + check-latest: true + + - name: Test + run: go test -v -count=1 -race -shuffle=on -coverprofile=coverage.txt ./...
+ + - name: Upload Coverage + uses: codecov/codecov-action@v3 + continue-on-error: true # we don't care if it fails + with: + token: ${{secrets.CODECOV_TOKEN}} # set in repository settings + file: ./coverage.txt # file from the previous step + fail_ci_if_error: false + + api-docs: + name: OpenAPI Specs + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + check-latest: true + + - name: Generate OpenAPI Schema + run: make docs-api && git diff --exit-code diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 00000000..93f56cdb --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,73 @@ +name: release + +on: + push: + tags: + - "*" + + branches: + - main + +permissions: + contents: write + packages: write + +jobs: + lint: + uses: ./.github/workflows/lint.yaml + vuln: + uses: ./.github/workflows/vuln.yaml + release: + needs: + - lint + - vuln + runs-on: ubuntu-latest + steps: + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v5 + with: + distribution: goreleaser + version: latest + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DISCORD_WEBHOOK_ID: ${{ secrets.DISCORD_WEBHOOK_ID }} + DISCORD_WEBHOOK_TOKEN: ${{ secrets.DISCORD_WEBHOOK_TOKEN }} + BREW_TAP_PRIVATE_KEY: ${{ secrets.BREW_TAP_PRIVATE_KEY }} + images: + strategy: + matrix: + image: + - alpine + - ubuntu + - distroless + - redhat + runs-on: ubuntu-latest + needs: + - release + steps: + - name: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: log into GitHub Container Registry + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin + + - name: build ${{ matrix.image }} image + working-directory: ./images + run: VERSION=${{ github.ref_name }} COMMIT=$(git rev-parse --short "$GITHUB_SHA") make ${{ matrix.image }} + + - name: publish ${{ matrix.image }} image to GitHub Container Registry + working-directory: ./images + run: VERSION=${{ github.ref_name }} make publish-ghcr-${{ matrix.image }} diff --git a/.github/workflows/vuln.yaml b/.github/workflows/vuln.yaml new file mode 100644 index 00000000..9e363c8e --- /dev/null +++ b/.github/workflows/vuln.yaml @@ -0,0 +1,46 @@ +name: vuln + +on: + workflow_call: + + push: + branches: + - "main" + - "develop" + + pull_request: + branches: + - "main" + - "develop" + + schedule: + - cron: '0 10 * * 1' # run "At 10:00 on Monday" + +jobs: + run: + name: Vulnerability Check + runs-on: ubuntu-latest + timeout-minutes: 5 + env: + GO111MODULE: on + steps: + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: '1.21.5' + check-latest: true + + - name: Checkout + uses: actions/checkout@v3 + + - name: Install govulncheck + run: go install golang.org/x/vuln/cmd/govulncheck@latest + + - name: Install gosec + run: go install github.com/securego/gosec/v2/cmd/gosec@latest + + - name: Govulncheck + run: govulncheck -test ./... + + - name: Gosec + run: gosec ./...
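All three workflows above declare a `workflow_call` trigger, which is what lets `release.yaml` chain `lint` and `vuln` as reusable jobs before publishing. As a minimal sketch (the workflow name and schedule here are hypothetical, not part of this PR), any other pipeline in the repository could reuse them the same way:

```yaml
# Hypothetical consumer of the reusable lint and vuln workflows.
name: nightly

on:
  schedule:
    - cron: "0 2 * * *" # assumed schedule, for illustration only

jobs:
  lint:
    uses: ./.github/workflows/lint.yaml
  vuln:
    uses: ./.github/workflows/vuln.yaml
  report:
    needs: [lint, vuln] # runs only when both reusable workflows succeed
    runs-on: ubuntu-latest
    steps:
      - run: echo "lint and vuln checks passed"
```

Since the called workflows keep their own `push`/`pull_request` triggers, they still run standalone on branch events; the caller's `on` block only controls when this chained run happens.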
diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..066b8f56 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.idea +dist +.env +config.yaml +bin +glide +tmp +coverage.txt +precommit.txt +.vscode/settings.json diff --git a/.go-version b/.go-version new file mode 100644 index 00000000..d2ab029d --- /dev/null +++ b/.go-version @@ -0,0 +1 @@ +1.21 diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..3b2c7b2e --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,41 @@ +output: + # Make output more digestible with quickfix in vim/emacs/etc. + sort-results: true + print-issued-lines: false + +linters: + enable: + - nolintlint + - revive + - staticcheck + - asasalint + - bodyclose + - contextcheck + - cyclop + - dupword + - errname + - exhaustive + - loggercheck + - misspell + - nestif + - perfsprint + - prealloc + - predeclared + - testifylint + - unconvert + - usestdlibvars + - wsl + +linters-settings: + govet: + # These govet checks are disabled by default, but they're useful. + enable: + - nilness + - reflectvaluecompare + - sortslice + - unusedwrite + - defers + - atomic + - nilfunc + - printf + - tests diff --git a/.goreleaser.yml b/.goreleaser.yml new file mode 100644 index 00000000..4f17455e --- /dev/null +++ b/.goreleaser.yml @@ -0,0 +1,260 @@ +project_name: glide + +before: + hooks: + - go generate + +builds: + - binary: glide + env: + - CGO_ENABLED=0 + ldflags: + - -s -w -X glide/pkg.version={{.Tag}} -X glide/pkg.commitSha={{.ShortCommit}} -X glide/pkg.buildDate={{.Date}} + goos: + - linux + - darwin + - freebsd + goarch: + - amd64 + - arm + - arm64 + goarm: + - '7' + - '6' + ignore: + - goos: openbsd + goarch: arm + - goos: openbsd + goarch: arm64 + - goos: freebsd + goarch: arm + - goos: freebsd + goarch: arm64 + - goos: linux + goarch: arm + +changelog: + skip: true + +archives: + - id: glide + name_template: '{{ .ProjectName }}_v{{ .Tag }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}' + format: tar.gz + format_overrides: + - goos: windows + format: zip + files: + - LICENSE + +checksum: + name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt" + +release: + # If set to true, will not auto-publish the release. + # Available only for GitHub and Gitea. + draft: true + + # Whether to remove existing draft releases with the same name before creating + # a new one. + # Only effective if `draft` is set to true. + # Available only for GitHub. + # + # Since: v1.11 + replace_existing_draft: true + + # Useful if you want to delay the creation of the tag in the remote. + # You can create the tag locally, but not push it, and run GoReleaser. + # It'll then set the `target_commitish` portion of the GitHub release to the + # value of this field. + # Only works on GitHub. + # + # Default: '' + # Since: v1.11 + # Templates: allowed + target_commitish: "{{ .Commit }}" + + # If set, will create a release discussion in the category specified. + # + # Warning: do not use categories in the 'Announcement' format. + # Check https://github.com/goreleaser/goreleaser/issues/2304 for more info. + # + # Default is empty. + discussion_category_name: Releases + + # If set to auto, will mark the release as not ready for production + # in case there is an indicator for this in the tag e.g. v1.0.0-rc1 + # If set to true, will mark the release as not ready for production. + # Default is false. + prerelease: auto + + # If set to false, will NOT mark the release as "latest".
+ # This prevents it from being shown at the top of the release list, + # and from being returned when calling https://api.github.com/repos/OWNER/REPO/releases/latest. + # + # Available only for GitHub. + # + # Default is true. + # Since: v1.20 + make_latest: true + + # What to do with the release notes in case the release already exists. + # + # Valid options are: + # - `keep-existing`: keep the existing notes + # - `append`: append the current release notes to the existing notes + # - `prepend`: prepend the current release notes to the existing notes + # - `replace`: replace existing notes + # + # Default is `keep-existing`. + mode: append + + # You can change the name of the release. + # + # Default: '{{.Tag}}' ('{{.PrefixedTag}}' on Pro) + # Templates: allowed + name_template: "v{{.Version}}" + +brews: + - + # Name of the recipe + # + # Default: ProjectName + # Templates: allowed + name: glide + + # Alternative names for the current recipe. + # + # Useful if you want to publish a versioned formula as well, so users can + # more easily downgrade. + # + # Since: v1.20 (pro) + # Templates: allowed + alternative_names: + - glide@{{ .Version }} + - glide@{{ .Major }} + - glide@{{ .Major }}{{ .Minor }} + + # GOARM to specify which 32-bit arm version to use if there are multiple + # versions from the build section. Brew formulas support only one 32-bit + # version. + # + # Default: 6 + goarm: 6 + + # GOAMD64 to specify which amd64 version to use if there are multiple + # versions from the build section. + # + # Default: v1 + goamd64: v1 + + # NOTE: make sure the url_template, the token and given repo (github or + # gitlab) owner and name are of the same kind. + # We will probably unify this in the next major version like it is + # done with scoop. + + # URL which is determined by the given Token (github, gitlab or gitea). + # + # Default depends on the client. + # Templates: allowed + url_template: "https://github.com/EinStack/glide/releases/download/{{ .Tag }}/{{ .ArtifactName }}" + + # Allows you to set a custom download strategy. Note that you'll need + # to implement the strategy and add it to your tap repository. + # Example: https://docs.brew.sh/Formula-Cookbook#specifying-the-download-strategy-explicitly + download_strategy: CurlDownloadStrategy + + # Git author used to commit to the repository. + commit_author: + name: Release Bot + email: roman.glushko.m@gmail.com + + # The project name and current git tag are used in the format string. + # + # Templates: allowed + commit_msg_template: "Brew formula update for {{ .ProjectName }} version {{ .Tag }}" + + # Folder inside the repository to put the formula. + folder: Formula + + # Caveats for the user of your binary. + caveats: "" + + # Your app's homepage. + homepage: "https://github.com/EinStack/glide" + + # Your app's description. + # + # Templates: allowed + description: "A Lightweight, Cloud-Native LLM Gateway" + + # SPDX identifier of your app's license. + license: "Apache-2.0" + + # Setting this will prevent goreleaser from actually trying to commit the updated + # formula - instead, the formula file will be stored in the dist folder only, + # leaving the responsibility of publishing it to the user. + # If set to auto, the release will not be uploaded to the homebrew tap + # in case there is an indicator for prerelease in the tag e.g. v1.0.0-rc1 + # + # Templates: allowed + skip_upload: auto + + # Custom block for brew. + # Can be used to specify alternate downloads for devel or head releases.
+ custom_block: | + head "https://github.com/EinStack/glide.git" + + # Packages your package depends on. + dependencies: [] + + # Repository to push the generated files to. + repository: + # Repository owner. + # + # Templates: allowed + owner: EinStack + + # Repository name. + # + # Templates: allowed + name: homebrew-tap + + # Optionally a branch can be provided. + # + # Default: default repository branch + # Templates: allowed + branch: main + + # Clone, create the file, commit and push, to a regular Git repository. + # + # Notice that this will only have any effect if the given URL is not + # empty. + # + # Since: v1.18 + git: + # The Git URL to push. + # + # Templates: allowed + url: 'git@github.com:EinStack/homebrew-tap.git' + + private_key: '{{ .Env.BREW_TAP_PRIVATE_KEY }}' + +announce: + discord: + # Whether it's enabled or not. + enabled: true + + # Message template to use while publishing. + # + # Templates: allowed + message_template: '📦 Glide {{.Tag}} is out! Check it out at {{ .ReleaseURL }}' + + # Set author of the embed. + author: 'EinStack' + + # Color code of the embed. You have to use decimal numeral system, not hexadecimal. + # Default: '3888754' (the grey-ish from GoReleaser) + color: '' + + # URL to an image to use as the icon for the embed. + icon_url: '' diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..eb05365e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,31 @@ +# Changelog + +The changelog consists of three categories: +- **Features** - new functionality that brings value to users +- **Improvements** - bugfixes, performance and other types of improvements to existing functionality +- **Miscellaneous** - all other updates like build, release, CLI, etc. + +## 0.0.1-rc.1 (Jan 21st, 2024) + +### Features +- ✨ [providers] Support for OpenAI Chat API #3 (@mkrueger12) +- ✨ [API] #54 Unified Chat API (@mkrueger12) +- ✨ [providers] Support for Cohere Chat API #5 (@mkrueger12) +- ✨ [providers] Support for Azure OpenAI Chat API #4 (@mkrueger12) +- ✨ [providers] Support for OctoML Chat API #58 (@mkrueger12) +- ✨ [routing] The Routing Mechanism, Adaptive Health Tracking, and Fallbacks #42 #43 #51 (@roma-glushko) +- ✨ [routing] Support for round robin routing strategy #44 (@roma-glushko) +- ✨ [routing] Support for the least latency routing strategy #46 (@roma-glushko) +- ✨ [routing] Support for weighted round robin routing strategy #45 (@roma-glushko) +- ✨ [providers] Support for Anthropic Chat API #60 (@mkrueger12) +- ✨ [docs] OpenAPI specifications #22 (@roma-glushko) + +### Miscellaneous + +- 🔧 [chores] Initialized the project #6 (@roma-glushko) +- 🔊 [telemetry] Initialized logging #14 (@roma-glushko) +- 🔧 [chores] Initialized Glide's CLI #12 (@roma-glushko) +- 👷 [chores] Setup CI workflows #8 (@roma-glushko) +- ⚙️ [config] Initialized configs #11 (@roma-glushko) +- 🔧 [chores] Automatic coverage reports #39 (@roma-glushko) +- 👷 [build] Setup release workflows #9 (@roma-glushko) diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..2606a2d6 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,28 @@ +# This CITATION.cff file was generated with cffinit. +# Visit https://bit.ly/cffinit to generate yours today!
+ +cff-version: 1.2.0 +title: Glide +message: 'A lightweight, cloud-native model gateway' +type: software +authors: + - given-names: Roman + family-names: Hlushko + email: roman.glushko.m@gmail.com + - given-names: Max + family-names: Krueger +repository-code: 'https://github.com/EinStack/glide' +repository-artifact: 'https://github.com/EinStack/glide/packages' +abstract: >- + Glide is your go-to cloud-native model gateway, delivering + high-performance, production-ready LLMOps in a + lightweight, all-in-one service. +keywords: + - generative-ai + - llmops + - mlops + - gateway + - infrastructure + - distributed-system + - llms +license: Apache-2.0 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..a13ebf99 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Community Code of Conduct + +The Glide project aims to be a welcoming place where new and existing members feel safe +to respectfully share their opinions and disagreements. +We want to attract a diverse group of people to collaborate with us, +which means acknowledging that people come from different backgrounds and cultures. + +## Our Pledge + +We as contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers.
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at roman.glushko.m@gmail.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org diff --git a/LICENSE b/LICENSE index 261eeb9e..4b39b4d0 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2023 Max Krueger, Roman Hlushko Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..ecf858bb --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ +CHECKER_BIN=$(PWD)/tmp/bin +VERSION_PACKAGE := glide/pkg +COMMIT ?= $(shell git describe --dirty --long --always --abbrev=15) +BUILD_DATE ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +VERSION ?= "latest" + +LDFLAGS_COMMON := "-s -w -X $(VERSION_PACKAGE).commitSha=$(COMMIT) -X $(VERSION_PACKAGE).version=$(VERSION) -X $(VERSION_PACKAGE).buildDate=$(BUILD_DATE)" + +.PHONY: help + +help: + @echo "🛠️ Glide Dev Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + + +install-checkers: ## Install static checkers + @echo "🚚 Downloading binaries.." + @GOBIN=$(CHECKER_BIN) go install mvdan.cc/gofumpt@latest + @GOBIN=$(CHECKER_BIN) go install golang.org/x/vuln/cmd/govulncheck@latest + @GOBIN=$(CHECKER_BIN) go install github.com/securego/gosec/v2/cmd/gosec@latest + @GOBIN=$(CHECKER_BIN) go install github.com/swaggo/swag/cmd/swag@latest + +lint: install-checkers ## Lint the source code + @echo "🧹 Cleaning go.mod.." + @go mod tidy + @echo "🧹 Formatting files.." + @go fmt ./... + @$(CHECKER_BIN)/gofumpt -l -w . + @echo "🧹 Vetting go.mod.." + @go vet ./... + @echo "🧹 GoCI Lint.." + @golangci-lint run ./... + +vuln: install-checkers ## Check for vulnerabilities + @echo "🔍 Checking for vulnerabilities" + @$(CHECKER_BIN)/govulncheck -test ./... + @$(CHECKER_BIN)/gosec -quiet -exclude=G104 ./... + +run: ## Run Glide + @go run -ldflags $(LDFLAGS_COMMON) main.go -c ./config.dev.yaml + +build: ## Build Glide + @go build -ldflags $(LDFLAGS_COMMON) -o ./dist/glide + +test: ## Run tests + @go test -v -count=1 -race -shuffle=on -coverprofile=coverage.txt ./... + +docs-api: install-checkers ## Generate OpenAPI API docs + @$(CHECKER_BIN)/swag init diff --git a/README.md b/README.md index 211a2f26..9df7c083 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # Glide: Cloud-Native LLM Gateway for Seamless LLMOps +<div align="center">
+  <img src="docs/images/glide.png" alt="Glide GH Header" /> +</div>
[![LICENSE](https://img.shields.io/github/license/modelgateway/glide.svg?style=flat-square&color=%233f90c8)](https://github.com/modelgateway/glide/blob/main/LICENSE) [![codecov](https://codecov.io/github/EinStack/glide/graph/badge.svg?token=F7JT39RHX9)](https://codecov.io/github/EinStack/glide) @@ -37,7 +40,7 @@ Check out our [documentation](https://backlandlabs.mintlify.app/introduction)! | | Azure OpenAI | 👍 Supported | | | Cohere | 👍 Supported | | | OctoML | 👍 Supported | -| | Anthropic | 🏗️ Coming Soon | +| | Anthropic | 👍 Supported | | | Google Gemini | 🏗️ Coming Soon | diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 00000000..a6e7eecf --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,53 @@ +*Updated: Fri, 12 Jan 2024* + +# Glide - Roadmap + +This document describes the current status and the upcoming milestones of the Glide LLM Router. + +## Glide + +#### Milestone Summary + +| Status | Milestone | Goals | + | :---: | :--- | :---: | + | 🏁 | **Unified Chat Endpoint Support** | 4 / 4 | + | 🏁 | **Fallback Routing Strategy** | 1 / 1 | + | 🏁 | **Priority, Round Robin, Weighted Round Robin, Least Latency** | 2 / 4 | + | 🎁 | **Documentation** | 1 / 1 | + | 🎁 | **Private Preview** | 4 / 5 | + | 🎁 | **Streaming Support** | 0 / 4 | + | 🎁 | **Embedding Support** | 0 / 4 | + | 🎁 | **Caching** | 0 / 1 | + | 🎁 | **Public Preview** | 0 / 3 | + | 🎁 | **Python SDK** | 0 / 1 | + | 🎁 | **STT & TTS Model Support** | 0 / 2 | + | 🎁 | **Intelligent Routing** | 0 / 1 | + | 🎁 | **General Availability Routing** | 0 / 1 | + +### Private Preview + +- Unified LLM Chat REST API +- Support for most popular LLM providers +- Seamless model fallbacks +- Routing Strategies: Priority, Round Robin, Weighted Round Robin, Least Latency + +### Public Preview +- Embeddings +- Streaming +- Intelligent Routing + +### General Availability + +- Python SDK +- Speech-to-text & Text-to-speech models +- Exact & Semantic Caching + +### Future + +- Cost Management & Budgeting +- Safety & Control Over Inputs & Outputs + +- and many more! + +Open [an issue](https://github.com/modelgateway/glide/issues) or start [a discussion](https://github.com/modelgateway/glide/discussions) +if there is a feature or an enhancement you'd like to see in Glide. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..05beb8da --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,6 @@ +# Security Policy + +We want to keep Glide safe for everyone. + +If you've discovered a security vulnerability in Glide, +we appreciate your help in disclosing it to us in a responsible manner, using this email: roman.glushko.m@gmail.com diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..b82018c1 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,7 @@ +coverage: + status: + project: + default: + threshold: 1% + + patch: false diff --git a/config.dev.yaml b/config.dev.yaml new file mode 100644 index 00000000..30f83e87 --- /dev/null +++ b/config.dev.yaml @@ -0,0 +1,21 @@ +telemetry: + logging: + level: debug # debug, info, warn, error, fatal + encoding: console + +#api: +# http: +# ...
+ +routers: + language: + - id: myrouter + models: + - id: openai + openai: + api_key: "" + - id: azureopenai + azureopenai: + api_key: "" + model: "" + base_url: "" \ No newline at end of file diff --git a/config.sample.yaml b/config.sample.yaml new file mode 100644 index 00000000..3ce72055 --- /dev/null +++ b/config.sample.yaml @@ -0,0 +1,8 @@ +telemetry: + logging: + level: INFO # DEBUG, INFO, WARNING, ERROR, FATAL + encoding: json # console, json + +#api: +# http: +# ... diff --git a/docs/api/Health.bru b/docs/api/Health.bru new file mode 100644 index 00000000..0486a046 --- /dev/null +++ b/docs/api/Health.bru @@ -0,0 +1,11 @@ +meta { + name: Health + type: http + seq: 1 +} + +get { + url: {{base_url}}/health + body: none + auth: none +} diff --git a/docs/api/[Lang] Chat.bru b/docs/api/[Lang] Chat.bru new file mode 100644 index 00000000..d3a31a71 --- /dev/null +++ b/docs/api/[Lang] Chat.bru @@ -0,0 +1,21 @@ +meta { + name: [Lang] Chat + type: http + seq: 2 +} + +post { + url: {{base_url}}/v1/language/myrouter/chat/ + body: json + auth: none +} + +body:json { + { + "message": { + "role": "user", + "content": "How are you doing?" + }, + "messageHistory": [] + } +} diff --git a/docs/api/[Lang] Router List.bru b/docs/api/[Lang] Router List.bru new file mode 100644 index 00000000..81ccec75 --- /dev/null +++ b/docs/api/[Lang] Router List.bru @@ -0,0 +1,11 @@ +meta { + name: [Lang] Router List + type: http + seq: 3 +} + +get { + url: {{base_url}}/v1/language/ + body: none + auth: none +} diff --git a/docs/api/bruno.json b/docs/api/bruno.json new file mode 100644 index 00000000..c543e3e0 --- /dev/null +++ b/docs/api/bruno.json @@ -0,0 +1,5 @@ +{ + "version": "1", + "name": "glide", + "type": "collection" +} \ No newline at end of file diff --git a/docs/api/environments/Development.bru b/docs/api/environments/Development.bru new file mode 100644 index 00000000..732c80cf --- /dev/null +++ b/docs/api/environments/Development.bru @@ -0,0 +1,3 @@ +vars { + base_url: http://127.0.0.1:9099 +} diff --git a/docs/docs.go b/docs/docs.go new file mode 100644 index 00000000..8a51f1ac --- /dev/null +++ b/docs/docs.go @@ -0,0 +1,712 @@ +// Package docs Code generated by swaggo/swag.
DO NOT EDIT +package docs + +import "github.com/swaggo/swag" + +const docTemplate = `{ + "schemes": {{ marshal .Schemes }}, + "swagger": "2.0", + "info": { + "description": "{{escape .Description}}", + "title": "{{.Title}}", + "contact": { + "name": "Glide Community", + "url": "https://github.com/modelgateway/glide" + }, + "license": { + "name": "Apache 2.0", + "url": "https://github.com/modelgateway/glide/blob/develop/LICENSE" + }, + "version": "{{.Version}}" + }, + "host": "{{.Host}}", + "basePath": "{{.BasePath}}", + "paths": { + "/v1/health/": { + "get": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Operations" + ], + "summary": "Gateway Health", + "operationId": "glide-health", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/http.HealthSchema" + } + } + } + } + }, + "/v1/language/": { + "get": { + "description": "Retrieve list of configured language routers and their configurations", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Router List", + "operationId": "glide-language-routers", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/http.RouterListSchema" + } + } + } + } + }, + "/v1/language/{router}/chat": { + "post": { + "description": "Talk to different LLMs Chat API via unified endpoint", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Chat", + "operationId": "glide-language-chat", + "parameters": [ + { + "type": "string", + "description": "Router ID", + "name": "router", + "in": "path", + "required": true + }, + { + "description": "Request Data", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/schemas.UnifiedChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schemas.UnifiedChatResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + } + } + } + } + }, + "definitions": { + "anthropic.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/anthropic.Params" + }, + "model": { + "type": "string" + } + } + }, + "anthropic.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "metadata": { + "type": "string" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "system": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + } + }, + "azureopenai.Config": { + "type": "object", + "required": [ + "apiVersion", + "baseUrl", + "model" + ], + "properties": { + "apiVersion": { + "description": "The API version to use for this operation. 
This follows the YYYY-MM-DD format (e.g 2023-05-15)", + "type": "string" + }, + "baseUrl": { + "description": "The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/)", + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/azureopenai.Params" + }, + "model": { + "description": "This is your deployment name. You're required to first deploy a model before you can make calls (e.g. glide-gpt-35)", + "type": "string" + } + } + }, + "azureopenai.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "response_format": { + "description": "TODO: should this be a part of the chat request API?" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": {}, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + } + }, + "clients.ClientConfig": { + "type": "object", + "properties": { + "timeout": { + "type": "string" + } + } + }, + "cohere.ChatHistory": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "role": { + "type": "string" + }, + "user": { + "type": "string" + } + } + }, + "cohere.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/cohere.Params" + }, + "model": { + "type": "string" + } + } + }, + "cohere.Params": { + "type": "object", + "properties": { + "chat_history": { + "type": "array", + "items": { + "$ref": "#/definitions/cohere.ChatHistory" + } + }, + "citiation_quality": { + "type": "string" + }, + "connectors": { + "type": "array", + "items": { + "type": "string" + } + }, + "conversation_id": { + "type": "string" + }, + "preamble_override": { + "type": "string" + }, + "prompt_truncation": { + "type": "string" + }, + "search_queries_only": { + "type": "boolean" + }, + "stream": { + "description": "unsupported right now", + "type": "boolean" + }, + "temperature": { + "type": "number" + } + } + }, + "http.ErrorSchema": { + "type": "object", + "properties": { + "message": { + "type": "string" + } + } + }, + "http.HealthSchema": { + "type": "object", + "properties": { + "healthy": { + "type": "boolean" + } + } + }, + "http.RouterListSchema": { + "type": "object", + "properties": { + "routers": { + "type": "array", + "items": { + "$ref": "#/definitions/routers.LangRouterConfig" + } + } + } + }, + "latency.Config": { + "type": "object", + "properties": { + "decay": { + "description": "Weight of new latency measurements", + "type": "number" + }, + "update_interval": { + "description": "How often gateway should probe models with not the lowest response latency", + "type": "string" + }, + "warmup_samples": { + "description": "The number of latency probes required to init moving average", + "type": "integer" + } + } + }, + "octoml.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + 
"defaultParams": { + "$ref": "#/definitions/octoml.Params" + }, + "model": { + "type": "string" + } + } + }, + "octoml.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "max_tokens": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + } + } + }, + "openai.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/openai.Params" + }, + "model": { + "type": "string" + } + } + }, + "openai.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "response_format": { + "description": "TODO: should this be a part of the chat request API?" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": {}, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + } + }, + "providers.LangModelConfig": { + "type": "object", + "required": [ + "id" + ], + "properties": { + "anthropic": { + "$ref": "#/definitions/anthropic.Config" + }, + "azureopenai": { + "$ref": "#/definitions/azureopenai.Config" + }, + "client": { + "$ref": "#/definitions/clients.ClientConfig" + }, + "cohere": { + "$ref": "#/definitions/cohere.Config" + }, + "enabled": { + "description": "Is the model enabled?", + "type": "boolean" + }, + "error_budget": { + "type": "string" + }, + "id": { + "description": "Model instance ID (unique in scope of the router)", + "type": "string" + }, + "latency": { + "$ref": "#/definitions/latency.Config" + }, + "octoml": { + "$ref": "#/definitions/octoml.Config" + }, + "openai": { + "$ref": "#/definitions/openai.Config" + }, + "weight": { + "type": "integer" + } + } + }, + "retry.ExpRetryConfig": { + "type": "object", + "properties": { + "base_multiplier": { + "type": "integer" + }, + "max_delay": { + "type": "integer" + }, + "max_retries": { + "type": "integer" + }, + "min_delay": { + "type": "integer" + } + } + }, + "routers.LangRouterConfig": { + "type": "object", + "required": [ + "models", + "routers" + ], + "properties": { + "enabled": { + "description": "Is router enabled?", + "type": "boolean" + }, + "models": { + "description": "the list of models that could handle requests", + "type": "array", + "items": { + "$ref": "#/definitions/providers.LangModelConfig" + } + }, + "retry": { + "description": "retry when no healthy model is available to router", + "allOf": [ + { + "$ref": "#/definitions/retry.ExpRetryConfig" + } + ] + }, + "routers": { + "description": "Unique router ID", + "type": "string" + }, + "strategy": { + "description": "strategy on picking the next model to serve the request", + "type": "string" + } + } + }, + "schemas.ChatMessage": { + "type": "object", + "properties": { + "content": { + "description": "The content of the message.", + "type": "string" + }, + "name": { + "description": "The name of the author of this message. 
May contain a-z, A-Z, 0-9, and underscores,\nwith a maximum length of 64 characters.", + "type": "string" + }, + "role": { + "description": "The role of the author of this message. One of system, user, or assistant.", + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenCount" + } + } + }, + "schemas.TokenCount": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } + }, + "schemas.UnifiedChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "messageHistory": { + "type": "array", + "items": { + "$ref": "#/definitions/schemas.ChatMessage" + } + } + } + }, + "schemas.UnifiedChatResponse": { + "type": "object", + "properties": { + "cached": { + "type": "boolean" + }, + "created": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "modelResponse": { + "$ref": "#/definitions/schemas.ProviderResponse" + }, + "model_id": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "router": { + "type": "string" + } + } + } + } +}` + +// SwaggerInfo holds exported Swagger Info so clients can modify it +var SwaggerInfo = &swag.Spec{ + Version: "1.0", + Host: "localhost:9099", + BasePath: "/", + Schemes: []string{"http"}, + Title: "Glide Gateway", + Description: "API documentation for Glide, an open-source lightweight high-performance model gateway", + InfoInstanceName: "swagger", + SwaggerTemplate: docTemplate, + LeftDelim: "{{", + RightDelim: "}}", +} + +func init() { + swag.Register(SwaggerInfo.InstanceName(), SwaggerInfo) +} diff --git a/docs/images/anthropic.svg b/docs/images/anthropic.svg new file mode 100644 index 00000000..1702fb45 --- /dev/null +++ b/docs/images/anthropic.svg @@ -0,0 +1,8 @@ + + anthropic_BIG copy-svg + + + + \ No newline at end of file diff --git a/docs/images/azure.svg b/docs/images/azure.svg new file mode 100644 index 00000000..445315a5 --- /dev/null +++ b/docs/images/azure.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/images/bard.svg b/docs/images/bard.svg new file mode 100644 index 00000000..4a943308 --- /dev/null +++ b/docs/images/bard.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/images/cohere.png b/docs/images/cohere.png new file mode 100644 index 00000000..3da0b837 Binary files /dev/null and b/docs/images/cohere.png differ diff --git a/docs/images/glide.png b/docs/images/glide.png new file mode 100644 index 00000000..48eb15a8 Binary files /dev/null and b/docs/images/glide.png differ diff --git a/docs/images/localai.webp b/docs/images/localai.webp new file mode 100644 index 00000000..7dbad578 Binary files /dev/null and b/docs/images/localai.webp differ diff --git a/docs/images/octo.png b/docs/images/octo.png new file mode 100644 index 00000000..a4ca6d5c Binary files /dev/null and b/docs/images/octo.png differ diff --git a/docs/images/openai.svg b/docs/images/openai.svg new file mode 100644 index 00000000..7ca399ea --- /dev/null +++ b/docs/images/openai.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/swagger.json b/docs/swagger.json new file mode 100644 index 00000000..4722323c --- /dev/null 
+++ b/docs/swagger.json @@ -0,0 +1,691 @@ +{ + "schemes": [ + "http" + ], + "swagger": "2.0", + "info": { + "description": "API documentation for Glide, an open-source lightweight high-performance model gateway", + "title": "Glide Gateway", + "contact": { + "name": "Glide Community", + "url": "https://github.com/modelgateway/glide" + }, + "license": { + "name": "Apache 2.0", + "url": "https://github.com/modelgateway/glide/blob/develop/LICENSE" + }, + "version": "1.0" + }, + "host": "localhost:9099", + "basePath": "/", + "paths": { + "/v1/health/": { + "get": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Operations" + ], + "summary": "Gateway Health", + "operationId": "glide-health", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/http.HealthSchema" + } + } + } + } + }, + "/v1/language/": { + "get": { + "description": "Retrieve list of configured language routers and their configurations", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Router List", + "operationId": "glide-language-routers", + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/http.RouterListSchema" + } + } + } + } + }, + "/v1/language/{router}/chat": { + "post": { + "description": "Talk to different LLMs Chat API via unified endpoint", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Chat", + "operationId": "glide-language-chat", + "parameters": [ + { + "type": "string", + "description": "Router ID", + "name": "router", + "in": "path", + "required": true + }, + { + "description": "Request Data", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/schemas.UnifiedChatRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schemas.UnifiedChatResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + } + } + } + } + }, + "definitions": { + "anthropic.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/anthropic.Params" + }, + "model": { + "type": "string" + } + } + }, + "anthropic.Params": { + "type": "object", + "properties": { + "max_tokens": { + "type": "integer" + }, + "metadata": { + "type": "string" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "system": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_k": { + "type": "integer" + }, + "top_p": { + "type": "number" + } + } + }, + "azureopenai.Config": { + "type": "object", + "required": [ + "apiVersion", + "baseUrl", + "model" + ], + "properties": { + "apiVersion": { + "description": "The API version to use for this operation. 
This follows the YYYY-MM-DD format (e.g 2023-05-15)", + "type": "string" + }, + "baseUrl": { + "description": "The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/)", + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/azureopenai.Params" + }, + "model": { + "description": "This is your deployment name. You're required to first deploy a model before you can make calls (e.g. glide-gpt-35)", + "type": "string" + } + } + }, + "azureopenai.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "response_format": { + "description": "TODO: should this be a part of the chat request API?" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": {}, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + } + }, + "clients.ClientConfig": { + "type": "object", + "properties": { + "timeout": { + "type": "string" + } + } + }, + "cohere.ChatHistory": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "role": { + "type": "string" + }, + "user": { + "type": "string" + } + } + }, + "cohere.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/cohere.Params" + }, + "model": { + "type": "string" + } + } + }, + "cohere.Params": { + "type": "object", + "properties": { + "chat_history": { + "type": "array", + "items": { + "$ref": "#/definitions/cohere.ChatHistory" + } + }, + "citiation_quality": { + "type": "string" + }, + "connectors": { + "type": "array", + "items": { + "type": "string" + } + }, + "conversation_id": { + "type": "string" + }, + "preamble_override": { + "type": "string" + }, + "prompt_truncation": { + "type": "string" + }, + "search_queries_only": { + "type": "boolean" + }, + "stream": { + "description": "unsupported right now", + "type": "boolean" + }, + "temperature": { + "type": "number" + } + } + }, + "http.ErrorSchema": { + "type": "object", + "properties": { + "message": { + "type": "string" + } + } + }, + "http.HealthSchema": { + "type": "object", + "properties": { + "healthy": { + "type": "boolean" + } + } + }, + "http.RouterListSchema": { + "type": "object", + "properties": { + "routers": { + "type": "array", + "items": { + "$ref": "#/definitions/routers.LangRouterConfig" + } + } + } + }, + "latency.Config": { + "type": "object", + "properties": { + "decay": { + "description": "Weight of new latency measurements", + "type": "number" + }, + "update_interval": { + "description": "How often gateway should probe models with not the lowest response latency", + "type": "string" + }, + "warmup_samples": { + "description": "The number of latency probes required to init moving average", + "type": "integer" + } + } + }, + "octoml.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + 
"defaultParams": { + "$ref": "#/definitions/octoml.Params" + }, + "model": { + "type": "string" + } + } + }, + "octoml.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "max_tokens": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + } + } + }, + "openai.Config": { + "type": "object", + "required": [ + "baseUrl", + "chatEndpoint", + "model" + ], + "properties": { + "baseUrl": { + "type": "string" + }, + "chatEndpoint": { + "type": "string" + }, + "defaultParams": { + "$ref": "#/definitions/openai.Params" + }, + "model": { + "type": "string" + } + } + }, + "openai.Params": { + "type": "object", + "properties": { + "frequency_penalty": { + "type": "integer" + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "presence_penalty": { + "type": "integer" + }, + "response_format": { + "description": "TODO: should this be a part of the chat request API?" + }, + "seed": { + "type": "integer" + }, + "stop": { + "type": "array", + "items": { + "type": "string" + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": {}, + "tools": { + "type": "array", + "items": { + "type": "string" + } + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + } + }, + "providers.LangModelConfig": { + "type": "object", + "required": [ + "id" + ], + "properties": { + "anthropic": { + "$ref": "#/definitions/anthropic.Config" + }, + "azureopenai": { + "$ref": "#/definitions/azureopenai.Config" + }, + "client": { + "$ref": "#/definitions/clients.ClientConfig" + }, + "cohere": { + "$ref": "#/definitions/cohere.Config" + }, + "enabled": { + "description": "Is the model enabled?", + "type": "boolean" + }, + "error_budget": { + "type": "string" + }, + "id": { + "description": "Model instance ID (unique in scope of the router)", + "type": "string" + }, + "latency": { + "$ref": "#/definitions/latency.Config" + }, + "octoml": { + "$ref": "#/definitions/octoml.Config" + }, + "openai": { + "$ref": "#/definitions/openai.Config" + }, + "weight": { + "type": "integer" + } + } + }, + "retry.ExpRetryConfig": { + "type": "object", + "properties": { + "base_multiplier": { + "type": "integer" + }, + "max_delay": { + "type": "integer" + }, + "max_retries": { + "type": "integer" + }, + "min_delay": { + "type": "integer" + } + } + }, + "routers.LangRouterConfig": { + "type": "object", + "required": [ + "models", + "routers" + ], + "properties": { + "enabled": { + "description": "Is router enabled?", + "type": "boolean" + }, + "models": { + "description": "the list of models that could handle requests", + "type": "array", + "items": { + "$ref": "#/definitions/providers.LangModelConfig" + } + }, + "retry": { + "description": "retry when no healthy model is available to router", + "allOf": [ + { + "$ref": "#/definitions/retry.ExpRetryConfig" + } + ] + }, + "routers": { + "description": "Unique router ID", + "type": "string" + }, + "strategy": { + "description": "strategy on picking the next model to serve the request", + "type": "string" + } + } + }, + "schemas.ChatMessage": { + "type": "object", + "properties": { + "content": { + "description": "The content of the message.", + "type": "string" + }, + "name": { + "description": "The name of the author of this message. 
May contain a-z, A-Z, 0-9, and underscores,\nwith a maximum length of 64 characters.", + "type": "string" + }, + "role": { + "description": "The role of the author of this message. One of system, user, or assistant.", + "type": "string" + } + } + }, + "schemas.ProviderResponse": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "responseId": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "tokenCount": { + "$ref": "#/definitions/schemas.TokenCount" + } + } + }, + "schemas.TokenCount": { + "type": "object", + "properties": { + "promptTokens": { + "type": "number" + }, + "responseTokens": { + "type": "number" + }, + "totalTokens": { + "type": "number" + } + } + }, + "schemas.UnifiedChatRequest": { + "type": "object", + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "messageHistory": { + "type": "array", + "items": { + "$ref": "#/definitions/schemas.ChatMessage" + } + } + } + }, + "schemas.UnifiedChatResponse": { + "type": "object", + "properties": { + "cached": { + "type": "boolean" + }, + "created": { + "type": "integer" + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "modelResponse": { + "$ref": "#/definitions/schemas.ProviderResponse" + }, + "model_id": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "router": { + "type": "string" + } + } + } + } +} \ No newline at end of file diff --git a/docs/swagger.yaml b/docs/swagger.yaml new file mode 100644 index 00000000..74616bb9 --- /dev/null +++ b/docs/swagger.yaml @@ -0,0 +1,468 @@ +basePath: / +definitions: + anthropic.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/anthropic.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + anthropic.Params: + properties: + max_tokens: + type: integer + metadata: + type: string + stop: + items: + type: string + type: array + system: + type: string + temperature: + type: number + top_k: + type: integer + top_p: + type: number + type: object + azureopenai.Config: + properties: + apiVersion: + description: The API version to use for this operation. This follows the YYYY-MM-DD + format (e.g 2023-05-15) + type: string + baseUrl: + description: The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/) + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/azureopenai.Params' + model: + description: This is your deployment name. You're required to first deploy + a model before you can make calls (e.g. glide-gpt-35) + type: string + required: + - apiVersion + - baseUrl + - model + type: object + azureopenai.Params: + properties: + frequency_penalty: + type: integer + logit_bias: + additionalProperties: + type: number + type: object + max_tokens: + type: integer + "n": + type: integer + presence_penalty: + type: integer + response_format: + description: 'TODO: should this be a part of the chat request API?' 
+ seed: + type: integer + stop: + items: + type: string + type: array + temperature: + type: number + tool_choice: {} + tools: + items: + type: string + type: array + top_p: + type: number + user: + type: string + type: object + clients.ClientConfig: + properties: + timeout: + type: string + type: object + cohere.ChatHistory: + properties: + message: + type: string + role: + type: string + user: + type: string + type: object + cohere.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/cohere.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + cohere.Params: + properties: + chat_history: + items: + $ref: '#/definitions/cohere.ChatHistory' + type: array + citiation_quality: + type: string + connectors: + items: + type: string + type: array + conversation_id: + type: string + preamble_override: + type: string + prompt_truncation: + type: string + search_queries_only: + type: boolean + stream: + description: unsupported right now + type: boolean + temperature: + type: number + type: object + http.ErrorSchema: + properties: + message: + type: string + type: object + http.HealthSchema: + properties: + healthy: + type: boolean + type: object + http.RouterListSchema: + properties: + routers: + items: + $ref: '#/definitions/routers.LangRouterConfig' + type: array + type: object + latency.Config: + properties: + decay: + description: Weight of new latency measurements + type: number + update_interval: + description: How often gateway should probe models with not the lowest response + latency + type: string + warmup_samples: + description: The number of latency probes required to init moving average + type: integer + type: object + octoml.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/octoml.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + octoml.Params: + properties: + frequency_penalty: + type: integer + max_tokens: + type: integer + presence_penalty: + type: integer + stop: + items: + type: string + type: array + temperature: + type: number + top_p: + type: number + type: object + openai.Config: + properties: + baseUrl: + type: string + chatEndpoint: + type: string + defaultParams: + $ref: '#/definitions/openai.Params' + model: + type: string + required: + - baseUrl + - chatEndpoint + - model + type: object + openai.Params: + properties: + frequency_penalty: + type: integer + logit_bias: + additionalProperties: + type: number + type: object + max_tokens: + type: integer + "n": + type: integer + presence_penalty: + type: integer + response_format: + description: 'TODO: should this be a part of the chat request API?' + seed: + type: integer + stop: + items: + type: string + type: array + temperature: + type: number + tool_choice: {} + tools: + items: + type: string + type: array + top_p: + type: number + user: + type: string + type: object + providers.LangModelConfig: + properties: + anthropic: + $ref: '#/definitions/anthropic.Config' + azureopenai: + $ref: '#/definitions/azureopenai.Config' + client: + $ref: '#/definitions/clients.ClientConfig' + cohere: + $ref: '#/definitions/cohere.Config' + enabled: + description: Is the model enabled? 
+ type: boolean + error_budget: + type: string + id: + description: Model instance ID (unique in scope of the router) + type: string + latency: + $ref: '#/definitions/latency.Config' + octoml: + $ref: '#/definitions/octoml.Config' + openai: + $ref: '#/definitions/openai.Config' + weight: + type: integer + required: + - id + type: object + retry.ExpRetryConfig: + properties: + base_multiplier: + type: integer + max_delay: + type: integer + max_retries: + type: integer + min_delay: + type: integer + type: object + routers.LangRouterConfig: + properties: + enabled: + description: Is router enabled? + type: boolean + models: + description: the list of models that could handle requests + items: + $ref: '#/definitions/providers.LangModelConfig' + type: array + retry: + allOf: + - $ref: '#/definitions/retry.ExpRetryConfig' + description: retry when no healthy model is available to router + routers: + description: Unique router ID + type: string + strategy: + description: strategy on picking the next model to serve the request + type: string + required: + - models + - routers + type: object + schemas.ChatMessage: + properties: + content: + description: The content of the message. + type: string + name: + description: |- + The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, + with a maximum length of 64 characters. + type: string + role: + description: The role of the author of this message. One of system, user, + or assistant. + type: string + type: object + schemas.ProviderResponse: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + responseId: + additionalProperties: + type: string + type: object + tokenCount: + $ref: '#/definitions/schemas.TokenCount' + type: object + schemas.TokenCount: + properties: + promptTokens: + type: number + responseTokens: + type: number + totalTokens: + type: number + type: object + schemas.UnifiedChatRequest: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + messageHistory: + items: + $ref: '#/definitions/schemas.ChatMessage' + type: array + type: object + schemas.UnifiedChatResponse: + properties: + cached: + type: boolean + created: + type: integer + id: + type: string + model: + type: string + model_id: + type: string + modelResponse: + $ref: '#/definitions/schemas.ProviderResponse' + provider: + type: string + router: + type: string + type: object +host: localhost:9099 +info: + contact: + name: Glide Community + url: https://github.com/modelgateway/glide + description: API documentation for Glide, an open-source lightweight high-performance + model gateway + license: + name: Apache 2.0 + url: https://github.com/modelgateway/glide/blob/develop/LICENSE + title: Glide Gateway + version: "1.0" +paths: + /v1/health/: + get: + consumes: + - application/json + operationId: glide-health + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/http.HealthSchema' + summary: Gateway Health + tags: + - Operations + /v1/language/: + get: + consumes: + - application/json + description: Retrieve list of configured language routers and their configurations + operationId: glide-language-routers + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/http.RouterListSchema' + summary: Language Router List + tags: + - Language + /v1/language/{router}/chat: + post: + consumes: + - application/json + description: Talk to different LLMs Chat API via unified endpoint + operationId: glide-language-chat + parameters: + 
- description: Router ID + in: path + name: router + required: true + type: string + - description: Request Data + in: body + name: payload + required: true + schema: + $ref: '#/definitions/schemas.UnifiedChatRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/schemas.UnifiedChatResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/http.ErrorSchema' + "404": + description: Not Found + schema: + $ref: '#/definitions/http.ErrorSchema' + summary: Language Chat + tags: + - Language +schemes: +- http +swagger: "2.0" diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..5db4005e --- /dev/null +++ b/go.mod @@ -0,0 +1,53 @@ +module glide + +go 1.21.5 + +require ( + github.com/cloudwego/hertz v0.7.3 + github.com/hertz-contrib/logger/zap v1.1.0 + github.com/hertz-contrib/swagger v0.1.0 + github.com/spf13/cobra v1.8.0 + github.com/stretchr/testify v1.8.4 + github.com/swaggo/files v1.0.1 + github.com/swaggo/swag v1.16.2 + go.uber.org/goleak v1.3.0 + go.uber.org/multierr v1.11.0 + go.uber.org/zap v1.26.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/KyleBanks/depth v1.2.1 // indirect + github.com/andeya/ameda v1.5.3 // indirect + github.com/andeya/goutil v1.0.1 // indirect + github.com/bytedance/go-tagexpr/v2 v2.9.11 // indirect + github.com/bytedance/gopkg v0.0.0-20231219111115-a5eedbe96960 // indirect + github.com/bytedance/sonic v1.10.2 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect + github.com/chenzhuoyu/iasm v0.9.1 // indirect + github.com/cloudwego/netpoll v0.5.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-openapi/jsonpointer v0.20.2 // indirect + github.com/go-openapi/jsonreference v0.20.4 // indirect + github.com/go-openapi/spec v0.20.13 // indirect + github.com/go-openapi/swag v0.22.7 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/nyaruka/phonenumbers v1.3.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/tidwall/gjson v1.17.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + golang.org/x/arch v0.6.0 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.16.1 // indirect + google.golang.org/protobuf v1.32.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..b8dd2d9d --- /dev/null +++ b/go.sum @@ -0,0 +1,186 @@ +github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= +github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= +github.com/andeya/ameda v1.5.3 h1:SvqnhQPZwwabS8HQTRGfJwWPl2w9ZIPInHAw9aE1Wlk= +github.com/andeya/ameda v1.5.3/go.mod h1:FQDHRe1I995v6GG+8aJ7UIUToEmbdTJn/U26NCPIgXQ= +github.com/andeya/goutil v1.0.1 h1:eiYwVyAnnK0dXU5FJsNjExkJW4exUGn/xefPt3k4eXg= +github.com/andeya/goutil v1.0.1/go.mod h1:jEG5/QnnhG7yGxwFUX6Q+JGMif7sjdHmmNVjn7nhJDo= +github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= 
+github.com/bytedance/go-tagexpr/v2 v2.9.11 h1:jJgmoDKPKacGl0llPYbYL/+/2N+Ng0vV0ipbnVssXHY= +github.com/bytedance/go-tagexpr/v2 v2.9.11/go.mod h1:UAyKh4ZRLBPGsyTRFZoPqTni1TlojMdOJXQnEIPCX84= +github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/gopkg v0.0.0-20231219111115-a5eedbe96960 h1:t2xAuIlnhWJDIpcHZEbpoVsQH1hOk9eGGaKU2dXl1PE= +github.com/bytedance/gopkg v0.0.0-20231219111115-a5eedbe96960/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= +github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= +github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= +github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE= +github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= +github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA= +github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= +github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0= +github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= +github.com/cloudwego/hertz v0.7.3 h1:VM1DxditA6vxI97rG5SBu4hHB24xdzDbKBQfUy7sfVE= +github.com/cloudwego/hertz v0.7.3/go.mod h1:WliNtVbwihWHHgAaIQEbVXl0O3aWj0ks1eoPrcEAnjs= +github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= +github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= +github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= +github.com/go-openapi/jsonreference v0.20.4 h1:bKlDxQxQJgwpUSgOENiMPzCTBVuc7vTdXSSgNeAhojU= +github.com/go-openapi/jsonreference v0.20.4/go.mod h1:5pZJyJP2MnYCpoeoMAql78cCHauHj0V9Lhc506VOpw4= +github.com/go-openapi/spec v0.20.13 h1:XJDIN+dLH6vqXgafnl5SUIMnzaChQ6QTo0/UPMbkIaE= +github.com/go-openapi/spec v0.20.13/go.mod h1:8EOhTpBoFiask8rrgwbLC3zmJfz4zsCUueRuPM6GNkw= +github.com/go-openapi/swag v0.22.7 h1:JWrc1uc/P9cSomxfnsFSVWoE1FW6bNbrVPmpQYpCcR8= +github.com/go-openapi/swag 
v0.22.7/go.mod h1:Gl91UqO+btAM0plGGxHqJcQZ1ZTy6jbmridBTsDy8A0= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= +github.com/hertz-contrib/logger/zap v1.1.0 h1:4efINiIDJrXEtAFeEdDJvc3Hye0VFxp+0X4BwaZgxNs= +github.com/hertz-contrib/logger/zap v1.1.0/go.mod h1:D/rJJgsYn+SGaHVfVqWS3vHTbbc7ODAlJO+6smWgTeE= +github.com/hertz-contrib/swagger v0.1.0 h1:FlnMPRHuvAt/3pt3KCQRZ6RH1g/agma9SU70Op2Pb58= +github.com/hertz-contrib/swagger v0.1.0/go.mod h1:Bt5i+Nyo7bGmYbuEfMArx7raf1oK+nWVgYbEvhpICKE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= +github.com/nyaruka/phonenumbers v1.3.0 h1:IFyyJfF2Elg8xGKFghWrRXzb6qAHk+Q3uPqmIgS20JQ= +github.com/nyaruka/phonenumbers v1.3.0/go.mod h1:4jyKp/BFUokLbCHyoZag+T3S1KezFVoEKtgnbpzItC4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/russross/blackfriday/v2 v2.1.0/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.5/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= +github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg= +github.com/swaggo/swag v1.16.2 h1:28Pp+8DkQoV+HLzLx8RGJZXNGKbFqnuvSbAAtoxiY04= +github.com/swaggo/swag v1.16.2/go.mod h1:6YzXnDcpr0767iOejs318CwYkCQqyGer6BizOg03f+E= +github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr 
v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.6.0 h1:S0JTfE48HbRj80+4tbvZDYsJ3tGv6BUU3XxyZ7CirAc= +golang.org/x/arch v0.6.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod 
h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/images/Makefile b/images/Makefile new file mode 100644 index 00000000..2bb33870 --- /dev/null +++ b/images/Makefile @@ -0,0 +1,101 @@ +VENDOR ?= einstack +PROJECT ?= Glide +SOURCE ?= https://github.com/EinStack/glide +LICENSE ?= Apache-2.0 +DESCRIPTION ?= "A lightweight, cloud-native LLM gateway" +REPOSITORY ?= einstack/glide + +VERSION ?= dev +RC_PART ?= rc +COMMIT ?= $(shell git describe --dirty --always --abbrev=15) +BUILD_DATE ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") + + +# OCI Labels: https://specs.opencontainers.org/image-spec/annotations +# Test images via: docker run --rm --platform linux/amd64 -i einstack/glide:dev-alpine --config config.dev.yaml + +alpine: ## Make an alpine-based image + @echo "๐Ÿ› ๏ธ Build alpine image ($(VERSION)).." + @echo "- Commit: $(COMMIT)" + @echo "- Build Date: $(BUILD_DATE)" + @docker build .. 
-t $(REPOSITORY):$(VERSION)-alpine -f alpine.Dockerfile \
+		--build-arg VERSION="$(VERSION)" \
+		--build-arg COMMIT="$(COMMIT)" \
+		--build-arg BUILD_DATE="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.vendor="$(VENDOR)" \
+		--label=org.opencontainers.image.title="$(PROJECT)" \
+		--label=org.opencontainers.image.revision="$(COMMIT)" \
+		--label=org.opencontainers.image.version="$(VERSION)" \
+		--label=org.opencontainers.image.created="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.source="$(SOURCE)" \
+		--label=org.opencontainers.image.licenses="$(LICENSE)" \
+		--label=org.opencontainers.image.description=$(DESCRIPTION)
+
+ubuntu: ## Make an Ubuntu-based image
+	@echo "🛠️ Build ubuntu image ($(VERSION)).."
+	@echo "- Commit: $(COMMIT)"
+	@echo "- Build Date: $(BUILD_DATE)"
+	@docker build .. -t $(REPOSITORY):$(VERSION)-ubuntu -f ubuntu.Dockerfile \
+		--build-arg VERSION="$(VERSION)" \
+		--build-arg COMMIT="$(COMMIT)" \
+		--build-arg BUILD_DATE="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.vendor=$(VENDOR) \
+		--label=org.opencontainers.image.title=$(PROJECT) \
+		--label=org.opencontainers.image.revision="$(COMMIT)" \
+		--label=org.opencontainers.image.version="$(VERSION)" \
+		--label=org.opencontainers.image.created="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.source=$(SOURCE) \
+		--label=org.opencontainers.image.licenses=$(LICENSE) \
+		--label=org.opencontainers.image.description=$(DESCRIPTION)
+
+distroless: ## Make a distroless-based image
+	@echo "🛠️ Build distroless image ($(VERSION)).."
+	@echo "- Commit: $(COMMIT)"
+	@echo "- Build Date: $(BUILD_DATE)"
+	@docker build .. -t $(REPOSITORY):$(VERSION)-distroless -f distroless.Dockerfile \
+		--build-arg VERSION="$(VERSION)" \
+		--build-arg COMMIT="$(COMMIT)" \
+		--build-arg BUILD_DATE="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.vendor=$(VENDOR) \
+		--label=org.opencontainers.image.title=$(PROJECT) \
+		--label=org.opencontainers.image.revision="$(COMMIT)" \
+		--label=org.opencontainers.image.version="$(VERSION)" \
+		--label=org.opencontainers.image.created="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.source=$(SOURCE) \
+		--label=org.opencontainers.image.licenses=$(LICENSE) \
+		--label=org.opencontainers.image.description=$(DESCRIPTION)
+
+redhat: ## Make a Red Hat-based image
+	@echo "🛠️ Build Red Hat image ($(VERSION)).."
+	@echo "- Commit: $(COMMIT)"
+	@echo "- Build Date: $(BUILD_DATE)"
+	@docker build .. -t $(REPOSITORY):$(VERSION)-redhat -f redhat.Dockerfile \
+		--build-arg VERSION="$(VERSION)" \
+		--build-arg COMMIT="$(COMMIT)" \
+		--build-arg BUILD_DATE="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.vendor=$(VENDOR) \
+		--label=org.opencontainers.image.title=$(PROJECT) \
+		--label=org.opencontainers.image.revision="$(COMMIT)" \
+		--label=org.opencontainers.image.version="$(VERSION)" \
+		--label=org.opencontainers.image.created="$(BUILD_DATE)" \
+		--label=org.opencontainers.image.source=$(SOURCE) \
+		--label=org.opencontainers.image.licenses=$(LICENSE) \
+		--label=org.opencontainers.image.description=$(DESCRIPTION)
+
+all: alpine ubuntu distroless redhat
+
+publish-ghcr-%: ## Push images to the GitHub Container Registry
+	@echo "🚚 Pushing the $* image to the GitHub Container Registry.."
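+	@# The version-specific tag is always pushed. The block below additionally pushes
+	@# "latest-<flavor>" (and the bare "latest" for alpine only) unless $(VERSION)
+	@# contains "$(RC_PART)", so e.g. an illustrative VERSION=1.0-rc1 gets no "latest" aliases.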
+ @docker tag $(REPOSITORY):$(VERSION)-$* ghcr.io/$(REPOSITORY):$(VERSION)-$* + @echo "- pushing ghcr.io/$(REPOSITORY):$(VERSION)-$*" + @docker push ghcr.io/$(REPOSITORY):$(VERSION)-$* + @echo $(VERSION) | grep -q $(RC_PART) || { \ + docker tag $(REPOSITORY):$(VERSION)-$* ghcr.io/$(REPOSITORY):latest-$*; \ + echo "- pushing ghcr.io/$(REPOSITORY):latest-$*"; \ + docker push ghcr.io/$(REPOSITORY):latest-$*; \ + if [ "$*" = "alpine" ]; then \ + docker tag $(REPOSITORY):$(VERSION)-$* ghcr.io/$(REPOSITORY):latest; \ + echo "- pushing ghcr.io/$(REPOSITORY):latest"; \ + docker push ghcr.io/$(REPOSITORY):latest; \ + fi; \ + } diff --git a/images/alpine.Dockerfile b/images/alpine.Dockerfile new file mode 100644 index 00000000..1454c0a7 --- /dev/null +++ b/images/alpine.Dockerfile @@ -0,0 +1,21 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.21-alpine as build + +ARG VERSION +ARG COMMIT +ARG BUILD_DATE + +ENV GOOS=linux + +WORKDIR /build + +COPY . /build/ +RUN go mod download +RUN go build -ldflags "-s -w -X glide/pkg.version=$VERSION -X glide/pkg.commitSha=$COMMIT -X glide/pkg.buildDate=$BUILD_DATE" -o /build/dist/glide + +FROM alpine:3.19 as release + +WORKDIR /bin +COPY --from=build /build/dist/glide /bin/ + +ENTRYPOINT ["/bin/glide"] diff --git a/images/distroless.Dockerfile b/images/distroless.Dockerfile new file mode 100644 index 00000000..34776aca --- /dev/null +++ b/images/distroless.Dockerfile @@ -0,0 +1,21 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.21-alpine as build + +ARG VERSION +ARG COMMIT +ARG BUILD_DATE + +ENV GOOS=linux + +WORKDIR /build + +COPY . /build/ +RUN go mod download +RUN go build -ldflags "-s -w -X glide/pkg.version=$VERSION -X glide/pkg.commitSha=$COMMIT -X glide/pkg.buildDate=$BUILD_DATE" -o /build/dist/glide + +FROM gcr.io/distroless/static-debian12:nonroot as release + +WORKDIR /bin +COPY --from=build /build/dist/glide /bin/ + +ENTRYPOINT ["/bin/glide"] diff --git a/images/redhat.Dockerfile b/images/redhat.Dockerfile new file mode 100644 index 00000000..f55c9534 --- /dev/null +++ b/images/redhat.Dockerfile @@ -0,0 +1,21 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.21-alpine as build + +ARG VERSION +ARG COMMIT +ARG BUILD_DATE + +ENV GOOS=linux + +WORKDIR /build + +COPY . /build/ +RUN go mod download +RUN go build -ldflags "-s -w -X glide/pkg.version=$VERSION -X glide/pkg.commitSha=$COMMIT -X glide/pkg.buildDate=$BUILD_DATE" -o /build/dist/glide + +FROM redhat/ubi8-micro:8.9 as release + +WORKDIR /bin +COPY --from=build /build/dist/glide /bin/ + +ENTRYPOINT ["/bin/glide"] diff --git a/images/ubuntu.Dockerfile b/images/ubuntu.Dockerfile new file mode 100644 index 00000000..7db2cb62 --- /dev/null +++ b/images/ubuntu.Dockerfile @@ -0,0 +1,21 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.21-alpine as build + +ARG VERSION +ARG COMMIT +ARG BUILD_DATE + +ENV GOOS=linux + +WORKDIR /build + +COPY . 
/build/ +RUN go mod download +RUN go build -ldflags "-s -w -X glide/pkg.version=$VERSION -X glide/pkg.commitSha=$COMMIT -X glide/pkg.buildDate=$BUILD_DATE" -o /build/dist/glide + +FROM ubuntu:22.04 as release + +WORKDIR /bin +COPY --from=build /build/dist/glide /bin/ + +ENTRYPOINT ["/bin/glide"] diff --git a/leak_test.go b/leak_test.go new file mode 100644 index 00000000..024a502a --- /dev/null +++ b/leak_test.go @@ -0,0 +1,11 @@ +package main + +import ( + _ "go.uber.org/goleak" +) + +// TODO: investigate why netpoll leaves pending goroutines +// https://github.com/modelgateway/Glide/issues/33 +//func TestMain(m *testing.M) { +// goleak.VerifyTestMain(m) +//} diff --git a/main.go b/main.go new file mode 100644 index 00000000..b37e583f --- /dev/null +++ b/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "log" + + "glide/pkg/cmd" +) + +// @title Glide Gateway +// @version 1.0 +// @description API documentation for Glide, an open-source lightweight high-performance model gateway + +// @contact.name Glide Community +// @contact.url https://github.com/modelgateway/glide + +// @license.name Apache 2.0 +// @license.url https://github.com/modelgateway/glide/blob/develop/LICENSE + +// @host localhost:9099 +// @BasePath / +// @schemes http +func main() { + cli := cmd.NewCLI() + + if err := cli.Execute(); err != nil { + log.Fatalf("glide run finished with error: %v", err) + } +} diff --git a/pkg/api/config.go b/pkg/api/config.go new file mode 100644 index 00000000..6f230344 --- /dev/null +++ b/pkg/api/config.go @@ -0,0 +1,14 @@ +package api + +import "glide/pkg/api/http" + +// Config defines configuration for all API types we support (e.g. HTTP, gRPC) +type Config struct { + HTTP *http.ServerConfig `yaml:"http"` +} + +func DefaultConfig() *Config { + return &Config{ + HTTP: http.DefaultServerConfig(), + } +} diff --git a/pkg/api/http/config.go b/pkg/api/http/config.go new file mode 100644 index 00000000..1e869db1 --- /dev/null +++ b/pkg/api/http/config.go @@ -0,0 +1,28 @@ +package http + +import ( + "time" + + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/network/netpoll" +) + +type ServerConfig struct { + HostPort string +} + +func DefaultServerConfig() *ServerConfig { + return &ServerConfig{ + HostPort: "127.0.0.1:9099", + } +} + +func (cfg *ServerConfig) ToServer() *server.Hertz { + // TODO: do real server build based on provided config + return server.Default( + server.WithIdleTimeout(1*time.Second), + server.WithHostPorts(cfg.HostPort), + server.WithMaxRequestBodySize(20<<20), + server.WithTransport(netpoll.NewTransporter), + ) +} diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go new file mode 100644 index 00000000..9db2e5fc --- /dev/null +++ b/pkg/api/http/handlers.go @@ -0,0 +1,116 @@ +package http + +import ( + "context" + "encoding/json" + "errors" + + "glide/pkg/api/schemas" + "glide/pkg/routers" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +type Handler = func(ctx context.Context, c *app.RequestContext) + +// Swagger 101: +// - https://github.com/swaggo/swag/tree/master/example/celler + +// LangChatHandler +// +// @id glide-language-chat +// @Summary Language Chat +// @Description Talk to different LLMs Chat API via unified endpoint +// @tags Language +// @Param router path string true "Router ID" +// @Param payload body schemas.UnifiedChatRequest true "Request Data" +// @Accept json +// @Produce json +// @Success 200 {object} schemas.UnifiedChatResponse +// @Failure 400 {object} 
http.ErrorSchema
+// @Failure 404 {object} http.ErrorSchema
+// @Router /v1/language/{router}/chat [POST]
+func LangChatHandler(routerManager *routers.RouterManager) Handler {
+	return func(ctx context.Context, c *app.RequestContext) {
+		var req *schemas.UnifiedChatRequest
+
+		err := json.Unmarshal(c.Request.Body(), &req)
+		if err != nil {
+			c.JSON(consts.StatusBadRequest, ErrorSchema{
+				Message: err.Error(),
+			})
+
+			return
+		}
+
+		routerID := c.Param("router")
+		router, err := routerManager.GetLangRouter(routerID)
+
+		if errors.Is(err, routers.ErrRouterNotFound) {
+			c.JSON(consts.StatusNotFound, ErrorSchema{
+				Message: err.Error(),
+			})
+
+			return
+		}
+
+		if err != nil {
+			// the router lookup failed for a reason other than "not found"
+			c.JSON(consts.StatusInternalServerError, ErrorSchema{
+				Message: err.Error(),
+			})
+
+			return
+		}
+
+		resp, err := router.Chat(ctx, req)
+		if err != nil {
+			// TODO: handle this in a more granular way; not every error here is an internal one
+			c.JSON(consts.StatusInternalServerError, ErrorSchema{
+				Message: err.Error(),
+			})
+
+			return
+		}
+
+		c.JSON(consts.StatusOK, resp)
+	}
+}
+
+// LangRoutersHandler
+//
+// @id glide-language-routers
+// @Summary Language Router List
+// @Description Retrieve list of configured language routers and their configurations
+// @tags Language
+// @Accept json
+// @Produce json
+// @Success 200 {object} http.RouterListSchema
+// @Router /v1/language/ [GET]
+func LangRoutersHandler(routerManager *routers.RouterManager) Handler {
+	return func(ctx context.Context, c *app.RequestContext) {
+		configuredRouters := routerManager.GetLangRouters()
+		cfgs := make([]*routers.LangRouterConfig, 0, len(configuredRouters))
+
+		for _, router := range configuredRouters {
+			cfgs = append(cfgs, router.Config)
+		}
+
+		c.JSON(consts.StatusOK, RouterListSchema{Routers: cfgs})
+	}
+}
+
+// HealthHandler
+//
+// @id glide-health
+// @Summary Gateway Health
+// @Description
+// @tags Operations
+// @Accept json
+// @Produce json
+// @Success 200 {object} http.HealthSchema
+// @Router /v1/health/ [get]
+func HealthHandler(_ context.Context, c *app.RequestContext) {
+	c.JSON(consts.StatusOK, HealthSchema{Healthy: true})
+}
diff --git a/pkg/api/http/schemas.go b/pkg/api/http/schemas.go
new file mode 100644
index 00000000..e69bd94f
--- /dev/null
+++ b/pkg/api/http/schemas.go
@@ -0,0 +1,15 @@
+package http
+
+import "glide/pkg/routers"
+
+type ErrorSchema struct {
+	Message string `json:"message"`
+}
+
+type HealthSchema struct {
+	Healthy bool `json:"healthy"`
+}
+
+type RouterListSchema struct {
+	Routers []*routers.LangRouterConfig `json:"routers"`
+}
diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go
new file mode 100644
index 00000000..fc274b91
--- /dev/null
+++ b/pkg/api/http/server.go
@@ -0,0 +1,62 @@
+package http
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/hertz-contrib/swagger"
+	swaggerFiles "github.com/swaggo/files"
+	_ "glide/docs" // importing the docs package to include the generated specs in the binary
+
+	"glide/pkg/routers"
+
+	"glide/pkg/telemetry"
+
+	"github.com/cloudwego/hertz/pkg/app/server"
+)
+
+type Server struct {
+	config        *ServerConfig
+	telemetry     *telemetry.Telemetry
+	routerManager *routers.RouterManager
+	server        *server.Hertz
+}
+
+func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) {
+	srv := config.ToServer()
+
+	return &Server{
+		config:        config,
+		telemetry:     tel,
+		routerManager: routerManager,
+		server:        srv,
+	}, nil
+}
+
+func (srv *Server) Run() error {
+	defaultGroup := srv.server.Group("/v1")
+
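+	// All public endpoints are registered on the versioned /v1 group below:
+	// the router list, the unified chat endpoint, health, and the Swagger UI.
+	// An illustrative request against the default host/port, assuming a
+	// hypothetical router named "myrouter" is configured:
+	//
+	//	curl -X POST http://127.0.0.1:9099/v1/language/myrouter/chat/ \
+	//	  -H 'Content-Type: application/json' \
+	//	  -d '{"message": {"role": "user", "content": "Hello!"}}'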
defaultGroup.GET("/language/", LangRoutersHandler(srv.routerManager)) + defaultGroup.POST("/language/:router/chat/", LangChatHandler(srv.routerManager)) + + defaultGroup.GET("/health/", HealthHandler) + + schemaDocURL := swagger.URL(fmt.Sprintf("http://%v/v1/swagger/doc.json", srv.config.HostPort)) + defaultGroup.GET("/swagger/*any", swagger.WrapHandler(swaggerFiles.Handler, schemaDocURL)) + + return srv.server.Run() +} + +func (srv *Server) Shutdown(_ context.Context) error { + exitWaitTime := srv.server.GetOptions().ExitWaitTimeout + + srv.telemetry.Logger.Info( + fmt.Sprintf("Begin graceful shutdown, wait at most %d seconds...", exitWaitTime/time.Second), + ) + + ctx, cancel := context.WithTimeout(context.Background(), exitWaitTime) + defer cancel() + + return srv.server.Shutdown(ctx) //nolint:contextcheck +} diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go new file mode 100644 index 00000000..068e0588 --- /dev/null +++ b/pkg/api/schemas/language.go @@ -0,0 +1,161 @@ +package schemas + +// UnifiedChatRequest defines Glide's Chat Request Schema unified across all language models +type UnifiedChatRequest struct { + Message ChatMessage `json:"message"` + MessageHistory []ChatMessage `json:"messageHistory"` +} + +func NewChatFromStr(message string) *UnifiedChatRequest { + return &UnifiedChatRequest{ + Message: ChatMessage{ + "human", + message, + "roma", + }, + } +} + +// UnifiedChatResponse defines Glide's Chat Response Schema unified across all language models +type UnifiedChatResponse struct { + ID string `json:"id,omitempty"` + Created int `json:"created,omitempty"` + Provider string `json:"provider,omitempty"` + RouterID string `json:"router,omitempty"` + ModelID string `json:"model_id,omitempty"` + Model string `json:"model,omitempty"` + Cached bool `json:"cached,omitempty"` + ModelResponse ProviderResponse `json:"modelResponse,omitempty"` +} + +// ProviderResponse is the unified response from the provider. + +type ProviderResponse struct { + SystemID map[string]string `json:"responseId,omitempty"` + Message ChatMessage `json:"message"` + TokenCount TokenCount `json:"tokenCount"` +} + +type TokenCount struct { + PromptTokens float64 `json:"promptTokens"` + ResponseTokens float64 `json:"responseTokens"` + TotalTokens float64 `json:"totalTokens"` +} + +// ChatMessage is a message in a chat request. +type ChatMessage struct { + // The role of the author of this message. One of system, user, or assistant. + Role string `json:"role"` + // The content of the message. + Content string `json:"content"` + // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name string `json:"name,omitempty"` +} + +// OpenAI Chat Response (also used by Azure OpenAI and OctoML) +// TODO: Should this live here? 
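+//
+// For reference, an abridged illustration of the payload this struct decodes
+// (all values are made up):
+//
+//	{
+//	  "id": "chatcmpl-123",
+//	  "object": "chat.completion",
+//	  "created": 1700000000,
+//	  "model": "gpt-3.5-turbo",
+//	  "system_fingerprint": "fp_abc",
+//	  "choices": [
+//	    {"index": 0, "message": {"role": "assistant", "content": "Hi!"}, "logprobs": null, "finish_reason": "stop"}
+//	  ],
+//	  "usage": {"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11}
+//	}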
+type OpenAIChatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens float64 `json:"prompt_tokens"` + CompletionTokens float64 `json:"completion_tokens"` + TotalTokens float64 `json:"total_tokens"` +} + +// Cohere Chat Response +type CohereChatCompletion struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` + ResponseID string `json:"response_id"` + TokenCount CohereTokenCount `json:"token_count"` + Citations []Citation `json:"citations"` + Documents []Documents `json:"documents"` + SearchQueries []SearchQuery `json:"search_queries"` + SearchResults []SearchResults `json:"search_results"` + Meta Meta `json:"meta"` + ToolInputs map[string]interface{} `json:"tool_inputs"` +} + +type CohereTokenCount struct { + PromptTokens float64 `json:"prompt_tokens"` + ResponseTokens float64 `json:"response_tokens"` + TotalTokens float64 `json:"total_tokens"` + BilledTokens float64 `json:"billed_tokens"` +} + +type Meta struct { + APIVersion struct { + Version string `json:"version"` + } `json:"api_version"` + BilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"billed_units"` +} + +type Citation struct { + Start int `json:"start"` + End int `json:"end"` + Text string `json:"text"` + DocumentID []string `json:"document_id"` +} + +type Documents struct { + ID string `json:"id"` + Data map[string]string `json:"data"` // TODO: This needs to be updated +} + +type SearchQuery struct { + Text string `json:"text"` + GenerationID string `json:"generation_id"` +} + +type SearchResults struct { + SearchQuery []SearchQueryObject `json:"search_query"` + Connectors []ConnectorsResponse `json:"connectors"` + DocumentID []string `json:"documentId"` +} + +type SearchQueryObject struct { + Text string `json:"text"` + GenerationID string `json:"generationId"` +} + +type ConnectorsResponse struct { + ID string `json:"id"` + UserAccessToken string `json:"user_access_token"` + ContOnFail string `json:"continue_on_failure"` + Options map[string]string `json:"options"` +} + +// Anthropic Chat Response +type AnthropicChatCompletion struct { + ID string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Role string `json:"role"` + Content []Content `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence string `json:"stop_sequence"` +} + +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} diff --git a/pkg/api/servers.go b/pkg/api/servers.go new file mode 100644 index 00000000..8b0580d9 --- /dev/null +++ b/pkg/api/servers.go @@ -0,0 +1,58 @@ +package api + +import ( + "context" + "sync" + + "glide/pkg/routers" + + "glide/pkg/telemetry" + + "glide/pkg/api/http" +) + +type ServerManager struct { + httpServer *http.Server + shutdownWG *sync.WaitGroup +} + +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { + httpServer, err := http.NewServer(cfg.HTTP, tel, router) + if err != nil { + return nil, err + } + + // TODO: init other servers like gRPC in future + + return &ServerManager{ + httpServer: 
httpServer,
+		shutdownWG: &sync.WaitGroup{},
+	}, nil
+}
+
+func (mgr *ServerManager) Start() {
+	if mgr.httpServer != nil {
+		mgr.shutdownWG.Add(1)
+
+		go func() {
+			defer mgr.shutdownWG.Done()
+
+			// TODO: log the error
+			err := mgr.httpServer.Run()
+
+			println(err)
+		}()
+	}
+}
+
+func (mgr *ServerManager) Shutdown(ctx context.Context) error {
+	var err error
+
+	if mgr.httpServer != nil {
+		err = mgr.httpServer.Shutdown(ctx)
+	}
+
+	mgr.shutdownWG.Wait()
+
+	return err
+}
diff --git a/pkg/cmd/cli.go b/pkg/cmd/cli.go
new file mode 100644
index 00000000..885ec2e3
--- /dev/null
+++ b/pkg/cmd/cli.go
@@ -0,0 +1,40 @@
+package cmd
+
+import (
+	"glide/pkg"
+	"glide/pkg/config"
+
+	"github.com/spf13/cobra"
+)
+
+var cfgFile string
+
+// NewCLI creates a Glide CLI
+func NewCLI() *cobra.Command {
+	// TODO: chances are we could use the built-in flag module if this is all we need from the CLI
+	cli := &cobra.Command{
+		Use:     "glide",
+		Short:   "🐦 Glide is an open-source, lightweight, high-performance model gateway",
+		Long:    "TODO",
+		Version: pkg.FullVersion,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			configProvider, err := config.NewProvider().Load(cfgFile)
+			if err != nil {
+				return err
+			}
+
+			gateway, err := pkg.NewGateway(configProvider)
+			if err != nil {
+				return err
+			}
+
+			return gateway.Run(cmd.Context())
+		},
+		// SilenceUsage: true,
+	}
+
+	cli.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file")
+	_ = cli.MarkPersistentFlagRequired("config")
+
+	return cli
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
new file mode 100644
index 00000000..80208bee
--- /dev/null
+++ b/pkg/config/config.go
@@ -0,0 +1,22 @@
+package config
+
+import (
+	"glide/pkg/api"
+	"glide/pkg/routers"
+	"glide/pkg/telemetry"
+)
+
+// Config is a general top-level Glide configuration
+type Config struct {
+	Telemetry *telemetry.Config `yaml:"telemetry"`
+	API       *api.Config       `yaml:"api"`
+	Routers   routers.Config    `yaml:"routers" validate:"required"`
+}
+
+func DefaultConfig() *Config {
+	return &Config{
+		Telemetry: telemetry.DefaultConfig(),
+		API:       api.DefaultConfig(),
+		// Routers should be defined by users
+	}
+}
diff --git a/pkg/config/expander.go b/pkg/config/expander.go
new file mode 100644
index 00000000..ef2c32b0
--- /dev/null
+++ b/pkg/config/expander.go
@@ -0,0 +1,82 @@
+package config
+
+import (
+	"log"
+	"os"
+	"path/filepath"
+	"regexp"
+)
+
+// Expander finds special directives like ${env:ENV_VAR} in the config file and fills them with actual values
+type Expander struct{}
+
+func (e *Expander) Expand(content []byte) []byte {
+	expandedContent := string(content)
+
+	expandedContent = e.expandEnvVarDirectives(expandedContent)
+	expandedContent = e.expandFileDirectives(expandedContent)
+	expandedContent = e.expandEnvVars(expandedContent)
+
+	return []byte(expandedContent)
+}
+
+// expandEnvVars expands $ENVAR
+func (e *Expander) expandEnvVars(content string) string {
+	return os.Expand(content, func(str string) string {
+		// This allows escaping environment variable substitution via $$, e.g.
+ // - $FOO will be substituted with env var FOO + // - $$FOO will be replaced with $FOO + // - $$$FOO will be replaced with $ + substituted env var FOO + if str == "$" { + return "$" + } + + return os.Getenv(str) + }) +} + +// expandEnvVarDirectives expands ${env:ENVAR} directives +func (e *Expander) expandEnvVarDirectives(content string) string { + dirMatcher := regexp.MustCompile(`\$\{env:(.+?)\}`) + + return dirMatcher.ReplaceAllStringFunc(content, func(match string) string { + matches := dirMatcher.FindStringSubmatch(match) + + if len(matches) != 2 { + return match // No replacement if the pattern is not matched + } + + envVarName := matches[1] + value, exists := os.LookupEnv(envVarName) + + if !exists { + log.Printf("could not expand the env var directive: \"%s\" variable is not found", envVarName) + + return "" + } + + return value + }) +} + +// expandFileDirectives expands ${file:/path/to/file} directives +func (e *Expander) expandFileDirectives(content string) string { + dirMatcher := regexp.MustCompile(`\$\{file:(.+?)\}`) + + return dirMatcher.ReplaceAllStringFunc(content, func(match string) string { + matches := dirMatcher.FindStringSubmatch(match) + + if len(matches) != 2 { + return match // No replacement if the pattern is not matched + } + + filePath := matches[1] + content, err := os.ReadFile(filepath.Clean(filePath)) + if err != nil { + log.Printf("could not expand the file directive (${file:%s}): %v", filePath, err) + return match // Return original match if there's an error + } + + return string(content) + }) +} diff --git a/pkg/config/expander_test.go b/pkg/config/expander_test.go new file mode 100644 index 00000000..f5d2930f --- /dev/null +++ b/pkg/config/expander_test.go @@ -0,0 +1,64 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +type sampleConfig struct { + Name string `yaml:"name"` + APIKey string `yaml:"api_key"` + Messages map[string]string `yaml:"messages"` + Seeds []string `yaml:"seeds"` + Params []struct { + Name string `yaml:"name"` + Value string `yaml:"value"` + } `yaml:"params"` +} + +func TestExpander_EnvVarExpanded(t *testing.T) { + const apiKey = "ABC1234" + + const seed1 = "40" + + const seed2 = "41" + + const answerMarker = "Answer:" + + const topP = "3" + + const budget = "100" + + t.Setenv("OPENAPI_KEY", apiKey) + t.Setenv("SEED_1", seed1) + t.Setenv("SEED_2", seed2) + t.Setenv("ANSWER_MARKER", answerMarker) + t.Setenv("OPENAI_TOP_P", topP) + t.Setenv("OPENAI_BUDGET", budget) + + content, err := os.ReadFile(filepath.Clean(filepath.Join(".", "testdata", "expander.env.yaml"))) + require.NoError(t, err) + + expander := Expander{} + updatedContent := expander.Expand(content) + + var cfg *sampleConfig + + err = yaml.Unmarshal(updatedContent, &cfg) + require.NoError(t, err) + + assert.Equal(t, apiKey, cfg.APIKey) + assert.Equal(t, []string{seed1, seed2, "42"}, cfg.Seeds) + + assert.Contains(t, cfg.Messages["human"], "how $$ $ $ does") + assert.Contains(t, cfg.Messages["human"], fmt.Sprintf("$%v", answerMarker)) + + assert.Equal(t, topP, cfg.Params[0].Value) + assert.Equal(t, fmt.Sprintf("$%v", budget), cfg.Params[1].Value) +} diff --git a/pkg/config/fields/secret.go b/pkg/config/fields/secret.go new file mode 100644 index 00000000..a5038571 --- /dev/null +++ b/pkg/config/fields/secret.go @@ -0,0 +1,17 @@ +package fields + +import ( + "encoding" +) + +// Secret is a string that is marshaled in an opaque way, so we are not 
leaking sensitive information
+type Secret string
+
+const maskedSecret = "[REDACTED]"
+
+var _ encoding.TextMarshaler = Secret("")
+
+// MarshalText marshals the secret as `[REDACTED]`.
+func (s Secret) MarshalText() ([]byte, error) {
+	return []byte(maskedSecret), nil
+}
diff --git a/pkg/config/fields/secret_test.go b/pkg/config/fields/secret_test.go
new file mode 100644
index 00000000..223617f6
--- /dev/null
+++ b/pkg/config/fields/secret_test.go
@@ -0,0 +1,32 @@
+package fields
+
+import (
+	"testing"
+
+	"gopkg.in/yaml.v3"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSecret_OpaqueOnMarshaling(t *testing.T) {
+	name := "OpenAI"
+	secretValue := "ABCDE123"
+
+	config := struct {
+		APIKey Secret `json:"api_key"`
+		Name   string `json:"name"`
+	}{
+		APIKey: Secret(secretValue),
+		Name:   name,
+	}
+
+	rawConfig, err := yaml.Marshal(config)
+	require.NoError(t, err)
+
+	rawConfigStr := string(rawConfig)
+
+	assert.NotContains(t, rawConfigStr, secretValue)
+	assert.Contains(t, rawConfigStr, maskedSecret)
+	assert.Contains(t, rawConfigStr, name)
+}
diff --git a/pkg/config/provider.go b/pkg/config/provider.go
new file mode 100644
index 00000000..03d4f7f9
--- /dev/null
+++ b/pkg/config/provider.go
@@ -0,0 +1,53 @@
+package config
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"gopkg.in/yaml.v3"
+)
+
+// Provider reads, collects, validates, and processes config files
+type Provider struct {
+	expander *Expander
+	Config   *Config
+}
+
+// NewProvider creates an instance of the Config Provider
+func NewProvider() *Provider {
+	return &Provider{
+		expander: &Expander{},
+		Config:   nil,
+	}
+}
+
+func (p *Provider) Load(configPath string) (*Provider, error) {
+	content, err := os.ReadFile(filepath.Clean(configPath))
+	if err != nil {
+		return p, fmt.Errorf("unable to read config file %v: %w", configPath, err)
+	}
+
+	// process raw config
+	content = p.expander.Expand(content)
+
+	// validate the config structure
+	cfg := DefaultConfig()
+
+	if err := yaml.Unmarshal(content, &cfg); err != nil {
+		return p, fmt.Errorf("unable to parse config file %v: %w", configPath, err)
+	}
+
+	// TODO: validate config values
+
+	p.Config = cfg
+
+	return p, nil
+}
+
+func (p *Provider) Get() *Config {
+	return p.Config
+}
+
+func (p *Provider) Start() {
+}
diff --git a/pkg/config/provider_test.go b/pkg/config/provider_test.go
new file mode 100644
index 00000000..05f246fb
--- /dev/null
+++ b/pkg/config/provider_test.go
@@ -0,0 +1,45 @@
+package config
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestConfigProvider_NonExistingConfigFile(t *testing.T) {
+	_, err := NewProvider().Load("./testdata/doesntexist.yaml")
+
+	require.Error(t, err)
+	require.ErrorContains(t, err, "no such file or directory")
+}
+
+func TestConfigProvider_NonYAMLConfigFile(t *testing.T) {
+	_, err := NewProvider().Load("./testdata/provider.broken.yaml")
+
+	require.Error(t, err)
+	require.ErrorContains(t, err, "unable to parse config file")
+}
+
+func TestConfigProvider_ValidConfigLoaded(t *testing.T) {
+	configProvider := NewProvider()
+	configProvider, err := configProvider.Load("./testdata/provider.fullconfig.yaml")
+	require.NoError(t, err)
+
+	cfg := configProvider.Get()
+
+	langRouters := cfg.Routers.LanguageRouters
+
+	require.Len(t, langRouters, 1)
+	require.True(t, langRouters[0].Enabled)
+
+	models := langRouters[0].Models
+	require.Len(t, models, 1)
+}
+
+func TestConfigProvider_NoProvider(t *testing.T) {
+	configProvider := NewProvider()
+	_, err := configProvider.Load("./testdata/provider.nomodelprovider.yaml")
configProvider.Load("./testdata/provider.nomodelprovider.yaml") + + require.Error(t, err) + require.ErrorContains(t, err, "none is configured") +} diff --git a/pkg/config/testdata/expander.env.yaml b/pkg/config/testdata/expander.env.yaml new file mode 100644 index 00000000..39fa8d67 --- /dev/null +++ b/pkg/config/testdata/expander.env.yaml @@ -0,0 +1,16 @@ +name: "OpenAI" +api_key: "${env:OPENAPI_KEY}" + +messages: + human: "Hello buddy, how $$$$ $$ $ does it cost to build a startup? $$$ANSWER_MARKER" + +seeds: + - "${SEED_1}" + - "${env:SEED_2}" + - "42" + +params: + - name: top_p + value: "$OPENAI_TOP_P" + - name: budget + value: "$$${env:OPENAI_BUDGET}" diff --git a/pkg/config/testdata/provider.broken.yaml b/pkg/config/testdata/provider.broken.yaml new file mode 100644 index 00000000..e3a0043d --- /dev/null +++ b/pkg/config/testdata/provider.broken.yaml @@ -0,0 +1,6 @@ +{ + "telemetry": { + "logging": { + "level": "debug", + "encoding": "console" + diff --git a/pkg/config/testdata/provider.fullconfig.yaml b/pkg/config/testdata/provider.fullconfig.yaml new file mode 100644 index 00000000..3d960918 --- /dev/null +++ b/pkg/config/testdata/provider.fullconfig.yaml @@ -0,0 +1,17 @@ +telemetry: + logging: + level: INFO # DEBUG, INFO, WARNING, ERROR, FATAL + encoding: json # console, json + +routers: + language: + - id: simplerouter + strategy: priority + models: + - id: openai-boring + openai: + model: gpt-3.5-turbo + api_key: "ABSC@124" + default_params: + temperature: 0 + diff --git a/pkg/config/testdata/provider.nomodelprovider.yaml b/pkg/config/testdata/provider.nomodelprovider.yaml new file mode 100644 index 00000000..68af932f --- /dev/null +++ b/pkg/config/testdata/provider.nomodelprovider.yaml @@ -0,0 +1,12 @@ +telemetry: + logging: + level: INFO # DEBUG, INFO, WARNING, ERROR, FATAL + encoding: json # console, json + +routers: + language: + - id: simplerouter + strategy: priority + models: + - id: openaimodel + diff --git a/pkg/gateway.go b/pkg/gateway.go new file mode 100644 index 00000000..0df1f368 --- /dev/null +++ b/pkg/gateway.go @@ -0,0 +1,103 @@ +package pkg + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + "glide/pkg/routers" + + "glide/pkg/config" + + "glide/pkg/telemetry" + "go.uber.org/zap" + + "glide/pkg/api" + "go.uber.org/multierr" +) + +// Gateway represents an instance of running Glide gateway. +// It loads configs, start API server(s), and listen to termination signals to shut down +type Gateway struct { + // configProvider holds all configurations + configProvider *config.Provider + // telemetry holds logger, meter, and tracer + telemetry *telemetry.Telemetry + // serverManager controls API over different protocols + serverManager *api.ServerManager + // signalChannel is used to receive termination signals from the OS. 
+ signalC chan os.Signal + // shutdownC is used to terminate the gateway + shutdownC chan struct{} +} + +func NewGateway(configProvider *config.Provider) (*Gateway, error) { + cfg := configProvider.Get() + + tel, err := telemetry.NewTelemetry(&telemetry.Config{LogConfig: cfg.Telemetry.LogConfig}) + if err != nil { + return nil, err + } + + routerManager, err := routers.NewManager(&cfg.Routers, tel) + if err != nil { + return nil, err + } + + serverManager, err := api.NewServerManager(cfg.API, tel, routerManager) + if err != nil { + return nil, err + } + + return &Gateway{ + configProvider: configProvider, + telemetry: tel, + serverManager: serverManager, + signalC: make(chan os.Signal, 3), // equal to number of signal types we expect to receive + shutdownC: make(chan struct{}), + }, nil +} + +// Run starts and runs the gateway according to given configuration +func (gw *Gateway) Run(ctx context.Context) error { + gw.configProvider.Start() + gw.serverManager.Start() + + signal.Notify(gw.signalC, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(gw.signalC) + +LOOP: + for { + select { + // TODO: Watch for config updates + case sig := <-gw.signalC: + gw.telemetry.Logger.Info("received signal from os", zap.String("signal", sig.String())) + break LOOP + case <-gw.shutdownC: + gw.telemetry.Logger.Info("received shutdown request") + break LOOP + case <-ctx.Done(): + gw.telemetry.Logger.Info("context done, terminating process") + // Call shutdown with background context as the passed in context has been canceled + return gw.shutdown(context.Background()) //nolint:contextcheck + } + } + + return gw.shutdown(ctx) +} + +func (gw *Gateway) Shutdown() { + close(gw.shutdownC) +} + +func (gw *Gateway) shutdown(ctx context.Context) error { + var errs error + + if err := gw.serverManager.Shutdown(ctx); err != nil { + errs = multierr.Append(errs, fmt.Errorf("failed to shutdown servers: %w", err)) + } + + return errs +} diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go new file mode 100644 index 00000000..11c742f0 --- /dev/null +++ b/pkg/providers/anthropic/chat.go @@ -0,0 +1,190 @@ +package anthropic + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + "go.uber.org/zap" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an Anthropic-specific request schema +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Metadata *string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} + +// NewChatRequestFromConfig fills the struct from the config. 
Not using reflection because of performance penalty it gives +func NewChatRequestFromConfig(cfg *Config) *ChatRequest { + return &ChatRequest{ + Model: cfg.Model, + System: cfg.DefaultParams.System, + Temperature: cfg.DefaultParams.Temperature, + TopP: cfg.DefaultParams.TopP, + TopK: cfg.DefaultParams.TopK, + MaxTokens: cfg.DefaultParams.MaxTokens, + Metadata: cfg.DefaultParams.Metadata, + StopSequences: cfg.DefaultParams.StopSequences, + Stream: false, // unsupported right now + } +} + +func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { + messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) + + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) + + return messages +} + +// Chat sends a chat request to the specified anthropic model. +func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + // Create a new chat request + chatRequest := c.createChatRequestSchema(request) + + chatResponse, err := c.doChatRequest(ctx, chatRequest) + if err != nil { + return nil, err + } + + if len(chatResponse.ModelResponse.Message.Content) == 0 { + return nil, ErrEmptyResponse + } + + return chatResponse, nil +} + +func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := c.chatRequestTemplate // hoping to get a copy of the template + chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + + return chatRequest +} + +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { + // Build request payload + rawPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("unable to marshal anthropic chat request payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create anthropic chat request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey)) + req.Header.Set("Content-Type", "application/json") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.telemetry.Logger.Debug( + "anthropic chat request", + zap.String("chat_url", c.chatURL), + zap.Any("payload", payload), + ) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send anthropic chat request: %w", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err)) + } + + c.telemetry.Logger.Error( + "anthropic chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return nil, fmt.Errorf("failed to 
parse cooldown delay from headers: %w", err) + } + + return nil, clients.NewRateLimitError(&cooldownDelay) + } + + // Server & client errors result in the same error to keep gateway resilient + return nil, clients.ErrProviderUnavailable + } + + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var anthropicCompletion schemas.AnthropicChatCompletion + + err = json.Unmarshal(bodyBytes, &anthropicCompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse anthropic chat response", zap.Error(err)) + return nil, err + } + + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: anthropicCompletion.ID, + Created: int(time.Now().UTC().Unix()), // not provided by anthropic + Provider: providerName, + Model: anthropicCompletion.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "system_fingerprint": anthropicCompletion.ID, + }, + Message: schemas.ChatMessage{ + Role: anthropicCompletion.Content[0].Type, + Content: anthropicCompletion.Content[0].Text, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: 0, // Anthropic doesn't send prompt tokens + ResponseTokens: 0, + TotalTokens: 0, + }, + }, + } + + return &response, nil +} diff --git a/pkg/providers/anthropic/client.go b/pkg/providers/anthropic/client.go new file mode 100644 index 00000000..c7131455 --- /dev/null +++ b/pkg/providers/anthropic/client.go @@ -0,0 +1,59 @@ +package anthropic + +import ( + "errors" + "net/http" + "net/url" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +const ( + providerName = "anthropic" +) + +// ErrEmptyResponse is returned when the Anthropic API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing the Anthropic API +type Client struct { + baseURL string + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new client for the Anthropic API.
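+// A minimal usage sketch (illustrative only; mirrors client_test.go and assumes a telemetry instance named tel): +// +// client, err := NewClient(DefaultConfig(), clients.DefaultClientConfig(), tel) +// if err != nil { /* handle */ } +// resp, err := client.Chat(ctx, &schemas.UnifiedChatRequest{Message: schemas.ChatMessage{Role: "human", Content: "Hi"}})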
+func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) + if err != nil { + return nil, err + } + + c := &Client{ + baseURL: providerConfig.BaseURL, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + Timeout: *clientConfig.Timeout, + // TODO: use values from the config + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/anthropic/client_test.go b/pkg/providers/anthropic/client_test.go new file mode 100644 index 00000000..7ffb0557 --- /dev/null +++ b/pkg/providers/anthropic/client_test.go @@ -0,0 +1,104 @@ +package anthropic + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestAnthropicClient_ChatRequest(t *testing.T) { + // Anthropic Messages API: https://docs.anthropic.com/claude/reference/messages_post + AnthropicMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + AnthropicServer := httptest.NewServer(AnthropicMock) + defer AnthropicServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = AnthropicServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, "msg_013Zva2CMHLNnXjNJJKqJ2EF", response.ID) +} + +func TestAnthropicClient_BadChatRequest(t *testing.T) { + // Anthropic Messages API: https://docs.anthropic.com/claude/reference/messages_post + AnthropicMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return a non-OK status code + w.WriteHeader(http.StatusBadRequest) + }) + + AnthropicServer := httptest.NewServer(AnthropicMock) + defer AnthropicServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = AnthropicServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + + // Assert that an error is returned + require.Error(t, err) + + // Assert that the error 
message contains the expected substring + require.Contains(t, err.Error(), "provider is not available") + + // Assert that the response is nil + require.Nil(t, response) +} diff --git a/pkg/providers/anthropic/config.go b/pkg/providers/anthropic/config.go new file mode 100644 index 00000000..beb98734 --- /dev/null +++ b/pkg/providers/anthropic/config.go @@ -0,0 +1,65 @@ +package anthropic + +import ( + "glide/pkg/config/fields" +) + +// Params defines Anthropic-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + System string `yaml:"system,omitempty" json:"system"` + Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` + TopP float64 `yaml:"top_p,omitempty" json:"top_p"` + TopK int `yaml:"top_k,omitempty" json:"top_k"` + MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"` + StopSequences []string `yaml:"stop,omitempty" json:"stop"` + Metadata *string `yaml:"metadata,omitempty" json:"metadata"` + // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment +} + +func DefaultParams() Params { + return Params{ + Temperature: 1, + TopP: 0, + TopK: 0, + MaxTokens: 250, + System: "You are a helpful assistant.", + StopSequences: []string{}, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"` + ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + DefaultParams *Params `yaml:"defaultParams,omitempty" json:"defaultParams"` +} + +// DefaultConfig for Anthropic models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "https://api.anthropic.com/v1", + ChatEndpoint: "/messages", + Model: "claude-instant-1.2", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/anthropic/testdata/chat.req.json b/pkg/providers/anthropic/testdata/chat.req.json new file mode 100644 index 00000000..e4aac07b --- /dev/null +++ b/pkg/providers/anthropic/testdata/chat.req.json @@ -0,0 +1,12 @@ +{ + "model": "claude-instant-1.2", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 1, + "top_p": 0, + "max_tokens": 100 +} diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/providers/anthropic/testdata/chat.success.json new file mode 100644 index 00000000..eaf0f6c9 --- /dev/null +++ b/pkg/providers/anthropic/testdata/chat.success.json @@ -0,0 +1,14 @@ +{ + "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", + "type": "message", + "model": "claude-2.1", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Blue is often seen as a calming and soothing color."
+ } + ], + "stop_reason": "end_turn", + "stop_sequence": null +} \ No newline at end of file diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go new file mode 100644 index 00000000..320c90ee --- /dev/null +++ b/pkg/providers/azureopenai/chat.go @@ -0,0 +1,202 @@ +package azureopenai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + "go.uber.org/zap" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an Azure openai-specific request schema +type ChatRequest struct { + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` +} + +// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives +func NewChatRequestFromConfig(cfg *Config) *ChatRequest { + return &ChatRequest{ + Temperature: cfg.DefaultParams.Temperature, + TopP: cfg.DefaultParams.TopP, + MaxTokens: cfg.DefaultParams.MaxTokens, + N: cfg.DefaultParams.N, + StopWords: cfg.DefaultParams.StopWords, + Stream: false, // unsupported right now + FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty, + PresencePenalty: cfg.DefaultParams.PresencePenalty, + LogitBias: cfg.DefaultParams.LogitBias, + User: cfg.DefaultParams.User, + Seed: cfg.DefaultParams.Seed, + Tools: cfg.DefaultParams.Tools, + ToolChoice: cfg.DefaultParams.ToolChoice, + ResponseFormat: cfg.DefaultParams.ResponseFormat, + } +} + +func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { + messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) + + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) + + return messages +} + +// Chat sends a chat request to the specified azure openai model. 
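+// Like the other providers, Chat copies the pre-filled request template, attaches the unified message history, POSTs the payload, and maps the OpenAI-style completion back onto the unified schema. One caveat worth noting: chatRequestTemplate is a pointer, so the assignment in createChatRequestSchema below aliases the template rather than copying it, and concurrent requests may race on Messages (see the objectpool TODO).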
+func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + // Create a new chat request + chatRequest := c.createChatRequestSchema(request) + + chatResponse, err := c.doChatRequest(ctx, chatRequest) + if err != nil { + return nil, err + } + + if len(chatResponse.ModelResponse.Message.Content) == 0 { + return nil, ErrEmptyResponse + } + + return chatResponse, nil +} + +func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := c.chatRequestTemplate // hoping to get a copy of the template + chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + + return chatRequest +} + +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { + // Build request payload + rawPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("unable to marshal azure openai chat request payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create azure openai chat request: %w", err) + } + + req.Header.Set("api-key", string(c.config.APIKey)) + req.Header.Set("Content-Type", "application/json") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.telemetry.Logger.Debug( + "azure openai chat request", + zap.String("chat_url", c.chatURL), + zap.Any("payload", payload), + ) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send azure openai chat request: %w", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err)) + } + + c.telemetry.Logger.Error( + "azure openai chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return nil, clients.NewRateLimitError(&cooldownDelay) + } + + // Server & client errors result in the same error to keep gateway resilient + return nil, clients.ErrProviderUnavailable + } + + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var openAICompletion schemas.OpenAIChatCompletion + + err = json.Unmarshal(bodyBytes, &openAICompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + return nil, err + } + + openAICompletion.SystemFingerprint = "" // Azure OpenAI doesn't return this + + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + Provider: providerName, + Model: openAICompletion.Model, + Cached: false, + ModelResponse: 
schemas.ProviderResponse{ + SystemID: map[string]string{ + "system_fingerprint": openAICompletion.SystemFingerprint, + }, + Message: schemas.ChatMessage{ + Role: openAICompletion.Choices[0].Message.Role, + Content: openAICompletion.Choices[0].Message.Content, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: openAICompletion.Usage.PromptTokens, + ResponseTokens: openAICompletion.Usage.CompletionTokens, + TotalTokens: openAICompletion.Usage.TotalTokens, + }, + }, + } + + return &response, nil +} diff --git a/pkg/providers/azureopenai/client.go b/pkg/providers/azureopenai/client.go new file mode 100644 index 00000000..03cce2f2 --- /dev/null +++ b/pkg/providers/azureopenai/client.go @@ -0,0 +1,58 @@ +package azureopenai + +import ( + "errors" + "fmt" + "net/http" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +const ( + providerName = "azureopenai" +) + +// ErrEmptyResponse is returned when the Azure OpenAI API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing the Azure OpenAI API +type Client struct { + baseURL string // The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/) + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new client for the Azure OpenAI API. +func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", providerConfig.BaseURL, providerConfig.Model, providerConfig.APIVersion) + + c := &Client{ + baseURL: providerConfig.BaseURL, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + // TODO: use values from the config + Timeout: *clientConfig.Timeout, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go new file mode 100644 index 00000000..5e96753b --- /dev/null +++ b/pkg/providers/azureopenai/client_test.go @@ -0,0 +1,127 @@ +package azureopenai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestAzureOpenAIClient_ChatRequest(t *testing.T) { + // AzureOpenAI Chat API: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions + azureOpenAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + azureOpenAIServer :=
httptest.NewServer(azureOpenAIMock) + defer azureOpenAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + providerCfg.BaseURL = azureOpenAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "user", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, "chatcmpl-8cdqrFT2lBQlHz0EDvvq6oQcRxNcZ", response.ID) +} + +func TestAzureOpenAIClient_ChatError(t *testing.T) { + azureOpenAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + }) + + azureOpenAIServer := httptest.NewServer(azureOpenAIMock) + defer azureOpenAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + providerCfg.BaseURL = azureOpenAIServer.URL + + // Verify the default configuration values + require.Equal(t, "/chat/completions", providerCfg.ChatEndpoint) + require.Equal(t, "", providerCfg.Model) + require.Equal(t, "2023-05-15", providerCfg.APIVersion) + require.NotNil(t, providerCfg.DefaultParams) + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "user", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.Error(t, err) + require.Nil(t, response) +} + +func TestDoChatRequest_ErrorResponse(t *testing.T) { + // Create a mock HTTP server that returns a non-OK status code + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + + defer mockServer.Close() + + // Create a new client with the mock server URL + client := &Client{ + httpClient: http.DefaultClient, + chatURL: mockServer.URL, + config: &Config{APIKey: "dummy_key"}, + telemetry: telemetry.NewTelemetryMock(), + } + + // Create a chat request payload + payload := &ChatRequest{ + Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, + } + + // Call the doChatRequest function + _, err := client.doChatRequest(context.Background(), payload) + + require.Error(t, err) + require.Contains(t, err.Error(), "provider is not available") +} diff --git a/pkg/providers/azureopenai/config.go b/pkg/providers/azureopenai/config.go new file mode 100644 index 00000000..29e2e978 --- /dev/null +++ b/pkg/providers/azureopenai/config.go @@ -0,0 +1,73 @@ +package azureopenai + +import ( + "glide/pkg/config/fields" +) + +// Params defines OpenAI-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` + TopP float64 `yaml:"top_p,omitempty" json:"top_p"` + MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"` + N int `yaml:"n,omitempty" json:"n"` + StopWords []string `yaml:"stop,omitempty" json:"stop"` + FrequencyPenalty int `yaml:"frequency_penalty,omitempty" json:"frequency_penalty"` + PresencePenalty int `yaml:"presence_penalty,omitempty" json:"presence_penalty"` + LogitBias *map[int]float64 `yaml:"logit_bias,omitempty" json:"logit_bias"` + User *string `yaml:"user,omitempty" json:"user"` + Seed *int 
`yaml:"seed,omitempty" json:"seed"` + Tools []string `yaml:"tools,omitempty" json:"tools"` + ToolChoice interface{} `yaml:"tool_choice,omitempty" json:"tool_choice"` + ResponseFormat interface{} `yaml:"response_format,omitempty" json:"response_format"` // TODO: should this be a part of the chat request API? + // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment +} + +func DefaultParams() Params { + return Params{ + Temperature: 0.8, + TopP: 1, + MaxTokens: 100, + N: 1, + StopWords: []string{}, + Tools: []string{}, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"base_url" json:"baseUrl" validate:"required"` // The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/) + ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint"` + Model string `yaml:"model" json:"model" validate:"required"` // This is your deployment name. You're required to first deploy a model before you can make calls (e.g. glide-gpt-35) + APIVersion string `yaml:"api_version" json:"apiVersion" validate:"required"` // The API version to use for this operation. This follows the YYYY-MM-DD format (e.g 2023-05-15) + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"` +} + +// DefaultConfig for OpenAI models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "", // This needs to come from config + ChatEndpoint: "/chat/completions", + Model: "", // This needs to come from config + APIVersion: "2023-05-15", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/azureopenai/testdata/chat.req.json b/pkg/providers/azureopenai/testdata/chat.req.json new file mode 100644 index 00000000..81327b2c --- /dev/null +++ b/pkg/providers/azureopenai/testdata/chat.req.json @@ -0,0 +1,15 @@ +{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 0.8, + "top_p": 1, + "max_tokens": 100, + "n": 1, + "user": null, + "seed": null +} diff --git a/pkg/providers/azureopenai/testdata/chat.success.json b/pkg/providers/azureopenai/testdata/chat.success.json new file mode 100644 index 00000000..33354c0b --- /dev/null +++ b/pkg/providers/azureopenai/testdata/chat.success.json @@ -0,0 +1,21 @@ +{ + "id": "chatcmpl-8cdqrFT2lBQlHz0EDvvq6oQcRxNcZ", + "object": "chat.completion", + "created": 1704220345, + "model": "gpt-35-turbo", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "role": "assistant", + "content": "The biggest animal is the blue whale, which can grow up to 100 feet long and weigh as much as 200 tons." 
+ } + } + ], + "usage": { + "prompt_tokens": 14, + "completion_tokens": 26, + "total_tokens": 40 + } +} \ No newline at end of file diff --git a/pkg/providers/clients/config.go b/pkg/providers/clients/config.go new file mode 100644 index 00000000..d01a2ab1 --- /dev/null +++ b/pkg/providers/clients/config.go @@ -0,0 +1,15 @@ +package clients + +import "time" + +type ClientConfig struct { + Timeout *time.Duration `yaml:"timeout,omitempty" json:"timeout" swaggertype:"primitive,string"` +} + +func DefaultClientConfig() *ClientConfig { + defaultTimeout := 10 * time.Second + + return &ClientConfig{ + Timeout: &defaultTimeout, + } +} diff --git a/pkg/providers/clients/errors.go b/pkg/providers/clients/errors.go new file mode 100644 index 00000000..8c704a3f --- /dev/null +++ b/pkg/providers/clients/errors.go @@ -0,0 +1,33 @@ +package clients + +import ( + "errors" + "fmt" + "time" +) + +var ErrProviderUnavailable = errors.New("provider is not available") + +type RateLimitError struct { + untilReset time.Duration +} + +func (e RateLimitError) Error() string { + return fmt.Sprintf("rate limit reached, please wait %v", e.untilReset) +} + +func (e RateLimitError) UntilReset() time.Duration { + return e.untilReset +} + +func NewRateLimitError(untilReset *time.Duration) *RateLimitError { + defaultResetTime := 1 * time.Minute + + if untilReset == nil { + untilReset = &defaultResetTime + } + + return &RateLimitError{ + untilReset: *untilReset, + } +} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go new file mode 100644 index 00000000..ffcc017c --- /dev/null +++ b/pkg/providers/cohere/chat.go @@ -0,0 +1,244 @@ +package cohere + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + "go.uber.org/zap" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatHistory struct { + Role string `json:"role"` + Message string `json:"message"` + User string `json:"user,omitempty"` +} + +// ChatRequest is a request to complete a chat completion.. +type ChatRequest struct { + Model string `json:"model"` + Message string `json:"message"` + Temperature float64 `json:"temperature,omitempty"` + PreambleOverride string `json:"preamble_override,omitempty"` + ChatHistory []ChatHistory `json:"chat_history,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` + Connectors []string `json:"connectors,omitempty"` + SearchQueriesOnly bool `json:"search_queries_only,omitempty"` + CitiationQuality string `json:"citiation_quality,omitempty"` + + // Stream bool `json:"stream,omitempty"` +} + +type Connectors struct { + ID string `json:"id"` + UserAccessToken string `json:"user_access_token"` + ContOnFail string `json:"continue_on_failure"` + Options map[string]string `json:"options"` +} + +// NewChatRequestFromConfig fills the struct from the config. 
Not using reflection because of the performance penalty it incurs +func NewChatRequestFromConfig(cfg *Config) *ChatRequest { + return &ChatRequest{ + Model: cfg.Model, + Temperature: cfg.DefaultParams.Temperature, + PreambleOverride: cfg.DefaultParams.PreambleOverride, + ChatHistory: cfg.DefaultParams.ChatHistory, + ConversationID: cfg.DefaultParams.ConversationID, + PromptTruncation: cfg.DefaultParams.PromptTruncation, + Connectors: cfg.DefaultParams.Connectors, + SearchQueriesOnly: cfg.DefaultParams.SearchQueriesOnly, + CitiationQuality: cfg.DefaultParams.CitiationQuality, + } +} + +// Chat sends a chat request to the specified cohere model. +func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + // Create a new chat request + chatRequest := c.createChatRequestSchema(request) + + chatResponse, err := c.doChatRequest(ctx, chatRequest) + if err != nil { + return nil, err + } + + if len(chatResponse.ModelResponse.Message.Content) == 0 { + return nil, ErrEmptyResponse + } + + return chatResponse, nil +} + +func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := *c.chatRequestTemplate // take a value copy so the shared template is not mutated between requests + chatRequest.Message = request.Message.Content + + // Build the Cohere-specific ChatHistory + if len(request.MessageHistory) > 0 { + chatRequest.ChatHistory = make([]ChatHistory, len(request.MessageHistory)) + for i, message := range request.MessageHistory { + chatRequest.ChatHistory[i] = ChatHistory{ + Role: message.Role, + Message: message.Content, + User: "", + } + } + } + + return &chatRequest +} + +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { + // Build request payload + rawPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("unable to marshal cohere chat request payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create cohere chat request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey)) + req.Header.Set("Content-Type", "application/json") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.telemetry.Logger.Debug( + "cohere chat request", + zap.String("chat_url", c.chatURL), + zap.Any("payload", payload), + ) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send cohere chat request: %w", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // handleErrorResponse logs the failure, maps 429s to a RateLimitError, and keeps other + // server & client errors behind the same error to keep the gateway resilient + return c.handleErrorResponse(resp) + } + + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { +
c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var responseJSON map[string]interface{} + + err = json.Unmarshal(bodyBytes, &responseJSON) + if err != nil { + c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var cohereCompletion schemas.CohereChatCompletion + + err = json.Unmarshal(bodyBytes, &cohereCompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + return nil, err + } + + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: cohereCompletion.ResponseID, + Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this + Provider: providerName, + Model: c.config.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "generationId": cohereCompletion.GenerationID, + "responseId": cohereCompletion.ResponseID, + }, + Message: schemas.ChatMessage{ + Role: "model", // TODO: Does this need to change? + Content: cohereCompletion.Text, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: cohereCompletion.TokenCount.PromptTokens, + ResponseTokens: cohereCompletion.TokenCount.ResponseTokens, + TotalTokens: cohereCompletion.TokenCount.TotalTokens, + }, + }, + } + + return &response, nil +} + +func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.UnifiedChatResponse, error) { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + return nil, err + } + + c.telemetry.Logger.Error( + "cohere chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + cooldownDelay, err := c.getCooldownDelay(resp) + if err != nil { + return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return nil, clients.NewRateLimitError(&cooldownDelay) + } + + return nil, clients.ErrProviderUnavailable +} + +func (c *Client) getCooldownDelay(resp *http.Response) (time.Duration, error) { + retryAfter := resp.Header.Get("Retry-After") + + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return 0, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return cooldownDelay, nil +} diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go new file mode 100644 index 00000000..a6cc9cf5 --- /dev/null +++ b/pkg/providers/cohere/client.go @@ -0,0 +1,59 @@ +package cohere + +import ( + "errors" + "net/http" + "net/url" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +const ( + providerName = "cohere" +) + +// ErrEmptyResponse is returned when the Cohere API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing Cohere API +type Client struct { + baseURL string + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new Cohere client for the Cohere API. 
+func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) + if err != nil { + return nil, err + } + + c := &Client{ + baseURL: providerConfig.BaseURL, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + Timeout: *clientConfig.Timeout, + // TODO: use values from the config + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/cohere/client_test.go b/pkg/providers/cohere/client_test.go new file mode 100644 index 00000000..5e49a3e0 --- /dev/null +++ b/pkg/providers/cohere/client_test.go @@ -0,0 +1,67 @@ +// pkg/providers/cohere/client_test.go +package cohere + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "glide/pkg/providers/clients" + + "github.com/stretchr/testify/require" +) + +func TestCohereClient_ChatRequest(t *testing.T) { + cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading cohere chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + cohereServer := httptest.NewServer(cohereMock) + defer cohereServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + providerCfg.BaseURL = cohereServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, "ec9eb88b-2da5-462e-8f0f-0899d243aa2e", response.ID) +} diff --git a/pkg/providers/cohere/config.go b/pkg/providers/cohere/config.go new file mode 100644 index 00000000..1a38aefa --- /dev/null +++ b/pkg/providers/cohere/config.go @@ -0,0 +1,69 @@ +package cohere + +import ( + "glide/pkg/config/fields" +) + +// Params defines Cohere-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `json:"temperature,omitempty"` + Stream bool `json:"stream,omitempty"` // unsupported right now + PreambleOverride string `json:"preamble_override,omitempty"` + ChatHistory []ChatHistory `json:"chat_history,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + PromptTruncation string `json:"prompt_truncation,omitempty"` + Connectors []string `json:"connectors,omitempty"` + SearchQueriesOnly bool `json:"search_queries_only,omitempty"` + CitiationQuality string `json:"citiation_quality,omitempty"` +} + +func DefaultParams() Params { + return Params{ + Temperature: 0.3, + 
Stream: false, + PreambleOverride: "", + ChatHistory: nil, + ConversationID: "", + PromptTruncation: "", + Connectors: []string{}, + SearchQueriesOnly: false, + CitiationQuality: "", + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"base_url" json:"baseUrl" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"` +} + +// DefaultConfig for Cohere models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "https://api.cohere.ai/v1", + ChatEndpoint: "/chat", + Model: "command-light", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/cohere/testdata/chat.req.json b/pkg/providers/cohere/testdata/chat.req.json new file mode 100644 index 00000000..62c659fc --- /dev/null +++ b/pkg/providers/cohere/testdata/chat.req.json @@ -0,0 +1,10 @@ +{ + "model": "command-light", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 0.8 +} diff --git a/pkg/providers/cohere/testdata/chat.success.json b/pkg/providers/cohere/testdata/chat.success.json new file mode 100644 index 00000000..5cc8c779 --- /dev/null +++ b/pkg/providers/cohere/testdata/chat.success.json @@ -0,0 +1,21 @@ +{ + "response_id": "ec9eb88b-2da5-462e-8f0f-0899d243aa2e", + "text": "It's difficult to definitively determine the \"biggest\" animal, as different animals have varying physical characteristics and occupy different ecological niches. However, some animals that are commonly recognized as large and impressive in size include:\n\n- Polar bears: Polar bears are large carnivorous mammals native to the Arctic. They have powerful front legs and a robust body structure, allowing them to prey upon seals and other prey in their environment. Polar bears can reach up to 1,300 kg (2,700 lb) in weight and stand up to 1.8 meters (5 ft 11 in) in height on their hind legs.\n\n- Elephants: Elephants are large mammals belonging to the family Elephantidae. They are known for their enormous size, intelligence, and strong social bonds. African elephants, the largest species, can reach up to 2.8 meters (9.2 ft) in height and weigh up to 3.5 tons (7,716 lb). They are highly adapted to their environment and play a significant role in the ecosystems of the African continent.\n\n- Giraffes: Giraffes are tall, long-necked mammals found in Africa. They have distinctive patterns and long necks and legs. Giraffes can reach up to 6.7 meters (22 ft) in height and weigh up to 600 kg (1,320 lb). Their size and distinctive appearance make them notable creatures in the wild.\n\nThese animals are not generally considered the biggest in terms of weight or body mass, as they often inhabit specific ecological niches within their environments. Instead, they are recognized for their impressive size, strength, and other physical adaptations that enable them to successfully survive and thrive in their respective habitats. 
\n\nWould you like me to provide more information on any of these animals?", + "generation_id": "ee579745-ab9c-410c-a28c-d9b9aa6a0fcc", + "token_count": { + "prompt_tokens": 68, + "response_tokens": 360, + "total_tokens": 428, + "billed_tokens": 417 + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 57, + "output_tokens": 360 + } + }, + "tool_inputs": null +} \ No newline at end of file diff --git a/pkg/providers/config.go b/pkg/providers/config.go new file mode 100644 index 00000000..957cfd66 --- /dev/null +++ b/pkg/providers/config.go @@ -0,0 +1,127 @@ +package providers + +import ( + "errors" + "fmt" + + "glide/pkg/routers/latency" + + "glide/pkg/providers/clients" + + "glide/pkg/routers/health" + + "glide/pkg/providers/anthropic" + "glide/pkg/providers/azureopenai" + "glide/pkg/providers/cohere" + "glide/pkg/providers/octoml" + "glide/pkg/providers/openai" + "glide/pkg/telemetry" +) + +var ErrProviderNotFound = errors.New("provider not found") + +type LangModelConfig struct { + ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) + Enabled bool `yaml:"enabled" json:"enabled"` // Is the model enabled? + ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` + Latency *latency.Config `yaml:"latency" json:"latency"` + Weight int `yaml:"weight" json:"weight"` + Client *clients.ClientConfig `yaml:"client" json:"client"` + OpenAI *openai.Config `yaml:"openai" json:"openai"` + AzureOpenAI *azureopenai.Config `yaml:"azureopenai" json:"azureopenai"` + Cohere *cohere.Config `yaml:"cohere" json:"cohere"` + OctoML *octoml.Config `yaml:"octoml" json:"octoml"` + Anthropic *anthropic.Config `yaml:"anthropic" json:"anthropic"` + // Add other providers like + // Cohere *cohere.Config + // Anthropic *anthropic.Config +} + +func DefaultLangModelConfig() *LangModelConfig { + return &LangModelConfig{ + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Weight: 1, + } +} + +func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error) { + client, err := c.initClient(tel) + if err != nil { + return nil, fmt.Errorf("error initializing client: %v", err) + } + + return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil +} + +// initClient initializes the language model client based on the provided configuration. +// It takes a telemetry object as input and returns a LangModelProvider and an error. 
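+// For example, a router model entry that selects the OpenAI client looks like this in YAML (values are illustrative, taken from the provider.fullconfig.yaml testdata): +// +// models: +// - id: openai-boring +// openai: +// model: gpt-3.5-turbo +// api_key: "..." +// +// Exactly one provider block (openai, azureopenai, cohere, octoml, anthropic) may be set per model; validateOneProvider below enforces that.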
+func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangModelProvider, error) { + switch { + case c.OpenAI != nil: + return openai.NewClient(c.OpenAI, c.Client, tel) + case c.AzureOpenAI != nil: + return azureopenai.NewClient(c.AzureOpenAI, c.Client, tel) + case c.Cohere != nil: + return cohere.NewClient(c.Cohere, c.Client, tel) + case c.OctoML != nil: + return octoml.NewClient(c.OctoML, c.Client, tel) + case c.Anthropic != nil: + return anthropic.NewClient(c.Anthropic, c.Client, tel) + default: + return nil, ErrProviderNotFound + } +} + +func (c *LangModelConfig) validateOneProvider() error { + providersConfigured := 0 + + if c.OpenAI != nil { + providersConfigured++ + } + + if c.AzureOpenAI != nil { + providersConfigured++ + } + + if c.Cohere != nil { + providersConfigured++ + } + + if c.OctoML != nil { + providersConfigured++ + } + + if c.Anthropic != nil { + providersConfigured++ + } + + // check other providers here + if providersConfigured == 0 { + return fmt.Errorf("exactly one provider must be configured for model \"%v\", none is configured", c.ID) + } + + if providersConfigured > 1 { + return fmt.Errorf( + "exactly one provider must be configured for model \"%v\", %v are configured", + c.ID, + providersConfigured, + ) + } + + return nil +} + +func (c *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultLangModelConfig() + + type plain LangModelConfig // to avoid recursion + + if err := unmarshal((*plain)(c)); err != nil { + return err + } + + return c.validateOneProvider() +} diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go new file mode 100644 index 00000000..00ab6aa0 --- /dev/null +++ b/pkg/providers/octoml/chat.go @@ -0,0 +1,188 @@ +package octoml + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + "go.uber.org/zap" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an octoml-specific request schema +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` +} + +// NewChatRequestFromConfig fills the struct from the config.
Not using reflection because of performance penalty it gives +func NewChatRequestFromConfig(cfg *Config) *ChatRequest { + return &ChatRequest{ + Model: cfg.Model, + Temperature: cfg.DefaultParams.Temperature, + TopP: cfg.DefaultParams.TopP, + MaxTokens: cfg.DefaultParams.MaxTokens, + StopWords: cfg.DefaultParams.StopWords, + Stream: false, // unsupported right now + FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty, + PresencePenalty: cfg.DefaultParams.PresencePenalty, + } +} + +func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage { + messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) + + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) + + return messages +} + +// Chat sends a chat request to the specified octoml model. +func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + // Create a new chat request + chatRequest := c.createChatRequestSchema(request) + + chatResponse, err := c.doChatRequest(ctx, chatRequest) + if err != nil { + return nil, err + } + + if len(chatResponse.ModelResponse.Message.Content) == 0 { + return nil, ErrEmptyResponse + } + + return chatResponse, nil +} + +func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := c.chatRequestTemplate // hoping to get a copy of the template + chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + + return chatRequest +} + +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) { + // Build request payload + rawPayload, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("unable to marshal octoml chat request payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create octoml chat request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey)) + req.Header.Set("Content-Type", "application/json") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.telemetry.Logger.Debug( + "octoml chat request", + zap.String("chat_url", c.chatURL), + zap.Any("payload", payload), + ) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send octoml chat request: %w", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err)) + } + + c.telemetry.Logger.Error( + "octoml chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return nil, fmt.Errorf("failed to parse cooldown delay from 
headers: %w", err) + } + + return nil, clients.NewRateLimitError(&cooldownDelay) + } + + // Server & client errors result in the same error to keep gateway resilient + return nil, clients.ErrProviderUnavailable + } + + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var openAICompletion schemas.OpenAIChatCompletion // Octo uses the same response schema as OpenAI + + err = json.Unmarshal(bodyBytes, &openAICompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + return nil, err + } + + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + Provider: providerName, + Model: openAICompletion.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "system_fingerprint": openAICompletion.SystemFingerprint, + }, + Message: schemas.ChatMessage{ + Role: openAICompletion.Choices[0].Message.Role, + Content: openAICompletion.Choices[0].Message.Content, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: openAICompletion.Usage.PromptTokens, + ResponseTokens: openAICompletion.Usage.CompletionTokens, + TotalTokens: openAICompletion.Usage.TotalTokens, + }, + }, + } + + return &response, nil +} diff --git a/pkg/providers/octoml/client.go b/pkg/providers/octoml/client.go new file mode 100644 index 00000000..df8cff5b --- /dev/null +++ b/pkg/providers/octoml/client.go @@ -0,0 +1,59 @@ +package octoml + +import ( + "errors" + "net/http" + "net/url" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +const ( + providerName = "octoml" +) + +// ErrEmptyResponse is returned when the OctoML API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing OctoML API +type Client struct { + baseURL string + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new OctoML client for the OctoML API. 
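
A note on the 429 branch above: `time.ParseDuration` only understands Go-style duration strings like "120s", while RFC 9110 defines Retry-After as either delta-seconds ("120") or an HTTP-date. If a provider sends one of the standard formats, the parse fails and the rate limit is never registered with the tracker. A minimal sketch of a more tolerant parser — the `parseRetryAfter` helper is hypothetical and not part of this patch:

```go
package clients

import (
	"net/http"
	"strconv"
	"time"
)

// parseRetryAfter accepts the three plausible Retry-After formats:
// delta-seconds, an HTTP-date, and (as a fallback) a Go duration string.
func parseRetryAfter(value string) (time.Duration, error) {
	if seconds, err := strconv.Atoi(value); err == nil {
		return time.Duration(seconds) * time.Second, nil
	}

	if at, err := http.ParseTime(value); err == nil {
		return time.Until(at), nil
	}

	return time.ParseDuration(value)
}
```
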
+func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) + if err != nil { + return nil, err + } + + c := &Client{ + baseURL: providerConfig.BaseURL, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + Timeout: *clientConfig.Timeout, + // TODO: use values from the config + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go new file mode 100644 index 00000000..a8f0d625 --- /dev/null +++ b/pkg/providers/octoml/client_test.go @@ -0,0 +1,133 @@ +package octoml + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/api/schemas" + + "glide/pkg/providers/clients" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestOctoMLClient_ChatRequest(t *testing.T) { + // OctoML Chat API: https://docs.octoai.cloud/docs/text-gen-api-docs + octoMLMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading octoml chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + octoMLServer := httptest.NewServer(octoMLMock) + defer octoMLServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + providerCfg.BaseURL = octoMLServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, providerCfg.Model, response.Model) + require.Equal(t, "cmpl-8ea213aece0747aca6d0608b02b57196", response.ID) +} + +func TestOctoMLClient_Chat_Error(t *testing.T) { + // Set up the test case + // Create a mock API server that returns an error + octoMLMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return an error + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + }) + + // Create a mock API server + octoMLServer := httptest.NewServer(octoMLMock) + defer octoMLServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + providerCfg.BaseURL = octoMLServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + // Create a chat request + request := schemas.UnifiedChatRequest{ + Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }, + } + + // Call the Chat function + _, err = client.Chat(ctx, &request) + + // 
Check the error + require.Error(t, err) + require.Contains(t, err.Error(), "provider is not available") +} + +func TestDoChatRequest_ErrorResponse(t *testing.T) { + // Create a mock HTTP server that returns a non-OK status code + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + + defer mockServer.Close() + + // Create a new client with the mock server URL + client := &Client{ + httpClient: http.DefaultClient, + chatURL: mockServer.URL, + config: &Config{APIKey: "dummy_key"}, + telemetry: telemetry.NewTelemetryMock(), + } + + // Create a chat request payload + payload := &ChatRequest{ + Model: "dummy_model", + Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, + } + + // Call the doChatRequest function + _, err := client.doChatRequest(context.Background(), payload) + + require.Error(t, err) + require.Contains(t, err.Error(), "provider is not available") +} diff --git a/pkg/providers/octoml/config.go b/pkg/providers/octoml/config.go new file mode 100644 index 00000000..ed79fce5 --- /dev/null +++ b/pkg/providers/octoml/config.go @@ -0,0 +1,62 @@ +package octoml + +import ( + "glide/pkg/config/fields" +) + +// Params defines OctoML-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` + TopP float64 `yaml:"top_p,omitempty" json:"top_p"` + MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"` + StopWords []string `yaml:"stop,omitempty" json:"stop"` + FrequencyPenalty int `yaml:"frequency_penalty,omitempty" json:"frequency_penalty"` + PresencePenalty int `yaml:"presence_penalty,omitempty" json:"presence_penalty"` + // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment +} + +func DefaultParams() Params { + return Params{ + Temperature: 1, + TopP: 1, + MaxTokens: 100, + StopWords: []string{}, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"base_url" json:"baseUrl" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"` +} + +// DefaultConfig for OctoML models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "https://text.octoai.run/v1", + ChatEndpoint: "/chat/completions", + Model: "mistral-7b-instruct-fp16", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/octoml/testdata/chat.req.json b/pkg/providers/octoml/testdata/chat.req.json new file mode 100644 index 00000000..b5fa8d98 --- /dev/null +++ b/pkg/providers/octoml/testdata/chat.req.json @@ -0,0 +1,12 @@ +{ + "model": "mistral-7b-instruct-fp16", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" 
+ } + ], + "temperature": 0.8, + "top_p": 1, + "max_tokens": 100 +} diff --git a/pkg/providers/octoml/testdata/chat.success.json b/pkg/providers/octoml/testdata/chat.success.json new file mode 100644 index 00000000..1bbb6e12 --- /dev/null +++ b/pkg/providers/octoml/testdata/chat.success.json @@ -0,0 +1,24 @@ +{ + "id": "cmpl-8ea213aece0747aca6d0608b02b57196", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The biggest animal that has ever lived is the blue whale (Balaenoptera musculus). Blue whales can reach lengths of up to 100 feet (30.5 meters) and weights of as much as 200 tons (181 metric tonnes). Their tongues alone can weigh as much as an elephant, and their hearts can be the size of a small car. Blue whales feed primarily on krill, which they filter from", + "function_call": null + }, + "delta": null, + "finish_reason": "length" + } + ], + "created": 5399, + "model": "mistral-7b-instruct-fp16", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "completion_tokens": 150, + "prompt_tokens": 571, + "total_tokens": 721 + } +} \ No newline at end of file diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go new file mode 100644 index 00000000..f8a69525 --- /dev/null +++ b/pkg/providers/openai/chat.go @@ -0,0 +1,202 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + "go.uber.org/zap" +) + +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an OpenAI-specific request schema +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` +} + +// NewChatRequestFromConfig fills the struct from the config. 
Not using reflection because of the performance penalty it incurs
+func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
+	return &ChatRequest{
+		Model:            cfg.Model,
+		Temperature:      cfg.DefaultParams.Temperature,
+		TopP:             cfg.DefaultParams.TopP,
+		MaxTokens:        cfg.DefaultParams.MaxTokens,
+		N:                cfg.DefaultParams.N,
+		StopWords:        cfg.DefaultParams.StopWords,
+		Stream:           false, // unsupported right now
+		FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
+		PresencePenalty:  cfg.DefaultParams.PresencePenalty,
+		LogitBias:        cfg.DefaultParams.LogitBias,
+		User:             cfg.DefaultParams.User,
+		Seed:             cfg.DefaultParams.Seed,
+		Tools:            cfg.DefaultParams.Tools,
+		ToolChoice:       cfg.DefaultParams.ToolChoice,
+		ResponseFormat:   cfg.DefaultParams.ResponseFormat,
+	}
+}
+
+func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []ChatMessage {
+	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+	// Add items from messageHistory first and the new chat message last
+	for _, message := range request.MessageHistory {
+		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
+	}
+
+	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+	return messages
+}
+
+// Chat sends a chat request to the specified OpenAI model.
+func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
+	// Create a new chat request
+	chatRequest := c.createChatRequestSchema(request)
+
+	chatResponse, err := c.doChatRequest(ctx, chatRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(chatResponse.ModelResponse.Message.Content) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
+	return chatResponse, nil
+}
+
+func (c *Client) createChatRequestSchema(request *schemas.UnifiedChatRequest) *ChatRequest {
+	// TODO: consider using objectpool to optimize memory allocation
+	chatRequest := *c.chatRequestTemplate // shallow copy, so the shared template is not mutated between requests
+	chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+
+	return &chatRequest
+}
+
+func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.UnifiedChatResponse, error) {
+	// Build request payload
+	rawPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("unable to marshal openai chat request payload: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+	if err != nil {
+		return nil, fmt.Errorf("unable to create openai chat request: %w", err)
+	}
+
+	req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
+	req.Header.Set("Content-Type", "application/json")
+
+	// TODO: this could leak information from messages, which may not be desirable
+	c.telemetry.Logger.Debug(
+		"openai chat request",
+		zap.String("chat_url", c.chatURL),
+		zap.Any("payload", payload),
+	)
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to send openai chat request: %w", err)
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, err := io.ReadAll(resp.Body)
+		if err != nil {
+			c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err))
+		}
+
+		c.telemetry.Logger.Error(
+			"openai chat request failed",
+			zap.Int("status_code", resp.StatusCode),
+			zap.String("response", string(bodyBytes)),
+			zap.Any("headers", resp.Header),
+		)
+
+		if resp.StatusCode == http.StatusTooManyRequests {
+			// Read the value of the
"Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return nil, clients.NewRateLimitError(&cooldownDelay) + } + + // Server & client errors result in the same error to keep gateway resilient + return nil, clients.ErrProviderUnavailable + } + + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var openAICompletion schemas.OpenAIChatCompletion + + err = json.Unmarshal(bodyBytes, &openAICompletion) + if err != nil { + c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + return nil, err + } + + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + Provider: providerName, + Model: openAICompletion.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "system_fingerprint": openAICompletion.SystemFingerprint, + }, + Message: schemas.ChatMessage{ + Role: openAICompletion.Choices[0].Message.Role, + Content: openAICompletion.Choices[0].Message.Content, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: openAICompletion.Usage.PromptTokens, + ResponseTokens: openAICompletion.Usage.CompletionTokens, + TotalTokens: openAICompletion.Usage.TotalTokens, + }, + }, + } + + return &response, nil +} diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go new file mode 100644 index 00000000..7a825cd5 --- /dev/null +++ b/pkg/providers/openai/client.go @@ -0,0 +1,63 @@ +package openai + +import ( + "errors" + "net/http" + "net/url" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +// TODO: Explore resource pooling +// TODO: Optimize Type use +// TODO: Explore Hertz TLS & resource pooling + +const ( + providerName = "openai" +) + +// ErrEmptyResponse is returned when the OpenAI API returns an empty response. +var ( + ErrEmptyResponse = errors.New("empty response") +) + +// Client is a client for accessing OpenAI API +type Client struct { + baseURL string + chatURL string + chatRequestTemplate *ChatRequest + config *Config + httpClient *http.Client + telemetry *telemetry.Telemetry +} + +// NewClient creates a new OpenAI client for the OpenAI API. 
+func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { + chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) + if err != nil { + return nil, err + } + + c := &Client{ + baseURL: providerConfig.BaseURL, + chatURL: chatURL, + config: providerConfig, + chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + httpClient: &http.Client{ + Timeout: *clientConfig.Timeout, + // TODO: use values from the config + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 2, + }, + }, + telemetry: tel, + } + + return c, nil +} + +func (c *Client) Provider() string { + return providerName +} diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/client_test.go new file mode 100644 index 00000000..d026298a --- /dev/null +++ b/pkg/providers/openai/client_test.go @@ -0,0 +1,68 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/providers/clients" + + "glide/pkg/api/schemas" + + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIClient_ChatRequest(t *testing.T) { + // OpenAI Chat API: https://platform.openai.com/docs/api-reference/chat/create + openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json")) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(chatResponse) + + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + openAIServer := httptest.NewServer(openAIMock) + defer openAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = openAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.UnifiedChatRequest{Message: schemas.ChatMessage{ + Role: "human", + Content: "What's the biggest animal?", + }} + + response, err := client.Chat(ctx, &request) + require.NoError(t, err) + + require.Equal(t, "chatcmpl-123", response.ID) +} diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go new file mode 100644 index 00000000..86854f3e --- /dev/null +++ b/pkg/providers/openai/config.go @@ -0,0 +1,71 @@ +package openai + +import ( + "glide/pkg/config/fields" +) + +// Params defines OpenAI-specific model params with the specific validation of values +// TODO: Add validations +type Params struct { + Temperature float64 `yaml:"temperature,omitempty" json:"temperature"` + TopP float64 `yaml:"top_p,omitempty" json:"top_p"` + MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"` + N int `yaml:"n,omitempty" json:"n"` + StopWords []string `yaml:"stop,omitempty" json:"stop"` + FrequencyPenalty int `yaml:"frequency_penalty,omitempty" json:"frequency_penalty"` + PresencePenalty int `yaml:"presence_penalty,omitempty" json:"presence_penalty"` + LogitBias *map[int]float64 `yaml:"logit_bias,omitempty" json:"logit_bias"` + User *string `yaml:"user,omitempty" json:"user"` + Seed *int 
`yaml:"seed,omitempty" json:"seed"` + Tools []string `yaml:"tools,omitempty" json:"tools"` + ToolChoice interface{} `yaml:"tool_choice,omitempty" json:"tool_choice"` + ResponseFormat interface{} `yaml:"response_format,omitempty" json:"response_format"` // TODO: should this be a part of the chat request API? + // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment +} + +func DefaultParams() Params { + return Params{ + Temperature: 0.8, + TopP: 1, + MaxTokens: 100, + N: 1, + StopWords: []string{}, + Tools: []string{}, + } +} + +func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error { + *p = DefaultParams() + + type plain Params // to avoid recursion + + return unmarshal((*plain)(p)) +} + +type Config struct { + BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"` + ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"` + Model string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` + DefaultParams *Params `yaml:"defaultParams,omitempty" json:"defaultParams"` +} + +// DefaultConfig for OpenAI models +func DefaultConfig() *Config { + defaultParams := DefaultParams() + + return &Config{ + BaseURL: "https://api.openai.com/v1", + ChatEndpoint: "/chat/completions", + Model: "gpt-3.5-turbo", + DefaultParams: &defaultParams, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultConfig() + + type plain Config // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/providers/openai/testdata/chat.req.json b/pkg/providers/openai/testdata/chat.req.json new file mode 100644 index 00000000..81327b2c --- /dev/null +++ b/pkg/providers/openai/testdata/chat.req.json @@ -0,0 +1,15 @@ +{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "human", + "content": "What's the biggest animal?" + } + ], + "temperature": 0.8, + "top_p": 1, + "max_tokens": 100, + "n": 1, + "user": null, + "seed": null +} diff --git a/pkg/providers/openai/testdata/chat.success.json b/pkg/providers/openai/testdata/chat.success.json new file mode 100644 index 00000000..8b863610 --- /dev/null +++ b/pkg/providers/openai/testdata/chat.success.json @@ -0,0 +1,21 @@ +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nHello there, how may I assist you today?" 
+ }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } +} diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go new file mode 100644 index 00000000..11e89ae7 --- /dev/null +++ b/pkg/providers/provider.go @@ -0,0 +1,107 @@ +package providers + +import ( + "context" + "errors" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/routers/health" + "glide/pkg/routers/latency" + + "glide/pkg/api/schemas" +) + +// LangModelProvider defines an interface a provider should fulfill to be able to serve language chat requests +type LangModelProvider interface { + Provider() string + Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) +} + +type Model interface { + ID() string + Healthy() bool + Latency() *latency.MovingAverage + LatencyUpdateInterval() *time.Duration + Weight() int +} + +type LanguageModel interface { + Model + LangModelProvider +} + +// LangModel wraps provider client and expend it with health & latency tracking +type LangModel struct { + modelID string + weight int + client LangModelProvider + rateLimit *health.RateLimitTracker + errorBudget *health.TokenBucket // TODO: centralize provider API health tracking in the registry + latency *latency.MovingAverage + latencyUpdateInterval *time.Duration +} + +func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, latencyConfig latency.Config, weight int) *LangModel { + return &LangModel{ + modelID: modelID, + client: client, + rateLimit: health.NewRateLimitTracker(), + errorBudget: health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()), + latency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), + latencyUpdateInterval: latencyConfig.UpdateInterval, + weight: weight, + } +} + +func (m *LangModel) ID() string { + return m.modelID +} + +func (m *LangModel) Provider() string { + return m.client.Provider() +} + +func (m *LangModel) Latency() *latency.MovingAverage { + return m.latency +} + +func (m *LangModel) LatencyUpdateInterval() *time.Duration { + return m.latencyUpdateInterval +} + +func (m *LangModel) Healthy() bool { + return !m.rateLimit.Limited() && m.errorBudget.HasTokens() +} + +func (m *LangModel) Weight() int { + return m.weight +} + +func (m *LangModel) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + // TODO: we may want to track time-to-first-byte to "normalize" response latency wrt response size + startedAt := time.Now() + resp, err := m.client.Chat(ctx, request) + + // Do we want to track latency in case of errors as well? 
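+	// For now we do: failed calls also consume wall-clock time, so they are included in the average.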
+ m.latency.Add(float64(time.Since(startedAt))) + + if err == nil { + // successful response + resp.ModelID = m.modelID + + return resp, err + } + + var rle *clients.RateLimitError + + if errors.As(err, &rle) { + m.rateLimit.SetLimited(rle.UntilReset()) + + return resp, err + } + + _ = m.errorBudget.Take(1) + + return resp, err +} diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go new file mode 100644 index 00000000..f408380c --- /dev/null +++ b/pkg/providers/testing.go @@ -0,0 +1,100 @@ +package providers + +import ( + "context" + "time" + + "glide/pkg/routers/latency" + + "glide/pkg/api/schemas" +) + +type ResponseMock struct { + Msg string + Err *error +} + +func (m *ResponseMock) Resp() *schemas.UnifiedChatResponse { + return &schemas.UnifiedChatResponse{ + ID: "rsp0001", + ModelResponse: schemas.ProviderResponse{ + SystemID: map[string]string{ + "ID": "0001", + }, + Message: schemas.ChatMessage{ + Content: m.Msg, + }, + }, + } +} + +type ProviderMock struct { + idx int + responses []ResponseMock +} + +func NewProviderMock(responses []ResponseMock) *ProviderMock { + return &ProviderMock{ + idx: 0, + responses: responses, + } +} + +func (c *ProviderMock) Chat(_ context.Context, _ *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + response := c.responses[c.idx] + c.idx++ + + if response.Err != nil { + return nil, *response.Err + } + + return response.Resp(), nil +} + +func (c *ProviderMock) Provider() string { + return "provider_mock" +} + +type LangModelMock struct { + modelID string + healthy bool + latency *latency.MovingAverage + weight int +} + +func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) *LangModelMock { + movingAverage := latency.NewMovingAverage(0.06, 3) + + if avgLatency > 0.0 { + movingAverage.Set(avgLatency) + } + + return &LangModelMock{ + modelID: ID, + healthy: healthy, + latency: movingAverage, + weight: weight, + } +} + +func (m *LangModelMock) ID() string { + return m.modelID +} + +func (m *LangModelMock) Healthy() bool { + return m.healthy +} + +func (m *LangModelMock) Latency() *latency.MovingAverage { + return m.latency +} + +func (m *LangModelMock) LatencyUpdateInterval() *time.Duration { + updateInterval := 30 * time.Second + + return &updateInterval +} + +func (m *LangModelMock) Weight() int { + return m.weight +} diff --git a/pkg/routers/config.go b/pkg/routers/config.go new file mode 100644 index 00000000..3dc12c7e --- /dev/null +++ b/pkg/routers/config.go @@ -0,0 +1,142 @@ +package routers + +import ( + "fmt" + + "glide/pkg/providers" + "glide/pkg/routers/retry" + "glide/pkg/routers/routing" + "glide/pkg/telemetry" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +type Config struct { + LanguageRouters []LangRouterConfig `yaml:"language"` // the list of language routers +} + +func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, error) { + routers := make([]*LangRouter, 0, len(c.LanguageRouters)) + + var errs error + + for idx, routerConfig := range c.LanguageRouters { + if !routerConfig.Enabled { + tel.Logger.Info("router is disabled, skipping", zap.String("routerID", routerConfig.ID)) + continue + } + + tel.Logger.Debug("init router", zap.String("routerID", routerConfig.ID)) + + router, err := NewLangRouter(&c.LanguageRouters[idx], tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + routers = append(routers, router) + } + + if errs != nil { + return nil, errs + } + + return routers, nil +} + +// TODO: how to specify other backoff strategies? 
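
On the backoff-strategy TODO above: one option is a small tagged config, sketched here with hypothetical types (this patch only implements exponential retry):

```go
package routers

import (
	"time"

	"glide/pkg/routers/retry"
)

// RetryConfig is a hypothetical discriminated union over backoff strategies.
type RetryConfig struct {
	Strategy string                `yaml:"strategy"` // e.g. "exp" or "constant"
	Exp      *retry.ExpRetryConfig `yaml:"exp,omitempty"`
	Constant *ConstantRetryConfig  `yaml:"constant,omitempty"`
}

// ConstantRetryConfig waits a fixed delay between attempts.
type ConstantRetryConfig struct {
	MaxRetries int           `yaml:"max_retries"`
	Delay      time.Duration `yaml:"delay"`
}
```
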
+// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 +// LangRouterConfig +type LangRouterConfig struct { + ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID + Enabled bool `yaml:"enabled" json:"enabled"` // Is router enabled? + Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry"` // retry when no healthy model is available to router + RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string"` // strategy on picking the next model to serve the request + Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required"` // the list of models that could handle requests +} + +// BuildModels creates LanguageModel slice out of the given config +func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.LanguageModel, error) { + var errs error + + models := make([]providers.LanguageModel, 0, len(c.Models)) + + for _, modelConfig := range c.Models { + if !modelConfig.Enabled { + tel.Logger.Info( + "model is disabled, skipping", + zap.String("router", c.ID), + zap.String("model", modelConfig.ID), + ) + + continue + } + + tel.Logger.Debug( + "init lang model", + zap.String("router", c.ID), + zap.String("model", modelConfig.ID), + ) + + model, err := modelConfig.ToModel(tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + models = append(models, model) + } + + if errs != nil { + return nil, errs + } + + return models, nil +} + +func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { + retryConfig := c.Retry + + return retry.NewExpRetry( + retryConfig.MaxRetries, + retryConfig.BaseMultiplier, + retryConfig.MinDelay, + retryConfig.MaxDelay, + ) +} + +func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routing.LangModelRouting, error) { + m := make([]providers.Model, 0, len(models)) + for _, model := range models { + m = append(m, model) + } + + switch c.RoutingStrategy { + case routing.Priority: + return routing.NewPriority(m), nil + case routing.RoundRobin: + return routing.NewRoundRobinRouting(m), nil + case routing.WeightedRoundRobin: + return routing.NewWeightedRoundRobin(m), nil + case routing.LeastLatency: + return routing.NewLeastLatencyRouting(m), nil + } + + return nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) +} + +func DefaultLangRouterConfig() LangRouterConfig { + return LangRouterConfig{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + } +} + +func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = DefaultLangRouterConfig() + + type plain LangRouterConfig // to avoid recursion + + return unmarshal((*plain)(c)) +} diff --git a/pkg/routers/config_test.go b/pkg/routers/config_test.go new file mode 100644 index 00000000..c7270b67 --- /dev/null +++ b/pkg/routers/config_test.go @@ -0,0 +1,71 @@ +package routers + +import ( + "testing" + + "github.com/stretchr/testify/require" + "glide/pkg/providers" + "glide/pkg/providers/clients" + "glide/pkg/providers/openai" + "glide/pkg/routers/health" + "glide/pkg/routers/latency" + "glide/pkg/routers/retry" + "glide/pkg/routers/routing" + "glide/pkg/telemetry" +) + +func TestRouterConfig_BuildModels(t *testing.T) { + defaultParams := openai.DefaultParams() + + cfg := Config{ + LanguageRouters: []LangRouterConfig{ + { + ID: "first_router", + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: 
retry.DefaultExpRetryConfig(), + Models: []providers.LangModelConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + OpenAI: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }, + { + ID: "first_router", + Enabled: true, + RoutingStrategy: routing.LeastLatency, + Retry: retry.DefaultExpRetryConfig(), + Models: []providers.LangModelConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + OpenAI: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }, + }, + } + + routers, err := cfg.BuildLangRouters(telemetry.NewTelemetryMock()) + + require.NoError(t, err) + require.Len(t, routers, 2) + require.Len(t, routers[0].models, 1) + require.IsType(t, routers[0].routing, &routing.PriorityRouting{}) + require.Len(t, routers[1].models, 1) + require.IsType(t, routers[1].routing, &routing.LeastLatencyRouting{}) +} diff --git a/pkg/routers/health/buckets.go b/pkg/routers/health/buckets.go new file mode 100644 index 00000000..ca7afd53 --- /dev/null +++ b/pkg/routers/health/buckets.go @@ -0,0 +1,87 @@ +package health + +import ( + "errors" + "sync/atomic" + "time" +) + +var ErrNoTokens = errors.New("not enough tokens in the bucket") + +// TokenBucket is a lock-free concurrency-safe implementation of token bucket algo +// based on atomic operations for the max performance +// +// We are not tracking the number of tokens directly, +// but rather the time that has passed since the last token consumption +type TokenBucket struct { + timePointer uint64 + timePerToken uint64 + timePerBurst uint64 +} + +func NewTokenBucket(timePerToken, burstSize uint) *TokenBucket { + return &TokenBucket{ + timePointer: 0, + timePerToken: uint64(timePerToken), + timePerBurst: uint64(burstSize * timePerToken), + } +} + +// Take one token from the bucket +func (b *TokenBucket) Take(tokens uint64) error { + oldTime := atomic.LoadUint64(&b.timePointer) + newTime := oldTime + + timeNeeded := tokens * b.timePerToken + + for { + now := b.nowInMicro() + minTime := now - b.timePerBurst + + // Take into account burst size. + if minTime > oldTime { + newTime = minTime + } + + // Now shift by the time needed. + newTime += timeNeeded + + // Check if too many tokens. + if newTime > now { + return ErrNoTokens + } + + if atomic.CompareAndSwapUint64(&b.timePointer, oldTime, newTime) { + // consumed tokens + return nil + } + + // Otherwise load old value and try again. + oldTime = atomic.LoadUint64(&b.timePointer) + newTime = oldTime + } +} + +func (b *TokenBucket) HasTokens() bool { + return b.Tokens() >= 1.0 +} + +// Tokens returns number of available tokens in the bucket +func (b *TokenBucket) Tokens() float64 { + timePointer := atomic.LoadUint64(&b.timePointer) + now := b.nowInMicro() + minTime := now - b.timePerBurst + + newTime := timePointer + + // Take into account burst size. 
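+	// Available tokens = time elapsed since the pointer, measured in timePerToken
+	// units, clamped so it never exceeds the configured burst size.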
+ if minTime > timePointer { + newTime = minTime + } + + return float64(now-newTime) / float64(b.timePerToken) +} + +func (b *TokenBucket) nowInMicro() uint64 { + return uint64(time.Now().UnixNano() / 1000.0) +} diff --git a/pkg/routers/health/buckets_test.go b/pkg/routers/health/buckets_test.go new file mode 100644 index 00000000..4abf9c13 --- /dev/null +++ b/pkg/routers/health/buckets_test.go @@ -0,0 +1,79 @@ +package health + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTokenBucket_Take(t *testing.T) { + bucketSize := 10 + bucket := NewTokenBucket(1000, uint(bucketSize)) + + for i := 0; i < bucketSize-1; i++ { + require.NoError(t, bucket.Take(1)) + require.True(t, bucket.HasTokens()) + } + + // consuming 10th token + require.NoError(t, bucket.Take(1)) + + // only 10 tokens in the bucket + require.ErrorIs(t, bucket.Take(1), ErrNoTokens) + require.False(t, bucket.HasTokens()) +} + +func TestTokenBucket_TakeConcurrently(t *testing.T) { + bucket := NewTokenBucket(10_000, 1) + wg := &sync.WaitGroup{} + + before := time.Now() + + for i := 0; i < 10; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for k := 0; k < 10; k++ { + for bucket.Take(1) != nil { + time.Sleep(10 * time.Millisecond) + } + } + }() + } + + wg.Wait() + + if time.Since(before) < 1*time.Second { + t.Fatal("Did not wait 1s") + } +} + +func TestTokenBucket_TokenNumberIsCorrect(t *testing.T) { + bucket := NewTokenBucket(1_000_000, 10) + require.InEpsilon(t, 10.0, bucket.Tokens(), 0.0001) + + require.NoError(t, bucket.Take(2)) + require.InEpsilon(t, 8.0, bucket.Tokens(), 0.0001) + + require.NoError(t, bucket.Take(2)) + require.InEpsilon(t, 6.0, bucket.Tokens(), 0.0001) + + require.NoError(t, bucket.Take(2)) + require.InEpsilon(t, 4.0, bucket.Tokens(), 0.0001) + + require.NoError(t, bucket.Take(2)) + require.InEpsilon(t, 2.0, bucket.Tokens(), 0.0001) + + require.NoError(t, bucket.Take(2)) + require.LessOrEqual(t, 0.0, bucket.Tokens()) +} + +func TestTokenBucket_TakeBurstly(t *testing.T) { + bucket := NewTokenBucket(1, 10) + + require.NoError(t, bucket.Take(10)) +} diff --git a/pkg/routers/health/error_budget.go b/pkg/routers/health/error_budget.go new file mode 100644 index 00000000..e6eae66c --- /dev/null +++ b/pkg/routers/health/error_budget.go @@ -0,0 +1,103 @@ +package health + +import ( + "fmt" + "strconv" + "strings" +) + +const budgetSeparator = "/" + +type Unit string + +const ( + MILLI Unit = "ms" + MIN Unit = "m" + SEC Unit = "s" + HOUR Unit = "h" +) + +// ErrorBudget parses human-friendly error budget representation and return it as errors & update rate pair +// Error budgets could be set as a string in the following format: "10/s", "5/ms", "100/m" "1500/h" +type ErrorBudget struct { + budget uint + unit Unit +} + +func NewErrorBudget(budget uint, unit Unit) *ErrorBudget { + return &ErrorBudget{ + budget: budget, + unit: unit, + } +} + +func DefaultErrorBudget() *ErrorBudget { + return &ErrorBudget{ + budget: 10, + unit: MIN, + } +} + +// Budget defines max allows number of errors per given time period +func (b *ErrorBudget) Budget() uint { + return b.budget +} + +// TimePerTokenMicro defines how much time do we need to wait to get one error token recovered (in microseconds) +func (b *ErrorBudget) TimePerTokenMicro() uint { + return b.unitToMicro(b.unit) / b.budget +} + +// MarshalText implements the encoding.TextMarshaler interface. +// This marshals the type and name as one string in the config. 
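
To make the budget arithmetic concrete, here is a usage sketch that mirrors how `NewLangModel` wires an `ErrorBudget` into a `TokenBucket` (the values shown assume the default "10/m" budget):

```go
package main

import (
	"fmt"

	"glide/pkg/routers/health"
)

func main() {
	budget := health.NewErrorBudget(10, health.MIN) // i.e. "10/m"

	// One token recovers every 60,000,000µs / 10 = 6,000,000µs (6 seconds),
	// and the bucket can hold at most the full budget of 10 tokens.
	bucket := health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget())

	for i := 0; i <= 10; i++ {
		if err := bucket.Take(1); err != nil {
			fmt.Println("budget exhausted after", i, "errors") // i == 10
			break
		}
	}
}
```
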
+func (b *ErrorBudget) MarshalText() (text []byte, err error) { + return []byte(b.String()), nil +} + +func (b *ErrorBudget) UnmarshalText(text []byte) error { + parts := strings.Split(string(text), budgetSeparator) + + if len(parts) != 2 { + return fmt.Errorf("invalid format") + } + + budget, err := strconv.Atoi(parts[0]) + if err != nil { + return fmt.Errorf("error parsing error number: %v", err) + } + + if budget <= 0 { + return fmt.Errorf("error number should be greater then 0 (%v given)", budget) + } + + unit := Unit(parts[1]) + + if unit != MILLI && unit != SEC && unit != MIN && unit != HOUR { + return fmt.Errorf("invalid unit (supported: ms, s, m, h)") + } + + b.budget = uint(budget) + b.unit = unit + + return nil +} + +func (b *ErrorBudget) unitToMicro(unit Unit) uint { + switch unit { + case MILLI: + return 1_000 // 1 ms = 1000 microseconds + case SEC: + return 1_000_000 // 1 s = 1,000,000 microseconds + case MIN: + return 60_000_000 // 1 m = 60,000,000 microseconds + case HOUR: + return 3_600_000_000 // 1 h = 3,600,000,000 microseconds + default: + return 1 + } +} + +// String returns the ID string representation as "type[/name]" format. +func (b *ErrorBudget) String() string { + return strconv.Itoa(int(b.budget)) + budgetSeparator + string(b.unit) +} diff --git a/pkg/routers/health/error_budget_test.go b/pkg/routers/health/error_budget_test.go new file mode 100644 index 00000000..31edde67 --- /dev/null +++ b/pkg/routers/health/error_budget_test.go @@ -0,0 +1,53 @@ +package health + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestErrorBudget_ParseValidString(t *testing.T) { + tests := map[string]struct { + input string + errors int + unit Unit + timePerToken int + }{ + "1/s": {input: "1/s", errors: 1, unit: SEC}, + "10/ms": {input: "10/ms", errors: 10, unit: MILLI}, + "1000/m": {input: "1000/m", errors: 1000, unit: MIN}, + "100000/h": {input: "100000/h", errors: 100000, unit: HOUR}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + budget := DefaultErrorBudget() + + require.NoError(t, budget.UnmarshalText([]byte(tc.input))) + require.Equal(t, tc.errors, int(budget.Budget())) + require.Equal(t, tc.unit, budget.unit) + require.Equal(t, tc.input, budget.String()) + }) + } +} + +func TestErrorBudget_ParseInvalidString(t *testing.T) { + tests := map[string]struct { + input string + }{ + "0/s": {input: "0/s"}, + "-1/s": {input: "-1/s"}, + "1.9/s": {input: "1.9/s"}, + "1,9/s": {input: "1,9/s"}, + "100/d": {input: "100/d"}, + "100/mo": {input: "100/mo"}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + budget := DefaultErrorBudget() + + require.Error(t, budget.UnmarshalText([]byte(tc.input))) + }) + } +} diff --git a/pkg/routers/health/ratelimit.go b/pkg/routers/health/ratelimit.go new file mode 100644 index 00000000..42a94655 --- /dev/null +++ b/pkg/routers/health/ratelimit.go @@ -0,0 +1,29 @@ +package health + +import "time" + +// RateLimitTracker handles rate/quota limits that often represented via 429 errors and +// has some well-defined cooldown period +type RateLimitTracker struct { + resetAt *time.Time +} + +func NewRateLimitTracker() *RateLimitTracker { + return &RateLimitTracker{ + resetAt: nil, + } +} + +func (t *RateLimitTracker) Limited() bool { + if t.resetAt != nil && time.Now().After(*t.resetAt) { + t.resetAt = nil + } + + return t.resetAt != nil +} + +func (t *RateLimitTracker) SetLimited(untilReset time.Duration) { + resetAt := time.Now().Add(untilReset) + + t.resetAt = &resetAt +} diff 
--git a/pkg/routers/health/ratelimit_test.go b/pkg/routers/health/ratelimit_test.go new file mode 100644 index 00000000..b61aa5d1 --- /dev/null +++ b/pkg/routers/health/ratelimit_test.go @@ -0,0 +1,19 @@ +package health + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimitTracker_ResetCorrectly(t *testing.T) { + tracker := NewRateLimitTracker() + require.False(t, tracker.Limited()) + + tracker.SetLimited(10 * time.Millisecond) + require.True(t, tracker.Limited()) + + time.Sleep(11 * time.Millisecond) + require.False(t, tracker.Limited()) +} diff --git a/pkg/routers/latency/config.go b/pkg/routers/latency/config.go new file mode 100644 index 00000000..dd1001f8 --- /dev/null +++ b/pkg/routers/latency/config.go @@ -0,0 +1,20 @@ +package latency + +import "time" + +// Config defines setting for moving average latency calculations +type Config struct { + Decay float64 `yaml:"decay" json:"decay"` // Weight of new latency measurements + WarmupSamples uint8 `yaml:"warmup_samples" json:"warmup_samples"` // The number of latency probes required to init moving average + UpdateInterval *time.Duration `yaml:"update_interval,omitempty" json:"update_interval" swaggertype:"primitive,string"` // How often gateway should probe models with not the lowest response latency +} + +func DefaultConfig() *Config { + defaultUpdateInterval := 30 * time.Second + + return &Config{ + Decay: 0.06, + WarmupSamples: 3, + UpdateInterval: &defaultUpdateInterval, + } +} diff --git a/pkg/routers/latency/config_test.go b/pkg/routers/latency/config_test.go new file mode 100644 index 00000000..64d02fc8 --- /dev/null +++ b/pkg/routers/latency/config_test.go @@ -0,0 +1,13 @@ +package latency + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLatencyConfig_Default(t *testing.T) { + config := DefaultConfig() + + require.NotEmpty(t, config) +} diff --git a/pkg/routers/latency/moving_average.go b/pkg/routers/latency/moving_average.go new file mode 100644 index 00000000..16f5bdc4 --- /dev/null +++ b/pkg/routers/latency/moving_average.go @@ -0,0 +1,75 @@ +package latency + +import "sync" + +// MovingAverage represents the exponentially weighted moving average of a series of numbers +type MovingAverage struct { + mu sync.RWMutex + // The multiplier factor by which the previous samples decay + decay float64 + // The current value of the average + value float64 + // The number of samples added to this instance. 
+ count uint8 + // The number of samples required to start estimating average + warmupSamples uint8 +} + +func NewMovingAverage(decay float64, warmupSamples uint8) *MovingAverage { + return &MovingAverage{ + mu: sync.RWMutex{}, + decay: decay, + warmupSamples: warmupSamples, + count: 0, + value: 0, + } +} + +// Add a value to the series and updates the moving average +func (e *MovingAverage) Add(value float64) { + e.mu.Lock() + defer e.mu.Unlock() + + switch { + case e.count < e.warmupSamples: + e.count++ + e.value += value + case e.count == e.warmupSamples: + e.count++ + e.value = e.value / float64(e.warmupSamples) + e.value = (value * e.decay) + (e.value * (1 - e.decay)) + default: + e.value = (value * e.decay) + (e.value * (1 - e.decay)) + } +} + +func (e *MovingAverage) WarmedUp() bool { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.count > e.warmupSamples +} + +// Value returns the current value of the average, or 0.0 if the series hasn't +// warmed up yet +func (e *MovingAverage) Value() float64 { + e.mu.RLock() + defer e.mu.RUnlock() + + if !e.WarmedUp() { + return 0.0 + } + + return e.value +} + +// Set sets the moving average value +func (e *MovingAverage) Set(value float64) { + e.mu.Lock() + e.value = value + e.mu.Unlock() + + if !e.WarmedUp() { + e.count = e.warmupSamples + 1 + } +} diff --git a/pkg/routers/latency/moving_average_test.go b/pkg/routers/latency/moving_average_test.go new file mode 100644 index 00000000..bb59d701 --- /dev/null +++ b/pkg/routers/latency/moving_average_test.go @@ -0,0 +1,37 @@ +package latency + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMovingAverage_WarmUpAndAverage(t *testing.T) { + latencies := []float64{100, 100, 150} + movingAverage := NewMovingAverage(0.9, 3) + + for _, latency := range latencies { + movingAverage.Add(latency) + + require.False(t, movingAverage.WarmedUp()) + require.InDelta(t, 0.0, movingAverage.Value(), 0.0001) + } + + movingAverage.Add(160) + + require.True(t, movingAverage.WarmedUp()) + require.InDelta(t, 155.6667, movingAverage.Value(), 0.0001) + + movingAverage.Add(160) + require.True(t, movingAverage.WarmedUp()) + require.InDelta(t, 159.5667, movingAverage.Value(), 0.0001) +} + +func TestMovingAverage_SetValue(t *testing.T) { + movingAverage := NewMovingAverage(0.9, 3) + + movingAverage.Set(200.0) + + require.True(t, movingAverage.WarmedUp()) + require.InDelta(t, 200.0, movingAverage.Value(), 0.0001) +} diff --git a/pkg/routers/manager.go b/pkg/routers/manager.go new file mode 100644 index 00000000..e8e9e5a7 --- /dev/null +++ b/pkg/routers/manager.go @@ -0,0 +1,52 @@ +package routers + +import ( + "errors" + + "glide/pkg/telemetry" +) + +var ErrRouterNotFound = errors.New("no router found with given ID") + +type RouterManager struct { + Config *Config + telemetry *telemetry.Telemetry + langRouterMap *map[string]*LangRouter + langRouters []*LangRouter +} + +// NewManager creates a new instance of Router Manager that creates, holds and returns all routers +func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { + langRouters, err := cfg.BuildLangRouters(tel) + if err != nil { + return nil, err + } + + langRouterMap := make(map[string]*LangRouter, len(langRouters)) + + for _, router := range langRouters { + langRouterMap[router.ID()] = router + } + + manager := RouterManager{ + Config: cfg, + telemetry: tel, + langRouters: langRouters, + langRouterMap: &langRouterMap, + } + + return &manager, err +} + +func (r *RouterManager) GetLangRouters() []*LangRouter { 
+ return r.langRouters +} + +// GetLangRouter returns a router by type and ID +func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { + if router, found := (*r.langRouterMap)[routerID]; found { + return router, nil + } + + return nil, ErrRouterNotFound +} diff --git a/pkg/routers/retry/config.go b/pkg/routers/retry/config.go new file mode 100644 index 00000000..6c3199a1 --- /dev/null +++ b/pkg/routers/retry/config.go @@ -0,0 +1,21 @@ +package retry + +import "time" + +type ExpRetryConfig struct { + MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries"` + BaseMultiplier int `yaml:"base_multiplier,omitempty" json:"base_multiplier"` + MinDelay time.Duration `yaml:"min_delay,omitempty" json:"min_delay" swaggertype:"primitive,integer"` + MaxDelay *time.Duration `yaml:"max_delay,omitempty" json:"max_delay" swaggertype:"primitive,integer"` +} + +func DefaultExpRetryConfig() *ExpRetryConfig { + maxDelay := 5 * time.Second + + return &ExpRetryConfig{ + MaxRetries: 3, + BaseMultiplier: 2, + MinDelay: 2 * time.Second, + MaxDelay: &maxDelay, + } +} diff --git a/pkg/routers/retry/exp.go b/pkg/routers/retry/exp.go new file mode 100644 index 00000000..46af5ed7 --- /dev/null +++ b/pkg/routers/retry/exp.go @@ -0,0 +1,77 @@ +package retry + +import ( + "context" + "time" +) + +// ExpRetry increase wait time exponentially with try number (delay = minDelay * baseMultiplier ^ attempt) +type ExpRetry struct { + maxRetries int + baseMultiplier int + minDelay time.Duration + maxDelay *time.Duration +} + +func NewExpRetry(maxRetries int, baseMultiplier int, minDelay time.Duration, maxDelay *time.Duration) *ExpRetry { + return &ExpRetry{ + maxRetries: maxRetries, + baseMultiplier: baseMultiplier, + minDelay: minDelay, + maxDelay: maxDelay, + } +} + +func (r *ExpRetry) Iterator() *ExpRetryIterator { + return &ExpRetryIterator{ + attempt: 0, + maxRetries: r.maxRetries, + baseMultiplier: r.baseMultiplier, + minDelay: r.minDelay, + maxDelay: r.maxDelay, + } +} + +type ExpRetryIterator struct { + attempt int + maxRetries int + baseMultiplier int + minDelay time.Duration + maxDelay *time.Duration +} + +func (i *ExpRetryIterator) HasNext() bool { + return i.attempt < i.maxRetries +} + +func (i *ExpRetryIterator) getNextWaitDuration(attempt int) time.Duration { + delay := i.minDelay + + if attempt > 0 { + delay = time.Duration(float64(delay) * float64(i.baseMultiplier<<(attempt-1))) + } + + if delay < i.minDelay { + delay = i.minDelay + } + + if i.maxDelay != nil && delay > *i.maxDelay { + delay = *i.maxDelay + } + + return delay +} + +func (i *ExpRetryIterator) WaitNext(ctx context.Context) error { + t := time.NewTimer(i.getNextWaitDuration(i.attempt)) + i.attempt++ + + defer t.Stop() + + select { + case <-t.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/pkg/routers/retry/exp_test.go b/pkg/routers/retry/exp_test.go new file mode 100644 index 00000000..ae7306b1 --- /dev/null +++ b/pkg/routers/retry/exp_test.go @@ -0,0 +1,47 @@ +package retry + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestExpRetry_RetryLoop(t *testing.T) { + maxDelay := 10 * time.Millisecond + ctx := context.Background() + + retry := NewExpRetry(3, 2, 2*time.Millisecond, &maxDelay) + + idx := 0 + iterator := retry.Iterator() + + for iterator.HasNext() { + idx++ + + require.NoError(t, iterator.WaitNext(ctx)) + } + + require.Equal(t, 3, idx) +} + +func TestExpRetry_WaitTime(t *testing.T) { + maxRetries := 4 + maxDelay := 10 * time.Millisecond + 
expectedDelays := []time.Duration{ + 2 * time.Millisecond, + 4 * time.Millisecond, + 8 * time.Millisecond, + 10 * time.Millisecond, + 10 * time.Millisecond, + } + + retry := NewExpRetry(maxRetries, 2, 2*time.Millisecond, &maxDelay) + + iterator := retry.Iterator() + + for attempt, expectedDelay := range expectedDelays { + require.Equal(t, expectedDelay, iterator.getNextWaitDuration(attempt)) + } +} diff --git a/pkg/routers/router.go b/pkg/routers/router.go new file mode 100644 index 00000000..5b7747c4 --- /dev/null +++ b/pkg/routers/router.go @@ -0,0 +1,111 @@ +package routers + +import ( + "context" + "errors" + + "glide/pkg/routers/retry" + "go.uber.org/zap" + + "glide/pkg/providers" + + "glide/pkg/api/schemas" + "glide/pkg/routers/routing" + "glide/pkg/telemetry" +) + +var ( + ErrNoModels = errors.New("no models configured for router") + ErrNoModelAvailable = errors.New("could not handle request because all providers are not available") +) + +type LangRouter struct { + routerID string + Config *LangRouterConfig + routing routing.LangModelRouting + retry *retry.ExpRetry + models []providers.LanguageModel + telemetry *telemetry.Telemetry +} + +func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) { + models, err := cfg.BuildModels(tel) + if err != nil { + return nil, err + } + + strategy, err := cfg.BuildRouting(models) + if err != nil { + return nil, err + } + + router := &LangRouter{ + routerID: cfg.ID, + Config: cfg, + models: models, + retry: cfg.BuildRetry(), + routing: strategy, + telemetry: tel, + } + + return router, err +} + +func (r *LangRouter) ID() string { + return r.routerID +} + +func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { + if len(r.models) == 0 { + return nil, ErrNoModels + } + + retryIterator := r.retry.Iterator() + + for retryIterator.HasNext() { + modelIterator := r.routing.Iterator() + + for { + model, err := modelIterator.Next() + + if errors.Is(err, routing.ErrNoHealthyModels) { + // no healthy model in the pool. 
Let's retry after some time + break + } + + langModel := model.(providers.LanguageModel) + + resp, err := langModel.Chat(ctx, request) + if err != nil { + r.telemetry.Logger.Warn( + "lang model failed processing chat request", + zap.String("routerID", r.ID()), + zap.String("modelID", langModel.ID()), + zap.String("provider", langModel.Provider()), + zap.Error(err), + ) + + continue + } + + resp.RouterID = r.routerID + + return resp, nil + } + + // no providers were available to handle the request, + // so we have to wait a bit with a hope there is some available next time + r.telemetry.Logger.Warn("no healthy model found, wait and retry", zap.String("routerID", r.ID())) + + err := retryIterator.WaitNext(ctx) + if err != nil { + // something has cancelled the context + return nil, err + } + } + + // if we reach this part, then we are in trouble + r.telemetry.Logger.Error("no model was available to handle request", zap.String("routerID", r.ID())) + + return nil, ErrNoModelAvailable +} diff --git a/pkg/routers/router_test.go b/pkg/routers/router_test.go new file mode 100644 index 00000000..77fb7226 --- /dev/null +++ b/pkg/routers/router_test.go @@ -0,0 +1,244 @@ +package routers + +import ( + "context" + "testing" + "time" + + "glide/pkg/routers/latency" + + "glide/pkg/providers/clients" + + "github.com/stretchr/testify/require" + "glide/pkg/api/schemas" + "glide/pkg/providers" + "glide/pkg/routers/health" + "glide/pkg/routers/retry" + "glide/pkg/routers/routing" + "glide/pkg/telemetry" +) + +func TestLangRouter_Priority_PickFistHealthy(t *testing.T) { + budget := health.NewErrorBudget(3, health.SEC) + latConfig := latency.DefaultConfig() + + langModels := []providers.LanguageModel{ + providers.NewLangModel( + "first", + providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "second", + providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}}), + *budget, + *latConfig, + 1, + ), + } + + models := make([]providers.Model, 0, len(langModels)) + for _, model := range langModels { + models = append(models, model) + } + + router := LangRouter{ + routerID: "test_router", + Config: &LangRouterConfig{}, + retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), + routing: routing.NewPriority(models), + models: langModels, + telemetry: telemetry.NewTelemetryMock(), + } + + ctx := context.Background() + req := schemas.NewChatFromStr("tell me a dad joke") + + for i := 0; i < 2; i++ { + resp, err := router.Chat(ctx, req) + + require.Equal(t, "first", resp.ModelID) + require.Equal(t, "test_router", resp.RouterID) + require.NoError(t, err) + } +} + +func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) { + budget := health.NewErrorBudget(1, health.SEC) + latConfig := latency.DefaultConfig() + langModels := []providers.LanguageModel{ + providers.NewLangModel( + "first", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "second", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "third", + providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}), + *budget, + *latConfig, + 1, + ), + } + + models := make([]providers.Model, 0, len(langModels)) + for _, model := range langModels { + models = append(models, model) + } + + expectedModels := []string{"third", "third"} + + router := 
LangRouter{ + routerID: "test_router", + Config: &LangRouterConfig{}, + retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), + routing: routing.NewPriority(models), + models: langModels, + telemetry: telemetry.NewTelemetryMock(), + } + + ctx := context.Background() + req := schemas.NewChatFromStr("tell me a dad joke") + + for _, modelID := range expectedModels { + resp, err := router.Chat(ctx, req) + + require.NoError(t, err) + require.Equal(t, modelID, resp.ModelID) + require.Equal(t, "test_router", resp.RouterID) + } +} + +func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) { + budget := health.NewErrorBudget(1, health.MILLI) + latConfig := latency.DefaultConfig() + langModels := []providers.LanguageModel{ + providers.NewLangModel( + "first", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "second", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}), + *budget, + *latConfig, + 1, + ), + } + + models := make([]providers.Model, 0, len(langModels)) + for _, model := range langModels { + models = append(models, model) + } + + router := LangRouter{ + routerID: "test_router", + Config: &LangRouterConfig{}, + retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), + routing: routing.NewPriority(models), + models: langModels, + telemetry: telemetry.NewTelemetryMock(), + } + + resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + + require.NoError(t, err) + require.Equal(t, "first", resp.ModelID) + require.Equal(t, "test_router", resp.RouterID) +} + +func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) { + budget := health.NewErrorBudget(1, health.MIN) + latConfig := latency.DefaultConfig() + langModels := []providers.LanguageModel{ + providers.NewLangModel( + "first", + providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "second", + providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}), + *budget, + *latConfig, + 1, + ), + } + + models := make([]providers.Model, 0, len(langModels)) + for _, model := range langModels { + models = append(models, model) + } + + router := LangRouter{ + routerID: "test_router", + Config: &LangRouterConfig{}, + retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), + routing: routing.NewPriority(models), + models: langModels, + telemetry: telemetry.NewTelemetryMock(), + } + + for i := 0; i < 2; i++ { + resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + + require.NoError(t, err) + require.Equal(t, "second", resp.ModelID) + require.Equal(t, "test_router", resp.RouterID) + } +} + +func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) { + budget := health.NewErrorBudget(1, health.SEC) + latConfig := latency.DefaultConfig() + langModels := []providers.LanguageModel{ + providers.NewLangModel( + "first", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}), + *budget, + *latConfig, + 1, + ), + providers.NewLangModel( + "second", + providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}), + *budget, + *latConfig, + 1, + ), + } + + models := make([]providers.Model, 0, len(langModels)) + for _, model := range langModels { + models = append(models, model) + } + + 
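+ // both models fail every request and maxRetries is set to just 1, + // so the router is expected to exhaust its retry budget quickly and return ErrNoModelAvailable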
router := LangRouter{ + routerID: "test_router", + Config: &LangRouterConfig{}, + retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), + routing: routing.NewPriority(models), + models: langModels, + telemetry: telemetry.NewTelemetryMock(), + } + + _, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + + require.Error(t, err) +} diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go new file mode 100644 index 00000000..2b65dc4c --- /dev/null +++ b/pkg/routers/routing/least_latency.go @@ -0,0 +1,152 @@ +package routing + +import ( + "sync" + "sync/atomic" + "time" + + "glide/pkg/providers" +) + +const ( + LeastLatency Strategy = "least_latency" +) + +// ModelSchedule defines the latency update schedule for a model +type ModelSchedule struct { + mu sync.RWMutex + model providers.Model + expireAt time.Time +} + +func NewSchedule(model providers.Model) *ModelSchedule { + schedule := &ModelSchedule{ + model: model, + } + + schedule.Update() + + return schedule +} + +func (s *ModelSchedule) ExpireAt() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.expireAt +} + +func (s *ModelSchedule) Expired() bool { + s.mu.RLock() + defer s.mu.RUnlock() + + return time.Now().After(s.expireAt) +} + +// Update extends the expiration deadline +func (s *ModelSchedule) Update() { + s.mu.Lock() + defer s.mu.Unlock() + + s.expireAt = time.Now().Add(*s.model.LatencyUpdateInterval()) +} + +// LeastLatencyRouting routes requests to the model that responds the fastest. +// At the beginning, we try to send requests to all models to find out the quickest one. +// After that, we use that model for some time. But we don't want to stick to it forever +// (another model's latency may improve over time and overtake the current best one), +// so we send some traffic to the other models from time to time to keep their latency stats up to date +type LeastLatencyRouting struct { + warmupIdx atomic.Uint32 + schedules []*ModelSchedule +} + +func NewLeastLatencyRouting(models []providers.Model) *LeastLatencyRouting { + schedules := make([]*ModelSchedule, 0, len(models)) + + for _, model := range models { + schedules = append(schedules, NewSchedule(model)) + } + + return &LeastLatencyRouting{ + schedules: schedules, + } +} + +func (r *LeastLatencyRouting) Iterator() LangModelIterator { + return r +} + +// Next picks the model with the least average latency over time +// The algorithm consists of two stages: +// - warm up: Before considering model latencies we may want to collect more than one sample to make better decisions. +// To learn about latencies, we route requests to all "cold" models in a round-robin manner +// - least latency selection: Once all models are warmed up, we pick the one with the least latency +// +// Additionally, we have to keep updating our stats: response latency is a dynamic distribution, +// so we cannot simply stick with the fastest model discovered during the warmup stage (we could overlook +// other models whose latencies have improved over time). +// For that, we introduce an expiration time after which a model receives a request +// even if it was not the fastest to respond.
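+// For example (illustrative numbers): given warmed-up models A (50ms) and B (80ms), +// A serves all traffic until B's schedule expires; the next request is then routed to B +// to refresh its latency stats, even though A is still the faster model.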
+func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop + coldSchedules := r.getColdModelSchedules() + + if len(coldSchedules) > 0 { + // warm up models + idx := r.warmupIdx.Add(1) - 1 + + schedule := coldSchedules[idx%uint32(len(coldSchedules))] + schedule.Update() + + return schedule.model, nil + } + + // latency-based routing + var nextSchedule *ModelSchedule + + for _, schedule := range r.schedules { + if !schedule.model.Healthy() { + // cannot do much with an unavailable model + continue + } + + if nextSchedule == nil { + nextSchedule = schedule + continue + } + + // We pick either the earliest expired model or the one with the least response latency + + if schedule.Expired() && schedule.ExpireAt().Before(nextSchedule.ExpireAt()) { + // if the model's latency stats have expired, it should be picked only if + // its expiration time is earlier than that of the previously picked model + nextSchedule = schedule + continue + } + + if !schedule.Expired() && !nextSchedule.Expired() && + nextSchedule.model.Latency().Value() > schedule.model.Latency().Value() { + nextSchedule = schedule + } + } + + if nextSchedule != nil { + nextSchedule.Update() + + return nextSchedule.model, nil + } + + return nil, ErrNoHealthyModels +} + +func (r *LeastLatencyRouting) getColdModelSchedules() []*ModelSchedule { + coldModels := make([]*ModelSchedule, 0, len(r.schedules)) + + for _, schedule := range r.schedules { + if schedule.model.Healthy() && !schedule.model.Latency().WarmedUp() { + coldModels = append(coldModels, schedule) + } + } + + return coldModels +} diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go new file mode 100644 index 00000000..dbbc699c --- /dev/null +++ b/pkg/routers/routing/least_latency_test.go @@ -0,0 +1,156 @@ +package routing + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "glide/pkg/providers" +) + +func TestLeastLatencyRouting_Warmup(t *testing.T) { + type Model struct { + modelID string + healthy bool + latency float64 + } + + type TestCase struct { + models []Model + expectedModelIDs []string + } + + tests := map[string]TestCase{ + "all cold models": {[]Model{{"first", true, 0.0}, {"second", true, 0.0}, {"third", true, 0.0}}, []string{"first", "second", "third"}}, + "all cold models & unhealthy": {[]Model{{"first", true, 0.0}, {"second", false, 0.0}, {"third", true, 0.0}}, []string{"first", "third", "first"}}, + "some models are warmed": {[]Model{{"first", true, 100.0}, {"second", true, 0.0}, {"third", true, 120.0}}, []string{"second", "second", "second"}}, + "cold unhealthy model": {[]Model{{"first", true, 120.0}, {"second", false, 0.0}, {"third", true, 100.0}}, []string{"third", "third", "third"}}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + models := make([]providers.Model, 0, len(tc.models)) + + for _, model := range tc.models { + models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + } + + routing := NewLeastLatencyRouting(models) + iterator := routing.Iterator() + + // loop three times over the whole pool to check if we return back to the beginning of the list + for _, modelID := range tc.expectedModelIDs { + model, err := iterator.Next() + + require.NoError(t, err) + require.Equal(t, modelID, model.ID()) + } + }) + } +} + +func TestLeastLatencyRouting_Routing(t *testing.T) { 
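+ // schedules are seeded below with fixed expireAt values, bypassing warmup, + // so only the expiration/latency selection logic of Next() is exercised here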
type Model struct { + modelID string + healthy bool + latency float64 + expireAt time.Time + } + + type TestCase struct { + models []Model + expectedModelIDs []string + } + + tests := map[string]TestCase{ + "no cold expired models": { + []Model{ + {"first", true, 100.0, time.Now().Add(30 * time.Second)}, + {"second", true, 80.0, time.Now().Add(30 * time.Second)}, + {"third", true, 101.0, time.Now().Add(30 * time.Second)}, + }, + []string{"second", "second", "second"}, + }, + "one expired model": { + []Model{ + {"first", true, 100.0, time.Now().Add(30 * time.Second)}, + {"second", true, 80.0, time.Now().Add(30 * time.Second)}, + {"third", true, 101.0, time.Now().Add(-30 * time.Second)}, + }, + []string{"third", "second", "second"}, + }, + "two expired models": { + []Model{ + {"first", true, 100.0, time.Now().Add(-60 * time.Second)}, + {"second", true, 80.0, time.Now().Add(30 * time.Second)}, + {"third", true, 101.0, time.Now().Add(-30 * time.Second)}, + }, + []string{"first", "third", "second"}, + }, + "all expired models": { + []Model{ + {"first", true, 100.0, time.Now().Add(-30 * time.Second)}, + {"second", true, 80.0, time.Now().Add(-20 * time.Second)}, + {"third", true, 101.0, time.Now().Add(-60 * time.Second)}, + }, + []string{"third", "first", "second"}, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + schedules := make([]*ModelSchedule, 0, len(tc.models)) + + for _, model := range tc.models { + schedules = append(schedules, &ModelSchedule{ + model: providers.NewLangModelMock( + model.modelID, + model.healthy, + model.latency, + 1, + ), + expireAt: model.expireAt, + }) + } + + routing := LeastLatencyRouting{ + schedules: schedules, + } + + iterator := routing.Iterator() + + // loop three times over the whole pool to check if we return back to the beginning of the list + for _, modelID := range tc.expectedModelIDs { + model, err := iterator.Next() + + require.NoError(t, err) + require.Equal(t, modelID, model.ID()) + } + }) + } +} + +func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { + tests := map[string][]float64{ + "all cold models unhealthy": {0.0, 0.0, 0.0}, + "all warm models unhealthy": {100.0, 120.0, 150.0}, + "cold & warm models unhealthy": {0.0, 120.0, 150.0}, + } + + for name, latencies := range tests { + t.Run(name, func(t *testing.T) { + models := make([]providers.Model, 0, len(latencies)) + + for idx, latency := range latencies { + models = append(models, providers.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + } + + routing := NewLeastLatencyRouting(models) + iterator := routing.Iterator() + + _, err := iterator.Next() + require.ErrorIs(t, err, ErrNoHealthyModels) + }) + } +} diff --git a/pkg/routers/routing/priority.go b/pkg/routers/routing/priority.go new file mode 100644 index 00000000..1b0254ef --- /dev/null +++ b/pkg/routers/routing/priority.go @@ -0,0 +1,55 @@ +package routing + +import ( + "sync/atomic" + + "glide/pkg/providers" +) + +const ( + Priority Strategy = "priority" +) + +// PriorityRouting routes requests to the first healthy model defined in the routing config +// +// The priority of a model is defined by its position in the list +// (e.g. the first model definition has the highest priority, then the second one, and so on)
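+// +// For instance (illustrative): with models [primary, fallback], every request goes to primary +// while it is healthy, and fallback starts serving traffic only once primary is marked unhealthy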
+type PriorityRouting struct { + models []providers.Model +} + +func NewPriority(models []providers.Model) *PriorityRouting { + return &PriorityRouting{ + models: models, + } +} + +func (r *PriorityRouting) Iterator() LangModelIterator { + iterator := PriorityIterator{ + idx: &atomic.Uint64{}, + models: r.models, + } + + return iterator +} + +type PriorityIterator struct { + idx *atomic.Uint64 + models []providers.Model +} + +func (r PriorityIterator) Next() (providers.Model, error) { + models := r.models + + for idx := int(r.idx.Load()); idx < len(models); idx = int(r.idx.Add(1)) { + model := models[idx] + + if !model.Healthy() { + continue + } + + return model, nil + } + + return nil, ErrNoHealthyModels +} diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go new file mode 100644 index 00000000..4b0d8f94 --- /dev/null +++ b/pkg/routers/routing/priority_test.go @@ -0,0 +1,60 @@ +package routing + +import ( + "testing" + + "github.com/stretchr/testify/require" + "glide/pkg/providers" +) + +func TestPriorityRouting_PickModelsInOrder(t *testing.T) { + type Model struct { + modelID string + healthy bool + } + + type TestCase struct { + models []Model + expectedModelIDs []string + } + + tests := map[string]TestCase{ + "all healthy": {[]Model{{"first", true}, {"second", true}, {"third", true}}, []string{"first", "first", "first"}}, + "first unhealthy": {[]Model{{"first", false}, {"second", true}, {"third", true}}, []string{"second", "second", "second"}}, + "first two unhealthy": {[]Model{{"first", false}, {"second", false}, {"third", true}}, []string{"third", "third", "third"}}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + models := make([]providers.Model, 0, len(tc.models)) + + for _, model := range tc.models { + models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + } + + routing := NewPriority(models) + iterator := routing.Iterator() + + // loop three times over the whole pool to check if we return back to the beginning of the list + for _, modelID := range tc.expectedModelIDs { + model, err := iterator.Next() + require.NoError(t, err) + require.Equal(t, modelID, model.ID()) + } + }) + } +} + +func TestPriorityRouting_NoHealthyModels(t *testing.T) { + models := []providers.Model{ + providers.NewLangModelMock("first", false, 0, 1), + providers.NewLangModelMock("second", false, 0, 1), + providers.NewLangModelMock("third", false, 0, 1), + } + + routing := NewPriority(models) + iterator := routing.Iterator() + + _, err := iterator.Next() + require.Error(t, err) +} diff --git a/pkg/routers/routing/round_robin.go b/pkg/routers/routing/round_robin.go new file mode 100644 index 00000000..fb7e7e23 --- /dev/null +++ b/pkg/routers/routing/round_robin.go @@ -0,0 +1,46 @@ +package routing + +import ( + "sync/atomic" + + "glide/pkg/providers" +) + +const ( + RoundRobin Strategy = "round_robin" +) + +// RoundRobinRouting routes requests to the models in the list in a round-robin cycle +type RoundRobinRouting struct { + idx atomic.Uint64 + models []providers.Model +} + +func NewRoundRobinRouting(models []providers.Model) *RoundRobinRouting { + return &RoundRobinRouting{ + models: models, + } +} + +func (r *RoundRobinRouting) Iterator() LangModelIterator { + return r +} + +func (r *RoundRobinRouting) Next() (providers.Model, error) { + modelLen := len(r.models) + + // to avoid an infinite loop when no healthy model is available, + // we track whether we have made a whole cycle around the model slice looking for a healthy one
+ for i := 0; i < modelLen; i++ { + idx := r.idx.Add(1) - 1 + model := r.models[idx%uint64(modelLen)] + + if !model.Healthy() { + continue + } + + return model, nil + } + + return nil, ErrNoHealthyModels +} diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go new file mode 100644 index 00000000..becdf69f --- /dev/null +++ b/pkg/routers/routing/round_robin_test.go @@ -0,0 +1,63 @@ +package routing + +import ( + "testing" + + "github.com/stretchr/testify/require" + "glide/pkg/providers" +) + +func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { + type Model struct { + modelID string + healthy bool + } + + type TestCase struct { + models []Model + expectedModelIDs []string + } + + tests := map[string]TestCase{ + "all healthy": {[]Model{{"first", true}, {"second", true}, {"third", true}}, []string{"first", "second", "third"}}, + "unhealthy in the middle": {[]Model{{"first", true}, {"second", false}, {"third", true}}, []string{"first", "third"}}, + "two unhealthy": {[]Model{{"first", true}, {"second", false}, {"third", false}}, []string{"first"}}, + "first unhealthy": {[]Model{{"first", false}, {"second", true}, {"third", true}}, []string{"second", "third"}}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + models := make([]providers.Model, 0, len(tc.models)) + + for _, model := range tc.models { + models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + } + + routing := NewRoundRobinRouting(models) + iterator := routing.Iterator() + + for i := 0; i < 3; i++ { + // loop three times over the whole pool to check if we return back to the beginning of the list + for _, modelID := range tc.expectedModelIDs { + model, err := iterator.Next() + require.NoError(t, err) + require.Equal(t, modelID, model.ID()) + } + } + }) + } +} + +func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { + models := []providers.Model{ + providers.NewLangModelMock("first", false, 0, 1), + providers.NewLangModelMock("second", false, 0, 1), + providers.NewLangModelMock("third", false, 0, 1), + } + + routing := NewRoundRobinRouting(models) + iterator := routing.Iterator() + + _, err := iterator.Next() + require.Error(t, err) +} diff --git a/pkg/routers/routing/strategies.go b/pkg/routers/routing/strategies.go new file mode 100644 index 00000000..1cd3aab8 --- /dev/null +++ b/pkg/routers/routing/strategies.go @@ -0,0 +1,20 @@ +package routing + +import ( + "errors" + + "glide/pkg/providers" +) + +var ErrNoHealthyModels = errors.New("no healthy models found") + +// Strategy defines supported routing strategies for language routers +type Strategy string + +type LangModelRouting interface { + Iterator() LangModelIterator +} + +type LangModelIterator interface { + Next() (providers.Model, error) +} diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/routers/routing/weighted_round_robin.go new file mode 100644 index 00000000..3e06a601 --- /dev/null +++ b/pkg/routers/routing/weighted_round_robin.go @@ -0,0 +1,91 @@ +package routing + +import ( + "sync" + + "glide/pkg/providers" +) + +const ( + WeightedRoundRobin Strategy = "weighted_round_robin" +) + +type Weighter struct { + model providers.Model + currentWeight int +} + +func (w *Weighter) Current() int { + return w.currentWeight +} + +func (w *Weighter) Weight() int { + return w.model.Weight() +} + 
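+// Incr and Decr together appear to implement the classic smooth weighted round-robin scheme +// (as popularized by nginx): on every selection round each healthy model gains its configured +// weight, the model with the highest current weight wins, and the winner is then decremented +// by the round's total weight, which spreads picks proportionally to the weights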
+func (w *Weighter) Incr() { + w.currentWeight += w.Weight() +} + +func (w *Weighter) Decr(totalWeight int) { + w.currentWeight -= totalWeight +} + +type WRoundRobinRouting struct { + mu sync.Mutex + weights []*Weighter +} + +func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting { + weights := make([]*Weighter, 0, len(models)) + + for _, model := range models { + weights = append(weights, &Weighter{ + model: model, + currentWeight: 0, + }) + } + + return &WRoundRobinRouting{ + weights: weights, + } +} + +func (r *WRoundRobinRouting) Iterator() LangModelIterator { + return r +} + +func (r *WRoundRobinRouting) Next() (providers.Model, error) { + r.mu.Lock() + defer r.mu.Unlock() + + totalWeight := 0 + + var maxWeighter *Weighter + + for _, weighter := range r.weights { + if !weighter.model.Healthy() { + continue + } + + weighter.Incr() + totalWeight += weighter.Weight() + + if maxWeighter == nil { + maxWeighter = weighter + continue + } + + if weighter.Current() > maxWeighter.Current() { + maxWeighter = weighter + } + } + + if maxWeighter != nil { + maxWeighter.Decr(totalWeight) + + return maxWeighter.model, nil + } + + return nil, ErrNoHealthyModels +} diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go new file mode 100644 index 00000000..71a412b3 --- /dev/null +++ b/pkg/routers/routing/weighted_round_robin_test.go @@ -0,0 +1,153 @@ +package routing + +import ( + "testing" + + "github.com/stretchr/testify/require" + "glide/pkg/providers" +) + +func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { + type Model struct { + modelID string + healthy bool + weight int + } + + type TestCase struct { + models []Model + numTries int + distribution map[string]int + } + + tests := map[string]TestCase{ + "equal weights 1": { + []Model{ + {"first", true, 1}, + {"second", true, 1}, + {"three", true, 1}, + }, + 999, + map[string]int{ + "first": 333, + "second": 333, + "three": 333, + }, + }, + "equal weights 2": { + []Model{ + {"first", true, 2}, + {"second", true, 2}, + {"three", true, 2}, + }, + 999, + map[string]int{ + "first": 333, + "second": 333, + "three": 333, + }, + }, + "4-2 split": { + []Model{ + {"first", true, 4}, + {"second", true, 2}, + {"three", true, 2}, + }, + 1000, + map[string]int{ + "first": 500, + "second": 250, + "three": 250, + }, + }, + "5-2-3 split": { + []Model{ + {"first", true, 2}, + {"second", true, 5}, + {"three", true, 3}, + }, + 1000, + map[string]int{ + "first": 200, + "second": 500, + "three": 300, + }, + }, + "1-2-3 split": { + []Model{ + {"first", true, 1}, + {"second", true, 2}, + {"three", true, 3}, + }, + 1000, + map[string]int{ + "first": 167, + "second": 333, + "three": 500, + }, + }, + "pareto split": { + []Model{ + {"first", true, 80}, + {"second", true, 20}, + }, + 1000, + map[string]int{ + "first": 800, + "second": 200, + }, + }, + "zero weight": { + []Model{ + {"first", true, 2}, + {"second", true, 0}, + {"three", true, 2}, + }, + 1000, + map[string]int{ + "first": 500, + "three": 500, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + models := make([]providers.Model, 0, len(tc.models)) + + for _, model := range tc.models { + models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + } + + routing := NewWeightedRoundRobin(models) + iterator := routing.Iterator() + + actualDistribution := make(map[string]int, len(tc.models)) + + // send tc.numTries requests and count how many of them each model ends up handling + 
for i := 0; i < tc.numTries; i++ { + model, err := iterator.Next() + + require.NoError(t, err) + + actualDistribution[model.ID()]++ + } + + require.Equal(t, tc.distribution, actualDistribution) + }) + } +} + +func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { + models := []providers.Model{ + providers.NewLangModelMock("first", false, 0, 1), + providers.NewLangModelMock("second", false, 0, 2), + providers.NewLangModelMock("third", false, 0, 3), + } + + routing := NewWeightedRoundRobin(models) + iterator := routing.Iterator() + + _, err := iterator.Next() + require.Error(t, err) +} diff --git a/pkg/telemetry/logging.go b/pkg/telemetry/logging.go new file mode 100644 index 00000000..dbeeb408 --- /dev/null +++ b/pkg/telemetry/logging.go @@ -0,0 +1,106 @@ +package telemetry + +import ( + "github.com/cloudwego/hertz/pkg/common/hlog" + hertzzap "github.com/hertz-contrib/logger/zap" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type LogConfig struct { + // Level is the minimum enabled logging level. + Level zapcore.Level `yaml:"level"` + + // Encoding sets the logger's encoding. Valid values are "json" and "console". + Encoding string `yaml:"encoding"` + + // DisableCaller stops annotating logs with the calling function's file name and line number. + // By default, all logs are annotated. + DisableCaller bool `yaml:"disable_caller"` + + // DisableStacktrace completely disables automatic stacktrace capturing. By + // default, stacktraces are captured for WarnLevel and above logs in + // development and ErrorLevel and above in production. + DisableStacktrace bool `yaml:"disable_stacktrace"` + + // OutputPaths is a list of URLs or file paths to write logging output to. + OutputPaths []string `yaml:"output_paths"` + + // InitialFields is a collection of fields to add to the root logger. + InitialFields map[string]interface{} `yaml:"initial_fields"` +} + +func DefaultLogConfig() *LogConfig { + return &LogConfig{ + Level: zap.InfoLevel, + Encoding: "json", + DisableCaller: false, + DisableStacktrace: false, + OutputPaths: []string{"stdout"}, + InitialFields: make(map[string]interface{}), + } +} + +func (c *LogConfig) ToZapConfig() *zap.Config { + zapConfig := zap.NewProductionConfig() + + if c.Encoding == "console" { + zapConfig = zap.NewDevelopmentConfig() + + // Human-readable timestamps for console format of logs. + zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + // Colorized plain console logs + zapConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + } + + zapConfig.Level = zap.NewAtomicLevelAt(c.Level) + zapConfig.DisableCaller = c.DisableCaller + zapConfig.DisableStacktrace = c.DisableStacktrace + zapConfig.OutputPaths = c.OutputPaths + zapConfig.InitialFields = c.InitialFields + + return &zapConfig +} + +func NewHertzLogger(zapConfig *zap.Config) (*hertzzap.Logger, error) { + // Both hertzzap and zap keep a set of private methods that prevents us from leveraging + // their native encoder & sink building functionality, + // so we had to copy some of that logic to get it working + var encoder zapcore.Encoder + + if zapConfig.Encoding == "console" { + encoder = zapcore.NewConsoleEncoder(zapConfig.EncoderConfig) + } else { + encoder = zapcore.NewJSONEncoder(zapConfig.EncoderConfig) + }
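+ + // zap.Open merges all configured output paths into a single WriteSyncer; + // its second return value is a cleanup function, discarded here presumably + // because the sinks live for the whole process lifetime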
+ sink, _, err := zap.Open(zapConfig.OutputPaths...) + if err != nil { + return nil, err + } + + return hertzzap.NewLogger( + hertzzap.WithCoreEnc(encoder), + hertzzap.WithCoreWs(sink), + hertzzap.WithCoreLevel(zapConfig.Level), + hertzzap.WithZapOptions(zap.AddCallerSkip(3)), + ), nil +} + +func NewLogger(cfg *LogConfig) (*zap.Logger, error) { + zapConfig := cfg.ToZapConfig() + + logger, err := zapConfig.Build() + if err != nil { + return nil, err + } + + hertzLogger, err := NewHertzLogger(zapConfig) + if err != nil { + return nil, err + } + + hlog.SetLogger(hertzLogger) + + return logger, nil +} diff --git a/pkg/telemetry/telemetry.go b/pkg/telemetry/telemetry.go new file mode 100644 index 00000000..e1c99428 --- /dev/null +++ b/pkg/telemetry/telemetry.go @@ -0,0 +1,40 @@ +package telemetry + +import "go.uber.org/zap" + +type Config struct { + LogConfig *LogConfig `yaml:"logging"` + // TODO: add OTEL config +} + +type Telemetry struct { + Config *Config + Logger *zap.Logger + // TODO: add OTEL meter, tracer +} + +func DefaultConfig() *Config { + return &Config{ + LogConfig: DefaultLogConfig(), + } +} + +func NewTelemetry(cfg *Config) (*Telemetry, error) { + logger, err := NewLogger(cfg.LogConfig) + if err != nil { + return nil, err + } + + return &Telemetry{ + Config: cfg, + Logger: logger, + }, nil +} + +// NewTelemetryMock returns a Telemetry object with no-op loggers, meters, and tracers +func NewTelemetryMock() *Telemetry { + return &Telemetry{ + Config: DefaultConfig(), + Logger: zap.NewNop(), + } +} diff --git a/pkg/version.go b/pkg/version.go new file mode 100644 index 00000000..5ad11003 --- /dev/null +++ b/pkg/version.go @@ -0,0 +1,29 @@ +package pkg + +import ( + "fmt" + "runtime" +) + +// version must be set from the contents of the VERSION file by go build's +// -X glide/pkg.version= linker option in the Makefile. +var version = "devel" + +// commitSha will be the hash that the binary was built from +// and will be populated by the Makefile +var commitSha = "unknown" + +// buildDate captures the time when the build happened +var buildDate = "unknown" + +var FullVersion string + +func init() { + FullVersion = fmt.Sprintf( + "%s (commit: %s, runtime: %s, buildDate: %s)", + version, + commitSha, + runtime.Version(), + buildDate, + ) +}
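+ +// A hypothetical build invocation, for illustration only (the real flags live in the Makefile, +// which is not part of this diff): +// +// go build -ldflags "-X glide/pkg.version=$(cat VERSION) -X glide/pkg.commitSha=$(git rev-parse --short HEAD) -X glide/pkg.buildDate=$(date -u +%FT%TZ)"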