Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle extra padding if key length > 2048 #648

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion examples/crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ import (
"os"
"strings"
"syscall"
"time"

"golang.org/x/crypto/ssh/terminal"

"github.com/gopcua/opcua"
"github.com/gopcua/opcua/debug"
"github.com/gopcua/opcua/errors"
"github.com/gopcua/opcua/ua"
"github.com/gopcua/opcua/uatest"
)

var (
Expand Down Expand Up @@ -119,7 +121,16 @@ func clientOptsFromFlags(endpoints []*ua.EndpointDescription) []opcua.Option {
var cert []byte
if *gencert || (*certfile != "" && *keyfile != "") {
if *gencert {
generate_cert(*appuri, 2048, *certfile, *keyfile)
certPEM, keyPEM, err := uatest.GenerateCert(*appuri, 2048, 24*time.Hour)
if err != nil {
log.Fatalf("failed to generate cert: %v", err)
}
if err := os.WriteFile(*certfile, certPEM, 0644); err != nil {
log.Fatalf("failed to write %s: %v", *certfile, err)
}
if err := os.WriteFile(*keyfile, keyPEM, 0644); err != nil {
log.Fatalf("failed to write %s: %v", *keyfile, err)
}
}
debug.Printf("Loading cert/key from %s/%s", *certfile, *keyfile)
c, err := tls.LoadX509KeyPair(*certfile, *keyfile)
Expand Down
18 changes: 16 additions & 2 deletions uasc/secure_channel_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,19 @@ func (c *channelInstance) signAndEncrypt(m *Message, b []byte) ([]byte, error) {
var encryptedLength int
if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric {
plaintextBlockSize := c.algo.PlaintextBlockSize()
paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + c.algo.SignatureLength() + 1) % plaintextBlockSize)
extraPadding := c.algo.RemoteSignatureLength() > 256
paddingBytes := 1
if extraPadding {
paddingBytes = 2
}
paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + c.algo.SignatureLength() + paddingBytes) % plaintextBlockSize)

for i := 0; i <= paddingLength; i++ {
b = append(b, byte(paddingLength))
}
if extraPadding {
b = append(b, byte(paddingLength>>8))
}
encryptedLength = ((len(b[headerLength:]) + c.algo.SignatureLength()) / plaintextBlockSize) * c.algo.BlockSize()
} else { // MessageSecurityModeSign
encryptedLength = len(b[headerLength:]) + c.algo.SignatureLength()
Expand Down Expand Up @@ -235,7 +243,13 @@ func (c *channelInstance) verifyAndDecrypt(m *MessageChunk, r []byte) ([]byte, e

var paddingLength int
if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric {
paddingLength = int(messageToVerify[len(messageToVerify)-1]) + 1
paddingLength = int(messageToVerify[len(messageToVerify)-1])
if c.algo.SignatureLength() > 256 {
paddingLength <<= 8
paddingLength += int(messageToVerify[len(messageToVerify)-2])
paddingLength += 1
}
paddingLength += 1
}

b = messageToVerify[headerLength : len(messageToVerify)-paddingLength]
Expand Down
161 changes: 161 additions & 0 deletions uasc/secure_channel_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package uasc

import (
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"math"
"testing"
"time"

"github.com/gopcua/opcua/id"
"github.com/gopcua/opcua/ua"
"github.com/gopcua/opcua/uapolicy"
"github.com/gopcua/opcua/uatest"

"github.com/pascaldekloe/goe/verify"
)
Expand Down Expand Up @@ -145,3 +152,157 @@ func TestNewRequestMessage(t *testing.T) {
})
}
}

func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) {
buildSecPolicy := func(bits int, uri string) *uapolicy.EncryptionAlgorithm {
t.Helper()

certPEM, keyPEM, err := uatest.GenerateCert("localhost", bits, 24*time.Hour)
block, _ := pem.Decode(keyPEM)
pk, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
certblock, _ := pem.Decode(certPEM)
remoteX509Cert, err := x509.ParseCertificate(certblock.Bytes)
if err != nil {
t.Fatal(err)
}
remoteKey := remoteX509Cert.PublicKey.(*rsa.PublicKey)
alg, _ := uapolicy.Asymmetric(uri, pk, remoteKey)
return alg
}

getConfig := func(uri string) *Config {
t.Helper()

if uri == ua.SecurityPolicyURINone {
return &Config{SecurityMode: ua.MessageSecurityModeNone}
}
return &Config{SecurityMode: ua.MessageSecurityModeSignAndEncrypt}
}

tests := []struct {
name string
c *channelInstance
m *Message
b []byte
}{}

for _, uri := range ua.SecurityPolicyURIs {
for i, keyLength := range []int{2048, 4096} {
if i == 1 && (uri == ua.SecurityPolicyURIBasic128Rsa15 || uri == ua.SecurityPolicyURIBasic256) {
continue
}
tests = append(tests, struct {
name string
c *channelInstance
m *Message
b []byte
}{fmt.Sprintf("encrypt/decrypt: bits: %d uri: %s", keyLength, uri),
&channelInstance{
sc: &SecureChannel{cfg: getConfig(uri)},
algo: buildSecPolicy(keyLength, uri),
},
&Message{
MessageHeader: &MessageHeader{
Header: &Header{
MessageType: MessageTypeOpenSecureChannel,
ChunkType: ChunkTypeFinal,
},
AsymmetricSecurityHeader: &AsymmetricSecurityHeader{
SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo",
},
SequenceHeader: &SequenceHeader{
SequenceNumber: 1,
RequestID: 1,
},
},
},
[]byte{ // OpenSecureChannelRequest
// Message Header
// MessageType: OPN
0x4f, 0x50, 0x4e,
// Chunk Type: Final
0x46,
// MessageSize: 131
0x8E, 0x00, 0x00, 0x00,
// SecureChannelID: 0
0x00, 0x00, 0x00, 0x00,
// AsymmetricSecurityHeader
// SecurityPolicyURILength
0x2e, 0x00, 0x00, 0x00,
// SecurityPolicyURI
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67,
0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78,
0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50,
0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75,
0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69,
0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f,
// SenderCertificate
0xff, 0xff, 0xff, 0xff,
// ReceiverCertificateThumbprint
0xff, 0xff, 0xff, 0xff,
// Sequence Header
// SequenceNumber
0x01, 0x00, 0x00, 0x00,
// RequestID
0x01, 0x00, 0x00, 0x00,
// TypeID
0x01, 0x00, 0xbe, 0x01,

// RequestHeader
// - AuthenticationToken
0x00, 0x00,
// - Timestamp
0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01,
// - RequestHandle
0x01, 0x00, 0x00, 0x00,
// - ReturnDiagnostics
0xff, 0x03, 0x00, 0x00,
// - AuditEntry
0xff, 0xff, 0xff, 0xff,
// - TimeoutHint
0x00, 0x00, 0x00, 0x00,
// - AdditionalHeader
// - TypeID
0x00, 0x00,
// - EncodingMask
0x00,
// ClientProtocolVersion
0x00, 0x00, 0x00, 0x00,
// SecurityTokenRequestType
0x00, 0x00, 0x00, 0x00,
// MessageSecurityMode
0x01, 0x00, 0x00, 0x00,
// ClientNonce
0xff, 0xff, 0xff, 0xff,
// RequestedLifetime
0x80, 0x8d, 0x5b, 0x00,
}})
}
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := tt.c.signAndEncrypt(tt.m, tt.b)
if err != nil {
t.Fatalf("error: message encrypt: %v", err)
}

m := new(MessageChunk)
if _, err := m.Decode(cipher); err != nil {
t.Fatalf("error: message decode: %v", err)
}
plain, err := tt.c.verifyAndDecrypt(m, cipher)
if err != nil {
t.Fatalf("error: message decrypt: %v", err)
}

headerLength := 12 + m.AsymmetricSecurityHeader.Len()
if got, want := plain, tt.b[headerLength:]; !bytes.Equal(got, want) {
t.Fatalf("got bytes %v want %v", got, want)
}
})
}
}
49 changes: 9 additions & 40 deletions examples/crypto/generate_cert.go → uatest/generate_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
// Generate a self-signed X.509 certificate for a TLS server. Outputs to
// 'cert.pem' and 'key.pem' and will overwrite existing files.

// Based on src/crypto/tls/generate_cert.go from the Go SDK
// Modified by the Gopcua Authors for use in creating an OPC-UA compliant client certificate

package main
package uatest

import (
"crypto/ecdsa"
Expand All @@ -17,7 +18,6 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log"
"math/big"
"net"
"net/url"
Expand All @@ -26,33 +26,26 @@ import (
"time"
)

func generate_cert(host string, rsaBits int, certFile, keyFile string) {

func GenerateCert(host string, rsaBits int, validFor time.Duration) (certPEM, keyPEM []byte, err error) {
if len(host) == 0 {
log.Fatalf("Missing required host parameter")
return nil, nil, fmt.Errorf("missing required host parameter")
}
if rsaBits == 0 {
rsaBits = 2048
}
if len(certFile) == 0 {
certFile = "cert.pem"
}
if len(keyFile) == 0 {
keyFile = "key.pem"
}

priv, err := rsa.GenerateKey(rand.Reader, rsaBits)
if err != nil {
log.Fatalf("failed to generate private key: %s", err)
return nil, nil, fmt.Errorf("failed to generate private key: %s", err)
}

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour) // 1 year
notAfter := notBefore.Add(validFor)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
log.Fatalf("failed to generate serial number: %s", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %s", err)
}

template := x509.Certificate{
Expand Down Expand Up @@ -82,34 +75,10 @@ func generate_cert(host string, rsaBits int, certFile, keyFile string) {

derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
log.Fatalf("Failed to create certificate: %s", err)
}

certOut, err := os.Create(certFile)
if err != nil {
log.Fatalf("failed to open %s for writing: %s", certFile, err)
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
log.Fatalf("failed to write data to %s: %s", certFile, err)
}
if err := certOut.Close(); err != nil {
log.Fatalf("error closing %s: %s", certFile, err)
}
log.Printf("wrote %s\n", certFile)

keyOut, err := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
log.Printf("failed to open %s for writing: %s", keyFile, err)
return
}
if err := pem.Encode(keyOut, pemBlockForKey(priv)); err != nil {
log.Fatalf("failed to write data to %s: %s", keyFile, err)
}
if err := keyOut.Close(); err != nil {
log.Fatalf("error closing %s: %s", keyFile, err)
return nil, nil, fmt.Errorf("Failed to create certificate: %s", err)
}
log.Printf("wrote %s\n", keyFile)

return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), pem.EncodeToMemory(pemBlockForKey(priv)), nil
}

func publicKey(priv interface{}) interface{} {
Expand Down