Skip to content

Commit

Permalink
fix: the load env logic for db creds injection
Browse files Browse the repository at this point in the history
  • Loading branch information
Paras-Wednesday committed Sep 26, 2024
1 parent 5aba242 commit ab71495
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 37 deletions.
72 changes: 46 additions & 26 deletions internal/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package config
import (
"encoding/json"
"fmt"
"go-template/pkg/utl/convert"
"log"
"os"
"path/filepath"
"runtime"
"strconv"

"github.com/joho/godotenv"

"go-template/pkg/utl/convert"
)

func GetString(key string) string {
Expand Down Expand Up @@ -57,9 +58,12 @@ func FileName() string {
}

func LoadEnv() error {
const (
localEnvFile = "local"
)
_, filename, _, ok := runtime.Caller(0)
if !ok {
return fmt.Errorf("Error getting current file path")
return fmt.Errorf("error getting current file path")
}

prefix := fmt.Sprintf("%s/", filepath.Join(filepath.Dir(filename), "../../"))
Expand All @@ -71,45 +75,61 @@ func LoadEnv() error {

envName := os.Getenv("ENVIRONMENT_NAME")
if envName == "" {
envName = "local"
envName = localEnvFile
}
log.Println("envName: " + envName)

envVarInjection := GetBool("ENV_INJECTION")
if !envVarInjection || envName == "local" {
if !envVarInjection || envName == localEnvFile {
err = godotenv.Load(fmt.Sprintf("%s.env.%s", prefix, envName))

if err != nil {
fmt.Printf(".env.%s\n", envName)
log.Println(err)
return err
return fmt.Errorf("failed to load env for environment %q file: %w", envName, err)
}
fmt.Println("loaded", fmt.Sprintf("%s.env.%s", prefix, envName))
return nil
}

dbCredsInjected := GetBool("COPILOT_DB_CREDS_VIA_SECRETS_MANAGER")

// except for local environment the db creds should be
// injected through the secret manager
if envName != localEnvFile && !dbCredsInjected {
return fmt.Errorf("db creds should be injected through secret manager")
}

// if db creds are injected, extract those
if dbCredsInjected {
type copilotSecrets struct {
Username string `json:"username"`
Host string `json:"host"`
DBName string `json:"dbname"`
Password string `json:"password"`
Port int `json:"port"`
}
secrets := &copilotSecrets{}
return extractDBCredsFromSecret()
}
// otherwise
return nil
}

err := json.Unmarshal([]byte(os.Getenv("DB_SECRET")), secrets)
if err != nil {
return err
}
// extractDBCredsFromSecret helper function to extract db secret
func extractDBCredsFromSecret() error {
type copilotSecrets struct {
Username string `json:"username"`
Host string `json:"host"`
DBName string `json:"dbname"`
Password string `json:"password"`
Port int `json:"port"`
}
secrets, dbSecret := &copilotSecrets{}, os.Getenv("DB_SECRET")

if dbSecret == "" {
return fmt.Errorf("'DB_SECRET' environment var is not set or is empty")
}

os.Setenv("PSQL_DBNAME", secrets.DBName)
os.Setenv("PSQL_HOST", secrets.Host)
os.Setenv("PSQL_PASS", secrets.Password)
os.Setenv("PSQL_PORT", strconv.Itoa(secrets.Port))
os.Setenv("PSQL_USER", secrets.Username)
err := json.Unmarshal([]byte(dbSecret), secrets)
if err != nil {
return fmt.Errorf("couldn't unmarshal db secret: %w", err)
}
return fmt.Errorf("COPILOT_DB_CREDS_VIA_SECRETS_MANAGER should have had a value")

os.Setenv("PSQL_DBNAME", secrets.DBName)
os.Setenv("PSQL_HOST", secrets.Host)
os.Setenv("PSQL_PASS", secrets.Password)
os.Setenv("PSQL_PORT", strconv.Itoa(secrets.Port))
os.Setenv("PSQL_USER", secrets.Username)

return nil
}
151 changes: 140 additions & 11 deletions internal/config/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package config_test

import (
"fmt"
. "go-template/internal/config"
"os"
"testing"

"github.com/stretchr/testify/assert"

. "go-template/internal/config"
)

func TestGetString(t *testing.T) {
Expand Down Expand Up @@ -231,6 +232,7 @@ func loadLocalEnv() envTestCaseArgs {
},
}
}

func errorOnEnvInjectionAndCopilotFalse() envTestCaseArgs {
return envTestCaseArgs{
name: "Error when ENV_INJECTION and COPILOT_DB_CREDS_VIA_SECRETS_MANAGER false",
Expand All @@ -254,10 +256,67 @@ func errorOnEnvInjectionAndCopilotFalse() envTestCaseArgs {
}
}

func loadOnDbCredsInjected(username string, host string, dbname string, password string, port string) envTestCaseArgs {
func loadOnCopilotTrueAndLocalEnv() envTestCaseArgs {
return envTestCaseArgs{
name: "dbCredsInjected True",
name: "Load local without copilot",
wantErr: false,
args: args{
setEnv: []keyValueArgs{
{
key: "ENV_INJECTION",
value: "true",
},
{
key: "ENVIRONMENT_NAME",
value: "local",
},
{
key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER",
value: "false",
},
},
},
}
}

func errorOnDbCredsInjectedInDevEnv() envTestCaseArgs {
return envTestCaseArgs{
name: "dbCredsInjected True for `develop` environment,with invalid json in db secret",
wantErr: true,
args: args{
setEnv: []keyValueArgs{
{
key: "ENV_INJECTION",
value: "true",
},
{
key: "ENVIRONMENT_NAME",
value: "develop",
},
{
key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER",
value: "true",
},
{
key: "DB_SECRET",
value: "invalid json",
},
},
expectedKeyValues: []keyValueArgs{},
},
}
}

func loadOnDbCredsInjectedInDevEnv(
username string,
host string,
dbname string,
password string,
port string,
) envTestCaseArgs {
return envTestCaseArgs{
name: "dbCredsInjected True for `develop` environment,and should parse the db secret",
wantErr: false,
args: args{
setEnv: []keyValueArgs{
{
Expand All @@ -274,27 +333,76 @@ func loadOnDbCredsInjected(username string, host string, dbname string, password
},
{
key: "DB_SECRET",
value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": "%s", "dbname": "%s", "host": "%s"}`,
value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": %s, "dbname": "%s", "host": "%s"}`,
username,
password,
port,
host,
dbname),
dbname,
host),
},
},
expectedKeyValues: []keyValueArgs{
{
key: "PSQL_USER",
value: username,
},
{
key: "PSQL_PORT",
value: port,
},
{
key: "PSQL_PASS",
value: password,
},
{
key: "PSQL_HOST",
value: host,
},
{
key: "PSQL_DBNAME",
value: dbname,
},
},
},
}
}


Check failure on line 370 in internal/config/env_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not `gofmt`-ed with `-s` (gofmt)
func loadDbCredsInjectedInLocalEnv(
username string,
host string,
dbname string,
password string,
port string,
) envTestCaseArgs {
return envTestCaseArgs{
name: "dbCredsInjected True for local environment,and should parse the db secret",
wantErr: false,
args: args{
setEnv: []keyValueArgs{
{
key: "PSQL_HOST",
value: host,
key: "ENV_INJECTION",
value: "true",
},
{
key: "ENVIRONMENT_NAME",
value: "local",
},
{
key: "COPILOT_DB_CREDS_VIA_SECRETS_MANAGER",
value: "true",
},
{
key: "DB_SECRET",
value: fmt.Sprintf(`{"username": "%s", "password": "%s", "port": %s, "dbname": "%s", "host": "%s"}`,
username,
password,
port,
dbname,
host),
},
},
expectedKeyValues: []keyValueArgs{
{
key: "PSQL_USER",
value: username,
Expand All @@ -303,31 +411,51 @@ func loadOnDbCredsInjected(username string, host string, dbname string, password
key: "PSQL_PORT",
value: port,
},
{
key: "PSQL_PASS",
value: password,
},
{
key: "PSQL_HOST",
value: host,
},
{
key: "PSQL_DBNAME",
value: dbname,
},
},
},
}
}

func errorOnWrongEnvName() envTestCaseArgs {
return envTestCaseArgs{
name: "Failed to load env",
name: "Failed to load env for local1",
wantErr: true,
args: args{
setEnv: []keyValueArgs{
{
key: "ENVIRONMENT_NAME",
value: "local1",
},
{
key: "ENV_INJECTION",
value: "false",
},
},
},
}
}

func getTestCases(username string, host string, dbname string, password string, port string) []envTestCaseArgs {
return []envTestCaseArgs{
loadLocalEnvIfNoEnvName(),
loadLocalEnv(),
errorOnEnvInjectionAndCopilotFalse(),
loadOnDbCredsInjected(username, host, dbname, password, port),
loadOnCopilotTrueAndLocalEnv(),
errorOnDbCredsInjectedInDevEnv(),
loadOnDbCredsInjectedInDevEnv(username, host, dbname, password, port),
loadDbCredsInjectedInLocalEnv(username, host, dbname, password, port),
errorOnWrongEnvName(),
}
}
Expand All @@ -342,7 +470,8 @@ func testLoadEnv(t *testing.T, tt struct {
name string
wantErr bool
args args
}) {
},
) {
if err := LoadEnv(); (err != nil) != tt.wantErr {
t.Errorf("LoadEnv() error = %v, wantErr %v", err, tt.wantErr)
} else {
Expand Down

0 comments on commit ab71495

Please sign in to comment.