Skip to content

Commit

Permalink
*: pass rsa key size correctly (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Oct 18, 2022
1 parent c0ab948 commit 9917e02
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
14 changes: 9 additions & 5 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath strin
}
}

certPEM, keyPEM, caPEM, err := CreateTempTLS(expiration)
certPEM, keyPEM, caPEM, err := CreateTempTLS(rsaKeySize, expiration)
if err != nil {
return err
}
Expand Down Expand Up @@ -99,7 +99,11 @@ func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, m
return nil
}

func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
func CreateTempTLS(rsaKeySize int, expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
if rsaKeySize < 1024 {
rsaKeySize = 1024
}

// set up our CA certificate
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Expand All @@ -120,7 +124,7 @@ func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *byt
}

// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
caPrivKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -159,7 +163,7 @@ func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *byt
KeyUsage: x509.KeyUsageDigitalSignature,
}

certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
certPrivKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -190,7 +194,7 @@ func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *byt

// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
certPEM, keyPEM, caPEM, uerr := CreateTempTLS(DefaultCertExpiration)
certPEM, keyPEM, caPEM, uerr := CreateTempTLS(0, DefaultCertExpiration)
if uerr != nil {
err = uerr
return
Expand Down
28 changes: 28 additions & 0 deletions lib/util/security/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package security

import (
"testing"

"github.com/stretchr/testify/require"
)

func BenchmarkCreateTLS(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, _, err := CreateTempTLS(0, DefaultCertExpiration)
require.Nil(b, err)
}
}
8 changes: 4 additions & 4 deletions pkg/manager/cert/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ func TestReloadCerts(t *testing.T) {
dir := t.TempDir()
lg := logger.CreateLoggerForTest(t)
sqlCfg := &config.TLSConfig{AutoCerts: true}
err := security.AutoTLS(lg, sqlCfg, true, dir, "sql", 1024)
err := security.AutoTLS(lg, sqlCfg, true, dir, "sql", 0)
require.NoError(t, err)
clusterCfg := &config.TLSConfig{AutoCerts: true}
err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 1024)
err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 0)
require.NoError(t, err)

cfg := &config.Config{
Expand Down Expand Up @@ -80,10 +80,10 @@ func TestReloadCerts(t *testing.T) {

var before = getAllCertificates(t, certMgr)
sqlCfg = &config.TLSConfig{AutoCerts: true}
err = security.AutoTLS(lg, sqlCfg, true, dir, "sql", 1024)
err = security.AutoTLS(lg, sqlCfg, true, dir, "sql", 0)
require.NoError(t, err)
clusterCfg = &config.TLSConfig{AutoCerts: true}
err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 1024)
err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 0)
require.NoError(t, err)

timer := time.NewTimer(10 * time.Second)
Expand Down

0 comments on commit 9917e02

Please sign in to comment.