Skip to content

Commit

Permalink
inet: Sanitize inet{,6} protocol demux.
Browse files Browse the repository at this point in the history
Don't pretend that inet_protos[] and inet6_protos[] are hashes, thay
are just a straight arrays.  Remove all unnecessary hash masking.

Document MAX_INET_PROTOS.

Use RAW_HTABLE_SIZE when appropriate.

Reported-by: Ben Hutchings <bhutchings@solarflare.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
davem330 committed Jun 20, 2012
1 parent 677a3d6 commit f9242b6
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 47 deletions.
7 changes: 5 additions & 2 deletions include/net/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
#include <linux/ipv6.h>
#endif

#define MAX_INET_PROTOS 256 /* Must be a power of 2 */

/* This is one larger than the largest protocol value that can be
* found in an ipv4 or ipv6 header. Since in both cases the protocol
* value is presented in a __u8, this is defined to be 256.
*/
#define MAX_INET_PROTOS 256

/* This is used to register protocols. */
struct net_protocol {
Expand Down
26 changes: 12 additions & 14 deletions net/ipv4/af_inet.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,18 @@ void build_ehash_secret(void)
}
EXPORT_SYMBOL(build_ehash_secret);

static inline int inet_netns_ok(struct net *net, int protocol)
static inline int inet_netns_ok(struct net *net, __u8 protocol)
{
int hash;
const struct net_protocol *ipprot;

if (net_eq(net, &init_net))
return 1;

hash = protocol & (MAX_INET_PROTOS - 1);
ipprot = rcu_dereference(inet_protos[hash]);

if (ipprot == NULL)
ipprot = rcu_dereference(inet_protos[protocol]);
if (ipprot == NULL) {
/* raw IP is OK */
return 1;
}
return ipprot->netns_ok;
}

Expand Down Expand Up @@ -1216,8 +1214,8 @@ EXPORT_SYMBOL(inet_sk_rebuild_header);

static int inet_gso_send_check(struct sk_buff *skb)
{
const struct iphdr *iph;
const struct net_protocol *ops;
const struct iphdr *iph;
int proto;
int ihl;
int err = -EINVAL;
Expand All @@ -1236,7 +1234,7 @@ static int inet_gso_send_check(struct sk_buff *skb)
__skb_pull(skb, ihl);
skb_reset_transport_header(skb);
iph = ip_hdr(skb);
proto = iph->protocol & (MAX_INET_PROTOS - 1);
proto = iph->protocol;
err = -EPROTONOSUPPORT;

rcu_read_lock();
Expand All @@ -1253,8 +1251,8 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb,
netdev_features_t features)
{
struct sk_buff *segs = ERR_PTR(-EINVAL);
struct iphdr *iph;
const struct net_protocol *ops;
struct iphdr *iph;
int proto;
int ihl;
int id;
Expand Down Expand Up @@ -1286,7 +1284,7 @@ static struct sk_buff *inet_gso_segment(struct sk_buff *skb,
skb_reset_transport_header(skb);
iph = ip_hdr(skb);
id = ntohs(iph->id);
proto = iph->protocol & (MAX_INET_PROTOS - 1);
proto = iph->protocol;
segs = ERR_PTR(-EPROTONOSUPPORT);

rcu_read_lock();
Expand Down Expand Up @@ -1340,7 +1338,7 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head,
goto out;
}

proto = iph->protocol & (MAX_INET_PROTOS - 1);
proto = iph->protocol;

rcu_read_lock();
ops = rcu_dereference(inet_protos[proto]);
Expand Down Expand Up @@ -1398,11 +1396,11 @@ static struct sk_buff **inet_gro_receive(struct sk_buff **head,

static int inet_gro_complete(struct sk_buff *skb)
{
const struct net_protocol *ops;
__be16 newlen = htons(skb->len - skb_network_offset(skb));
struct iphdr *iph = ip_hdr(skb);
int proto = iph->protocol & (MAX_INET_PROTOS - 1);
const struct net_protocol *ops;
int proto = iph->protocol;
int err = -ENOSYS;
__be16 newlen = htons(skb->len - skb_network_offset(skb));

csum_replace2(&iph->check, iph->tot_len, newlen);
iph->tot_len = newlen;
Expand Down
9 changes: 4 additions & 5 deletions net/ipv4/icmp.c
Original file line number Diff line number Diff line change
Expand Up @@ -637,12 +637,12 @@ EXPORT_SYMBOL(icmp_send);

static void icmp_unreach(struct sk_buff *skb)
{
const struct net_protocol *ipprot;
const struct iphdr *iph;
struct icmphdr *icmph;
int hash, protocol;
const struct net_protocol *ipprot;
u32 info = 0;
struct net *net;
u32 info = 0;
int protocol;

net = dev_net(skb_dst(skb)->dev);

Expand Down Expand Up @@ -731,9 +731,8 @@ static void icmp_unreach(struct sk_buff *skb)
*/
raw_icmp_error(skb, protocol, info);

hash = protocol & (MAX_INET_PROTOS - 1);
rcu_read_lock();
ipprot = rcu_dereference(inet_protos[hash]);
ipprot = rcu_dereference(inet_protos[protocol]);
if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, info);
rcu_read_unlock();
Expand Down
5 changes: 2 additions & 3 deletions net/ipv4/ip_input.c
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,13 @@ static int ip_local_deliver_finish(struct sk_buff *skb)
rcu_read_lock();
{
int protocol = ip_hdr(skb)->protocol;
int hash, raw;
const struct net_protocol *ipprot;
int raw;

resubmit:
raw = raw_local_deliver(skb, protocol);

hash = protocol & (MAX_INET_PROTOS - 1);
ipprot = rcu_dereference(inet_protos[hash]);
ipprot = rcu_dereference(inet_protos[protocol]);
if (ipprot != NULL) {
int ret;

Expand Down
8 changes: 3 additions & 5 deletions net/ipv4/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ const struct net_protocol __rcu *inet_protos[MAX_INET_PROTOS] __read_mostly;

int inet_add_protocol(const struct net_protocol *prot, unsigned char protocol)
{
int hash = protocol & (MAX_INET_PROTOS - 1);

return !cmpxchg((const struct net_protocol **)&inet_protos[hash],
return !cmpxchg((const struct net_protocol **)&inet_protos[protocol],
NULL, prot) ? 0 : -1;
}
EXPORT_SYMBOL(inet_add_protocol);
Expand All @@ -49,9 +47,9 @@ EXPORT_SYMBOL(inet_add_protocol);

int inet_del_protocol(const struct net_protocol *prot, unsigned char protocol)
{
int ret, hash = protocol & (MAX_INET_PROTOS - 1);
int ret;

ret = (cmpxchg((const struct net_protocol **)&inet_protos[hash],
ret = (cmpxchg((const struct net_protocol **)&inet_protos[protocol],
prot, NULL) == prot) ? 0 : -1;

synchronize_net();
Expand Down
7 changes: 2 additions & 5 deletions net/ipv6/icmp.c
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info)
{
const struct inet6_protocol *ipprot;
int inner_offset;
int hash;
u8 nexthdr;
__be16 frag_off;
u8 nexthdr;

if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
return;
Expand All @@ -629,10 +628,8 @@ static void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info)
--ANK (980726)
*/

hash = nexthdr & (MAX_INET_PROTOS - 1);

rcu_read_lock();
ipprot = rcu_dereference(inet6_protos[hash]);
ipprot = rcu_dereference(inet6_protos[nexthdr]);
if (ipprot && ipprot->err_handler)
ipprot->err_handler(skb, NULL, type, code, inner_offset, info);
rcu_read_unlock();
Expand Down
9 changes: 3 additions & 6 deletions net/ipv6/ip6_input.c
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,12 @@ int ipv6_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt

static int ip6_input_finish(struct sk_buff *skb)
{
struct net *net = dev_net(skb_dst(skb)->dev);
const struct inet6_protocol *ipprot;
struct inet6_dev *idev;
unsigned int nhoff;
int nexthdr;
bool raw;
u8 hash;
struct inet6_dev *idev;
struct net *net = dev_net(skb_dst(skb)->dev);

/*
* Parse extension headers
Expand All @@ -189,9 +188,7 @@ static int ip6_input_finish(struct sk_buff *skb)
nexthdr = skb_network_header(skb)[nhoff];

raw = raw6_local_deliver(skb, nexthdr);

hash = nexthdr & (MAX_INET_PROTOS - 1);
if ((ipprot = rcu_dereference(inet6_protos[hash])) != NULL) {
if ((ipprot = rcu_dereference(inet6_protos[nexthdr])) != NULL) {
int ret;

if (ipprot->flags & INET6_PROTO_FINAL) {
Expand Down
8 changes: 3 additions & 5 deletions net/ipv6/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ const struct inet6_protocol __rcu *inet6_protos[MAX_INET_PROTOS] __read_mostly;

int inet6_add_protocol(const struct inet6_protocol *prot, unsigned char protocol)
{
int hash = protocol & (MAX_INET_PROTOS - 1);

return !cmpxchg((const struct inet6_protocol **)&inet6_protos[hash],
return !cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol],
NULL, prot) ? 0 : -1;
}
EXPORT_SYMBOL(inet6_add_protocol);
Expand All @@ -42,9 +40,9 @@ EXPORT_SYMBOL(inet6_add_protocol);

int inet6_del_protocol(const struct inet6_protocol *prot, unsigned char protocol)
{
int ret, hash = protocol & (MAX_INET_PROTOS - 1);
int ret;

ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[hash],
ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol],
prot, NULL) == prot) ? 0 : -1;

synchronize_net();
Expand Down
4 changes: 2 additions & 2 deletions net/ipv6/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
saddr = &ipv6_hdr(skb)->saddr;
daddr = saddr + 1;

hash = nexthdr & (MAX_INET_PROTOS - 1);
hash = nexthdr & (RAW_HTABLE_SIZE - 1);

read_lock(&raw_v6_hashinfo.lock);
sk = sk_head(&raw_v6_hashinfo.ht[hash]);
Expand Down Expand Up @@ -229,7 +229,7 @@ bool raw6_local_deliver(struct sk_buff *skb, int nexthdr)
{
struct sock *raw_sk;

raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (MAX_INET_PROTOS - 1)]);
raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (RAW_HTABLE_SIZE - 1)]);
if (raw_sk && !ipv6_raw_deliver(skb, nexthdr))
raw_sk = NULL;

Expand Down

0 comments on commit f9242b6

Please sign in to comment.