113113#include <trace/events/skb.h>
114114#include <net/busy_poll.h>
115115#include "udp_impl.h"
116+ #include <net/sock_reuseport.h>
116117
117118struct udp_table udp_table __read_mostly ;
118119EXPORT_SYMBOL (udp_table );
@@ -137,7 +138,8 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
137138 unsigned long * bitmap ,
138139 struct sock * sk ,
139140 int (* saddr_comp )(const struct sock * sk1 ,
140- const struct sock * sk2 ),
141+ const struct sock * sk2 ,
142+ bool match_wildcard ),
141143 unsigned int log )
142144{
143145 struct sock * sk2 ;
@@ -152,8 +154,9 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
152154 (!sk2 -> sk_bound_dev_if || !sk -> sk_bound_dev_if ||
153155 sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
154156 (!sk2 -> sk_reuseport || !sk -> sk_reuseport ||
157+ rcu_access_pointer (sk -> sk_reuseport_cb ) ||
155158 !uid_eq (uid , sock_i_uid (sk2 ))) &&
156- saddr_comp (sk , sk2 )) {
159+ saddr_comp (sk , sk2 , true )) {
157160 if (!bitmap )
158161 return 1 ;
159162 __set_bit (udp_sk (sk2 )-> udp_port_hash >> log , bitmap );
@@ -170,7 +173,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
170173 struct udp_hslot * hslot2 ,
171174 struct sock * sk ,
172175 int (* saddr_comp )(const struct sock * sk1 ,
173- const struct sock * sk2 ))
176+ const struct sock * sk2 ,
177+ bool match_wildcard ))
174178{
175179 struct sock * sk2 ;
176180 struct hlist_nulls_node * node ;
@@ -186,8 +190,9 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
186190 (!sk2 -> sk_bound_dev_if || !sk -> sk_bound_dev_if ||
187191 sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
188192 (!sk2 -> sk_reuseport || !sk -> sk_reuseport ||
193+ rcu_access_pointer (sk -> sk_reuseport_cb ) ||
189194 !uid_eq (uid , sock_i_uid (sk2 ))) &&
190- saddr_comp (sk , sk2 )) {
195+ saddr_comp (sk , sk2 , true )) {
191196 res = 1 ;
192197 break ;
193198 }
@@ -196,6 +201,35 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
196201 return res ;
197202}
198203
204+ static int udp_reuseport_add_sock (struct sock * sk , struct udp_hslot * hslot ,
205+ int (* saddr_same )(const struct sock * sk1 ,
206+ const struct sock * sk2 ,
207+ bool match_wildcard ))
208+ {
209+ struct net * net = sock_net (sk );
210+ struct hlist_nulls_node * node ;
211+ kuid_t uid = sock_i_uid (sk );
212+ struct sock * sk2 ;
213+
214+ sk_nulls_for_each (sk2 , node , & hslot -> head ) {
215+ if (net_eq (sock_net (sk2 ), net ) &&
216+ sk2 != sk &&
217+ sk2 -> sk_family == sk -> sk_family &&
218+ ipv6_only_sock (sk2 ) == ipv6_only_sock (sk ) &&
219+ (udp_sk (sk2 )-> udp_port_hash == udp_sk (sk )-> udp_port_hash ) &&
220+ (sk2 -> sk_bound_dev_if == sk -> sk_bound_dev_if ) &&
221+ sk2 -> sk_reuseport && uid_eq (uid , sock_i_uid (sk2 )) &&
222+ (* saddr_same )(sk , sk2 , false)) {
223+ return reuseport_add_sock (sk , sk2 );
224+ }
225+ }
226+
227+ /* Initial allocation may have already happened via setsockopt */
228+ if (!rcu_access_pointer (sk -> sk_reuseport_cb ))
229+ return reuseport_alloc (sk );
230+ return 0 ;
231+ }
232+
199233/**
200234 * udp_lib_get_port - UDP/-Lite port lookup for IPv4 and IPv6
201235 *
@@ -207,7 +241,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
207241 */
208242int udp_lib_get_port (struct sock * sk , unsigned short snum ,
209243 int (* saddr_comp )(const struct sock * sk1 ,
210- const struct sock * sk2 ),
244+ const struct sock * sk2 ,
245+ bool match_wildcard ),
211246 unsigned int hash2_nulladdr )
212247{
213248 struct udp_hslot * hslot , * hslot2 ;
@@ -290,6 +325,14 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
290325 udp_sk (sk )-> udp_port_hash = snum ;
291326 udp_sk (sk )-> udp_portaddr_hash ^= snum ;
292327 if (sk_unhashed (sk )) {
328+ if (sk -> sk_reuseport &&
329+ udp_reuseport_add_sock (sk , hslot , saddr_comp )) {
330+ inet_sk (sk )-> inet_num = 0 ;
331+ udp_sk (sk )-> udp_port_hash = 0 ;
332+ udp_sk (sk )-> udp_portaddr_hash ^= snum ;
333+ goto fail_unlock ;
334+ }
335+
293336 sk_nulls_add_node_rcu (sk , & hslot -> head );
294337 hslot -> count ++ ;
295338 sock_prot_inuse_add (sock_net (sk ), sk -> sk_prot , 1 );
@@ -309,13 +352,22 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
309352}
310353EXPORT_SYMBOL (udp_lib_get_port );
311354
312- static int ipv4_rcv_saddr_equal (const struct sock * sk1 , const struct sock * sk2 )
355+ /* match_wildcard == true: 0.0.0.0 equals to any IPv4 addresses
356+ * match_wildcard == false: addresses must be exactly the same, i.e.
357+ * 0.0.0.0 only equals to 0.0.0.0
358+ */
359+ static int ipv4_rcv_saddr_equal (const struct sock * sk1 , const struct sock * sk2 ,
360+ bool match_wildcard )
313361{
314362 struct inet_sock * inet1 = inet_sk (sk1 ), * inet2 = inet_sk (sk2 );
315363
316- return (!ipv6_only_sock (sk2 ) &&
317- (!inet1 -> inet_rcv_saddr || !inet2 -> inet_rcv_saddr ||
318- inet1 -> inet_rcv_saddr == inet2 -> inet_rcv_saddr ));
364+ if (!ipv6_only_sock (sk2 )) {
365+ if (inet1 -> inet_rcv_saddr == inet2 -> inet_rcv_saddr )
366+ return 1 ;
367+ if (!inet1 -> inet_rcv_saddr || !inet2 -> inet_rcv_saddr )
368+ return match_wildcard ;
369+ }
370+ return 0 ;
319371}
320372
321373static u32 udp4_portaddr_hash (const struct net * net , __be32 saddr ,
@@ -459,8 +511,14 @@ static struct sock *udp4_lib_lookup2(struct net *net,
459511 badness = score ;
460512 reuseport = sk -> sk_reuseport ;
461513 if (reuseport ) {
514+ struct sock * sk2 ;
462515 hash = udp_ehashfn (net , daddr , hnum ,
463516 saddr , sport );
517+ sk2 = reuseport_select_sock (sk , hash );
518+ if (sk2 ) {
519+ result = sk2 ;
520+ goto found ;
521+ }
464522 matches = 1 ;
465523 }
466524 } else if (score == badness && reuseport ) {
@@ -478,6 +536,7 @@ static struct sock *udp4_lib_lookup2(struct net *net,
478536 if (get_nulls_value (node ) != slot2 )
479537 goto begin ;
480538 if (result ) {
539+ found :
481540 if (unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
482541 result = NULL ;
483542 else if (unlikely (compute_score2 (result , net , saddr , sport ,
@@ -540,8 +599,14 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
540599 badness = score ;
541600 reuseport = sk -> sk_reuseport ;
542601 if (reuseport ) {
602+ struct sock * sk2 ;
543603 hash = udp_ehashfn (net , daddr , hnum ,
544604 saddr , sport );
605+ sk2 = reuseport_select_sock (sk , hash );
606+ if (sk2 ) {
607+ result = sk2 ;
608+ goto found ;
609+ }
545610 matches = 1 ;
546611 }
547612 } else if (score == badness && reuseport ) {
@@ -560,6 +625,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
560625 goto begin ;
561626
562627 if (result ) {
628+ found :
563629 if (unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
564630 result = NULL ;
565631 else if (unlikely (compute_score (result , net , saddr , hnum , sport ,
@@ -587,7 +653,8 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
587653struct sock * udp4_lib_lookup (struct net * net , __be32 saddr , __be16 sport ,
588654 __be32 daddr , __be16 dport , int dif )
589655{
590- return __udp4_lib_lookup (net , saddr , sport , daddr , dport , dif , & udp_table );
656+ return __udp4_lib_lookup (net , saddr , sport , daddr , dport , dif ,
657+ & udp_table );
591658}
592659EXPORT_SYMBOL_GPL (udp4_lib_lookup );
593660
@@ -1398,6 +1465,8 @@ void udp_lib_unhash(struct sock *sk)
13981465 hslot2 = udp_hashslot2 (udptable , udp_sk (sk )-> udp_portaddr_hash );
13991466
14001467 spin_lock_bh (& hslot -> lock );
1468+ if (rcu_access_pointer (sk -> sk_reuseport_cb ))
1469+ reuseport_detach_sock (sk );
14011470 if (sk_nulls_del_node_init_rcu (sk )) {
14021471 hslot -> count -- ;
14031472 inet_sk (sk )-> inet_num = 0 ;
@@ -1425,22 +1494,28 @@ void udp_lib_rehash(struct sock *sk, u16 newhash)
14251494 hslot2 = udp_hashslot2 (udptable , udp_sk (sk )-> udp_portaddr_hash );
14261495 nhslot2 = udp_hashslot2 (udptable , newhash );
14271496 udp_sk (sk )-> udp_portaddr_hash = newhash ;
1428- if (hslot2 != nhslot2 ) {
1497+
1498+ if (hslot2 != nhslot2 ||
1499+ rcu_access_pointer (sk -> sk_reuseport_cb )) {
14291500 hslot = udp_hashslot (udptable , sock_net (sk ),
14301501 udp_sk (sk )-> udp_port_hash );
14311502 /* we must lock primary chain too */
14321503 spin_lock_bh (& hslot -> lock );
1433-
1434- spin_lock (& hslot2 -> lock );
1435- hlist_nulls_del_init_rcu (& udp_sk (sk )-> udp_portaddr_node );
1436- hslot2 -> count -- ;
1437- spin_unlock (& hslot2 -> lock );
1438-
1439- spin_lock (& nhslot2 -> lock );
1440- hlist_nulls_add_head_rcu (& udp_sk (sk )-> udp_portaddr_node ,
1441- & nhslot2 -> head );
1442- nhslot2 -> count ++ ;
1443- spin_unlock (& nhslot2 -> lock );
1504+ if (rcu_access_pointer (sk -> sk_reuseport_cb ))
1505+ reuseport_detach_sock (sk );
1506+
1507+ if (hslot2 != nhslot2 ) {
1508+ spin_lock (& hslot2 -> lock );
1509+ hlist_nulls_del_init_rcu (& udp_sk (sk )-> udp_portaddr_node );
1510+ hslot2 -> count -- ;
1511+ spin_unlock (& hslot2 -> lock );
1512+
1513+ spin_lock (& nhslot2 -> lock );
1514+ hlist_nulls_add_head_rcu (& udp_sk (sk )-> udp_portaddr_node ,
1515+ & nhslot2 -> head );
1516+ nhslot2 -> count ++ ;
1517+ spin_unlock (& nhslot2 -> lock );
1518+ }
14441519
14451520 spin_unlock_bh (& hslot -> lock );
14461521 }
0 commit comments