-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
277 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package disk | ||
|
||
import ( | ||
"os" | ||
"path/filepath" | ||
|
||
"github.com/danjacques/gofslock/fslock" | ||
"github.com/pingcap/tidb/util/logutil" | ||
"go.uber.org/zap" | ||
) | ||
|
||
const ( | ||
lockFile = "_dir.lock" | ||
) | ||
|
||
// InitializeTempDir initializes the temp directory. | ||
func InitializeTempDir(tempDir string) error { | ||
_, err := os.Stat(tempDir) | ||
if err != nil && !os.IsExist(err) { | ||
err = os.MkdirAll(tempDir, 0750) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
_, err = fslock.Lock(filepath.Join(tempDir, lockFile)) | ||
if err != nil { | ||
switch err { | ||
case fslock.ErrLockHeld: | ||
logutil.BgLogger().Error("The current temporary storage dir has been occupied by another instance, "+ | ||
"check tmp-storage-path config and make sure they are different.", zap.String("TempStoragePath", tempDir), zap.Error(err)) | ||
default: | ||
logutil.BgLogger().Error("Failed to acquire exclusive lock on the temporary storage dir.", zap.String("TempStoragePath", tempDir), zap.Error(err)) | ||
} | ||
return err | ||
} | ||
|
||
subDirs, err := os.ReadDir(tempDir) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// If it exists others files except lock file, creates another goroutine to clean them. | ||
if len(subDirs) > 2 { | ||
go func() { | ||
for _, subDir := range subDirs { | ||
// Do not remove the lock file. | ||
switch subDir.Name() { | ||
case lockFile: | ||
continue | ||
} | ||
err := os.RemoveAll(filepath.Join(tempDir, subDir.Name())) | ||
if err != nil { | ||
logutil.BgLogger().Warn("Remove temporary file error", | ||
zap.String("tempStorageSubDir", filepath.Join(tempDir, subDir.Name())), zap.Error(err)) | ||
} | ||
} | ||
}() | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
package security | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"go.uber.org/zap" | ||
"math/big" | ||
"os" | ||
"path/filepath" | ||
"time" | ||
|
||
"github.com/pingcap/errors" | ||
"github.com/pingcap/tidb/util/logutil" | ||
) | ||
|
||
func CreateServerTLSConfig(ca, key, cert, minTLSVer, path string, rsaKeySize int) (tlsConfig *tls.Config, err error) { | ||
if len(cert) == 0 || len(key) == 0 { | ||
cert = filepath.Join(path, "/cert.pem") | ||
key = filepath.Join(path, "/key.pem") | ||
err = createTLSCertificates(cert, key, rsaKeySize) | ||
if err != nil { | ||
logutil.BgLogger().Warn("TLS Certificate creation failed", zap.Error(err)) | ||
return | ||
} | ||
} | ||
|
||
var tlsCert tls.Certificate | ||
tlsCert, err = tls.LoadX509KeyPair(cert, key) | ||
if err != nil { | ||
logutil.BgLogger().Warn("load x509 failed", zap.Error(err)) | ||
err = errors.Trace(err) | ||
return | ||
} | ||
|
||
var minTLSVersion uint16 = tls.VersionTLS11 | ||
switch minTLSVer { | ||
case "TLSv1.0": | ||
minTLSVersion = tls.VersionTLS10 | ||
case "TLSv1.1": | ||
minTLSVersion = tls.VersionTLS11 | ||
case "TLSv1.2": | ||
minTLSVersion = tls.VersionTLS12 | ||
case "TLSv1.3": | ||
minTLSVersion = tls.VersionTLS13 | ||
case "": | ||
default: | ||
logutil.BgLogger().Warn( | ||
"Invalid TLS version, using default instead", | ||
zap.String("tls-version", minTLSVer), | ||
) | ||
} | ||
if minTLSVersion < tls.VersionTLS12 { | ||
logutil.BgLogger().Warn( | ||
"Minimum TLS version allows pre-TLSv1.2 protocols, this is not recommended", | ||
) | ||
} | ||
|
||
// Try loading CA cert. | ||
clientAuthPolicy := tls.NoClientCert | ||
var certPool *x509.CertPool | ||
if len(ca) > 0 { | ||
var caCert []byte | ||
caCert, err = os.ReadFile(ca) | ||
if err != nil { | ||
logutil.BgLogger().Warn("read file failed", zap.Error(err)) | ||
err = errors.Trace(err) | ||
return | ||
} | ||
certPool = x509.NewCertPool() | ||
if certPool.AppendCertsFromPEM(caCert) { | ||
clientAuthPolicy = tls.VerifyClientCertIfGiven | ||
} | ||
} | ||
|
||
// This excludes ciphers listed in tls.InsecureCipherSuites() and can be used to filter out more | ||
var cipherSuites []uint16 | ||
var cipherNames []string | ||
for _, sc := range tls.CipherSuites() { | ||
switch sc.ID { | ||
case tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: | ||
logutil.BgLogger().Info("Disabling weak cipherSuite", zap.String("cipherSuite", sc.Name)) | ||
default: | ||
cipherNames = append(cipherNames, sc.Name) | ||
cipherSuites = append(cipherSuites, sc.ID) | ||
} | ||
|
||
} | ||
logutil.BgLogger().Info("Enabled ciphersuites", zap.Strings("cipherNames", cipherNames)) | ||
|
||
/* #nosec G402 */ | ||
tlsConfig = &tls.Config{ | ||
Certificates: []tls.Certificate{tlsCert}, | ||
ClientCAs: certPool, | ||
ClientAuth: clientAuthPolicy, | ||
MinVersion: minTLSVersion, | ||
CipherSuites: cipherSuites, | ||
} | ||
return | ||
} | ||
|
||
func createTLSCertificates(certpath string, keypath string, rsaKeySize int) error { | ||
privkey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
certValidity := 90 * 24 * time.Hour // 90 days | ||
notBefore := time.Now() | ||
notAfter := notBefore.Add(certValidity) | ||
hostname, err := os.Hostname() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
template := x509.Certificate{ | ||
Subject: pkix.Name{ | ||
CommonName: "TiDB_Server_Auto_Generated_Server_Certificate", | ||
}, | ||
SerialNumber: big.NewInt(1), | ||
NotBefore: notBefore, | ||
NotAfter: notAfter, | ||
DNSNames: []string{hostname}, | ||
} | ||
|
||
// DER: Distinguished Encoding Rules, this is the ASN.1 encoding rule of the certificate. | ||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privkey.PublicKey, privkey) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
certOut, err := os.Create(certpath) | ||
if err != nil { | ||
return err | ||
} | ||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { | ||
return err | ||
} | ||
if err := certOut.Close(); err != nil { | ||
return err | ||
} | ||
|
||
keyOut, err := os.OpenFile(keypath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
privBytes, err := x509.MarshalPKCS8PrivateKey(privkey) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { | ||
return err | ||
} | ||
|
||
if err := keyOut.Close(); err != nil { | ||
return err | ||
} | ||
|
||
logutil.BgLogger().Info("TLS Certificates created", zap.String("cert", certpath), zap.String("key", keypath), | ||
zap.Duration("validity", certValidity), zap.Int("rsaKeySize", rsaKeySize)) | ||
return nil | ||
} | ||
|
||
// ToTLSConfig generates tls's config based on security section of the config. | ||
func CreateClusterTLSConfig(sslCA, sslKey, sslCert string) (tlsConfig *tls.Config, err error) { | ||
if len(sslCA) != 0 { | ||
certPool := x509.NewCertPool() | ||
// Create a certificate pool from the certificate authority | ||
var ca []byte | ||
ca, err = os.ReadFile(sslCA) | ||
if err != nil { | ||
err = errors.Errorf("could not read ca certificate: %s", err) | ||
return | ||
} | ||
// Append the certificates from the CA | ||
if !certPool.AppendCertsFromPEM(ca) { | ||
err = errors.New("failed to append ca certs") | ||
return | ||
} | ||
tlsConfig = &tls.Config{ | ||
RootCAs: certPool, | ||
ClientCAs: certPool, | ||
} | ||
|
||
if len(sslCert) != 0 && len(sslKey) != 0 { | ||
getCert := func() (*tls.Certificate, error) { | ||
// Load the client certificates from disk | ||
cert, err := tls.LoadX509KeyPair(sslCert, sslKey) | ||
if err != nil { | ||
return nil, errors.Errorf("could not load client key pair: %s", err) | ||
} | ||
return &cert, nil | ||
} | ||
// pre-test cert's loading. | ||
if _, err = getCert(); err != nil { | ||
return | ||
} | ||
tlsConfig.GetClientCertificate = func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, err error) { | ||
return getCert() | ||
} | ||
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) { | ||
return getCert() | ||
} | ||
} | ||
} | ||
return | ||
} | ||
|
||
func CreateClientTLSConfig() *tls.Config { | ||
return &tls.Config{} | ||
} |