1// SPDX-License-Identifier: GPL-2.0 2/* Multipath TCP token management 3 * Copyright (c) 2017 - 2019, Intel Corporation. 4 * 5 * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org, 6 * authored by: 7 * 8 * Sébastien Barré <sebastien.barre@uclouvain.be> 9 * Christoph Paasch <christoph.paasch@uclouvain.be> 10 * Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi> 11 * Gregory Detal <gregory.detal@uclouvain.be> 12 * Fabien Duchêne <fabien.duchene@uclouvain.be> 13 * Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de> 14 * Lavkesh Lahngir <lavkesh51@gmail.com> 15 * Andreas Ripke <ripke@neclab.eu> 16 * Vlad Dogaru <vlad.dogaru@intel.com> 17 * Octavian Purdila <octavian.purdila@intel.com> 18 * John Ronan <jronan@tssg.org> 19 * Catalin Nicutar <catalin.nicutar@gmail.com> 20 * Brandon Heller <brandonh@stanford.edu> 21 */ 22 23#define pr_fmt(fmt) "MPTCP: " fmt 24 25#include <linux/kernel.h> 26#include <linux/module.h> 27#include <linux/memblock.h> 28#include <linux/ip.h> 29#include <linux/tcp.h> 30#include <net/sock.h> 31#include <net/inet_common.h> 32#include <net/protocol.h> 33#include <net/mptcp.h> 34#include "protocol.h" 35 36#define TOKEN_MAX_RETRIES 4 37#define TOKEN_MAX_CHAIN_LEN 4 38 39struct token_bucket { 40 spinlock_t lock; 41 int chain_len; 42 struct hlist_nulls_head req_chain; 43 struct hlist_nulls_head msk_chain; 44}; 45 46static struct token_bucket *token_hash __read_mostly; 47static unsigned int token_mask __read_mostly; 48 49static struct token_bucket *token_bucket(u32 token) 50{ 51 return &token_hash[token & token_mask]; 52} 53 54/* called with bucket lock held */ 55static struct mptcp_subflow_request_sock * 56__token_lookup_req(struct token_bucket *t, u32 token) 57{ 58 struct mptcp_subflow_request_sock *req; 59 struct hlist_nulls_node *pos; 60 61 hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node) 62 if (req->token == token) 63 return req; 64 return NULL; 65} 66 67/* called with bucket lock held */ 68static struct mptcp_sock * 69__token_lookup_msk(struct token_bucket *t, u32 token) 70{ 71 struct hlist_nulls_node *pos; 72 struct sock *sk; 73 74 sk_nulls_for_each_rcu(sk, pos, &t->msk_chain) 75 if (mptcp_sk(sk)->token == token) 76 return mptcp_sk(sk); 77 return NULL; 78} 79 80static bool __token_bucket_busy(struct token_bucket *t, u32 token) 81{ 82 return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN || 83 __token_lookup_req(t, token) || __token_lookup_msk(t, token); 84} 85 86static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) 87{ 88 /* we might consider a faster version that computes the key as a 89 * hash of some information available in the MPTCP socket. Use 90 * random data at the moment, as it's probably the safest option 91 * in case multiple sockets are opened in different namespaces at 92 * the same time. 93 */ 94 get_random_bytes(key, sizeof(u64)); 95 mptcp_crypto_key_sha(*key, token, idsn); 96} 97 98/** 99 * mptcp_token_new_request - create new key/idsn/token for subflow_request 100 * @req: the request socket 101 * 102 * This function is called when a new mptcp connection is coming in. 103 * 104 * It creates a unique token to identify the new mptcp connection, 105 * a secret local key and the initial data sequence number (idsn). 106 * 107 * Returns 0 on success. 108 */ 109int mptcp_token_new_request(struct request_sock *req) 110{ 111 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 112 struct token_bucket *bucket; 113 u32 token; 114 115 mptcp_crypto_key_sha(subflow_req->local_key, 116 &subflow_req->token, 117 &subflow_req->idsn); 118 pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n", 119 req, subflow_req->local_key, subflow_req->token, 120 subflow_req->idsn); 121 122 token = subflow_req->token; 123 bucket = token_bucket(token); 124 spin_lock_bh(&bucket->lock); 125 if (__token_bucket_busy(bucket, token)) { 126 spin_unlock_bh(&bucket->lock); 127 return -EBUSY; 128 } 129 130 hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain); 131 bucket->chain_len++; 132 spin_unlock_bh(&bucket->lock); 133 return 0; 134} 135 136/** 137 * mptcp_token_new_connect - create new key/idsn/token for subflow 138 * @sk: the socket that will initiate a connection 139 * 140 * This function is called when a new outgoing mptcp connection is 141 * initiated. 142 * 143 * It creates a unique token to identify the new mptcp connection, 144 * a secret local key and the initial data sequence number (idsn). 145 * 146 * On success, the mptcp connection can be found again using 147 * the computed token at a later time, this is needed to process 148 * join requests. 149 * 150 * returns 0 on success. 151 */ 152int mptcp_token_new_connect(struct sock *sk) 153{ 154 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); 155 struct mptcp_sock *msk = mptcp_sk(subflow->conn); 156 int retries = TOKEN_MAX_RETRIES; 157 struct token_bucket *bucket; 158 159again: 160 mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token, 161 &subflow->idsn); 162 163 bucket = token_bucket(subflow->token); 164 spin_lock_bh(&bucket->lock); 165 if (__token_bucket_busy(bucket, subflow->token)) { 166 spin_unlock_bh(&bucket->lock); 167 if (!--retries) 168 return -EBUSY; 169 goto again; 170 } 171 172 pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", 173 sk, subflow->local_key, subflow->token, subflow->idsn); 174 175 WRITE_ONCE(msk->token, subflow->token); 176 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 177 bucket->chain_len++; 178 spin_unlock_bh(&bucket->lock); 179 return 0; 180} 181 182/** 183 * mptcp_token_accept - replace a req sk with full sock in token hash 184 * @req: the request socket to be removed 185 * @msk: the just cloned socket linked to the new connection 186 * 187 * Called when a SYN packet creates a new logical connection, i.e. 188 * is not a join request. 189 */ 190void mptcp_token_accept(struct mptcp_subflow_request_sock *req, 191 struct mptcp_sock *msk) 192{ 193 struct mptcp_subflow_request_sock *pos; 194 struct token_bucket *bucket; 195 196 bucket = token_bucket(req->token); 197 spin_lock_bh(&bucket->lock); 198 199 /* pedantic lookup check for the moved token */ 200 pos = __token_lookup_req(bucket, req->token); 201 if (!WARN_ON_ONCE(pos != req)) 202 hlist_nulls_del_init_rcu(&req->token_node); 203 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 204 spin_unlock_bh(&bucket->lock); 205} 206 207bool mptcp_token_exists(u32 token) 208{ 209 struct hlist_nulls_node *pos; 210 struct token_bucket *bucket; 211 struct mptcp_sock *msk; 212 struct sock *sk; 213 214 rcu_read_lock(); 215 bucket = token_bucket(token); 216 217again: 218 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 219 msk = mptcp_sk(sk); 220 if (READ_ONCE(msk->token) == token) 221 goto found; 222 } 223 if (get_nulls_value(pos) != (token & token_mask)) 224 goto again; 225 226 rcu_read_unlock(); 227 return false; 228found: 229 rcu_read_unlock(); 230 return true; 231} 232 233/** 234 * mptcp_token_get_sock - retrieve mptcp connection sock using its token 235 * @net: restrict to this namespace 236 * @token: token of the mptcp connection to retrieve 237 * 238 * This function returns the mptcp connection structure with the given token. 239 * A reference count on the mptcp socket returned is taken. 240 * 241 * returns NULL if no connection with the given token value exists. 242 */ 243struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token) 244{ 245 struct hlist_nulls_node *pos; 246 struct token_bucket *bucket; 247 struct mptcp_sock *msk; 248 struct sock *sk; 249 250 rcu_read_lock(); 251 bucket = token_bucket(token); 252 253again: 254 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 255 msk = mptcp_sk(sk); 256 if (READ_ONCE(msk->token) != token || 257 !net_eq(sock_net(sk), net)) 258 continue; 259 260 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 261 goto not_found; 262 263 if (READ_ONCE(msk->token) != token || 264 !net_eq(sock_net(sk), net)) { 265 sock_put(sk); 266 goto again; 267 } 268 goto found; 269 } 270 if (get_nulls_value(pos) != (token & token_mask)) 271 goto again; 272 273not_found: 274 msk = NULL; 275 276found: 277 rcu_read_unlock(); 278 return msk; 279} 280EXPORT_SYMBOL_GPL(mptcp_token_get_sock); 281 282/** 283 * mptcp_token_iter_next - iterate over the token container from given pos 284 * @net: namespace to be iterated 285 * @s_slot: start slot number 286 * @s_num: start number inside the given lock 287 * 288 * This function returns the first mptcp connection structure found inside the 289 * token container starting from the specified position, or NULL. 290 * 291 * On successful iteration, the iterator is move to the next position and the 292 * the acquires a reference to the returned socket. 293 */ 294struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, 295 long *s_num) 296{ 297 struct mptcp_sock *ret = NULL; 298 struct hlist_nulls_node *pos; 299 int slot, num = 0; 300 301 for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) { 302 struct token_bucket *bucket = &token_hash[slot]; 303 struct sock *sk; 304 305 num = 0; 306 307 if (hlist_nulls_empty(&bucket->msk_chain)) 308 continue; 309 310 rcu_read_lock(); 311 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 312 ++num; 313 if (!net_eq(sock_net(sk), net)) 314 continue; 315 316 if (num <= *s_num) 317 continue; 318 319 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 320 continue; 321 322 if (!net_eq(sock_net(sk), net)) { 323 sock_put(sk); 324 continue; 325 } 326 327 ret = mptcp_sk(sk); 328 rcu_read_unlock(); 329 goto out; 330 } 331 rcu_read_unlock(); 332 } 333 334out: 335 *s_slot = slot; 336 *s_num = num; 337 return ret; 338} 339EXPORT_SYMBOL_GPL(mptcp_token_iter_next); 340 341/** 342 * mptcp_token_destroy_request - remove mptcp connection/token 343 * @req: mptcp request socket dropping the token 344 * 345 * Remove the token associated to @req. 346 */ 347void mptcp_token_destroy_request(struct request_sock *req) 348{ 349 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 350 struct mptcp_subflow_request_sock *pos; 351 struct token_bucket *bucket; 352 353 if (hlist_nulls_unhashed(&subflow_req->token_node)) 354 return; 355 356 bucket = token_bucket(subflow_req->token); 357 spin_lock_bh(&bucket->lock); 358 pos = __token_lookup_req(bucket, subflow_req->token); 359 if (!WARN_ON_ONCE(pos != subflow_req)) { 360 hlist_nulls_del_init_rcu(&pos->token_node); 361 bucket->chain_len--; 362 } 363 spin_unlock_bh(&bucket->lock); 364} 365 366/** 367 * mptcp_token_destroy - remove mptcp connection/token 368 * @msk: mptcp connection dropping the token 369 * 370 * Remove the token associated to @msk 371 */ 372void mptcp_token_destroy(struct mptcp_sock *msk) 373{ 374 struct token_bucket *bucket; 375 struct mptcp_sock *pos; 376 377 if (sk_unhashed((struct sock *)msk)) 378 return; 379 380 bucket = token_bucket(msk->token); 381 spin_lock_bh(&bucket->lock); 382 pos = __token_lookup_msk(bucket, msk->token); 383 if (!WARN_ON_ONCE(pos != msk)) { 384 __sk_nulls_del_node_init_rcu((struct sock *)pos); 385 bucket->chain_len--; 386 } 387 spin_unlock_bh(&bucket->lock); 388} 389 390void __init mptcp_token_init(void) 391{ 392 int i; 393 394 token_hash = alloc_large_system_hash("MPTCP token", 395 sizeof(struct token_bucket), 396 0, 397 20,/* one slot per 1MB of memory */ 398 HASH_ZERO, 399 NULL, 400 &token_mask, 401 0, 402 64 * 1024); 403 for (i = 0; i < token_mask + 1; ++i) { 404 INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i); 405 INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i); 406 spin_lock_init(&token_hash[i].lock); 407 } 408} 409 410#if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS) 411EXPORT_SYMBOL_GPL(mptcp_token_new_request); 412EXPORT_SYMBOL_GPL(mptcp_token_new_connect); 413EXPORT_SYMBOL_GPL(mptcp_token_accept); 414EXPORT_SYMBOL_GPL(mptcp_token_destroy_request); 415EXPORT_SYMBOL_GPL(mptcp_token_destroy); 416#endif 417