@@ -361,25 +361,15 @@ static void smc_destruct(struct sock *sk)
361361 return ;
362362}
363363
364- static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
365- int protocol )
364+ void smc_sock_init (struct net * net , struct sock * sk , int protocol )
366365{
367- struct smc_sock * smc ;
368- struct proto * prot ;
369- struct sock * sk ;
370-
371- prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
372- sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
373- if (!sk )
374- return NULL ;
366+ struct smc_sock * smc = smc_sk (sk );
375367
376- sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
377368 sk -> sk_state = SMC_INIT ;
378369 sk -> sk_destruct = smc_destruct ;
379370 sk -> sk_protocol = protocol ;
380371 WRITE_ONCE (sk -> sk_sndbuf , 2 * READ_ONCE (net -> smc .sysctl_wmem ));
381372 WRITE_ONCE (sk -> sk_rcvbuf , 2 * READ_ONCE (net -> smc .sysctl_rmem ));
382- smc = smc_sk (sk );
383373 INIT_WORK (& smc -> tcp_listen_work , smc_tcp_listen_work );
384374 INIT_WORK (& smc -> connect_work , smc_connect_work );
385375 INIT_DELAYED_WORK (& smc -> conn .tx_work , smc_tx_work );
@@ -389,6 +379,24 @@ static struct sock *smc_sock_alloc(struct net *net, struct socket *sock,
389379 sk -> sk_prot -> hash (sk );
390380 mutex_init (& smc -> clcsock_release_lock );
391381 smc_init_saved_callbacks (smc );
382+ smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
383+ smc -> use_fallback = false; /* assume rdma capability first */
384+ smc -> fallback_rsn = 0 ;
385+ }
386+
387+ static struct sock * smc_sock_alloc (struct net * net , struct socket * sock ,
388+ int protocol )
389+ {
390+ struct proto * prot ;
391+ struct sock * sk ;
392+
393+ prot = (protocol == SMCPROTO_SMC6 ) ? & smc_proto6 : & smc_proto ;
394+ sk = sk_alloc (net , PF_SMC , GFP_KERNEL , prot , 0 );
395+ if (!sk )
396+ return NULL ;
397+
398+ sock_init_data (sock , sk ); /* sets sk_refcnt to 1 */
399+ smc_sock_init (net , sk , protocol );
392400
393401 return sk ;
394402}
@@ -3321,6 +3329,31 @@ static const struct proto_ops smc_sock_ops = {
33213329 .splice_read = smc_splice_read ,
33223330};
33233331
3332+ int smc_create_clcsk (struct net * net , struct sock * sk , int family )
3333+ {
3334+ struct smc_sock * smc = smc_sk (sk );
3335+ int rc ;
3336+
3337+ rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3338+ & smc -> clcsock );
3339+ if (rc ) {
3340+ sk_common_release (sk );
3341+ return rc ;
3342+ }
3343+
3344+ /* smc_clcsock_release() does not wait smc->clcsock->sk's
3345+ * destruction; its sk_state might not be TCP_CLOSE after
3346+ * smc->sk is close()d, and TCP timers can be fired later,
3347+ * which need net ref.
3348+ */
3349+ sk = smc -> clcsock -> sk ;
3350+ __netns_tracker_free (net , & sk -> ns_tracker , false);
3351+ sk -> sk_net_refcnt = 1 ;
3352+ get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3353+ sock_inuse_add (net , 1 );
3354+ return 0 ;
3355+ }
3356+
33243357static int __smc_create (struct net * net , struct socket * sock , int protocol ,
33253358 int kern , struct socket * clcsock )
33263359{
@@ -3346,35 +3379,12 @@ static int __smc_create(struct net *net, struct socket *sock, int protocol,
33463379
33473380 /* create internal TCP socket for CLC handshake and fallback */
33483381 smc = smc_sk (sk );
3349- smc -> use_fallback = false; /* assume rdma capability first */
3350- smc -> fallback_rsn = 0 ;
3351-
3352- /* default behavior from limit_smc_hs in every net namespace */
3353- smc -> limit_smc_hs = net -> smc .limit_smc_hs ;
33543382
33553383 rc = 0 ;
3356- if (!clcsock ) {
3357- rc = sock_create_kern (net , family , SOCK_STREAM , IPPROTO_TCP ,
3358- & smc -> clcsock );
3359- if (rc ) {
3360- sk_common_release (sk );
3361- goto out ;
3362- }
3363-
3364- /* smc_clcsock_release() does not wait smc->clcsock->sk's
3365- * destruction; its sk_state might not be TCP_CLOSE after
3366- * smc->sk is close()d, and TCP timers can be fired later,
3367- * which need net ref.
3368- */
3369- sk = smc -> clcsock -> sk ;
3370- __netns_tracker_free (net , & sk -> ns_tracker , false);
3371- sk -> sk_net_refcnt = 1 ;
3372- get_net_track (net , & sk -> ns_tracker , GFP_KERNEL );
3373- sock_inuse_add (net , 1 );
3374- } else {
3384+ if (!clcsock )
3385+ rc = smc_create_clcsk (net , sk , family );
3386+ else
33753387 smc -> clcsock = clcsock ;
3376- }
3377-
33783388out :
33793389 return rc ;
33803390}
0 commit comments