diff --git a/client.go b/client.go index 4e088f2f..b03145cd 100644 --- a/client.go +++ b/client.go @@ -8,7 +8,6 @@ import ( "context" "crypto/rand" "fmt" - "io" "log" "reflect" "sort" @@ -106,9 +105,6 @@ type Client struct { subscriptions map[uint32]*Subscription subMux sync.RWMutex - //cancelMonitor cancels the monitorChannel goroutine - cancelMonitor context.CancelFunc - // once initializes session once sync.Once } @@ -168,71 +164,33 @@ func (c *Client) Dial(ctx context.Context) error { if c.sechan != nil { return errors.Errorf("secure channel already connected") } + var err error c.conn, err = uacp.Dial(ctx, c.endpointURL) if err != nil { return err } + c.sechan, err = uasc.NewSecureChannel(c.endpointURL, c.conn, c.cfg) if err != nil { _ = c.conn.Close() return err } - // Issue #313: decouple the dial context from the monitor context - // mctx must *not* be a child context of 'ctx'. Otherwise, the - // monitor go routine terminates whenever the dial context is done - // which may get triggered unexpectedly by a timer context. - var mctx context.Context - mctx, c.cancelMonitor = context.WithCancel(context.Background()) - go c.monitorChannel(mctx) - return c.openSecureChannel(mctx, c.sechan.Open) -} - -func (c *Client) openSecureChannel(ctx context.Context, open func() error) error { - if err := open(); err != nil { - c.cancelMonitor() - _ = c.conn.Close() - c.sechan = nil - return err - } - return c.scheduleRenewingToken(ctx) -} - -func (c *Client) monitorChannel(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - default: - msg := c.sechan.Receive(ctx) - if msg.Err != nil { - if msg.Err == io.EOF { - debug.Printf("Connection closed") - } else { - debug.Printf("Received error: %s", msg.Err) - } - // todo (dh): apart from the above message, we're ignoring this error because there is nothing watching it - // I'd prefer to have a way to return the error to the upper application. - return - } - debug.Printf("Received unsolicited message from server: %T", msg.V) - } - } + return c.sechan.Open(ctx) } // Close closes the session and the secure channel. func (c *Client) Close() error { - if c.sechan == nil { - return ua.StatusBadServerNotConnected - } + defer c.conn.Close() + // try to close the session but ignore any error // so that we close the underlying channel and connection. _ = c.CloseSession() - if c.cancelMonitor != nil { - c.cancelMonitor() - } - return c.sechan.Close() + + _ = c.sechan.Close() + + return nil } var errNotConnected = errors.New("not connected") @@ -757,25 +715,6 @@ func (c *Client) HistoryReadRawModified(nodes []*ua.HistoryReadValueID, details return res, err } -func (c *Client) scheduleRenewingToken(ctx context.Context) error { - if c.sechan == nil { - return ua.StatusBadServerNotConnected - } - timer := time.NewTimer(time.Duration(0.75*float64(c.sechan.Lifetime())) * time.Millisecond) // 0.75 is from Part 4, Section 5.5.2.1 - - go func() { - select { - case <-ctx.Done(): - timer.Stop() - case <-timer.C: - debug.Printf("renewing security token...") - // Ignore the error. openSecureChannel will close the connection on error and the user will surely notice - _ = c.openSecureChannel(ctx, c.sechan.Renew) - } - }() - return nil -} - // safeAssign implements a type-safe assign from T to *T. func safeAssign(t, ptrT interface{}) error { if reflect.TypeOf(t) != reflect.TypeOf(ptrT).Elem() { diff --git a/uacp/conn.go b/uacp/conn.go index 09d4a196..03b2798f 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -292,6 +292,8 @@ const hdrlen = 8 // The size of b must be at least ReceiveBufSize. Otherwise, // the function returns an error. func (c *Conn) Receive() ([]byte, error) { + // TODO(kung-foo): allow user-specified buffer + // TODO(kung-foo): sync.Pool b := make([]byte, c.ack.ReceiveBufSize) if _, err := io.ReadFull(c, b[:hdrlen]); err != nil { diff --git a/uapolicy/securitypolicy.go b/uapolicy/securitypolicy.go index d2410e4c..3b309338 100644 --- a/uapolicy/securitypolicy.go +++ b/uapolicy/securitypolicy.go @@ -8,7 +8,9 @@ package uapolicy import ( + "crypto/rand" "crypto/rsa" + "io" "sort" "github.com/gopcua/opcua/errors" @@ -145,6 +147,17 @@ func (e *EncryptionAlgorithm) NonceLength() int { return e.nonceLength } +func (e *EncryptionAlgorithm) MakeNonce() ([]byte, error) { + b := make([]byte, e.NonceLength()) + // note: we use `rand.Reader` instead of `rand.Read(...)` to ensure that we don't accidentally switch to using + // math/rand (which has a default, fixed seed). Only crypto/rand exposes a global `io.Reader` var. + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return nil, err + } + return b, nil +} + // EncryptionURI returns the URI for the encryption algorithm as defined // by the OPC-UA profiles in Part 7 func (e *EncryptionAlgorithm) EncryptionURI() string { diff --git a/uasc/message_test.go b/uasc/message_test.go index b7a9992a..1153b612 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -22,11 +22,13 @@ func TestMessage(t *testing.T) { cfg: &Config{ SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", }, - requestID: 1, - sequenceNumber: 1, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, securityTokenID: 0, } - m := s.newMessage( + m := instance.newMessage( &ua.OpenSecureChannelRequest{ RequestHeader: &ua.RequestHeader{ AuthenticationToken: ua.NewTwoByteNodeID(0), @@ -41,6 +43,7 @@ func TestMessage(t *testing.T) { RequestedLifetime: 6000000, }, id.OpenSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), ) // set message size manually, since it is computed in Encode @@ -110,18 +113,21 @@ func TestMessage(t *testing.T) { // RequestedLifetime 0x80, 0x8d, 0x5b, 0x00, }, - }, { + }, + { Name: "MSG", Struct: func() interface{} { s := &SecureChannel{ cfg: &Config{ SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", }, - requestID: 1, - sequenceNumber: 1, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, securityTokenID: 0, } - m := s.newMessage( + m := instance.newMessage( &ua.GetEndpointsRequest{ RequestHeader: &ua.RequestHeader{ AuthenticationToken: ua.NewTwoByteNodeID(0), @@ -133,6 +139,7 @@ func TestMessage(t *testing.T) { EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", }, id.GetEndpointsRequest_Encoding_DefaultBinary, + s.nextRequestID(), ) // set message size manually, since it is computed in Encode @@ -185,11 +192,13 @@ func TestMessage(t *testing.T) { cfg: &Config{ SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", }, - requestID: 1, - sequenceNumber: 1, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, securityTokenID: 0, } - m := s.newMessage( + m := instance.newMessage( &ua.CloseSecureChannelRequest{ RequestHeader: &ua.RequestHeader{ AuthenticationToken: ua.NewTwoByteNodeID(0), @@ -200,6 +209,7 @@ func TestMessage(t *testing.T) { }, }, id.CloseSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), ) // set message size manually, since it is computed in Encode diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 50ce0d6b..2751f6ee 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -6,7 +6,6 @@ package uasc import ( "context" - "crypto/rand" "crypto/rsa" "crypto/x509" "io" @@ -17,21 +16,17 @@ import ( "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" - "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uacp" "github.com/gopcua/opcua/uapolicy" ) const ( - secureChannelCreated int32 = iota - secureChannelOpen - secureChannelClosed timeoutLeniency = 250 * time.Millisecond MaxTimeout = math.MaxUint32 * time.Millisecond ) -type Response struct { +type response struct { ReqID uint32 SCID uint32 V interface{} @@ -39,67 +34,56 @@ type Response struct { } type SecureChannel struct { - EndpointURL string + endpointURL string - // c is the uacp connection. + // c is the uacp connection c *uacp.Conn // cfg is the configuration for the secure channel. cfg *Config - // reqhdr is the header for the next request. - reqhdr *ua.RequestHeader + // time returns the current time. When not set it defaults to time.Now(). + time func() time.Time - // state is the state of the secure channel. - // Must be accessed with atomic.LoadInt32/StoreInt32 - state int32 + // closing is channel used to indicate to go routines that the secure channel is closing + closing chan struct{} - // mu guards handler which contains the response channels - // for the outstanding requests. The key is the request - // handle which is part of the Request and Response headers. - mu sync.Mutex - handler map[uint32]chan Response + // closingMu is used to protect the _changing_ of the mutex + // i.e. when we _read_ from the closing chan we acquire a read lock, and when in `reset`, we acquire a write lock + closingMu sync.RWMutex - chunks map[uint32][]*MessageChunk + // startDispatcher ensures only one dispatcher is running + startDispatcher sync.Once - enc *uapolicy.EncryptionAlgorithm + // requestID is a "global" counter shared between multiple channels and tokens + requestID uint32 + requestIDMu sync.Mutex - // time returns the current time. When not set it defaults to time.Now(). - time func() time.Time + // instances maps secure channel IDs to a list to channel states + instances map[uint32][]*channelInstance + activeInstance *channelInstance + instancesMu sync.Mutex - // The lifetime of the SecurityToken in milliseconds. The UTC expiration time for the token - // may be calculated by adding the lifetime to the createdAt time. - lifetime uint32 - - // secureChannelID is a unique identifier for the SecureChannel assigned by the Server. - // If a Server receives a SecureChannelId which it does not recognize it shall return an - // appropriate transport layer error. - // - // When a Server starts the first SecureChannelId used should be a value that is likely to - // be unique after each restart. This ensures that a Server restart does not cause - // previously connected Clients to accidentally ‘reuse’ SecureChannels that did not belong - // to them. - secureChannelID uint32 - - // sequenceNumber is a monotonically increasing sequence number assigned by the sender to each - // MessageChunk sent over the SecureChannel. - sequenceNumber uint32 - - // requestID is an identifier assigned by the Client to OPC UA request Message. All MessageChunks - // for the request and the associated response use the same identifier - requestID uint32 - - // securityTokenID is a unique identifier for the SecureChannel SecurityToken used to secure the Message. - // This identifier is returned by the Server in an OpenSecureChannel response Message. - // If a Server receives a TokenId which it does not recognize it shall return an appropriate - // transport layer error. - securityTokenID uint32 + // handles maps request IDs to response channels + handlers map[uint32]chan *response + handlersMu sync.Mutex + + // chunks maintains a temporary list of chunks for a given request ID + chunks map[uint32][]*MessageChunk + chunksMu sync.Mutex + + // openingInstance is a temporary var that allows the dispatcher know how to handle a open channel request + // note: we only allow a single "open" request in flight at any point in time. The mutex is held for the entire + // duration of the "open" request. + openingInstance *channelInstance + openingMu sync.Mutex } func NewSecureChannel(endpoint string, c *uacp.Conn, cfg *Config) (*SecureChannel, error) { if c == nil { return nil, errors.Errorf("no connection") } + if cfg == nil { return nil, errors.Errorf("no secure channel config") } @@ -114,258 +98,203 @@ func NewSecureChannel(endpoint string, c *uacp.Conn, cfg *Config) (*SecureChanne } // Force the security mode to None if the policy is also None + // TODO: I don't like that a SecureChannel changes the incoming config if cfg.SecurityPolicyURI == ua.SecurityPolicyURINone { cfg.SecurityMode = ua.MessageSecurityModeNone } - return &SecureChannel{ - EndpointURL: endpoint, + s := &SecureChannel{ + endpointURL: endpoint, c: c, cfg: cfg, - reqhdr: &ua.RequestHeader{ - TimeoutHint: uint32(cfg.RequestTimeout / time.Millisecond), - AdditionalHeader: ua.NewExtensionObject(nil), - }, - state: secureChannelCreated, - handler: make(map[uint32]chan Response), - chunks: make(map[uint32][]*MessageChunk), - requestID: cfg.RequestIDSeed, - }, nil -} + requestID: cfg.RequestIDSeed, + } -func (s *SecureChannel) LocalEndpoint() string { - return s.EndpointURL -} + s.reset() -func (s *SecureChannel) Lifetime() uint32 { - return s.lifetime + return s, nil } -func (s *SecureChannel) setState(n int32) { - atomic.StoreInt32(&s.state, n) +func (s *SecureChannel) reset() { + s.closingMu.Lock() + defer s.closingMu.Unlock() + + // note: we _don't_ reset s.requestID + s.closing = make(chan struct{}) + s.startDispatcher = sync.Once{} + s.instances = make(map[uint32][]*channelInstance) + s.chunks = make(map[uint32][]*MessageChunk) + s.handlers = make(map[uint32]chan *response) + s.activeInstance = nil + s.openingInstance = nil } -func (s *SecureChannel) hasState(n int32) bool { - return atomic.LoadInt32(&s.state) == n +func (s *SecureChannel) getActiveChannelInstance() (*channelInstance, error) { + s.instancesMu.Lock() + defer s.instancesMu.Unlock() + if s.activeInstance == nil { + return nil, errors.Errorf("sechan: secure channel not open.") + } + return s.activeInstance, nil } -// SendRequest sends the service request and calls h with the response. -func (s *SecureChannel) SendRequest(req ua.Request, authToken *ua.NodeID, h func(interface{}) error) error { - return s.SendRequestWithTimeout(req, authToken, s.cfg.RequestTimeout, h) -} +func (s *SecureChannel) dispatcher() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -// SendRequestWithTimeout sends the service request and calls h with the response with a specific timeout. -func (s *SecureChannel) SendRequestWithTimeout(req ua.Request, authToken *ua.NodeID, timeout time.Duration, h func(interface{}) error) error { - respRequired := h != nil + s.closingMu.RLock() + defer s.closingMu.RUnlock() - ch, reqid, err := s.sendAsyncWithTimeout(req, authToken, respRequired, timeout) - if err != nil { - return err - } + for { + select { + case <-s.closing: + return + default: + resp := s.receive(ctx) + if resp.Err == io.EOF { + return + } - if !respRequired { - return nil - } + if resp.Err != nil { + debug.Printf("uasc %d/%d: err: %v", s.c.ID(), resp.ReqID, resp.Err) + } else { + debug.Printf("uasc %d/%d: recv %T", s.c.ID(), resp.ReqID, resp.V) + } - // `+ timeoutLeniency` to give the server a chance to respond to TimeoutHint - timer := time.NewTimer(timeout + timeoutLeniency) - defer timer.Stop() + ch, ok := s.popHandler(resp.ReqID) - select { - case resp := <-ch: - if resp.Err != nil { - if resp.V != nil { - _ = h(resp.V) // ignore result because resp.Err takes precedence + if !ok { + debug.Printf("uasc %d/%d: no handler for %T", s.c.ID(), resp.ReqID, resp.V) + continue + } + + debug.Printf("sending %T to handler\n", resp.V) + select { + case ch <- resp: + default: + // this should never happen since the chan is of size one + debug.Printf("unexpected state. channel write should always succeed.") } - return resp.Err } - return h(resp.V) - case <-timer.C: - s.mu.Lock() - s.popHandlerLock(reqid) - s.mu.Unlock() - return ua.StatusBadTimeout } } -// sendAsyncWithTimeout sends the service request with a specific timeout and -// returns a channel which will receive the response when it arrives. -func (s *SecureChannel) sendAsyncWithTimeout(req ua.Request, authToken *ua.NodeID, respReq bool, timeout time.Duration) (resp chan Response, reqID uint32, err error) { - s.mu.Lock() - defer s.mu.Unlock() +// receive receives message chunks from the secure channel, decodes and forwards +// them to the registered callback channel, if there is one. Otherwise, +// the message is dropped. +func (s *SecureChannel) receive(ctx context.Context) *response { + for { + select { + case <-ctx.Done(): + return &response{Err: ctx.Err()} - // encode the message - m, err := s.newRequestMessage(req, authToken, timeout) - if err != nil { - return nil, 0, err - } - reqid := m.SequenceHeader.RequestID - b, err := m.Encode() - if err != nil { - return nil, reqid, err - } + default: + chunk, err := s.readChunk() + if err == io.EOF { + debug.Printf("uasc readChunk EOF") + return &response{Err: err} + } - // encrypt the message prior to sending it - // if SecurityMode == None, this returns the byte stream untouched - b, err = s.signAndEncrypt(m, b) - if err != nil { - return nil, reqid, err - } + if err != nil { + return &response{Err: err} + } - // send the message - if _, err := s.c.Write(b); err != nil { - return nil, reqid, err - } - debug.Printf("uasc %d/%d: send %T with %d bytes", s.c.ID(), reqid, req, len(b)) + hdr := chunk.Header + reqID := chunk.SequenceHeader.RequestID - // register the handler if a callback was passed - if !respReq { - return nil, 0, nil - } - resp = make(chan Response) - if s.handler[reqid] != nil { - return nil, reqid, errors.Errorf("error: duplicate handler registration for request id %d", reqid) - } - s.handler[reqid] = resp - return resp, reqid, nil -} + resp := &response{ + ReqID: reqID, + SCID: chunk.MessageHeader.Header.SecureChannelID, + } -// New creates a OPC UA Secure Conversation message.New -// MessageType of UASC is determined depending on the type of service given as below. -// -// Service type: OpenSecureChannel => Message type: OPN. -// -// Service type: CloseSecureChannel => Message type: CLO. -// -// Service type: Others => Message type: MSG. -// -func (s *SecureChannel) newMessage(srv interface{}, typeID uint16) *Message { - switch typeID { - case id.OpenSecureChannelRequest_Encoding_DefaultBinary, id.OpenSecureChannelResponse_Encoding_DefaultBinary: - // Do not send the thumbprint for security mode None - // even if we have a certificate. - // - // See https://github.com/gopcua/opcua/issues/259 - thumbprint := s.cfg.Thumbprint - if s.cfg.SecurityMode == ua.MessageSecurityModeNone { - thumbprint = nil - } + debug.Printf("uasc %d/%d: recv %s%c with %d bytes", s.c.ID(), reqID, hdr.MessageType, hdr.ChunkType, hdr.MessageSize) - return &Message{ - MessageHeader: &MessageHeader{ - Header: NewHeader(MessageTypeOpenSecureChannel, ChunkTypeFinal, s.secureChannelID), - AsymmetricSecurityHeader: NewAsymmetricSecurityHeader(s.cfg.SecurityPolicyURI, s.cfg.Certificate, thumbprint), - SequenceHeader: NewSequenceHeader(s.sequenceNumber, s.requestID), - }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), - Service: srv, - } + s.chunksMu.Lock() - case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: - return &Message{ - MessageHeader: &MessageHeader{ - Header: NewHeader(MessageTypeCloseSecureChannel, ChunkTypeFinal, s.secureChannelID), - SymmetricSecurityHeader: NewSymmetricSecurityHeader(s.securityTokenID), - SequenceHeader: NewSequenceHeader(s.sequenceNumber, s.requestID), - }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), - Service: srv, - } + switch hdr.ChunkType { + case 'A': + delete(s.chunks, reqID) + s.chunksMu.Unlock() - default: - return &Message{ - MessageHeader: &MessageHeader{ - Header: NewHeader(MessageTypeMessage, ChunkTypeFinal, s.secureChannelID), - SymmetricSecurityHeader: NewSymmetricSecurityHeader(s.securityTokenID), - SequenceHeader: NewSequenceHeader(s.sequenceNumber, s.requestID), - }, - TypeID: ua.NewFourByteExpandedNodeID(0, typeID), - Service: srv, - } - } -} + msga := new(MessageAbort) + if _, err := msga.Decode(chunk.Data); err != nil { + debug.Printf("conn %d/%d: invalid MSGA chunk. %s", s.c.ID(), reqID, err) + resp.Err = ua.StatusBadDecodingError + return resp + } -func (s *SecureChannel) newRequestMessage(req ua.Request, authToken *ua.NodeID, timeout time.Duration) (*Message, error) { - typeID := ua.ServiceTypeID(req) - if typeID == 0 { - return nil, errors.Errorf("unknown service %T. Did you call register?", req) - } - if authToken == nil { - authToken = ua.NewTwoByteNodeID(0) - } + return &response{ReqID: reqID, Err: ua.StatusCode(msga.ErrorCode)} - s.sequenceNumber++ - if s.sequenceNumber > math.MaxUint32-1023 { - s.sequenceNumber = 1 - } - s.requestID++ - if s.requestID == 0 { - s.requestID = 1 - } - s.reqhdr.RequestHandle++ - if s.reqhdr.RequestHandle == 0 { - s.reqhdr.RequestHandle = 1 - } - s.reqhdr.AuthenticationToken = authToken - s.reqhdr.Timestamp = s.timeNow() - if timeout > 0 && timeout < s.cfg.RequestTimeout { - timeout = s.cfg.RequestTimeout - } - s.reqhdr.TimeoutHint = uint32(timeout / time.Millisecond) - req.SetHeader(s.reqhdr) + case 'C': + s.chunks[reqID] = append(s.chunks[reqID], chunk) + if n := len(s.chunks[reqID]); uint32(n) > s.c.MaxChunkCount() { + delete(s.chunks, reqID) + s.chunksMu.Unlock() + resp.Err = errors.Errorf("too many chunks: %d > %d", n, s.c.MaxChunkCount()) + return resp + } + s.chunksMu.Unlock() + continue + } - // encode the message - return s.newMessage(req, typeID), nil -} + // merge chunks + all := append(s.chunks[reqID], chunk) + delete(s.chunks, reqID) -// SendResponse sends a service response. -// todo(fs): this method is most likely needed for the server and we haven't tested it yet. -// todo(fs): it exists to implement the handleOpenSecureChannelRequest() method during the -// todo(fs): refactor to remove the reflect code. It will likely change. -func (s *SecureChannel) SendResponse(req ua.Response) error { - typeID := ua.ServiceTypeID(req) - if typeID == 0 { - return errors.Errorf("unknown service %T. Did you call register?", req) - } + s.chunksMu.Unlock() - // encode the message - m := s.newMessage(req, typeID) - reqid := m.SequenceHeader.RequestID - b, err := m.Encode() - if err != nil { - return err - } + b, err := mergeChunks(all) + if err != nil { + resp.Err = err + return resp + } - // encrypt the message prior to sending it - // if SecurityMode == None, this returns the byte stream untouched - b, err = s.signAndEncrypt(m, b) - if err != nil { - return err - } + if uint32(len(b)) > s.c.MaxMessageSize() { + resp.Err = errors.Errorf("message too large: %d > %d", uint32(len(b)), s.c.MaxMessageSize()) + return resp + } - // send the message - if _, err := s.c.Write(b); err != nil { - return err - } - debug.Printf("uasc %d/%d: send %T with %d bytes", s.c.ID(), reqid, req, len(b)) + // since we are not decoding the ResponseHeader separately + // we need to drop every message that has an error since we + // cannot get to the RequestHandle in the ResponseHeader. + // To fix this we must a) decode the ResponseHeader separately + // and subsequently remove it and the TypeID from all service + // structs and tests. We also need to add a deadline to all + // handlers and check them periodically to time them out. + _, svc, err := ua.DecodeService(b) + if err != nil { + resp.Err = err + return resp + } + + resp.V = svc + + // If the service status is not OK then bubble + // that error up to the caller. + if r, ok := svc.(ua.Response); ok { + if status := r.Header().ServiceResult; status != ua.StatusOK { + resp.Err = status + return resp + } + } - return nil + return resp + } + } } func (s *SecureChannel) readChunk() (*MessageChunk, error) { // read a full message from the underlying conn. b, err := s.c.Receive() - if err == io.EOF || s.hasState(secureChannelClosed) { + if err == io.EOF || len(b) == 0 { return nil, io.EOF } - if errf, ok := err.(*uacp.Error); ok { - return nil, errf - } + if err != nil { return nil, errors.Errorf("sechan: read header failed: %s %#v", err, err) } - const hdrlen = 12 + const hdrlen = 12 // TODO: move to pkg level const h := new(Header) if _, err := h.Decode(b[:hdrlen]); err != nil { return nil, errors.Errorf("sechan: decode header failed: %s", err) @@ -377,52 +306,39 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) { return nil, errors.Errorf("sechan: decode chunk failed: %s", err) } - // OPN Request, initialize encryption - // todo(dh): How to account for renew requests? + var decryptWith *channelInstance + switch m.MessageType { case "OPN": - debug.Printf("uasc: OPN Request") + debug.Printf("uasc OPN Request") + // Make sure we have a valid security header if m.AsymmetricSecurityHeader == nil { return nil, ua.StatusBadDecodingError // todo(dh): check if this is the correct error } - // Load the remote certificates from the security header, if present - var remoteKey *rsa.PublicKey - if m.SecurityPolicyURI != ua.SecurityPolicyURINone { - remoteKey, err = uapolicy.PublicKey(m.AsymmetricSecurityHeader.SenderCertificate) - if err != nil { - return nil, err - } + if s.openingInstance == nil { + return nil, errors.Errorf("sechan: invalid state. openingInstance is nil.") + } + if m.SecurityPolicyURI != ua.SecurityPolicyURINone { s.cfg.RemoteCertificate = m.AsymmetricSecurityHeader.SenderCertificate debug.Printf("Setting securityPolicy to %s", m.SecurityPolicyURI) } s.cfg.SecurityPolicyURI = m.SecurityPolicyURI - s.requestID = m.RequestID - - s.enc, err = uapolicy.Asymmetric(m.SecurityPolicyURI, s.cfg.LocalKey, remoteKey) - if err != nil { - return nil, err - } + decryptWith = s.openingInstance case "CLO": - if !s.hasState(secureChannelOpen) { - return nil, ua.StatusBadSecureChannelIDInvalid - } - - // We received the close request so no response is necessary. - // Returning io.EOF signals to the calling methods that the channel is to be shut down - s.setState(secureChannelClosed) - return nil, io.EOF - case "MSG": + // nop + default: + return nil, errors.Errorf("sechan: unknown message type: %s", m.MessageType) } - // Decrypts the block and returns data back into m.Data - m.Data, err = s.verifyAndDecrypt(m, b) + // Decrypt the block and put data back into m.Data + m.Data, err = s.verifyAndDecrypt(m, b, decryptWith) if err != nil { return nil, err } @@ -433,250 +349,82 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) { } m.Data = m.Data[n:] - if s.secureChannelID == 0 { - s.secureChannelID = h.SecureChannelID - debug.Printf("uasc %d/%d: set secure channel id to %d", s.c.ID(), m.SequenceHeader.RequestID, s.secureChannelID) - } - return m, nil } -// Receive waits for a complete message to be read from the channel and sends -// it back to the caller. If the caller was initiated from a SendRequest(), the -// message is directed to the registered callback function and Receive() does -// not return. Otherwise, if no handler is detected, the Receive returns with -// the message as a return value. -// -// This behaviour means that anticipated results are automatically directed -// back to their callers but unsolicited messages are sent to the caller of -// Receive() to handle. -func (s *SecureChannel) Receive(ctx context.Context) Response { - for { - select { - case <-ctx.Done(): - return Response{Err: io.EOF} - default: - reqid, svc, err := s.receive(ctx) - if _, ok := err.(*uacp.Error); ok || err == io.EOF { - s.notifyCallers(ctx, err) - return Response{ - ReqID: reqid, - SCID: s.secureChannelID, - V: svc, - Err: err, - } - } - if err != nil { - debug.Printf("uasc %d/%d: err: %v", s.c.ID(), reqid, err) - } else { - debug.Printf("uasc %d/%d: recv %T", s.c.ID(), reqid, svc) - } - - // Revert data race fix from #232 with an additional type check - if _, ok := svc.(ua.Request); ok { - s.requestID = reqid - } - - switch svc.(type) { - case *ua.OpenSecureChannelRequest: - err := s.handleOpenSecureChannelRequest(svc) - if err != nil { - return Response{ - Err: err, - } - } - continue - } - - // check if we have a pending request handler for this response. - s.mu.Lock() - ch, ok := s.handler[reqid] - delete(s.handler, reqid) - s.mu.Unlock() - if !ok { - debug.Printf("uasc %d/%d: no handler for %T, returning result to caller", s.c.ID(), reqid, svc) - return Response{ - ReqID: reqid, - SCID: s.secureChannelID, - V: svc, - Err: err, - } - } - - // send response to caller - go func() { - debug.Printf("sending %T to handler\n", svc) - r := Response{ - ReqID: reqid, - SCID: s.secureChannelID, - V: svc, - Err: err, - } - select { - case <-ctx.Done(): - case ch <- r: - } - }() - } +// verifyAndDecrypt verifies and optionally decrypts a message. if `instance` is given, then it will only use that +// state. Otherwise it will look up states by channel ID and try each. +func (s *SecureChannel) verifyAndDecrypt(m *MessageChunk, b []byte, instance *channelInstance) ([]byte, error) { + if instance != nil { + return instance.verifyAndDecrypt(m, b) } -} - -// receive receives message chunks from the secure channel, decodes and forwards -// them to the registered callback channel, if there is one. Otherwise, -// the message is dropped. -func (s *SecureChannel) receive(ctx context.Context) (uint32, interface{}, error) { - for { - select { - case <-ctx.Done(): - return 0, nil, nil - - default: - chunk, err := s.readChunk() - if err == io.EOF { - return 0, nil, err - } - if errf, ok := err.(*uacp.Error); ok { - s.notifyCallers(ctx, errf) - return 0, nil, errf - } - if err != nil { - debug.Printf("error received while receiving chunk: %s", err) - continue - } - - hdr := chunk.Header - reqid := chunk.SequenceHeader.RequestID - debug.Printf("uasc %d/%d: recv %s%c with %d bytes", s.c.ID(), reqid, hdr.MessageType, hdr.ChunkType, hdr.MessageSize) - - switch hdr.ChunkType { - case 'A': - delete(s.chunks, reqid) - msga := new(MessageAbort) - if _, err := msga.Decode(chunk.Data); err != nil { - debug.Printf("conn %d/%d: invalid MSGA chunk. %s", s.c.ID(), reqid, err) - return reqid, nil, ua.StatusBadDecodingError - } - - return reqid, nil, ua.StatusCode(msga.ErrorCode) - - case 'C': - s.chunks[reqid] = append(s.chunks[reqid], chunk) - if n := len(s.chunks[reqid]); uint32(n) > s.c.MaxChunkCount() { - delete(s.chunks, reqid) - return reqid, nil, errors.Errorf("too many chunks: %d > %d", n, s.c.MaxChunkCount()) - } - continue - } - - // merge chunks - all := append(s.chunks[reqid], chunk) - delete(s.chunks, reqid) - b, err := mergeChunks(all) - if err != nil { - return reqid, nil, errors.Errorf("chunk merge error: %v", err) - } - - if uint32(len(b)) > s.c.MaxMessageSize() { - return reqid, nil, errors.Errorf("message too large: %d > %d", uint32(len(b)), s.c.MaxMessageSize()) - } + instances := s.getInstancesBySecureChannelID(m.MessageHeader.SecureChannelID) + if len(instances) == 0 { + return nil, errors.Errorf("sechan: unable to find instance for SecureChannelID=%d", m.MessageHeader.SecureChannelID) + } - // since we are not decoding the ResponseHeader separately - // we need to drop every message that has an error since we - // cannot get to the RequestHandle in the ResponseHeader. - // To fix this we must a) decode the ResponseHeader separately - // and subsequently remove it and the TypeID from all service - // structs and tests. We also need to add a deadline to all - // handlers and check them periodically to time them out. - _, svc, err := ua.DecodeService(b) - if err != nil { - return reqid, nil, err - } + var ( + err error + verified []byte + ) - // If the service status is not OK then bubble - // that error up to the caller. - if resp, ok := svc.(ua.Response); ok { - status := resp.Header().ServiceResult - debug.Printf("uasc %d/%d: res:%v", s.c.ID(), reqid, status) - if status != ua.StatusOK { - return reqid, svc, status - } - } - return reqid, svc, err + for i := len(instances) - 1; i >= 0; i-- { + // instances[i].Lock() + if verified, err = instances[i].verifyAndDecrypt(m, b); err == nil { + // instances[i].Unlock() + return verified, nil } + // instances[i].Unlock() + debug.Printf("attempting an older channel state...") } -} -func (s *SecureChannel) notifyCallers(ctx context.Context, err error) { - s.mu.Lock() - var reqids []uint32 - for rid := range s.handler { - reqids = append(reqids, rid) - } - for _, rid := range reqids { - s.notifyCallerLock(ctx, rid, nil, err) - } - s.mu.Unlock() + return nil, err } -func (s *SecureChannel) notifyCallerLock(ctx context.Context, reqid uint32, svc interface{}, err error) { - if err != nil { - debug.Printf("uasc %d/%d: %v", s.c.ID(), reqid, err) - } else { - debug.Printf("uasc %d/%d: recv %T", s.c.ID(), reqid, svc) - } +func (s *SecureChannel) getInstancesBySecureChannelID(id uint32) []*channelInstance { + s.instancesMu.Lock() + defer s.instancesMu.Unlock() - // check if we have a pending request handler for this response. - ch := s.popHandlerLock(reqid) - - // no handler -> next response - if ch == nil { - debug.Printf("uasc %d/%d: no handler for %T", s.c.ID(), reqid, svc) - return + instances := s.instances[id] + if instances == nil { + return nil } - // send response to caller - go func() { - r := Response{ - ReqID: reqid, - SCID: s.secureChannelID, - V: svc, - Err: err, - } - select { - case <-ctx.Done(): - case ch <- r: - } - close(ch) - }() + // return a copy of the slice in case a renewal is triggered + cpy := make([]*channelInstance, len(instances)) + copy(cpy, instances) + + return instances } -// Open opens a new secure channel with a server -func (s *SecureChannel) Open() error { - return s.openSecureChannel(ua.SecurityTokenRequestTypeIssue) +func (s *SecureChannel) LocalEndpoint() string { + return s.endpointURL } -func (s *SecureChannel) Renew() error { - return s.openSecureChannel(ua.SecurityTokenRequestTypeRenew) +func (s *SecureChannel) Open(ctx context.Context) error { + return s.open(ctx, nil, ua.SecurityTokenRequestTypeIssue) } -// Close closes an existing secure channel -func (s *SecureChannel) Close() error { - if err := s.closeSecureChannel(); err != nil && err != io.EOF { - debug.Printf("failed to send close secure channel request: %s", err) - } +func (s *SecureChannel) open(ctx context.Context, instance *channelInstance, requestType ua.SecurityTokenRequestType) error { + // TODO: do something with the context + + s.openingMu.Lock() + defer s.openingMu.Unlock() - if err := s.c.Close(); err != nil && err != io.EOF { - debug.Printf("failed to close transport connection: %s", err) + if s.openingInstance != nil { + return errors.Errorf("sechan: invalid state. openingInstance must be nil when opening a new secure channel.") } - return io.EOF -} + var ( + err error + localKey *rsa.PrivateKey + remoteKey *rsa.PublicKey + ) -func (s *SecureChannel) openSecureChannel(requestType ua.SecurityTokenRequestType) error { - var err error - var localKey *rsa.PrivateKey - var remoteKey *rsa.PublicKey + s.startDispatcher.Do(func() { + go s.dispatcher() + }) // Set the encryption methods to Asymmetric with the appropriate // public keys. OpenSecureChannel is always encrypted with the @@ -692,19 +440,39 @@ func (s *SecureChannel) openSecureChannel(requestType ua.SecurityTokenRequestTyp return err } var ok bool - remoteKey, ok = remoteCert.PublicKey.(*rsa.PublicKey) - if !ok { + if remoteKey, ok = remoteCert.PublicKey.(*rsa.PublicKey); !ok { return ua.StatusBadCertificateInvalid } } - s.enc, err = uapolicy.Asymmetric(s.cfg.SecurityPolicyURI, localKey, remoteKey) + algo, err := uapolicy.Asymmetric(s.cfg.SecurityPolicyURI, localKey, remoteKey) if err != nil { return err } - nonce := make([]byte, s.enc.NonceLength()) - if _, err := rand.Read(nonce); err != nil { + s.openingInstance = newChannelInstance(s) + + if requestType == ua.SecurityTokenRequestTypeRenew { + // TODO: lock? sequenceNumber++? + // this seems racy. if another request goes out while the other open request is in flight then won't an error + // be raised on the server? can the sequenceNumber be as "global" as the request ID? + s.openingInstance.sequenceNumber = instance.sequenceNumber + } + + // trigger cleanup after we are all done + defer func() { + if s.openingInstance == nil || s.openingInstance.state != channelActive { + debug.Printf("failed to open a new secure channel") + } + s.openingInstance = nil + }() + + reqID := s.nextRequestID() + + s.openingInstance.algo = algo + + localNonce, err := algo.MakeNonce() + if err != nil { return err } @@ -712,99 +480,296 @@ func (s *SecureChannel) openSecureChannel(requestType ua.SecurityTokenRequestTyp ClientProtocolVersion: 0, RequestType: requestType, SecurityMode: s.cfg.SecurityMode, - ClientNonce: nonce, + ClientNonce: localNonce, RequestedLifetime: s.cfg.Lifetime, } - return s.SendRequest(req, nil, func(v interface{}) error { + return s.sendRequestWithTimeout(req, reqID, s.openingInstance, nil, s.cfg.RequestTimeout, func(v interface{}) error { resp, ok := v.(*ua.OpenSecureChannelResponse) if !ok { - return errors.Errorf("got %T, want OpenSecureChannelResponse", req) - } - s.securityTokenID = resp.SecurityToken.TokenID - s.lifetime = resp.SecurityToken.RevisedLifetime - debug.Printf("received security token tokenID: %v, createdAt: %v, lifetime %v", resp.SecurityToken.TokenID, resp.SecurityToken.CreatedAt, resp.SecurityToken.RevisedLifetime) - - s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, resp.ServerNonce) - if err != nil { - return err + return errors.Errorf("got %T, want OpenSecureChannelResponse", v) } - - s.setState(secureChannelOpen) - return nil + return s.handleOpenSecureChannelResponse(resp, localNonce, s.openingInstance) }) } -// closeSecureChannel sends CloseSecureChannelRequest on top of UASC to SecureChannel. -func (s *SecureChannel) closeSecureChannel() error { - req := &ua.CloseSecureChannelRequest{} +func (s *SecureChannel) handleOpenSecureChannelResponse(resp *ua.OpenSecureChannelResponse, localNonce []byte, instance *channelInstance) (err error) { + instance.state = channelActive + instance.secureChannelID = resp.SecurityToken.ChannelID + instance.securityTokenID = resp.SecurityToken.TokenID + instance.createdAt = resp.SecurityToken.CreatedAt + instance.revisedLifetime = time.Millisecond * time.Duration(resp.SecurityToken.RevisedLifetime) - defer s.setState(secureChannelClosed) - // Don't send the CloseSecureChannel message if it was never fully opened (due to ERR, etc) - if !s.hasState(secureChannelOpen) { - return io.EOF + // allow the client to specify a lifetime that is smaller + if int64(s.cfg.Lifetime) < int64(instance.revisedLifetime/time.Millisecond) { + instance.revisedLifetime = time.Millisecond * time.Duration(s.cfg.Lifetime) } - err := s.SendRequest(req, nil, nil) - if err != nil { + if instance.algo, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, localNonce, resp.ServerNonce); err != nil { return err } - return io.EOF + s.instancesMu.Lock() + defer s.instancesMu.Unlock() + + if _, ok := s.instances[resp.SecurityToken.ChannelID]; ok { + // since there's already an existing entry for this SecureChannelID it means we are in a renewal + s.instances[resp.SecurityToken.ChannelID] = append( + s.instances[resp.SecurityToken.ChannelID], + s.openingInstance, + ) + } else { + s.instances[resp.SecurityToken.ChannelID] = []*channelInstance{s.openingInstance} + } + + s.activeInstance = instance + + debug.Printf("received security token: channelID=%d tokenID=%d createdAt=%s lifetime=%s", instance.secureChannelID, instance.securityTokenID, instance.createdAt.Format(time.RFC3339), instance.revisedLifetime) + + if s.cfg.SecurityMode != ua.MessageSecurityModeNone { + go s.scheduleRenewal(instance) + go s.scheduleExpiration(instance) + } + + return +} + +func (s *SecureChannel) scheduleRenewal(instance *channelInstance) { + // https://reference.opcfoundation.org/v104/Core/docs/Part4/5.5.2/#5.5.2.1 + // Clients should request a new SecurityToken after 75 % of its lifetime has elapsed. This should ensure that + // clients will receive the new SecurityToken before the old one actually expire + const renewAfter = 0.75 + when := time.Second * time.Duration(instance.revisedLifetime.Seconds()*renewAfter) + + debug.Printf("channelID %d will be refreshed in %s (%s)", instance.secureChannelID, when, time.Now().UTC().Add(when).Format(time.RFC3339)) + + t := time.NewTimer(when) + defer t.Stop() + + s.closingMu.RLock() + defer s.closingMu.RUnlock() + + select { + case <-s.closing: + return + case <-t.C: + } + + // TODO: where should this error go? + _ = s.renew(instance) +} + +func (s *SecureChannel) renew(instance *channelInstance) error { + // lock ensure no one else renews this at the same time + instance.Lock() + defer instance.Unlock() + + return s.open(context.Background(), instance, ua.SecurityTokenRequestTypeRenew) } -func (s *SecureChannel) handleOpenSecureChannelRequest(svc interface{}) error { - debug.Printf("handleOpenSecureChannelRequest: Got OPN Request\n") +func (s *SecureChannel) scheduleExpiration(instance *channelInstance) { + // https://reference.opcfoundation.org/v104/Core/docs/Part4/5.5.2/#5.5.2.1 + // Clients should accept Messages secured by an expired SecurityToken for up to 25 % of the token lifetime. + const expireAfter = 1.25 + when := instance.createdAt.Add(time.Second * time.Duration(instance.revisedLifetime.Seconds()*expireAfter)) + + debug.Printf("channelID %d/%d will expire at %s", instance.secureChannelID, instance.securityTokenID, when.UTC().Format(time.RFC3339)) - var err error + t := time.NewTimer(time.Until(when)) - req, ok := svc.(*ua.OpenSecureChannelRequest) - if !ok { - debug.Printf("Expected OpenSecureChannel Request, got %T\n", svc) + s.closingMu.RLock() + defer s.closingMu.RUnlock() + + select { + case <-s.closing: + return + case <-t.C: } - s.cfg.Lifetime = req.RequestedLifetime - s.cfg.SecurityMode = req.SecurityMode + s.instancesMu.Lock() + defer s.instancesMu.Unlock() - nonce := make([]byte, s.enc.NonceLength()) - if _, err := rand.Read(nonce); err != nil { + oldInstances := s.instances[instance.securityTokenID] + + s.instances[instance.securityTokenID] = []*channelInstance{} + + for _, oldInstance := range oldInstances { + if oldInstance.secureChannelID != instance.secureChannelID { + // something has gone horribly wrong! + debug.Printf("secureChannelID mismatch during scheduleExpiration!") + } + if oldInstance.securityTokenID == instance.securityTokenID { + continue + } + s.instances[instance.securityTokenID] = append( + s.instances[instance.securityTokenID], + oldInstance, + ) + } +} + +func (s *SecureChannel) sendRequestWithTimeout( + req ua.Request, + reqID uint32, + instance *channelInstance, + authToken *ua.NodeID, + timeout time.Duration, + h func(interface{}) error) error { + + respRequired := h != nil + + ch, err := s.sendAsyncWithTimeout(req, reqID, instance, authToken, respRequired, timeout) + if err != nil { return err } - resp := &ua.OpenSecureChannelResponse{ - ResponseHeader: &ua.ResponseHeader{ - Timestamp: s.timeNow(), - RequestHandle: req.RequestHeader.RequestHandle, - ServiceDiagnostics: &ua.DiagnosticInfo{}, - StringTable: []string{}, - AdditionalHeader: ua.NewExtensionObject(nil), - }, - ServerProtocolVersion: 0, - SecurityToken: &ua.ChannelSecurityToken{ - ChannelID: s.secureChannelID, - TokenID: s.securityTokenID, - CreatedAt: s.timeNow(), - RevisedLifetime: req.RequestedLifetime, - }, - ServerNonce: nonce, - } - - if err := s.SendResponse(resp); err != nil { + + if !respRequired { + return nil + } + + // `+ timeoutLeniency` to give the server a chance to respond to TimeoutHint + timer := time.NewTimer(timeout + timeoutLeniency) + defer timer.Stop() + + select { + case resp := <-ch: + if resp.Err != nil { + if resp.V != nil { + _ = h(resp.V) // ignore result because resp.Err takes precedence + } + return resp.Err + } + return h(resp.V) + case <-timer.C: + s.popHandler(reqID) + return ua.StatusBadTimeout + } +} + +func (s *SecureChannel) popHandler(reqID uint32) (chan *response, bool) { + s.handlersMu.Lock() + defer s.handlersMu.Unlock() + + ch, ok := s.handlers[reqID] + if ok { + delete(s.handlers, reqID) + } + return ch, ok +} + +func (s *SecureChannel) Renew(ctx context.Context) error { + instance, err := s.getActiveChannelInstance() + if err != nil { return err } - s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, req.ClientNonce) + return s.renew(instance) +} + +// SendRequest sends the service request and calls h with the response. +func (s *SecureChannel) SendRequest(req ua.Request, authToken *ua.NodeID, h func(interface{}) error) error { + return s.SendRequestWithTimeout(req, authToken, s.cfg.RequestTimeout, h) +} + +func (s *SecureChannel) SendRequestWithTimeout(req ua.Request, authToken *ua.NodeID, timeout time.Duration, h func(interface{}) error) error { + active, err := s.getActiveChannelInstance() if err != nil { return err } - s.setState(secureChannelOpen) - return nil + return s.sendRequestWithTimeout(req, s.nextRequestID(), active, authToken, timeout, h) +} + +func (s *SecureChannel) sendAsyncWithTimeout( + req ua.Request, + reqID uint32, + instance *channelInstance, + authToken *ua.NodeID, + respRequired bool, + timeout time.Duration, +) (<-chan *response, error) { + + instance.Lock() + + m, err := instance.newRequestMessage(req, reqID, authToken, timeout) + if err != nil { + instance.Unlock() + return nil, err + } + + b, err := m.Encode() + if err != nil { + instance.Unlock() + return nil, err + } + + b, err = instance.signAndEncrypt(m, b) + if err != nil { + instance.Unlock() + return nil, err + } + + instance.Unlock() + + var resp chan *response + + if respRequired { + // register the handler if a callback was passed + resp = make(chan *response, 1) + + s.handlersMu.Lock() + + if s.handlers[reqID] != nil { + s.handlersMu.Unlock() + return nil, errors.Errorf("error: duplicate handler registration for request id %d", reqID) + } + + s.handlers[reqID] = resp + s.handlersMu.Unlock() + } + + // send the message + var n int + if n, err = s.c.Write(b); err != nil { + return nil, err + } + + atomic.AddUint64(&instance.bytesSent, uint64(n)) + atomic.AddUint32(&instance.messagesSent, 1) + + debug.Printf("uasc %d/%d: send %T with %d bytes", s.c.ID(), reqID, req, len(b)) + + return resp, nil +} + +func (s *SecureChannel) nextRequestID() uint32 { + s.requestIDMu.Lock() + defer s.requestIDMu.Unlock() + + s.requestID++ + if s.requestID == 0 { + s.requestID = 1 + } + + return s.requestID } -func (s *SecureChannel) popHandlerLock(reqid uint32) chan Response { - ch := s.handler[reqid] - delete(s.handler, reqid) - return ch +// Close closes an existing secure channel +func (s *SecureChannel) Close() error { + debug.Printf("uasc Close()") + + defer func() { + close(s.closing) + s.reset() + }() + + err := s.SendRequest(&ua.CloseSecureChannelRequest{}, nil, nil) + + if err != nil { + return err + } + + return io.EOF } func (s *SecureChannel) timeNow() time.Time { @@ -825,8 +790,11 @@ func mergeChunks(chunks []*MessageChunk) ([]byte, error) { // todo(fs): check if this is correct and necessary // sort.Sort(bySequence(chunks)) - var b []byte - var seqnr uint32 + var ( + b []byte + seqnr uint32 + ) + for _, c := range chunks { if c.SequenceHeader.SequenceNumber == seqnr { continue // duplicate chunk diff --git a/uasc/secure_channel_crypto.go b/uasc/secure_channel_crypto.go index ebd59baf..22fe4a1c 100644 --- a/uasc/secure_channel_crypto.go +++ b/uasc/secure_channel_crypto.go @@ -1,3 +1,7 @@ +// Copyright 2018-2020 opcua authors. All rights reserved. +// Use of this source code is governed by a MIT-style license that can be +// found in the LICENSE file. + package uasc import ( @@ -9,112 +13,8 @@ import ( "github.com/gopcua/opcua/uapolicy" ) -// signAndEncrypt encrypts the message bytes stored in b and returns the -// data signed and encrypted per the security policy information from the -// secure channel. -func (s *SecureChannel) signAndEncrypt(m *Message, b []byte) ([]byte, error) { - // Nothing to do - if s.cfg.SecurityMode == ua.MessageSecurityModeNone { - return b, nil - } - - var isAsymmetric bool - if s.hasState(secureChannelCreated) { - isAsymmetric = true - } - - var headerLength int - if isAsymmetric { - headerLength = 12 + m.AsymmetricSecurityHeader.Len() - } else { - headerLength = 12 + m.SymmetricSecurityHeader.Len() - } - - var encryptedLength int - if s.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { - plaintextBlockSize := s.enc.PlaintextBlockSize() - paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + s.enc.SignatureLength() + 1) % plaintextBlockSize) - - for i := 0; i <= paddingLength; i++ { - b = append(b, byte(paddingLength)) - } - encryptedLength = ((len(b[headerLength:]) + s.enc.SignatureLength()) / plaintextBlockSize) * s.enc.BlockSize() - } else { // MessageSecurityModeSign - encryptedLength = len(b[headerLength:]) + s.enc.SignatureLength() - } - - // Fix header size to account for signing / encryption - binary.LittleEndian.PutUint32(b[4:], uint32(headerLength+encryptedLength)) - m.Header.MessageSize = uint32(headerLength + encryptedLength) - - signature, err := s.enc.Signature(b) - if err != nil { - return nil, ua.StatusBadSecurityChecksFailed - } - - b = append(b, signature...) - c := b[headerLength:] - if s.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { - c, err = s.enc.Encrypt(c) - if err != nil { - return nil, ua.StatusBadSecurityChecksFailed - } - } - return append(b[:headerLength], c...), nil -} - -// verifyAndDecrypt decrypts an incoming message stored in b and returns the -// data in plaintext. After decryption, the message signature is also verified. -// Any error in decryption or verification of the signature will return an error -// The result is stored in m.Data -func (s *SecureChannel) verifyAndDecrypt(m *MessageChunk, b []byte) ([]byte, error) { - var err error - - var isAsymmetric bool - if s.hasState(secureChannelCreated) { - isAsymmetric = true - } - - var headerLength int - if isAsymmetric { - headerLength = 12 + m.AsymmetricSecurityHeader.Len() - } else { - headerLength = 12 + m.SymmetricSecurityHeader.Len() - } - - // Nothing to do - if s.cfg.SecurityMode == ua.MessageSecurityModeNone { - return m.Data, nil - } - - if s.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { - p, err := s.enc.Decrypt(b[headerLength:]) - if err != nil { - return nil, ua.StatusBadSecurityChecksFailed - } - b = append(b[:headerLength], p...) - } - - signature := b[len(b)-s.enc.RemoteSignatureLength():] - messageToVerify := b[:len(b)-s.enc.RemoteSignatureLength()] - - if err = s.enc.VerifySignature(messageToVerify, signature); err != nil { - return nil, ua.StatusBadSecurityChecksFailed - } - - var paddingLength int - if s.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { - paddingLength = int(messageToVerify[len(messageToVerify)-1]) + 1 - } - - b = messageToVerify[headerLength : len(messageToVerify)-paddingLength] - - return b, nil -} - // NewSessionSignature issues a new signature for the client to send on the next ActivateSessionRequest func (s *SecureChannel) NewSessionSignature(cert, nonce []byte) ([]byte, string, error) { - if s.cfg.SecurityMode == ua.MessageSecurityModeNone { return nil, "", nil } @@ -139,9 +39,8 @@ func (s *SecureChannel) NewSessionSignature(cert, nonce []byte) ([]byte, string, return sig, sigAlg, nil } -// VerifySessionSignature checks the integrity of a Create/Activate Session Response's signature +// VerifySessionSignature checks the integrity of a Create/Activate Session response's signature func (s *SecureChannel) VerifySessionSignature(cert, nonce, signature []byte) error { - if s.cfg.SecurityMode == ua.MessageSecurityModeNone { return nil } @@ -166,7 +65,6 @@ func (s *SecureChannel) VerifySessionSignature(cert, nonce, signature []byte) er // EncryptUserPassword issues a new signature for the client to send in ActivateSessionRequest func (s *SecureChannel) EncryptUserPassword(policyURI, password string, cert, nonce []byte) ([]byte, string, error) { - if policyURI == ua.SecurityPolicyURINone { return []byte(password), "", nil } @@ -203,7 +101,6 @@ func (s *SecureChannel) EncryptUserPassword(policyURI, password string, cert, no // NewUserTokenSignature issues a new signature for the client to send in ActivateSessionRequest func (s *SecureChannel) NewUserTokenSignature(policyURI string, cert, nonce []byte) ([]byte, string, error) { - if policyURI == ua.SecurityPolicyURINone { return nil, "", nil } diff --git a/uasc/secure_channel_instance.go b/uasc/secure_channel_instance.go new file mode 100644 index 00000000..e6e62655 --- /dev/null +++ b/uasc/secure_channel_instance.go @@ -0,0 +1,226 @@ +// Copyright 2018-2020 opcua authors. All rights reserved. +// Use of this source code is governed by a MIT-style license that can be +// found in the LICENSE file. + +package uasc + +import ( + "encoding/binary" + "math" + "sync" + "time" + + "github.com/gopcua/opcua/errors" + "github.com/gopcua/opcua/id" + "github.com/gopcua/opcua/ua" + "github.com/gopcua/opcua/uapolicy" +) + +type instanceState int + +const ( + channelOpening instanceState = iota + channelActive +) + +type channelInstance struct { + sync.Mutex + sc *SecureChannel + state instanceState + createdAt time.Time + revisedLifetime time.Duration + secureChannelID uint32 + securityTokenID uint32 + sequenceNumber uint32 + algo *uapolicy.EncryptionAlgorithm + + messagesSent uint32 + // messagesReceived uint32 + bytesSent uint64 + // bytesReceived uint64 +} + +func newChannelInstance(sc *SecureChannel) *channelInstance { + return &channelInstance{ + sc: sc, + state: channelOpening, + } +} + +func (c *channelInstance) nextSequenceNumber() uint32 { + // lock must be held + c.sequenceNumber++ + if c.sequenceNumber > math.MaxUint32-1023 { + c.sequenceNumber = 1 + } + + return c.sequenceNumber +} + +func (c *channelInstance) newRequestMessage(req ua.Request, reqID uint32, authToken *ua.NodeID, timeout time.Duration) (*Message, error) { + typeID := ua.ServiceTypeID(req) + if typeID == 0 { + return nil, errors.Errorf("unknown service %T. Did you call register?", req) + } + if authToken == nil { + authToken = ua.NewTwoByteNodeID(0) + } + + reqHdr := &ua.RequestHeader{ + AuthenticationToken: authToken, + Timestamp: c.sc.timeNow(), + RequestHandle: reqID, // TODO: can I cheat like this? + } + + if timeout > 0 && timeout < c.sc.cfg.RequestTimeout { + timeout = c.sc.cfg.RequestTimeout + } + reqHdr.TimeoutHint = uint32(timeout / time.Millisecond) + req.SetHeader(reqHdr) + + // encode the message + return c.newMessage(req, typeID, reqID), nil +} + +func (c *channelInstance) newMessage(srv interface{}, typeID uint16, requestID uint32) *Message { + sequenceNumber := c.nextSequenceNumber() + + switch typeID { + case id.OpenSecureChannelRequest_Encoding_DefaultBinary, id.OpenSecureChannelResponse_Encoding_DefaultBinary: + // Do not send the thumbprint for security mode None + // even if we have a certificate. + // + // See https://github.com/gopcua/opcua/issues/259 + thumbprint := c.sc.cfg.Thumbprint + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeNone { + thumbprint = nil + } + + return &Message{ + MessageHeader: &MessageHeader{ + Header: NewHeader(MessageTypeOpenSecureChannel, ChunkTypeFinal, c.secureChannelID), + AsymmetricSecurityHeader: NewAsymmetricSecurityHeader(c.sc.cfg.SecurityPolicyURI, c.sc.cfg.Certificate, thumbprint), + SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), + }, + TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + Service: srv, + } + + case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: + return &Message{ + MessageHeader: &MessageHeader{ + Header: NewHeader(MessageTypeCloseSecureChannel, ChunkTypeFinal, c.secureChannelID), + SymmetricSecurityHeader: NewSymmetricSecurityHeader(c.securityTokenID), + SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), + }, + TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + Service: srv, + } + + default: + return &Message{ + MessageHeader: &MessageHeader{ + Header: NewHeader(MessageTypeMessage, ChunkTypeFinal, c.secureChannelID), + SymmetricSecurityHeader: NewSymmetricSecurityHeader(c.securityTokenID), + SequenceHeader: NewSequenceHeader(sequenceNumber, requestID), + }, + TypeID: ua.NewFourByteExpandedNodeID(0, typeID), + Service: srv, + } + } +} + +// signAndEncrypt encrypts the message bytes stored in b and returns the +// data signed and encrypted per the security policy information from the +// secure channel. +func (c *channelInstance) signAndEncrypt(m *Message, b []byte) ([]byte, error) { + // Nothing to do + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeNone { + return b, nil + } + + isAsymmetric := m.MessageHeader.AsymmetricSecurityHeader != nil + + var headerLength int + + if isAsymmetric { + headerLength = 12 + m.AsymmetricSecurityHeader.Len() + } else { + headerLength = 12 + m.SymmetricSecurityHeader.Len() + } + + var encryptedLength int + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { + plaintextBlockSize := c.algo.PlaintextBlockSize() + paddingLength := plaintextBlockSize - ((len(b[headerLength:]) + c.algo.SignatureLength() + 1) % plaintextBlockSize) + + for i := 0; i <= paddingLength; i++ { + b = append(b, byte(paddingLength)) + } + encryptedLength = ((len(b[headerLength:]) + c.algo.SignatureLength()) / plaintextBlockSize) * c.algo.BlockSize() + } else { // MessageSecurityModeSign + encryptedLength = len(b[headerLength:]) + c.algo.SignatureLength() + } + + // Fix header size to account for signing / encryption + binary.LittleEndian.PutUint32(b[4:], uint32(headerLength+encryptedLength)) + m.Header.MessageSize = uint32(headerLength + encryptedLength) + + signature, err := c.algo.Signature(b) + if err != nil { + return nil, ua.StatusBadSecurityChecksFailed + } + + b = append(b, signature...) + p := b[headerLength:] + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { + p, err = c.algo.Encrypt(p) + if err != nil { + return nil, ua.StatusBadSecurityChecksFailed + } + } + return append(b[:headerLength], p...), nil +} + +func (c *channelInstance) verifyAndDecrypt(m *MessageChunk, r []byte) ([]byte, error) { + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeNone { + return m.Data, nil + } + + isAsymmetric := m.AsymmetricSecurityHeader != nil + + headerLength := 12 + + if isAsymmetric { + headerLength += m.AsymmetricSecurityHeader.Len() + } else { + headerLength += m.SymmetricSecurityHeader.Len() + } + + b := make([]byte, len(r)) + copy(b, r) + + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { + p, err := c.algo.Decrypt(b[headerLength:]) + if err != nil { + return nil, ua.StatusBadSecurityChecksFailed + } + b = append(b[:headerLength], p...) + } + + signature := b[len(b)-c.algo.RemoteSignatureLength():] + messageToVerify := b[:len(b)-c.algo.RemoteSignatureLength()] + + if err := c.algo.VerifySignature(messageToVerify, signature); err != nil { + return nil, ua.StatusBadSecurityChecksFailed + } + + var paddingLength int + if c.sc.cfg.SecurityMode == ua.MessageSecurityModeSignAndEncrypt || isAsymmetric { + paddingLength = int(messageToVerify[len(messageToVerify)-1]) + 1 + } + + b = messageToVerify[headerLength : len(messageToVerify)-paddingLength] + + return b, nil +} diff --git a/uasc/secure_channel_test.go b/uasc/secure_channel_test.go index c16a2bae..0aadf46c 100644 --- a/uasc/secure_channel_test.go +++ b/uasc/secure_channel_test.go @@ -13,6 +13,16 @@ import ( func TestNewRequestMessage(t *testing.T) { fixedTime := func() time.Time { return time.Date(2019, 1, 1, 12, 13, 14, 0, time.UTC) } + + buildSecureChannel := func(sc *SecureChannel, instance *channelInstance) *SecureChannel { + if instance == nil { + instance = newChannelInstance(sc) + } + sc.activeInstance = instance + sc.activeInstance.sc = sc + return sc + } + tests := []struct { name string sechan *SecureChannel @@ -23,11 +33,11 @@ func TestNewRequestMessage(t *testing.T) { }{ { name: "first-request", - sechan: &SecureChannel{ - cfg: &Config{}, - reqhdr: &ua.RequestHeader{}, - time: fixedTime, - }, + sechan: buildSecureChannel(&SecureChannel{ + cfg: &Config{}, + // reqhdr: &ua.RequestHeader{}, + time: fixedTime, + }, nil), req: &ua.ReadRequest{}, m: &Message{ MessageHeader: &MessageHeader{ @@ -53,15 +63,19 @@ func TestNewRequestMessage(t *testing.T) { }, { name: "subsequent-request", - sechan: &SecureChannel{ - cfg: &Config{}, - sequenceNumber: 777, - requestID: 555, - reqhdr: &ua.RequestHeader{ - RequestHandle: 444, + sechan: buildSecureChannel( + &SecureChannel{ + cfg: &Config{}, + requestID: 555, + // reqhdr: &ua.RequestHeader{ + // RequestHandle: 444, + // }, + time: fixedTime, }, - time: fixedTime, - }, + &channelInstance{ + sequenceNumber: 777, + }, + ), req: &ua.ReadRequest{}, m: &Message{ MessageHeader: &MessageHeader{ @@ -80,22 +94,22 @@ func TestNewRequestMessage(t *testing.T) { RequestHeader: &ua.RequestHeader{ AuthenticationToken: ua.NewTwoByteNodeID(0), Timestamp: fixedTime(), - RequestHandle: 445, + RequestHandle: 556, }, }, }, }, { name: "counter-rollover", - sechan: &SecureChannel{ - cfg: &Config{}, - sequenceNumber: math.MaxUint32 - 1023, - requestID: math.MaxUint32, - reqhdr: &ua.RequestHeader{ - RequestHandle: math.MaxUint32, + sechan: buildSecureChannel( + &SecureChannel{ + cfg: &Config{}, + requestID: math.MaxUint32, + time: fixedTime, }, - time: fixedTime, - }, + &channelInstance{ + sequenceNumber: math.MaxUint32 - 1023, + }), req: &ua.ReadRequest{}, m: &Message{ MessageHeader: &MessageHeader{ @@ -123,7 +137,7 @@ func TestNewRequestMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m, err := tt.sechan.newRequestMessage(tt.req, tt.authToken, tt.timeout) + m, err := tt.sechan.activeInstance.newRequestMessage(tt.req, tt.sechan.nextRequestID(), tt.authToken, tt.timeout) if err != nil { t.Fatalf("got err %v want nil", err) }