@@ -119,7 +119,6 @@ struct pppol2tp_session {
119119 struct mutex sk_lock ; /* Protects .sk */
120120 struct sock __rcu * sk ; /* Pointer to the session PPPoX socket */
121121 struct sock * __sk ; /* Copy of .sk, for cleanup */
122- struct rcu_head rcu ; /* For asynchronous release */
123122};
124123
125124static int pppol2tp_xmit (struct ppp_channel * chan , struct sk_buff * skb );
@@ -157,20 +156,16 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
157156 if (!sk )
158157 return NULL ;
159158
160- sock_hold (sk );
161- session = (struct l2tp_session * )(sk -> sk_user_data );
162- if (!session ) {
163- sock_put (sk );
164- goto out ;
165- }
166- if (WARN_ON (session -> magic != L2TP_SESSION_MAGIC )) {
167- session = NULL ;
168- sock_put (sk );
169- goto out ;
159+ rcu_read_lock ();
160+ session = rcu_dereference_sk_user_data (sk );
161+ if (session && refcount_inc_not_zero (& session -> ref_count )) {
162+ rcu_read_unlock ();
163+ WARN_ON_ONCE (session -> magic != L2TP_SESSION_MAGIC );
164+ return session ;
170165 }
166+ rcu_read_unlock ();
171167
172- out :
173- return session ;
168+ return NULL ;
174169}
175170
176171/*****************************************************************************
@@ -318,12 +313,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
318313 l2tp_xmit_skb (session , skb );
319314 local_bh_enable ();
320315
321- sock_put ( sk );
316+ l2tp_session_dec_refcount ( session );
322317
323318 return total_len ;
324319
325320error_put_sess :
326- sock_put ( sk );
321+ l2tp_session_dec_refcount ( session );
327322error :
328323 return error ;
329324}
@@ -377,12 +372,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
377372 l2tp_xmit_skb (session , skb );
378373 local_bh_enable ();
379374
380- sock_put ( sk );
375+ l2tp_session_dec_refcount ( session );
381376
382377 return 1 ;
383378
384379abort_put_sess :
385- sock_put ( sk );
380+ l2tp_session_dec_refcount ( session );
386381abort :
387382 /* Free the original skb */
388383 kfree_skb (skb );
@@ -393,28 +388,31 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
393388 * Session (and tunnel control) socket create/destroy.
394389 *****************************************************************************/
395390
396- static void pppol2tp_put_sk (struct rcu_head * head )
397- {
398- struct pppol2tp_session * ps ;
399-
400- ps = container_of (head , typeof (* ps ), rcu );
401- sock_put (ps -> __sk );
402- }
403-
404391/* Really kill the session socket. (Called from sock_put() if
405392 * refcnt == 0.)
406393 */
407394static void pppol2tp_session_destruct (struct sock * sk )
408395{
409- struct l2tp_session * session = sk -> sk_user_data ;
410-
411396 skb_queue_purge (& sk -> sk_receive_queue );
412397 skb_queue_purge (& sk -> sk_write_queue );
398+ }
413399
414- if (session ) {
415- sk -> sk_user_data = NULL ;
416- if (WARN_ON (session -> magic != L2TP_SESSION_MAGIC ))
417- return ;
400+ static void pppol2tp_session_close (struct l2tp_session * session )
401+ {
402+ struct pppol2tp_session * ps ;
403+
404+ ps = l2tp_session_priv (session );
405+ mutex_lock (& ps -> sk_lock );
406+ ps -> __sk = rcu_dereference_protected (ps -> sk ,
407+ lockdep_is_held (& ps -> sk_lock ));
408+ RCU_INIT_POINTER (ps -> sk , NULL );
409+ mutex_unlock (& ps -> sk_lock );
410+ if (ps -> __sk ) {
411+ /* detach socket */
412+ rcu_assign_sk_user_data (ps -> __sk , NULL );
413+ sock_put (ps -> __sk );
414+
415+ /* drop ref taken when we referenced socket via sk_user_data */
418416 l2tp_session_dec_refcount (session );
419417 }
420418}
@@ -444,30 +442,13 @@ static int pppol2tp_release(struct socket *sock)
444442
445443 session = pppol2tp_sock_to_session (sk );
446444 if (session ) {
447- struct pppol2tp_session * ps ;
448-
449445 l2tp_session_delete (session );
450-
451- ps = l2tp_session_priv (session );
452- mutex_lock (& ps -> sk_lock );
453- ps -> __sk = rcu_dereference_protected (ps -> sk ,
454- lockdep_is_held (& ps -> sk_lock ));
455- RCU_INIT_POINTER (ps -> sk , NULL );
456- mutex_unlock (& ps -> sk_lock );
457- call_rcu (& ps -> rcu , pppol2tp_put_sk );
458-
459- /* Rely on the sock_put() call at the end of the function for
460- * dropping the reference held by pppol2tp_sock_to_session().
461- * The last reference will be dropped by pppol2tp_put_sk().
462- */
446+ /* drop ref taken by pppol2tp_sock_to_session */
447+ l2tp_session_dec_refcount (session );
463448 }
464449
465450 release_sock (sk );
466451
467- /* This will delete the session context via
468- * pppol2tp_session_destruct() if the socket's refcnt drops to
469- * zero.
470- */
471452 sock_put (sk );
472453
473454 return 0 ;
@@ -506,6 +487,7 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
506487 goto out ;
507488
508489 sock_init_data (sock , sk );
490+ sock_set_flag (sk , SOCK_RCU_FREE );
509491
510492 sock -> state = SS_UNCONNECTED ;
511493 sock -> ops = & pppol2tp_ops ;
@@ -542,6 +524,7 @@ static void pppol2tp_session_init(struct l2tp_session *session)
542524 struct pppol2tp_session * ps ;
543525
544526 session -> recv_skb = pppol2tp_recv ;
527+ session -> session_close = pppol2tp_session_close ;
545528 if (IS_ENABLED (CONFIG_L2TP_DEBUGFS ))
546529 session -> show = pppol2tp_show ;
547530
@@ -830,12 +813,13 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
830813
831814out_no_ppp :
832815 /* This is how we get the session context from the socket. */
833- sk -> sk_user_data = session ;
816+ sock_hold (sk );
817+ rcu_assign_sk_user_data (sk , session );
834818 rcu_assign_pointer (ps -> sk , sk );
835819 mutex_unlock (& ps -> sk_lock );
836820
837821 /* Keep the reference we've grabbed on the session: sk doesn't expect
838- * the session to disappear. pppol2tp_session_destruct () is responsible
822+ * the session to disappear. pppol2tp_session_close () is responsible
839823 * for dropping it.
840824 */
841825 drop_refcnt = false;
@@ -1002,7 +986,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
1002986
1003987 error = len ;
1004988
1005- sock_put ( sk );
989+ l2tp_session_dec_refcount ( session );
1006990end :
1007991 return error ;
1008992}
@@ -1274,7 +1258,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
12741258 err = pppol2tp_session_setsockopt (sk , session , optname , val );
12751259 }
12761260
1277- sock_put ( sk );
1261+ l2tp_session_dec_refcount ( session );
12781262end :
12791263 return err ;
12801264}
@@ -1395,7 +1379,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
13951379 err = 0 ;
13961380
13971381end_put_sess :
1398- sock_put ( sk );
1382+ l2tp_session_dec_refcount ( session );
13991383end :
14001384 return err ;
14011385}
0 commit comments