@@ -18,14 +18,18 @@ var (
18
18
aLongTimeAgo = time .Unix (1 , 0 )
19
19
)
20
20
21
- func (d * Dialer ) connect (ctx context.Context , c net.Conn , address string ) (_ net.Addr , ctxErr error ) {
22
- host , port , err := splitHostPort (address )
21
+ func (d * Dialer ) connect (ctx context.Context , c net.Conn , req Request ) (conn net.Conn , _ net.Addr , ctxErr error ) {
22
+ var udpHeader []byte
23
+
24
+ host , port , err := splitHostPort (req .DstAddress )
23
25
if err != nil {
24
- return nil , err
26
+ return c , nil , err
25
27
}
26
28
if deadline , ok := ctx .Deadline (); ok && ! deadline .IsZero () {
27
29
c .SetDeadline (deadline )
28
- defer c .SetDeadline (noDeadline )
30
+ if req .Cmd != CmdUDPAssociate {
31
+ defer c .SetDeadline (noDeadline )
32
+ }
29
33
}
30
34
if ctx != context .Background () {
31
35
errCh := make (chan error , 1 )
@@ -47,14 +51,15 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
47
51
}()
48
52
}
49
53
54
+ conn = c
50
55
b := make ([]byte , 0 , 6 + len (host )) // the size here is just an estimate
51
56
b = append (b , Version5 )
52
57
if len (d .AuthMethods ) == 0 || d .Authenticate == nil {
53
58
b = append (b , 1 , byte (AuthMethodNotRequired ))
54
59
} else {
55
60
ams := d .AuthMethods
56
61
if len (ams ) > 255 {
57
- return nil , errors .New ("too many authentication methods" )
62
+ return c , nil , errors .New ("too many authentication methods" )
58
63
}
59
64
b = append (b , byte (len (ams )))
60
65
for _ , am := range ams {
@@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
69
74
return
70
75
}
71
76
if b [0 ] != Version5 {
72
- return nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
77
+ return c , nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
73
78
}
74
79
am := AuthMethod (b [1 ])
75
80
if am == AuthMethodNoAcceptableMethods {
76
- return nil , errors .New ("no acceptable authentication methods" )
81
+ return c , nil , errors .New ("no acceptable authentication methods" )
77
82
}
78
83
if d .Authenticate != nil {
79
84
if ctxErr = d .Authenticate (ctx , c , am ); ctxErr != nil {
@@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
82
87
}
83
88
84
89
b = b [:0 ]
85
- b = append (b , Version5 , byte (d . cmd ), 0 )
90
+ b = append (b , Version5 , byte (req . Cmd ), 0 )
86
91
if ip := net .ParseIP (host ); ip != nil {
87
92
if ip4 := ip .To4 (); ip4 != nil {
88
93
b = append (b , AddrTypeIPv4 )
@@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
91
96
b = append (b , AddrTypeIPv6 )
92
97
b = append (b , ip6 ... )
93
98
} else {
94
- return nil , errors .New ("unknown address type" )
99
+ return c , nil , errors .New ("unknown address type" )
95
100
}
96
101
} else {
97
102
if len (host ) > 255 {
98
- return nil , errors .New ("FQDN too long" )
103
+ return c , nil , errors .New ("FQDN too long" )
99
104
}
100
105
b = append (b , AddrTypeFQDN )
101
106
b = append (b , byte (len (host )))
102
107
b = append (b , host ... )
103
108
}
104
109
b = append (b , byte (port >> 8 ), byte (port ))
110
+
111
+ if req .Cmd == CmdUDPAssociate {
112
+ udpHeader = make ([]byte , len (b ))
113
+ copy (udpHeader [3 :], b [3 :])
114
+ }
115
+
105
116
if _ , ctxErr = c .Write (b ); ctxErr != nil {
106
117
return
107
118
}
@@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
110
121
return
111
122
}
112
123
if b [0 ] != Version5 {
113
- return nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
124
+ return c , nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
114
125
}
115
126
if cmdErr := Reply (b [1 ]); cmdErr != StatusSucceeded {
116
- return nil , errors .New ("unknown error " + cmdErr .String ())
127
+ return c , nil , errors .New ("unknown error " + cmdErr .String ())
117
128
}
118
129
if b [2 ] != 0 {
119
- return nil , errors .New ("non-zero reserved field" )
130
+ return c , nil , errors .New ("non-zero reserved field" )
120
131
}
121
132
l := 2
133
+ addrType := b [3 ]
122
134
var a Addr
123
- switch b [ 3 ] {
135
+ switch addrType {
124
136
case AddrTypeIPv4 :
125
137
l += net .IPv4len
126
138
a .IP = make (net.IP , net .IPv4len )
@@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
129
141
a .IP = make (net.IP , net .IPv6len )
130
142
case AddrTypeFQDN :
131
143
if _ , err := io .ReadFull (c , b [:1 ]); err != nil {
132
- return nil , err
144
+ return c , nil , err
133
145
}
134
146
l += int (b [0 ])
135
147
default :
136
- return nil , errors .New ("unknown address type " + strconv .Itoa (int (b [3 ])))
148
+ return c , nil , errors .New ("unknown address type " + strconv .Itoa (int (b [3 ])))
137
149
}
150
+
138
151
if cap (b ) < l {
139
152
b = make ([]byte , l )
140
153
} else {
@@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
149
162
a .Name = string (b [:len (b )- 2 ])
150
163
}
151
164
a .Port = int (b [len (b )- 2 ])<< 8 | int (b [len (b )- 1 ])
152
- return & a , nil
153
- }
154
165
155
- func splitHostPort ( address string ) ( string , int , error ) {
156
- host , port , err := net .SplitHostPort ( address )
157
- if err != nil {
158
- return "" , 0 , err
159
- }
160
- portnum , err := strconv . Atoi ( port )
161
- if err != nil {
162
- return "" , 0 , err
163
- }
164
- if 1 > portnum || portnum > 0xffff {
165
- return "" , 0 , errors . New ( "port number out of range " + port )
166
+ if req . Cmd == CmdUDPAssociate {
167
+ var uc net.Conn
168
+ if uc , err = d . proxyDial ( ctx , req . UDPNetwork , a . String ()); err != nil {
169
+ return c , & a , err
170
+ }
171
+ c . SetDeadline ( noDeadline )
172
+ go func () {
173
+ defer uc . Close ()
174
+ io . Copy ( io . Discard , c )
175
+ }()
176
+ return udpConn { Conn : uc , socksConn : c , header : udpHeader }, & a , nil
166
177
}
167
- return host , portnum , nil
178
+
179
+ return c , & a , nil
168
180
}
0 commit comments