diff --git a/net/core/skmsg.c b/net/core/skmsg.c index c479372f2cd2..9d72f71e9b47 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -682,13 +682,43 @@ static struct sk_psock *sk_psock_from_strp(struct strparser *strp) return container_of(parser, struct sk_psock, parser); } -static void sk_psock_verdict_apply(struct sk_psock *psock, - struct sk_buff *skb, int verdict) +static void sk_psock_skb_redirect(struct sk_psock *psock, struct sk_buff *skb) { struct sk_psock *psock_other; struct sock *sk_other; bool ingress; + sk_other = tcp_skb_bpf_redirect_fetch(skb); + if (unlikely(!sk_other)) { + kfree_skb(skb); + return; + } + psock_other = sk_psock(sk_other); + if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || + !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) { + kfree_skb(skb); + return; + } + + ingress = tcp_skb_bpf_ingress(skb); + if ((!ingress && sock_writeable(sk_other)) || + (ingress && + atomic_read(&sk_other->sk_rmem_alloc) <= + sk_other->sk_rcvbuf)) { + if (!ingress) + skb_set_owner_w(skb, sk_other); + skb_queue_tail(&psock_other->ingress_skb, skb); + schedule_work(&psock_other->work); + } else { + kfree_skb(skb); + } +} + +static void sk_psock_verdict_apply(struct sk_psock *psock, + struct sk_buff *skb, int verdict) +{ + struct sock *sk_other; + switch (verdict) { case __SK_PASS: sk_other = psock->sk; @@ -707,25 +737,8 @@ static void sk_psock_verdict_apply(struct sk_psock *psock, } goto out_free; case __SK_REDIRECT: - sk_other = tcp_skb_bpf_redirect_fetch(skb); - if (unlikely(!sk_other)) - goto out_free; - psock_other = sk_psock(sk_other); - if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || - !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) - goto out_free; - ingress = tcp_skb_bpf_ingress(skb); - if ((!ingress && sock_writeable(sk_other)) || - (ingress && - atomic_read(&sk_other->sk_rmem_alloc) <= - sk_other->sk_rcvbuf)) { - if (!ingress) - skb_set_owner_w(skb, sk_other); - skb_queue_tail(&psock_other->ingress_skb, skb); - schedule_work(&psock_other->work); - break; - } - /* fall-through */ + sk_psock_skb_redirect(psock, skb); + break; case __SK_DROP: /* fall-through */ default: