diff --git a/include/net/mptcp.h b/include/net/mptcp.h index 5694370be3d42..cea69c8015959 100644 --- a/include/net/mptcp.h +++ b/include/net/mptcp.h @@ -34,6 +34,13 @@ struct mptcp_ext { /* one byte hole */ }; +#define MPTCP_RM_IDS_MAX 8 + +struct mptcp_rm_list { + u8 ids[MPTCP_RM_IDS_MAX]; + u8 nr; +}; + struct mptcp_out_options { #if IS_ENABLED(CONFIG_MPTCP) u16 suboptions; @@ -48,7 +55,7 @@ struct mptcp_out_options { u8 addr_id; u16 port; u64 ahmac; - u8 rm_id; + struct mptcp_rm_list rm_list; u8 join_id; u8 backup; u32 nonce; diff --git a/net/mptcp/options.c b/net/mptcp/options.c index 308abf5e25be1..f8a99ef34a071 100644 --- a/net/mptcp/options.c +++ b/net/mptcp/options.c @@ -674,20 +674,25 @@ static bool mptcp_established_options_rm_addr(struct sock *sk, { struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_sock *msk = mptcp_sk(subflow->conn); - u8 rm_id; + struct mptcp_rm_list rm_list; + int i, len; if (!mptcp_pm_should_rm_signal(msk) || - !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_id))) + !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_list))) return false; - if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE) + len = mptcp_rm_addr_len(rm_list); + if (len < 0) + return false; + if (remaining < len) return false; - *size = TCPOLEN_MPTCP_RM_ADDR_BASE; + *size = len; opts->suboptions |= OPTION_MPTCP_RM_ADDR; - opts->rm_id = rm_id; + opts->rm_list = rm_list; - pr_debug("rm_id=%d", opts->rm_id); + for (i = 0; i < opts->rm_list.nr; i++) + pr_debug("rm_list_ids[%d]=%d", i, opts->rm_list.ids[i]); return true; } @@ -1217,9 +1222,23 @@ void mptcp_write_options(__be32 *ptr, const struct tcp_sock *tp, } if (OPTION_MPTCP_RM_ADDR & opts->suboptions) { + u8 i; + + for (i = opts->rm_list.nr; i < MPTCP_RM_IDS_MAX; i++) + opts->rm_list.ids[i] = TCPOPT_NOP; *ptr++ = mptcp_option(MPTCPOPT_RM_ADDR, - TCPOLEN_MPTCP_RM_ADDR_BASE, - 0, opts->rm_id); + TCPOLEN_MPTCP_RM_ADDR_BASE + opts->rm_list.nr, + 0, opts->rm_list.ids[0]); + if (opts->rm_list.nr > 1) { + put_unaligned_be32(opts->rm_list.ids[1] << 24 | opts->rm_list.ids[2] << 16 | + opts->rm_list.ids[3] << 8 | opts->rm_list.ids[4], ptr); + ptr += 1; + } + if (opts->rm_list.nr > 5) { + put_unaligned_be32(opts->rm_list.ids[5] << 24 | opts->rm_list.ids[6] << 16 | + opts->rm_list.ids[7] << 8 | TCPOPT_NOP, ptr); + ptr += 1; + } } if (OPTION_MPTCP_PRIO & opts->suboptions) { diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index c22481da40730..1c7d33c3904dd 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -257,7 +257,7 @@ bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining, } bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, - u8 *rm_id) + struct mptcp_rm_list *rm_list) { int ret = false; @@ -270,7 +270,8 @@ bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE) goto out_unlock; - *rm_id = msk->pm.rm_id; + rm_list->ids[0] = msk->pm.rm_id; + rm_list->nr = 1; WRITE_ONCE(msk->pm.addr_signal, 0); ret = true; diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index d5326cebe333e..31dcc3840a566 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -60,7 +60,7 @@ #define TCPOLEN_MPTCP_ADD_ADDR6_BASE 20 #define TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT 24 #define TCPOLEN_MPTCP_PORT_LEN 4 -#define TCPOLEN_MPTCP_RM_ADDR_BASE 4 +#define TCPOLEN_MPTCP_RM_ADDR_BASE 3 #define TCPOLEN_MPTCP_PRIO 3 #define TCPOLEN_MPTCP_PRIO_ALIGN 4 #define TCPOLEN_MPTCP_FASTCLOSE 12 @@ -707,10 +707,18 @@ static inline unsigned int mptcp_add_addr_len(int family, bool echo, bool port) return len; } +static inline int mptcp_rm_addr_len(struct mptcp_rm_list rm_list) +{ + if (rm_list.nr == 0 || rm_list.nr >= MPTCP_RM_IDS_MAX) + return -EINVAL; + + return TCPOLEN_MPTCP_RM_ADDR_BASE + roundup(rm_list.nr - 1, 4) + 1; +} + bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining, struct mptcp_addr_info *saddr, bool *echo, bool *port); bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, - u8 *rm_id); + struct mptcp_rm_list *rm_list); int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); void __init mptcp_pm_nl_init(void);