diff --git a/pkg/drivers/mysql/mysql.go b/pkg/drivers/mysql/mysql.go index 335fc714..967476e7 100644 --- a/pkg/drivers/mysql/mysql.go +++ b/pkg/drivers/mysql/mysql.go @@ -5,6 +5,7 @@ import ( cryptotls "crypto/tls" "database/sql" "fmt" + "strings" "github.com/go-sql-driver/mysql" "github.com/k3s-io/kine/pkg/drivers/generic" @@ -137,7 +138,7 @@ func createDBIfNotExist(dataSourceName string) error { if err != nil { return err } - dbName := config.DBName + dbName := quoteIdentifier(config.DBName) db, err := sql.Open("mysql", dataSourceName) if err != nil { @@ -188,3 +189,7 @@ func prepareDSN(dataSourceName string, tlsConfig *cryptotls.Config) (string, err return parsedDSN, nil } + +func quoteIdentifier(id string) string { + return "`" + strings.ReplaceAll(id, "`", "``") + "`" +} diff --git a/pkg/drivers/pgsql/pgsql.go b/pkg/drivers/pgsql/pgsql.go index 4fe3d297..0030bc75 100644 --- a/pkg/drivers/pgsql/pgsql.go +++ b/pkg/drivers/pgsql/pgsql.go @@ -123,7 +123,7 @@ func createDBIfNotExist(dataSourceName string) error { return err } - dbName := strings.SplitN(u.Path, "/", 2)[1] + dbName := quoteIdentifier(strings.SplitN(u.Path, "/", 2)[1]) db, err := sql.Open("postgres", dataSourceName) if err != nil { return err @@ -208,3 +208,7 @@ func prepareDSN(dataSourceName string, tlsInfo tls.Config) (string, error) { u.RawQuery = params.Encode() return u.String(), nil } + +func quoteIdentifier(id string) string { + return pq.QuoteIdentifier(id) +}