Skip to content

Commit 48b78d2

Browse files
committed
fix: connections not closing
1 parent ffb54dd commit 48b78d2

File tree

6 files changed

+61
-44
lines changed

6 files changed

+61
-44
lines changed

client/client.go

+2-19
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (c *proxyConn) connect(ctx context.Context) (net.Conn, error) {
150150
return nil, err
151151
}
152152

153-
var cipherTarget io.ReadWriter
153+
var cipherTarget net.Conn
154154
ed := crypto.StreamEncryptDecrypter{
155155
EncryptKey: c.sessionKey,
156156
DecryptKey: c.sessionKey,
@@ -209,10 +209,7 @@ func (c *proxyConn) connect(ctx context.Context) (net.Conn, error) {
209209
}
210210
}
211211

212-
return &CipherConn {
213-
cipherTarget,
214-
c.target,
215-
}, nil
212+
return cipherTarget, nil
216213
}
217214

218215
func (c *proxyConn) writePubKey() error {
@@ -314,17 +311,3 @@ func (c *proxyConn) readReply() error {
314311

315312
return nil
316313
}
317-
318-
// CipherConn implements net.Conn interface, with a underlying io.ReadWriter.
319-
type CipherConn struct {
320-
io.ReadWriter
321-
net.Conn
322-
}
323-
324-
func (c *CipherConn) Read(b []byte) (n int, err error) {
325-
return c.ReadWriter.Read(b)
326-
}
327-
328-
func (c *CipherConn) Write(b []byte) (n int, err error) {
329-
return c.ReadWriter.Write(b)
330-
}

cmd/groundhog/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func initLogger() {
8787
case "fatal":
8888
level = yagl.LvlFatal
8989
default:
90-
fmt.Fprintf(os.Stderr, "unrecognized logging level: %s\n", level)
90+
fmt.Fprintf(os.Stderr, "unrecognized logging level: %s\n", logLevel)
9191
os.Exit(1)
9292
}
9393

common/crypto/crypto.go

+43-22
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package crypto
22

33
import (
4-
"crypto/cipher"
5-
"io"
6-
"errors"
74
"crypto/aes"
5+
"errors"
6+
"io"
7+
"net"
8+
"crypto/cipher"
89
)
910

1011
type readWriter struct {
@@ -68,36 +69,35 @@ func (ed *StreamEncryptDecrypter) initCipherStream() error {
6869
// corresponding ciphertext io.ReadWriter. Any ciphertext write to returned
6970
// io.ReadWriter will be decrypted and write to plaintext. Any plaintext read
7071
// from plaintext will be encrypted and write to returned io.ReadWriter.
71-
func (ed *StreamEncryptDecrypter) Ciphertext(plaintext io.ReadWriter) (io.ReadWriter, error) {
72+
func (ed *StreamEncryptDecrypter) Ciphertext(plaintext net.Conn) (net.Conn, error) {
7273
if err := ed.initCipherStream(); err != nil {
7374
return nil, err
7475
}
7576

7677
// logic here could be simpler if golang has a built-in duplex pipe.
77-
// net.Pipe is a solution, but seems too heavy,
78-
// with possibility of causing memory leak if not careful
7978
cipherRdIn, cipherWtOut := io.Pipe()
8079
cipherRdOut, cipherWtIn := io.Pipe()
8180

82-
ciphertext := &readWriter{
83-
cipherRdOut,
84-
cipherWtOut,
81+
ciphertext := &CipherConn{
82+
&readWriter{
83+
cipherRdOut,
84+
cipherWtOut,
85+
},
86+
plaintext,
8587
}
8688

8789
// decrypt ciphertext to plaintext
8890
go func() {
8991
decrypter := &cipher.StreamReader{S: ed.DecryptStream, R: cipherRdIn}
9092
io.Copy(plaintext, decrypter)
91-
cipherWtOut.Close()
92-
cipherRdIn.Close()
93+
ciphertext.Close() // which close the underlying plaintext
9394
}()
9495

9596
// encrypt plaintext to ciphertext
9697
go func() {
9798
encrypter := &cipher.StreamWriter{S: ed.EncryptStream, W: cipherWtIn}
9899
io.Copy(encrypter, plaintext)
99-
cipherWtIn.Close()
100-
cipherRdOut.Close()
100+
ciphertext.Close()
101101
}()
102102

103103
return ciphertext, nil
@@ -107,34 +107,55 @@ func (ed *StreamEncryptDecrypter) Ciphertext(plaintext io.ReadWriter) (io.ReadWr
107107
// corresponding plaintext io.ReadWriter. Any plaintext write to returned
108108
// io.ReadWriter will be encrypted and write to ciphertext. Any ciphertext read
109109
// from ciphertext will be decrypted and write to returned io.ReadWriter.
110-
func (ed *StreamEncryptDecrypter) Plaintext(ciphertext io.ReadWriter) (io.ReadWriter, error) {
110+
func (ed *StreamEncryptDecrypter) Plaintext(ciphertext net.Conn) (net.Conn, error) {
111111
if err := ed.initCipherStream(); err != nil {
112112
return nil, err
113113
}
114114

115115
plainRdIn, plainWtOut := io.Pipe()
116116
plainRdOut, plainWtIn := io.Pipe()
117117

118-
plaintext := &readWriter{
119-
plainRdOut,
120-
plainWtOut,
118+
plaintext := &CipherConn {
119+
&readWriter{
120+
plainRdOut,
121+
plainWtOut,
122+
},
123+
ciphertext,
121124
}
122125

123126
// encrypt plaintext to ciphertext
124127
go func() {
125128
encrypter := &cipher.StreamWriter{S: ed.EncryptStream, W: ciphertext}
126129
io.Copy(encrypter, plainRdIn)
127-
plainWtIn.Close()
128-
plainRdOut.Close()
130+
plaintext.Close()
129131
}()
130132

131133
// decrypt ciphertext to plaintext
132134
go func() {
133135
decrypter := &cipher.StreamReader{S: ed.DecryptStream, R: ciphertext}
134136
io.Copy(plainWtIn, decrypter)
135-
plainWtOut.Close()
136-
plainRdIn.Close()
137+
plaintext.Close()
137138
}()
138139

139140
return plaintext, nil
140-
}
141+
}
142+
143+
// CipherConn implements net.Conn interface, with a underlying io.ReadWriter.
144+
type CipherConn struct {
145+
io.ReadWriter
146+
net.Conn
147+
}
148+
149+
func (c *CipherConn) Read(b []byte) (n int, err error) {
150+
if _, err := c.Conn.Read([]byte{}); err != nil {
151+
return 0, err
152+
}
153+
return c.ReadWriter.Read(b)
154+
}
155+
156+
func (c *CipherConn) Write(b []byte) (n int, err error) {
157+
if _, err := c.Conn.Write([]byte{}); err != nil {
158+
return 0, err
159+
}
160+
return c.ReadWriter.Write(b)
161+
}

common/util/util.go

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package util
44
import (
55
"io"
66
"sync"
7+
"net"
78
)
89

910
// Proxy connect two ReadWriter, forward data between them in a full-duplex
@@ -16,6 +17,7 @@ func Proxy(lhs io.ReadWriter, rhs io.ReadWriter) (lhsWritten, rhsWritten int64,
1617
wg.Add(1)
1718
go func() {
1819
lhsWritten, err = io.Copy(lhs, rhs)
20+
closeNetConn(lhs, rhs)
1921
if err != nil {
2022
wg.Done()
2123
} else {
@@ -27,6 +29,7 @@ func Proxy(lhs io.ReadWriter, rhs io.ReadWriter) (lhsWritten, rhsWritten int64,
2729
wg.Add(1)
2830
go func() {
2931
rhsWritten, err = io.Copy(rhs, lhs)
32+
closeNetConn(lhs, rhs)
3033
if err != nil {
3134
wg.Done()
3235
} else {
@@ -37,3 +40,11 @@ func Proxy(lhs io.ReadWriter, rhs io.ReadWriter) (lhsWritten, rhsWritten int64,
3740
wg.Wait()
3841
return
3942
}
43+
44+
func closeNetConn(rws... io.ReadWriter) {
45+
for _, v := range rws {
46+
if conn, ok := v.(net.Conn); ok {
47+
conn.Close()
48+
}
49+
}
50+
}

server/server.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ type gndhog struct {
140140
}
141141

142142
func (g *gndhog) init(ctx context.Context, conn net.Conn) {
143-
defer conn.Close()
143+
ctx, cancel := context.WithCancel(ctx)
144+
defer cancel()
144145

145146
// watchdog to close connections if context cancelled
146147
go func() {

socks5/socks5.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ type socks struct {
8787
}
8888

8989
func (s *socks) init(ctx context.Context, conn net.Conn) {
90-
defer conn.Close()
90+
ctx, cancel := context.WithCancel(ctx)
91+
defer cancel()
9192

9293
// watchdog to close connections if context cancelled
9394
go func() {

0 commit comments

Comments
 (0)