diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h index 995efbb031ab..f74665d7bea8 100644 --- a/include/net/inet6_hashtables.h +++ b/include/net/inet6_hashtables.h @@ -96,10 +96,14 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, const __be16 sport, const __be16 dport) { - return __inet6_lookup(dev_net(skb->dst->dev), hashinfo, - &ipv6_hdr(skb)->saddr, sport, - &ipv6_hdr(skb)->daddr, ntohs(dport), - inet6_iif(skb)); + struct sock *sk; + + if (unlikely(sk = skb_steal_sock(skb))) + return sk; + else return __inet6_lookup(dev_net(skb->dst->dev), hashinfo, + &ipv6_hdr(skb)->saddr, sport, + &ipv6_hdr(skb)->daddr, ntohs(dport), + inet6_iif(skb)); } extern struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h index 3522bbcd546d..5cc182f9ecae 100644 --- a/include/net/inet_hashtables.h +++ b/include/net/inet_hashtables.h @@ -378,11 +378,15 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, const __be16 sport, const __be16 dport) { + struct sock *sk; const struct iphdr *iph = ip_hdr(skb); - return __inet_lookup(dev_net(skb->dst->dev), hashinfo, - iph->saddr, sport, - iph->daddr, dport, inet_iif(skb)); + if (unlikely(sk = skb_steal_sock(skb))) + return sk; + else + return __inet_lookup(dev_net(skb->dst->dev), hashinfo, + iph->saddr, sport, + iph->daddr, dport, inet_iif(skb)); } extern int __inet_hash_connect(struct inet_timewait_death_row *death_row, diff --git a/include/net/sock.h b/include/net/sock.h index 75a312d3888a..18f96708f3a6 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1324,6 +1324,18 @@ static inline void sk_change_net(struct sock *sk, struct net *net) sock_net_set(sk, hold_net(net)); } +static inline struct sock *skb_steal_sock(struct sk_buff *skb) +{ + if (unlikely(skb->sk)) { + struct sock *sk = skb->sk; + + skb->destructor = NULL; + skb->sk = NULL; + return sk; + } + return NULL; +} + extern void sock_enable_timestamp(struct sock *sk); extern int sock_get_timestamp(struct sock *, struct timeval __user *); extern int sock_get_timestampns(struct sock *, struct timespec __user *); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index c7a90b546b21..822c9deac83b 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -306,11 +306,15 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport, struct hlist_head udptable[]) { + struct sock *sk; const struct iphdr *iph = ip_hdr(skb); - return __udp4_lib_lookup(dev_net(skb->dst->dev), iph->saddr, sport, - iph->daddr, dport, inet_iif(skb), - udptable); + if (unlikely(sk = skb_steal_sock(skb))) + return sk; + else + return __udp4_lib_lookup(dev_net(skb->dst->dev), iph->saddr, sport, + iph->daddr, dport, inet_iif(skb), + udptable); } struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index ce26c41d157d..e51da8c092fa 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -111,11 +111,15 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport, struct hlist_head udptable[]) { + struct sock *sk; struct ipv6hdr *iph = ipv6_hdr(skb); - return __udp6_lib_lookup(dev_net(skb->dst->dev), &iph->saddr, sport, - &iph->daddr, dport, inet6_iif(skb), - udptable); + if (unlikely(sk = skb_steal_sock(skb))) + return sk; + else + return __udp6_lib_lookup(dev_net(skb->dst->dev), &iph->saddr, sport, + &iph->daddr, dport, inet6_iif(skb), + udptable); } /*