diff --git a/README.md b/README.md index 0a75fc9..b272770 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,8 @@ table: users # user: root # password: pwd # database: e5sub +# ssl_mode is only required when the database requires a SSL connection (e.g. TiDB Cloud) +# ssl_mode: PREFERRED sqlite: db: data.db ``` diff --git a/README_zhCN.md b/README_zhCN.md index 9c6527a..a55f5f2 100644 --- a/README_zhCN.md +++ b/README_zhCN.md @@ -126,6 +126,8 @@ table: users # user: root # password: pwd # database: e5sub +# ssl_mode仅在数据库需要SSL链接时才需要配置(如连接TiDB Cloud) +# ssl_mode: PREFERRED sqlite: db: data.db ``` diff --git a/config/config.go b/config/config.go index a21734f..73478be 100644 --- a/config/config.go +++ b/config/config.go @@ -50,11 +50,13 @@ func Init() { switch DB { case "mysql": Mysql = mysqlConfig{ - Host: viper.GetString("mysql.host"), - Port: viper.GetInt("mysql.port"), - User: viper.GetString("mysql.user"), - Password: viper.GetString("mysql.password"), - DB: viper.GetString("mysql.database"), + Host: viper.GetString("mysql.host"), + Port: viper.GetInt("mysql.port"), + User: viper.GetString("mysql.user"), + Password: viper.GetString("mysql.password"), + DB: viper.GetString("mysql.database"), + SSLMode: viper.GetString("mysql.ssl_mode"), + EnabledTLSProtocols: viper.GetString("mysql.enabled_tls_protocols"), } case "sqlite": Sqlite = sqliteConfig{ diff --git a/config/model.go b/config/model.go index d9ce26e..9b2861a 100644 --- a/config/model.go +++ b/config/model.go @@ -19,9 +19,11 @@ type sqliteConfig struct { DB string `json:"db,omitempty"` } type mysqlConfig struct { - Host string `json:"host,omitempty"` - Port int `json:"port,omitempty"` - User string `json:"user,omitempty"` - Password string `json:"password,omitempty"` - DB string `json:"db,omitempty"` + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + DB string `json:"db,omitempty"` + SSLMode string `json:"ssl_mode,omitempty"` + EnabledTLSProtocols string `json:"enabled_tls_protocols,omitempty"` } diff --git a/db/db.go b/db/db.go index 7dee403..bdcc60e 100644 --- a/db/db.go +++ b/db/db.go @@ -21,13 +21,19 @@ func Init() { switch config.DB { case "mysql": - dial = mysql.Open(fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", + var dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.Mysql.User, config.Mysql.Password, config.Mysql.Host, config.Mysql.Port, config.Mysql.DB, - )) + ) + + if config.Mysql.SSLMode != "" { + dsn += "&tls=" + config.Mysql.SSLMode + } + + dial = mysql.Open(dsn) case "sqlite": dial = sqlite.Open(config.Sqlite.DB) }