diff --git a/drivers/net/gtp.c b/drivers/net/gtp.c index 375adfb743235f..427b91aca50d3a 100644 --- a/drivers/net/gtp.c +++ b/drivers/net/gtp.c @@ -141,7 +141,7 @@ static u32 ipv6_hashfn(const struct in6_addr *ip6) } /* Resolve a PDP context structure based on the 64bit TID. */ -static struct pdp_ctx *gtp0_pdp_find(struct gtp_dev *gtp, u64 tid) +static struct pdp_ctx *gtp0_pdp_find(struct gtp_dev *gtp, u64 tid, u16 family) { struct hlist_head *head; struct pdp_ctx *pdp; @@ -149,7 +149,8 @@ static struct pdp_ctx *gtp0_pdp_find(struct gtp_dev *gtp, u64 tid) head = >p->tid_hash[gtp0_hashfn(tid) % gtp->hash_size]; hlist_for_each_entry_rcu(pdp, head, hlist_tid) { - if (pdp->gtp_version == GTP_V0 && + if (pdp->af == family && + pdp->gtp_version == GTP_V0 && pdp->u.v0.tid == tid) return pdp; } @@ -157,7 +158,7 @@ static struct pdp_ctx *gtp0_pdp_find(struct gtp_dev *gtp, u64 tid) } /* Resolve a PDP context structure based on the 32bit TEI. */ -static struct pdp_ctx *gtp1_pdp_find(struct gtp_dev *gtp, u32 tid) +static struct pdp_ctx *gtp1_pdp_find(struct gtp_dev *gtp, u32 tid, u16 family) { struct hlist_head *head; struct pdp_ctx *pdp; @@ -165,7 +166,8 @@ static struct pdp_ctx *gtp1_pdp_find(struct gtp_dev *gtp, u32 tid) head = >p->tid_hash[gtp1u_hashfn(tid) % gtp->hash_size]; hlist_for_each_entry_rcu(pdp, head, hlist_tid) { - if (pdp->gtp_version == GTP_V1 && + if (pdp->af == family && + pdp->gtp_version == GTP_V1 && pdp->u.v1.i_tei == tid) return pdp; } @@ -305,15 +307,8 @@ static int gtp_inner_proto(struct sk_buff *skb, unsigned int hdrlen, } static int gtp_rx(struct pdp_ctx *pctx, struct sk_buff *skb, - unsigned int hdrlen, unsigned int role) + unsigned int hdrlen, unsigned int role, __u16 inner_proto) { - __u16 inner_proto; - - if (gtp_inner_proto(skb, hdrlen, &inner_proto) < 0) { - netdev_dbg(pctx->dev, "GTP packet does not encapsulate an IP packet\n"); - return -1; - } - if (!gtp_check_ms(skb, pctx, hdrlen, role, inner_proto)) { netdev_dbg(pctx->dev, "No PDP ctx for this MS\n"); return 1; @@ -562,6 +557,21 @@ static int gtp0_handle_echo_resp(struct gtp_dev *gtp, struct sk_buff *skb) msg, 0, GTP_GENL_MCGRP, GFP_ATOMIC); } +static int gtp_proto_to_family(__u16 proto) +{ + switch (proto) { + case ETH_P_IP: + return AF_INET; + case ETH_P_IPV6: + return AF_INET6; + default: + WARN_ON_ONCE(1); + break; + } + + return AF_UNSPEC; +} + /* 1 means pass up to the stack, -1 means drop and 0 means decapsulated. */ static int gtp0_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) { @@ -569,6 +579,7 @@ static int gtp0_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) sizeof(struct gtp0_header); struct gtp0_header *gtp0; struct pdp_ctx *pctx; + __u16 inner_proto; if (!pskb_may_pull(skb, hdrlen)) return -1; @@ -591,13 +602,19 @@ static int gtp0_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) if (gtp0->type != GTP_TPDU) return 1; - pctx = gtp0_pdp_find(gtp, be64_to_cpu(gtp0->tid)); + if (gtp_inner_proto(skb, hdrlen, &inner_proto) < 0) { + netdev_dbg(gtp->dev, "GTP packet does not encapsulate an IP packet\n"); + return -1; + } + + pctx = gtp0_pdp_find(gtp, be64_to_cpu(gtp0->tid), + gtp_proto_to_family(inner_proto)); if (!pctx) { netdev_dbg(gtp->dev, "No PDP ctx to decap skb=%p\n", skb); return 1; } - return gtp_rx(pctx, skb, hdrlen, gtp->role); + return gtp_rx(pctx, skb, hdrlen, gtp->role, inner_proto); } /* msg_type has to be GTP_ECHO_REQ or GTP_ECHO_RSP */ @@ -768,6 +785,7 @@ static int gtp1u_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) sizeof(struct gtp1_header); struct gtp1_header *gtp1; struct pdp_ctx *pctx; + __u16 inner_proto; if (!pskb_may_pull(skb, hdrlen)) return -1; @@ -803,9 +821,15 @@ static int gtp1u_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) if (!pskb_may_pull(skb, hdrlen)) return -1; + if (gtp_inner_proto(skb, hdrlen, &inner_proto) < 0) { + netdev_dbg(gtp->dev, "GTP packet does not encapsulate an IP packet\n"); + return -1; + } + gtp1 = (struct gtp1_header *)(skb->data + sizeof(struct udphdr)); - pctx = gtp1_pdp_find(gtp, ntohl(gtp1->tid)); + pctx = gtp1_pdp_find(gtp, ntohl(gtp1->tid), + gtp_proto_to_family(inner_proto)); if (!pctx) { netdev_dbg(gtp->dev, "No PDP ctx to decap skb=%p\n", skb); return 1; @@ -815,7 +839,7 @@ static int gtp1u_udp_encap_recv(struct gtp_dev *gtp, struct sk_buff *skb) gtp_parse_exthdrs(skb, &hdrlen) < 0) return -1; - return gtp_rx(pctx, skb, hdrlen, gtp->role); + return gtp_rx(pctx, skb, hdrlen, gtp->role, inner_proto); } static void __gtp_encap_destroy(struct sock *sk) @@ -1843,10 +1867,12 @@ static struct pdp_ctx *gtp_pdp_add(struct gtp_dev *gtp, struct sock *sk, found = true; if (version == GTP_V0) pctx_tid = gtp0_pdp_find(gtp, - nla_get_u64(info->attrs[GTPA_TID])); + nla_get_u64(info->attrs[GTPA_TID]), + family); else if (version == GTP_V1) pctx_tid = gtp1_pdp_find(gtp, - nla_get_u32(info->attrs[GTPA_I_TEI])); + nla_get_u32(info->attrs[GTPA_I_TEI]), + family); if (pctx_tid) found = true; @@ -2034,6 +2060,12 @@ static struct pdp_ctx *gtp_find_pdp_by_link(struct net *net, struct nlattr *nla[]) { struct gtp_dev *gtp; + int family; + + if (nla[GTPA_FAMILY]) + family = nla_get_u8(nla[GTPA_FAMILY]); + else + family = AF_INET; gtp = gtp_find_dev(net, nla); if (!gtp) @@ -2042,10 +2074,16 @@ static struct pdp_ctx *gtp_find_pdp_by_link(struct net *net, if (nla[GTPA_MS_ADDRESS]) { __be32 ip = nla_get_be32(nla[GTPA_MS_ADDRESS]); + if (family != AF_INET) + return ERR_PTR(-EINVAL); + return ipv4_pdp_find(gtp, ip); } else if (nla[GTPA_MS_ADDR6]) { struct in6_addr addr = nla_get_in6_addr(nla[GTPA_MS_ADDR6]); + if (family != AF_INET6) + return ERR_PTR(-EINVAL); + if (addr.s6_addr32[2] || addr.s6_addr32[3]) return ERR_PTR(-EADDRNOTAVAIL); @@ -2054,10 +2092,13 @@ static struct pdp_ctx *gtp_find_pdp_by_link(struct net *net, } else if (nla[GTPA_VERSION]) { u32 gtp_version = nla_get_u32(nla[GTPA_VERSION]); - if (gtp_version == GTP_V0 && nla[GTPA_TID]) - return gtp0_pdp_find(gtp, nla_get_u64(nla[GTPA_TID])); - else if (gtp_version == GTP_V1 && nla[GTPA_I_TEI]) - return gtp1_pdp_find(gtp, nla_get_u32(nla[GTPA_I_TEI])); + if (gtp_version == GTP_V0 && nla[GTPA_TID]) { + return gtp0_pdp_find(gtp, nla_get_u64(nla[GTPA_TID]), + family); + } else if (gtp_version == GTP_V1 && nla[GTPA_I_TEI]) { + return gtp1_pdp_find(gtp, nla_get_u32(nla[GTPA_I_TEI]), + family); + } } return ERR_PTR(-EINVAL);