@@ -20,6 +20,21 @@ const (
2020 DefaultSSHPort = 10022
2121)
2222
23+ type writeTimeoutConn struct {
24+ net.Conn
25+ timeout time.Duration
26+ }
27+
28+ func (c * writeTimeoutConn ) Write (p []byte ) (n int , err error ) {
29+ if err = c .Conn .SetWriteDeadline (time .Now ().Add (c .timeout )); err != nil {
30+ return 0 , fmt .Errorf ("writeTimeoutConn: SetWriteDeadline: %w" , err )
31+ }
32+ if n , err = c .Conn .Write (p ); err != nil {
33+ return n , fmt .Errorf ("writeTimeoutConn: write: %w" , err )
34+ }
35+ return n , nil
36+ }
37+
2338// legacyConnection is an insecure TCP connection.
2439type legacyConnection struct {
2540 net.Conn
@@ -32,10 +47,16 @@ func (c *legacyConnection) Connect(addr string, timeout time.Duration) error {
3247 return err
3348 }
3449
35- c . Conn , err = net .DialTimeout ("tcp" , addr , timeout )
50+ conn , err : = net .DialTimeout ("tcp" , addr , timeout )
3651 if err != nil {
3752 return fmt .Errorf ("legacy connection: dial: %w" , err )
3853 }
54+
55+ c .Conn = & writeTimeoutConn {
56+ Conn : conn ,
57+ timeout : timeout ,
58+ }
59+
3960 return nil
4061}
4162
@@ -53,10 +74,16 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {
5374 return err
5475 }
5576
56- if c .Conn , err = net .DialTimeout ("tcp" , addr , timeout ); err != nil {
77+ conn , err := net .DialTimeout ("tcp" , addr , timeout )
78+ if err != nil {
5779 return fmt .Errorf ("ssh connection: dial: %w" , err )
5880 }
5981
82+ c .Conn = & writeTimeoutConn {
83+ Conn : conn ,
84+ timeout : timeout ,
85+ }
86+
6087 clientConn , chans , reqs , err := ssh .NewClientConn (c .Conn , addr , c .config )
6188 if err != nil {
6289 return fmt .Errorf ("ssh connecion: ssh client conn: %w" , err )
0 commit comments