diff --git a/examples/crypto/crypto.go b/examples/crypto/crypto.go index 9a48ce7b..874f28bf 100644 --- a/examples/crypto/crypto.go +++ b/examples/crypto/crypto.go @@ -15,6 +15,7 @@ import ( "os" "strings" "syscall" + "time" "golang.org/x/crypto/ssh/terminal" @@ -22,6 +23,7 @@ import ( "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" + "github.com/gopcua/opcua/uatest" ) var ( @@ -119,7 +121,16 @@ func clientOptsFromFlags(endpoints []*ua.EndpointDescription) []opcua.Option { var cert []byte if *gencert || (*certfile != "" && *keyfile != "") { if *gencert { - generate_cert(*appuri, 2048, *certfile, *keyfile) + certPEM, keyPEM, err := uatest.GenerateCert(*appuri, 2048, 24*time.Hour) + if err != nil { + log.Fatalf("failed to generate cert: %v", err) + } + if err := os.WriteFile(*certfile, certPEM, 0644); err != nil { + log.Fatalf("failed to write %s: %v", *certfile, err) + } + if err := os.WriteFile(*keyfile, keyPEM, 0644); err != nil { + log.Fatalf("failed to write %s: %v", *keyfile, err) + } } debug.Printf("Loading cert/key from %s/%s", *certfile, *keyfile) c, err := tls.LoadX509KeyPair(*certfile, *keyfile) diff --git a/uasc/secure_channel_instance.go b/uasc/secure_channel_instance.go index 8950c6ff..bf123eb8 100644 --- a/uasc/secure_channel_instance.go +++ b/uasc/secure_channel_instance.go @@ -170,11 +170,19 @@ func (c *channelInstance) signAndEncrypt(m *Message, b []byte) ([]byte, error) { var encryptedLength int if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { plaintextBlockSize := c.algo.PlaintextBlockSize() - paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + c.algo.SignatureLength() + 1) % plaintextBlockSize) + extraPadding := c.algo.RemoteSignatureLength() > 256 + paddingBytes := 1 + if extraPadding { + paddingBytes = 2 + } + paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + c.algo.SignatureLength() + paddingBytes) % plaintextBlockSize) for i := 0; i <= paddingLength; i++ { b = append(b, byte(paddingLength)) } + if extraPadding { + b = append(b, byte(paddingLength>>8)) + } encryptedLength = ((len(b[headerLength:]) + c.algo.SignatureLength()) / plaintextBlockSize) * c.algo.BlockSize() } else { // MessageSecurityModeSign encryptedLength = len(b[headerLength:]) + c.algo.SignatureLength() @@ -235,7 +243,13 @@ func (c *channelInstance) verifyAndDecrypt(m *MessageChunk, r []byte) ([]byte, e var paddingLength int if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { - paddingLength = int(messageToVerify[len(messageToVerify)-1]) + 1 + paddingLength = int(messageToVerify[len(messageToVerify)-1]) + if c.algo.SignatureLength() > 256 { + paddingLength <<= 8 + paddingLength += int(messageToVerify[len(messageToVerify)-2]) + paddingLength += 1 + } + paddingLength += 1 } b = messageToVerify[headerLength : len(messageToVerify)-paddingLength] diff --git a/uasc/secure_channel_test.go b/uasc/secure_channel_test.go index 0aadf46c..1149d8e9 100644 --- a/uasc/secure_channel_test.go +++ b/uasc/secure_channel_test.go @@ -1,12 +1,19 @@ package uasc import ( + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" "math" "testing" "time" "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" + "github.com/gopcua/opcua/uapolicy" + "github.com/gopcua/opcua/uatest" "github.com/pascaldekloe/goe/verify" ) @@ -145,3 +152,157 @@ func TestNewRequestMessage(t *testing.T) { }) } } + +func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { + buildSecPolicy := func(bits int, uri string) *uapolicy.EncryptionAlgorithm { + t.Helper() + + certPEM, keyPEM, err := uatest.GenerateCert("localhost", bits, 24*time.Hour) + block, _ := pem.Decode(keyPEM) + pk, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + t.Fatal(err) + } + certblock, _ := pem.Decode(certPEM) + remoteX509Cert, err := x509.ParseCertificate(certblock.Bytes) + if err != nil { + t.Fatal(err) + } + remoteKey := remoteX509Cert.PublicKey.(*rsa.PublicKey) + alg, _ := uapolicy.Asymmetric(uri, pk, remoteKey) + return alg + } + + getConfig := func(uri string) *Config { + t.Helper() + + if uri == ua.SecurityPolicyURINone { + return &Config{SecurityMode: ua.MessageSecurityModeNone} + } + return &Config{SecurityMode: ua.MessageSecurityModeSignAndEncrypt} + } + + tests := []struct { + name string + c *channelInstance + m *Message + b []byte + }{} + + for _, uri := range ua.SecurityPolicyURIs { + for i, keyLength := range []int{2048, 4096} { + if i == 1 && (uri == ua.SecurityPolicyURIBasic128Rsa15 || uri == ua.SecurityPolicyURIBasic256) { + continue + } + tests = append(tests, struct { + name string + c *channelInstance + m *Message + b []byte + }{fmt.Sprintf("encrypt/decrypt: bits: %d uri: %s", keyLength, uri), + &channelInstance{ + sc: &SecureChannel{cfg: getConfig(uri)}, + algo: buildSecPolicy(keyLength, uri), + }, + &Message{ + MessageHeader: &MessageHeader{ + Header: &Header{ + MessageType: MessageTypeOpenSecureChannel, + ChunkType: ChunkTypeFinal, + }, + AsymmetricSecurityHeader: &AsymmetricSecurityHeader{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + SequenceHeader: &SequenceHeader{ + SequenceNumber: 1, + RequestID: 1, + }, + }, + }, + []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: OPN + 0x4f, 0x50, 0x4e, + // Chunk Type: Final + 0x46, + // MessageSize: 131 + 0x8E, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // AsymmetricSecurityHeader + // SecurityPolicyURILength + 0x2e, 0x00, 0x00, 0x00, + // SecurityPolicyURI + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67, + 0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50, + 0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75, + 0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69, + 0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f, + // SenderCertificate + 0xff, 0xff, 0xff, 0xff, + // ReceiverCertificateThumbprint + 0xff, 0xff, 0xff, 0xff, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xbe, 0x01, + + // RequestHeader + // - AuthenticationToken + 0x00, 0x00, + // - Timestamp + 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01, + // - RequestHandle + 0x01, 0x00, 0x00, 0x00, + // - ReturnDiagnostics + 0xff, 0x03, 0x00, 0x00, + // - AuditEntry + 0xff, 0xff, 0xff, 0xff, + // - TimeoutHint + 0x00, 0x00, 0x00, 0x00, + // - AdditionalHeader + // - TypeID + 0x00, 0x00, + // - EncodingMask + 0x00, + // ClientProtocolVersion + 0x00, 0x00, 0x00, 0x00, + // SecurityTokenRequestType + 0x00, 0x00, 0x00, 0x00, + // MessageSecurityMode + 0x01, 0x00, 0x00, 0x00, + // ClientNonce + 0xff, 0xff, 0xff, 0xff, + // RequestedLifetime + 0x80, 0x8d, 0x5b, 0x00, + }}) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := tt.c.signAndEncrypt(tt.m, tt.b) + if err != nil { + t.Fatalf("error: message encrypt: %v", err) + } + + m := new(MessageChunk) + if _, err := m.Decode(cipher); err != nil { + t.Fatalf("error: message decode: %v", err) + } + plain, err := tt.c.verifyAndDecrypt(m, cipher) + if err != nil { + t.Fatalf("error: message decrypt: %v", err) + } + + headerLength := 12 + m.AsymmetricSecurityHeader.Len() + if got, want := plain, tt.b[headerLength:]; !bytes.Equal(got, want) { + t.Fatalf("got bytes %v want %v", got, want) + } + }) + } +} diff --git a/examples/crypto/generate_cert.go b/uatest/generate_cert.go similarity index 65% rename from examples/crypto/generate_cert.go rename to uatest/generate_cert.go index d2142b17..0e2cdc51 100644 --- a/examples/crypto/generate_cert.go +++ b/uatest/generate_cert.go @@ -5,9 +5,10 @@ // Generate a self-signed X.509 certificate for a TLS server. Outputs to // 'cert.pem' and 'key.pem' and will overwrite existing files. +// Based on src/crypto/tls/generate_cert.go from the Go SDK // Modified by the Gopcua Authors for use in creating an OPC-UA compliant client certificate -package main +package uatest import ( "crypto/ecdsa" @@ -17,7 +18,6 @@ import ( "crypto/x509/pkix" "encoding/pem" "fmt" - "log" "math/big" "net" "net/url" @@ -26,33 +26,26 @@ import ( "time" ) -func generate_cert(host string, rsaBits int, certFile, keyFile string) { - +func GenerateCert(host string, rsaBits int, validFor time.Duration) (certPEM, keyPEM []byte, err error) { if len(host) == 0 { - log.Fatalf("Missing required host parameter") + return nil, nil, fmt.Errorf("missing required host parameter") } if rsaBits == 0 { rsaBits = 2048 } - if len(certFile) == 0 { - certFile = "cert.pem" - } - if len(keyFile) == 0 { - keyFile = "key.pem" - } priv, err := rsa.GenerateKey(rand.Reader, rsaBits) if err != nil { - log.Fatalf("failed to generate private key: %s", err) + return nil, nil, fmt.Errorf("failed to generate private key: %s", err) } notBefore := time.Now() - notAfter := notBefore.Add(365 * 24 * time.Hour) // 1 year + notAfter := notBefore.Add(validFor) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - log.Fatalf("failed to generate serial number: %s", err) + return nil, nil, fmt.Errorf("failed to generate serial number: %s", err) } template := x509.Certificate{ @@ -82,34 +75,10 @@ func generate_cert(host string, rsaBits int, certFile, keyFile string) { derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) if err != nil { - log.Fatalf("Failed to create certificate: %s", err) - } - - certOut, err := os.Create(certFile) - if err != nil { - log.Fatalf("failed to open %s for writing: %s", certFile, err) - } - if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - log.Fatalf("failed to write data to %s: %s", certFile, err) - } - if err := certOut.Close(); err != nil { - log.Fatalf("error closing %s: %s", certFile, err) - } - log.Printf("wrote %s\n", certFile) - - keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) - if err != nil { - log.Printf("failed to open %s for writing: %s", keyFile, err) - return - } - if err := pem.Encode(keyOut, pemBlockForKey(priv)); err != nil { - log.Fatalf("failed to write data to %s: %s", keyFile, err) - } - if err := keyOut.Close(); err != nil { - log.Fatalf("error closing %s: %s", keyFile, err) + return nil, nil, fmt.Errorf("Failed to create certificate: %s", err) } - log.Printf("wrote %s\n", keyFile) + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), pem.EncodeToMemory(pemBlockForKey(priv)), nil } func publicKey(priv interface{}) interface{} {