Skip to content

Add unit tests for internal/service module #3781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/services/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ func CheckVersions(ctx context.Context, fsys afero.Fs) []imageVersion {
func listRemoteImages(ctx context.Context, projectRef string) map[string]string {
linked := make(map[string]string, 4)
var wg sync.WaitGroup
var mu sync.Mutex

wg.Add(1)
go func() {
defer wg.Done()
if version, err := tenant.GetDatabaseVersion(ctx, projectRef); err == nil {
mu.Lock()
linked[utils.Config.Db.Image] = version
mu.Unlock()
}
}()
keys, err := tenant.GetApiKeys(ctx, projectRef)
Expand All @@ -84,13 +88,17 @@ func listRemoteImages(ctx context.Context, projectRef string) map[string]string
go func() {
defer wg.Done()
if version, err := api.GetGotrueVersion(ctx); err == nil {
mu.Lock()
linked[utils.Config.Auth.Image] = version
mu.Unlock()
}
}()
go func() {
defer wg.Done()
if version, err := api.GetPostgrestVersion(ctx); err == nil {
mu.Lock()
linked[utils.Config.Api.Image] = version
mu.Unlock()
}
}()
wg.Wait()
Expand Down
335 changes: 335 additions & 0 deletions internal/services/services_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
package services

import (
"context"
"testing"

"github.com/h2non/gock"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/internal/utils/flags"
)

// TestRun tests the main Run function that displays service versions
func TestRun(t *testing.T) {
// Test case: Display service versions without linked project
t.Run("displays service versions without linked project", func(t *testing.T) {
// Setup: Create an in-memory filesystem
fsys := afero.NewMemMapFs()

// Execute: Call the Run function
err := Run(context.Background(), fsys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make these tests pass, you need to mock all api requests using gock.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it — I’ll use gock to mock the API calls and make the tests pass. Thanks!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Use a fake but valid access token for tests, with #nosec G101 to avoid gosec false positives
  • Fix race condition in listRemoteImages by protecting map writes with sync.Mutex
  • All tests now pass with -race and are CI-friendly


// Verify: Check that no error occurred
assert.NoError(t, err)
})

// Test case: Display service versions with linked project
t.Run("displays service versions with linked project", func(t *testing.T) {
// Setup: Create an in-memory filesystem and simulate linked project
fsys := afero.NewMemMapFs()

// Create project config file with project reference
projectRef := "abcdefghijklmnopqrst"
require.NoError(t, utils.InitConfig(utils.InitParams{
ProjectId: projectRef,
}, fsys))
flags.ProjectRef = projectRef

// Mock all API requests
defer gock.OffAll()

// Mock API keys
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + projectRef + "/api-keys").
Reply(200).
JSON([]map[string]string{{"name": "anon", "api_key": "test-key"}})

// Mock database version
gock.New(utils.DefaultApiHost).
Get("/v1/projects").
Reply(200).
JSON([]map[string]interface{}{
{
"id": projectRef,
"database": map[string]string{"version": "1.0.0"},
},
})

// Mock auth version
gock.New("https://" + utils.GetSupabaseHost(projectRef)).
Get("/auth/v1/health").
Reply(200).
JSON(map[string]string{"version": "2.0.0"})

// Mock postgrest version
gock.New("https://" + utils.GetSupabaseHost(projectRef)).
Get("/rest/v1/").
Reply(200).
JSON(map[string]interface{}{
"swagger": "2.0",
"info": map[string]string{"version": "3.0.0"},
})

// Execute: Call the Run function
err := Run(context.Background(), fsys)

// Verify: Check that no error occurred
assert.NoError(t, err)
})
}

// TestCheckVersions tests the function that checks local and remote service versions
func TestCheckVersions(t *testing.T) {
// Test case: Check local versions only
t.Run("checks local versions", func(t *testing.T) {
// Setup: Create an in-memory filesystem
fsys := afero.NewMemMapFs()

// Execute: Call CheckVersions function
versions := CheckVersions(context.Background(), fsys)

// Verify: Check that versions are returned and contain required fields
assert.NotEmpty(t, versions)
for _, v := range versions {
assert.NotEmpty(t, v.Name, "Service name should not be empty")
assert.NotEmpty(t, v.Local, "Local version should not be empty")
}
})

// Test case: Check both local and remote versions
t.Run("checks local and remote versions", func(t *testing.T) {
// Setup: Create an in-memory filesystem and simulate linked project
fsys := afero.NewMemMapFs()

// Create project config file with project reference
projectRef := "abcdefghijklmnopqrst"
require.NoError(t, utils.InitConfig(utils.InitParams{
ProjectId: projectRef,
}, fsys))

// Set project reference in flags
flags.ProjectRef = projectRef

// Execute: Call CheckVersions function
versions := CheckVersions(context.Background(), fsys)

// Verify: Check that versions are returned and contain required fields
assert.NotEmpty(t, versions)
for _, v := range versions {
assert.NotEmpty(t, v.Name, "Service name should not be empty")
assert.NotEmpty(t, v.Local, "Local version should not be empty")
// Remote version might be empty if not linked
}
})

// Test case: Handle version mismatch
t.Run("handles version mismatch", func(t *testing.T) {
// Setup: Create an in-memory filesystem and simulate linked project
fsys := afero.NewMemMapFs()

// Create project config file with project reference
projectRef := "abcdefghijklmnopqrst"
require.NoError(t, utils.InitConfig(utils.InitParams{
ProjectId: projectRef,
}, fsys))

// Set project reference in flags
flags.ProjectRef = projectRef

// Execute: Call CheckVersions function
versions := CheckVersions(context.Background(), fsys)

// Verify: Check that versions are returned and contain required fields
assert.NotEmpty(t, versions)
for _, v := range versions {
assert.NotEmpty(t, v.Name, "Service name should not be empty")
assert.NotEmpty(t, v.Local, "Local version should not be empty")
// Remote version might be empty if not linked
}
})

// Test case: Verify version comparison logic
t.Run("compares local and remote versions correctly", func(t *testing.T) {
fsys := afero.NewMemMapFs()
projectRef := "abcdefghijklmnopqrst"

// Setup: Create linked project with specific versions
require.NoError(t, utils.InitConfig(utils.InitParams{
ProjectId: projectRef,
}, fsys))
flags.ProjectRef = projectRef

// Mock remote versions
// #nosec G101 -- This is a fake token for testing purposes only
token := "sbp_0102030405060708091011121314151617181920"
require.NoError(t, utils.SaveAccessToken(token, fsys))

defer gock.OffAll()
// Mock API responses with specific versions
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + projectRef + "/api-keys").
Reply(200).
JSON([]map[string]string{{"name": "anon", "api_key": "test-key"}})

gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + projectRef + "/database/version").
Reply(200).
JSON(map[string]string{"version": "1.0.0"})

versions := CheckVersions(context.Background(), fsys)

// Verify version comparison logic
for _, v := range versions {
assert.NotEmpty(t, v.Name)
assert.NotEmpty(t, v.Local)
// Check if remote versions are properly assigned
}
})
}

// TestListRemoteImages tests the function that retrieves remote service versions
func TestListRemoteImages(t *testing.T) {
// Test case: Get remote versions successfully
t.Run("gets remote versions successfully", func(t *testing.T) {
// Setup: Create context and project reference
ctx := context.Background()
projectRef := "abcdefghijklmnopqrst"

// Setup: Create in-memory filesystem
fsys := afero.NewMemMapFs()

// Setup: Create access token file with valid format
// #nosec G101 -- This is a fake token for testing purposes only
token := "sbp_0102030405060708091011121314151617181920"
require.NoError(t, utils.SaveAccessToken(token, fsys))

// Setup: Mock API responses
defer gock.OffAll()

// Mock API keys response
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + projectRef + "/api-keys").
Reply(200).
JSON([]map[string]string{
{"name": "anon", "api_key": "test-key"},
})

gock.New(utils.DefaultApiHost).
Get("/v1/projects").
Reply(200).
JSON([]map[string]interface{}{
{
"id": projectRef,
"database": map[string]string{
"version": "1.0.0",
},
},
})

gock.New("https://" + utils.GetSupabaseHost(projectRef)).
Get("/auth/v1/health").
Reply(200).
JSON(map[string]string{"version": "2.0.0"})

// Mock postgrest version response (endpoint = /rest/v1/ sur le host du projet)
gock.New("https://" + utils.GetSupabaseHost(projectRef)).
Get("/rest/v1/").
Reply(200).
JSON(map[string]interface{}{
"swagger": "2.0",
"info": map[string]string{
"version": "3.0.0",
},
})

// Execute: Call listRemoteImages function
remoteVersions := listRemoteImages(ctx, projectRef)

// Verify: Check that remote versions are returned
assert.NotNil(t, remoteVersions)
assert.NotEmpty(t, remoteVersions)

// Verify: Check that all expected versions are present
for _, version := range remoteVersions {
assert.NotEmpty(t, version)
}
})

// Test case: Handle API errors
t.Run("handles API errors", func(t *testing.T) {
// Setup: Create context and project reference
ctx := context.Background()
projectRef := "invalid-project"

// Setup: Create in-memory filesystem
fsys := afero.NewMemMapFs()

// Setup: Create access token file with valid format
// #nosec G101 -- This is a fake token for testing purposes only
token := "sbp_0102030405060708091011121314151617181920"
require.NoError(t, utils.SaveAccessToken(token, fsys))

// Setup: Mock API error response
defer gock.OffAll()
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + projectRef + "/api-keys").
Reply(404)

// Execute: Call listRemoteImages function
remoteVersions := listRemoteImages(ctx, projectRef)

// Verify: Check that remote versions are empty
assert.Empty(t, remoteVersions)
})

// Test case: Handle missing access token
t.Run("handles missing access token", func(t *testing.T) {
// Setup: Create context and project reference
ctx := context.Background()
projectRef := "abcdefghijklmnopqrst"

// Setup: Create in-memory filesystem without access token
afero.NewMemMapFs()

// Execute: Call listRemoteImages function
remoteVersions := listRemoteImages(ctx, projectRef)

// Verify: Check that remote versions are empty
assert.Empty(t, remoteVersions)
})
}

// TestSuggestUpdateCmd tests the function that generates update command suggestions
func TestSuggestUpdateCmd(t *testing.T) {
// Test case: Generate update command for version mismatch
t.Run("generates update command for version mismatch", func(t *testing.T) {
// Setup: Create map of service images with version mismatches
serviceImages := map[string]string{
"service1": "v1.0.0",
"service2": "v2.0.0",
}

// Execute: Call suggestUpdateCmd function
cmd := suggestUpdateCmd(serviceImages)

// Verify: Check that command contains expected content
assert.Contains(t, cmd, "WARNING:")
assert.Contains(t, cmd, "supabase link")
})

// Test case: Handle empty service images
t.Run("handles empty service images", func(t *testing.T) {
// Setup: Create empty map of service images
serviceImages := map[string]string{}

// Execute: Call suggestUpdateCmd function
cmd := suggestUpdateCmd(serviceImages)

// Verify: Check that command contains expected content
assert.Contains(t, cmd, "WARNING:")
assert.Contains(t, cmd, "supabase link")
})
}