Skip to content

Commit 36397a1

Browse files
author
Martin KaFai Lau
committed
Merge branch 'Add SO_REUSEPORT support for TC bpf_sk_assign'
Lorenz Bauer says: ==================== We want to replace iptables TPROXY with a BPF program at TC ingress. To make this work in all cases we need to assign a SO_REUSEPORT socket to an skb, which is currently prohibited. This series adds support for such sockets to bpf_sk_assign. I did some refactoring to cut down on the amount of duplicate code. The key to this is to use INDIRECT_CALL in the reuseport helpers. To show that this approach is not just beneficial to TC sk_assign I removed duplicate code for bpf_sk_lookup as well. Joint work with Daniel Borkmann. Signed-off-by: Lorenz Bauer <lmb@isovalent.com> --- Changes in v6: - Reject unhashed UDP sockets in bpf_sk_assign to avoid ref leak - Link to v5: https://lore.kernel.org/r/20230613-so-reuseport-v5-0-f6686a0dbce0@isovalent.com Changes in v5: - Drop reuse_sk == sk check in inet[6]_steal_sock (Kuniyuki) - Link to v4: https://lore.kernel.org/r/20230613-so-reuseport-v4-0-4ece76708bba@isovalent.com Changes in v4: - WARN_ON_ONCE if reuseport socket is refcounted (Kuniyuki) - Use inet[6]_ehashfn_t to shorten function declarations (Kuniyuki) - Shuffle documentation patch around (Kuniyuki) - Update commit message to explain why IPv6 needs EXPORT_SYMBOL - Link to v3: https://lore.kernel.org/r/20230613-so-reuseport-v3-0-907b4cbb7b99@isovalent.com Changes in v3: - Fix warning re udp_ehashfn and udp6_ehashfn (Simon) - Return higher scoring connected UDP reuseport sockets (Kuniyuki) - Fix ipv6 module builds - Link to v2: https://lore.kernel.org/r/20230613-so-reuseport-v2-0-b7c69a342613@isovalent.com Changes in v2: - Correct commit abbrev length (Kuniyuki) - Reduce duplication (Kuniyuki) - Add checks on sk_state (Martin) - Split exporting inet[6]_lookup_reuseport into separate patch (Eric) --- Daniel Borkmann (1): selftests/bpf: Test that SO_REUSEPORT can be used with sk_assign helper ==================== Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
2 parents 7b2b201 + 22408d5 commit 36397a1

File tree

13 files changed

+661
-178
lines changed

13 files changed

+661
-178
lines changed

include/net/inet6_hashtables.h

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ struct sock *__inet6_lookup_established(struct net *net,
4848
const u16 hnum, const int dif,
4949
const int sdif);
5050

51+
typedef u32 (inet6_ehashfn_t)(const struct net *net,
52+
const struct in6_addr *laddr, const u16 lport,
53+
const struct in6_addr *faddr, const __be16 fport);
54+
55+
inet6_ehashfn_t inet6_ehashfn;
56+
57+
INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);
58+
59+
struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
60+
struct sk_buff *skb, int doff,
61+
const struct in6_addr *saddr,
62+
__be16 sport,
63+
const struct in6_addr *daddr,
64+
unsigned short hnum,
65+
inet6_ehashfn_t *ehashfn);
66+
5167
struct sock *inet6_lookup_listener(struct net *net,
5268
struct inet_hashinfo *hashinfo,
5369
struct sk_buff *skb, int doff,
@@ -57,6 +73,15 @@ struct sock *inet6_lookup_listener(struct net *net,
5773
const unsigned short hnum,
5874
const int dif, const int sdif);
5975

76+
struct sock *inet6_lookup_run_sk_lookup(struct net *net,
77+
int protocol,
78+
struct sk_buff *skb, int doff,
79+
const struct in6_addr *saddr,
80+
const __be16 sport,
81+
const struct in6_addr *daddr,
82+
const u16 hnum, const int dif,
83+
inet6_ehashfn_t *ehashfn);
84+
6085
static inline struct sock *__inet6_lookup(struct net *net,
6186
struct inet_hashinfo *hashinfo,
6287
struct sk_buff *skb, int doff,
@@ -78,21 +103,67 @@ static inline struct sock *__inet6_lookup(struct net *net,
78103
daddr, hnum, dif, sdif);
79104
}
80105

106+
static inline
107+
struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
108+
const struct in6_addr *saddr, const __be16 sport,
109+
const struct in6_addr *daddr, const __be16 dport,
110+
bool *refcounted, inet6_ehashfn_t *ehashfn)
111+
{
112+
struct sock *sk, *reuse_sk;
113+
bool prefetched;
114+
115+
sk = skb_steal_sock(skb, refcounted, &prefetched);
116+
if (!sk)
117+
return NULL;
118+
119+
if (!prefetched)
120+
return sk;
121+
122+
if (sk->sk_protocol == IPPROTO_TCP) {
123+
if (sk->sk_state != TCP_LISTEN)
124+
return sk;
125+
} else if (sk->sk_protocol == IPPROTO_UDP) {
126+
if (sk->sk_state != TCP_CLOSE)
127+
return sk;
128+
} else {
129+
return sk;
130+
}
131+
132+
reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
133+
saddr, sport, daddr, ntohs(dport),
134+
ehashfn);
135+
if (!reuse_sk)
136+
return sk;
137+
138+
/* We've chosen a new reuseport sock which is never refcounted. This
139+
* implies that sk also isn't refcounted.
140+
*/
141+
WARN_ON_ONCE(*refcounted);
142+
143+
return reuse_sk;
144+
}
145+
81146
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
82147
struct sk_buff *skb, int doff,
83148
const __be16 sport,
84149
const __be16 dport,
85150
int iif, int sdif,
86151
bool *refcounted)
87152
{
88-
struct sock *sk = skb_steal_sock(skb, refcounted);
89-
153+
struct net *net = dev_net(skb_dst(skb)->dev);
154+
const struct ipv6hdr *ip6h = ipv6_hdr(skb);
155+
struct sock *sk;
156+
157+
sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
158+
refcounted, inet6_ehashfn);
159+
if (IS_ERR(sk))
160+
return NULL;
90161
if (sk)
91162
return sk;
92163

93-
return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
94-
doff, &ipv6_hdr(skb)->saddr, sport,
95-
&ipv6_hdr(skb)->daddr, ntohs(dport),
164+
return __inet6_lookup(net, hashinfo, skb,
165+
doff, &ip6h->saddr, sport,
166+
&ip6h->daddr, ntohs(dport),
96167
iif, sdif, refcounted);
97168
}
98169

include/net/inet_hashtables.h

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,27 @@ struct sock *__inet_lookup_established(struct net *net,
379379
const __be32 daddr, const u16 hnum,
380380
const int dif, const int sdif);
381381

382+
typedef u32 (inet_ehashfn_t)(const struct net *net,
383+
const __be32 laddr, const __u16 lport,
384+
const __be32 faddr, const __be16 fport);
385+
386+
inet_ehashfn_t inet_ehashfn;
387+
388+
INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
389+
390+
struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
391+
struct sk_buff *skb, int doff,
392+
__be32 saddr, __be16 sport,
393+
__be32 daddr, unsigned short hnum,
394+
inet_ehashfn_t *ehashfn);
395+
396+
struct sock *inet_lookup_run_sk_lookup(struct net *net,
397+
int protocol,
398+
struct sk_buff *skb, int doff,
399+
__be32 saddr, __be16 sport,
400+
__be32 daddr, u16 hnum, const int dif,
401+
inet_ehashfn_t *ehashfn);
402+
382403
static inline struct sock *
383404
inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
384405
const __be32 saddr, const __be16 sport,
@@ -428,6 +449,46 @@ static inline struct sock *inet_lookup(struct net *net,
428449
return sk;
429450
}
430451

452+
static inline
453+
struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
454+
const __be32 saddr, const __be16 sport,
455+
const __be32 daddr, const __be16 dport,
456+
bool *refcounted, inet_ehashfn_t *ehashfn)
457+
{
458+
struct sock *sk, *reuse_sk;
459+
bool prefetched;
460+
461+
sk = skb_steal_sock(skb, refcounted, &prefetched);
462+
if (!sk)
463+
return NULL;
464+
465+
if (!prefetched)
466+
return sk;
467+
468+
if (sk->sk_protocol == IPPROTO_TCP) {
469+
if (sk->sk_state != TCP_LISTEN)
470+
return sk;
471+
} else if (sk->sk_protocol == IPPROTO_UDP) {
472+
if (sk->sk_state != TCP_CLOSE)
473+
return sk;
474+
} else {
475+
return sk;
476+
}
477+
478+
reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
479+
saddr, sport, daddr, ntohs(dport),
480+
ehashfn);
481+
if (!reuse_sk)
482+
return sk;
483+
484+
/* We've chosen a new reuseport sock which is never refcounted. This
485+
* implies that sk also isn't refcounted.
486+
*/
487+
WARN_ON_ONCE(*refcounted);
488+
489+
return reuse_sk;
490+
}
491+
431492
static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
432493
struct sk_buff *skb,
433494
int doff,
@@ -436,22 +497,23 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
436497
const int sdif,
437498
bool *refcounted)
438499
{
439-
struct sock *sk = skb_steal_sock(skb, refcounted);
500+
struct net *net = dev_net(skb_dst(skb)->dev);
440501
const struct iphdr *iph = ip_hdr(skb);
502+
struct sock *sk;
441503

504+
sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
505+
refcounted, inet_ehashfn);
506+
if (IS_ERR(sk))
507+
return NULL;
442508
if (sk)
443509
return sk;
444510

445-
return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
511+
return __inet_lookup(net, hashinfo, skb,
446512
doff, iph->saddr, sport,
447513
iph->daddr, dport, inet_iif(skb), sdif,
448514
refcounted);
449515
}
450516

451-
u32 inet6_ehashfn(const struct net *net,
452-
const struct in6_addr *laddr, const u16 lport,
453-
const struct in6_addr *faddr, const __be16 fport);
454-
455517
static inline void sk_daddr_set(struct sock *sk, __be32 addr)
456518
{
457519
sk->sk_daddr = addr; /* alias of inet_daddr */

include/net/sock.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,20 +2815,23 @@ sk_is_refcounted(struct sock *sk)
28152815
* skb_steal_sock - steal a socket from an sk_buff
28162816
* @skb: sk_buff to steal the socket from
28172817
* @refcounted: is set to true if the socket is reference-counted
2818+
* @prefetched: is set to true if the socket was assigned from bpf
28182819
*/
28192820
static inline struct sock *
2820-
skb_steal_sock(struct sk_buff *skb, bool *refcounted)
2821+
skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
28212822
{
28222823
if (skb->sk) {
28232824
struct sock *sk = skb->sk;
28242825

28252826
*refcounted = true;
2826-
if (skb_sk_is_prefetched(skb))
2827+
*prefetched = skb_sk_is_prefetched(skb);
2828+
if (*prefetched)
28272829
*refcounted = sk_is_refcounted(sk);
28282830
skb->destructor = NULL;
28292831
skb->sk = NULL;
28302832
return sk;
28312833
}
2834+
*prefetched = false;
28322835
*refcounted = false;
28332836
return NULL;
28342837
}

include/uapi/linux/bpf.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4198,9 +4198,6 @@ union bpf_attr {
41984198
* **-EOPNOTSUPP** if the operation is not supported, for example
41994199
* a call from outside of TC ingress.
42004200
*
4201-
* **-ESOCKTNOSUPPORT** if the socket type is not supported
4202-
* (reuseport).
4203-
*
42044201
* long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
42054202
* Description
42064203
* Helper is overloaded depending on BPF program type. This

net/core/filter.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7351,8 +7351,8 @@ BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags)
73517351
return -EOPNOTSUPP;
73527352
if (unlikely(dev_net(skb->dev) != sock_net(sk)))
73537353
return -ENETUNREACH;
7354-
if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
7355-
return -ESOCKTNOSUPPORT;
7354+
if (sk_unhashed(sk))
7355+
return -EOPNOTSUPP;
73567356
if (sk_is_refcounted(sk) &&
73577357
unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
73587358
return -ENOENT;

net/ipv4/inet_hashtables.c

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
#include <net/tcp.h>
2929
#include <net/sock_reuseport.h>
3030

31-
static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
32-
const __u16 lport, const __be32 faddr,
33-
const __be16 fport)
31+
u32 inet_ehashfn(const struct net *net, const __be32 laddr,
32+
const __u16 lport, const __be32 faddr,
33+
const __be16 fport)
3434
{
3535
static u32 inet_ehash_secret __read_mostly;
3636

@@ -39,6 +39,7 @@ static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
3939
return __inet_ehashfn(laddr, lport, faddr, fport,
4040
inet_ehash_secret + net_hash_mix(net));
4141
}
42+
EXPORT_SYMBOL_GPL(inet_ehashfn);
4243

4344
/* This function handles inet_sock, but also timewait and request sockets
4445
* for IPv4/IPv6.
@@ -332,20 +333,40 @@ static inline int compute_score(struct sock *sk, struct net *net,
332333
return score;
333334
}
334335

335-
static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
336-
struct sk_buff *skb, int doff,
337-
__be32 saddr, __be16 sport,
338-
__be32 daddr, unsigned short hnum)
336+
INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
337+
338+
/**
339+
* inet_lookup_reuseport() - execute reuseport logic on AF_INET socket if necessary.
340+
* @net: network namespace.
341+
* @sk: AF_INET socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
342+
* @skb: context for a potential SK_REUSEPORT program.
343+
* @doff: header offset.
344+
* @saddr: source address.
345+
* @sport: source port.
346+
* @daddr: destination address.
347+
* @hnum: destination port in host byte order.
348+
* @ehashfn: hash function used to generate the fallback hash.
349+
*
350+
* Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
351+
* the selected sock or an error.
352+
*/
353+
struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
354+
struct sk_buff *skb, int doff,
355+
__be32 saddr, __be16 sport,
356+
__be32 daddr, unsigned short hnum,
357+
inet_ehashfn_t *ehashfn)
339358
{
340359
struct sock *reuse_sk = NULL;
341360
u32 phash;
342361

343362
if (sk->sk_reuseport) {
344-
phash = inet_ehashfn(net, daddr, hnum, saddr, sport);
363+
phash = INDIRECT_CALL_2(ehashfn, udp_ehashfn, inet_ehashfn,
364+
net, daddr, hnum, saddr, sport);
345365
reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
346366
}
347367
return reuse_sk;
348368
}
369+
EXPORT_SYMBOL_GPL(inet_lookup_reuseport);
349370

350371
/*
351372
* Here are some nice properties to exploit here. The BSD API
@@ -369,8 +390,8 @@ static struct sock *inet_lhash2_lookup(struct net *net,
369390
sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
370391
score = compute_score(sk, net, hnum, daddr, dif, sdif);
371392
if (score > hiscore) {
372-
result = lookup_reuseport(net, sk, skb, doff,
373-
saddr, sport, daddr, hnum);
393+
result = inet_lookup_reuseport(net, sk, skb, doff,
394+
saddr, sport, daddr, hnum, inet_ehashfn);
374395
if (result)
375396
return result;
376397

@@ -382,24 +403,23 @@ static struct sock *inet_lhash2_lookup(struct net *net,
382403
return result;
383404
}
384405

385-
static inline struct sock *inet_lookup_run_bpf(struct net *net,
386-
struct inet_hashinfo *hashinfo,
387-
struct sk_buff *skb, int doff,
388-
__be32 saddr, __be16 sport,
389-
__be32 daddr, u16 hnum, const int dif)
406+
struct sock *inet_lookup_run_sk_lookup(struct net *net,
407+
int protocol,
408+
struct sk_buff *skb, int doff,
409+
__be32 saddr, __be16 sport,
410+
__be32 daddr, u16 hnum, const int dif,
411+
inet_ehashfn_t *ehashfn)
390412
{
391413
struct sock *sk, *reuse_sk;
392414
bool no_reuseport;
393415

394-
if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
395-
return NULL; /* only TCP is supported */
396-
397-
no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport,
416+
no_reuseport = bpf_sk_lookup_run_v4(net, protocol, saddr, sport,
398417
daddr, hnum, dif, &sk);
399418
if (no_reuseport || IS_ERR_OR_NULL(sk))
400419
return sk;
401420

402-
reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
421+
reuse_sk = inet_lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum,
422+
ehashfn);
403423
if (reuse_sk)
404424
sk = reuse_sk;
405425
return sk;
@@ -417,9 +437,11 @@ struct sock *__inet_lookup_listener(struct net *net,
417437
unsigned int hash2;
418438

419439
/* Lookup redirect from BPF */
420-
if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
421-
result = inet_lookup_run_bpf(net, hashinfo, skb, doff,
422-
saddr, sport, daddr, hnum, dif);
440+
if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
441+
hashinfo == net->ipv4.tcp_death_row.hashinfo) {
442+
result = inet_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
443+
saddr, sport, daddr, hnum, dif,
444+
inet_ehashfn);
423445
if (result)
424446
goto done;
425447
}

0 commit comments

Comments
 (0)