From 674c645cd210f1730423636b36f66b5351b6fe39 Mon Sep 17 00:00:00 2001 From: Jin Date: Sun, 12 Jan 2025 11:04:52 +0800 Subject: [PATCH] feat:optimize p2p read message timeout --- p2p/peers/message.go | 86 +++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/p2p/peers/message.go b/p2p/peers/message.go index 9c14c791..eb64b505 100644 --- a/p2p/peers/message.go +++ b/p2p/peers/message.go @@ -219,7 +219,6 @@ func (p *ConnMsgRW) Encoder() encoder.NetworkEncoding { func (p *ConnMsgRW) readLoop(pe *Peer, errc chan<- error) { defer p.wg.Done() returnFun := func(err error) { - if err != nil { select { case <-p.closing: @@ -228,20 +227,11 @@ func (p *ConnMsgRW) readLoop(pe *Peer, errc chan<- error) { } } } - var msg *Msg for { - ctx, can := context.WithTimeout(context.Background(), HandleTimeout) - defer can() - ret := make(chan *Msg) - msg = nil - go p.readMsg(pe, ret) - select { - case <-p.closing: - return - case <-ctx.Done(): - returnFun(fmt.Errorf("ConnMsgRW read message timeout:%s", pe.GetID())) + msg, err := p.readMsg(pe) + if err != nil { + returnFun(err) return - case msg = <-ret: } if msg == nil { returnFun(fmt.Errorf("No read msg")) @@ -264,7 +254,7 @@ func (p *ConnMsgRW) readLoop(pe *Peer, errc chan<- error) { msgT := reflect.New(ty) msgd := msgT.Interface() - err := p.en.DecodeWithMaxLength(bytes.NewReader(msg.Payload), msgd) + err = p.en.DecodeWithMaxLength(bytes.NewReader(msg.Payload), msgd) // value, ok := p.pending.Load(msg.ID) if ok { @@ -289,60 +279,64 @@ func (p *ConnMsgRW) readLoop(pe *Peer, errc chan<- error) { } } -func (p *ConnMsgRW) readMsg(pe *Peer, ret chan *Msg) { - returnFun := func(msg *Msg) { - select { - case <-p.closing: - return - case ret <- msg: - } - } - +func (p *ConnMsgRW) readMsg(pe *Peer) (*Msg, error) { if p.closed.Load() { - log.Warn(ErrConnClosed.Error()) - returnFun(nil) - return + return nil, ErrConnClosed } dataHead := make([]byte, PacketSize) size, err := p.rw.Read(dataHead) if err == io.EOF { log.Debug("Base Stream closed by peer", "peer", pe.IDWithAddress()) - returnFun(nil) - return + return nil, nil } if err != nil { log.Warn("Error reading from base stream", "peer", pe.IDWithAddress(), "error", err) - returnFun(nil) - return + return nil, err } if size != PacketSize { - log.Warn("Error message head size", "peer", pe.IDWithAddress()) - returnFun(nil) - return + err = fmt.Errorf("Error message head size") + log.Warn(err.Error(), "peer", pe.IDWithAddress()) + return nil, err } dataSize := binary.BigEndian.Uint64(dataHead) log.Debug("Receive message head", "peer", pe.IDWithAddress(), "size", dataSize) if dataSize > MaxMessageSize { - log.Warn("Too large message", "size", dataSize, "max", MaxMessageSize) - returnFun(nil) - return + return nil, fmt.Errorf("Too large message size: %d > %d", dataSize, MaxMessageSize) } msgData := make([]byte, dataSize) - size, err = io.ReadFull(p.rw, msgData) + ctx, can := context.WithTimeout(context.Background(), HandleTimeout) + defer can() + + ret := make(chan error) + go func(ret chan error) { + size, err = io.ReadFull(p.rw, msgData) + select { + case <-p.closing: + return + case ret <- err: + } + }(ret) + + select { + case <-p.closing: + return nil, ErrConnClosed + case <-ctx.Done(): + return nil, fmt.Errorf("ConnMsgRW read message timeout:%s", pe.GetID()) + case err = <-ret: + } + if err == io.EOF { log.Debug("Base Stream closed by peer", "peer", pe.IDWithAddress()) - returnFun(nil) - return + return nil, nil } if err != nil { log.Warn("Error reading from long stream", "peer", pe.IDWithAddress(), "error", err) - returnFun(nil) - return + return nil, err } if uint64(size) != dataSize { - log.Warn("Receive error size message data", "peer", pe.IDWithAddress()) - returnFun(nil) - return + err = fmt.Errorf("Receive error size message data") + log.Warn(err.Error(), "peer", pe.IDWithAddress()) + return nil, err } msgIDBs := msgData[:MsgCodeSize] msgID := binary.BigEndian.Uint64(msgIDBs) @@ -351,11 +345,11 @@ func (p *ConnMsgRW) readMsg(pe *Peer, ret chan *Msg) { log.Debug("Receive message", "id", msgID, "code", msgCode, "peer", pe.IDWithAddress(), "size", size) - returnFun(&Msg{ + return &Msg{ ID: msgID, Code: msgCode, Payload: msgData[MsgCodeSize*2:], - }) + }, nil } func NewConnRW(stream network.Stream, en encoder.NetworkEncoding) *ConnMsgRW {