igmp: add a hash table to speed up ip_check_mc_rcu()

After IP route cache removal, multicast applications using
a lot of multicast addresses hit an O(N) behavior in ip_check_mc_rcu().

Add a per in_device hash table to get faster lookup.

This hash table is created only once the number of items in mc_list
reaches 4.

Reported-by: Shawn Bohrer <sbohrer@rgmadvisors.com>
Signed-off-by: Eric Dumazet <edumazet@google.com>
Tested-by: Shawn Bohrer <sbohrer@rgmadvisors.com>
Reviewed-by: Cong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c
index b047e2d..3469506 100644
--- a/net/ipv4/devinet.c
+++ b/net/ipv4/devinet.c
@@ -215,6 +215,7 @@
 
 	WARN_ON(idev->ifa_list);
 	WARN_ON(idev->mc_list);
+	kfree(rcu_dereference_protected(idev->mc_hash, 1));
 #ifdef NET_REFCNT_DEBUG
 	pr_debug("%s: %p=%s\n", __func__, idev, dev ? dev->name : "NIL");
 #endif
diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c
index 450f625..f72011d 100644
--- a/net/ipv4/igmp.c
+++ b/net/ipv4/igmp.c
@@ -1217,6 +1217,57 @@
  *	Multicast list managers
  */
 
+/* Bucket index of @im's group; __force as the raw __be32 bits are hashed. */
+static u32 ip_mc_hash(const struct ip_mc_list *im)
+{
+	return hash_32((__force u32)im->multiaddr, MC_HASH_SZ_LOG);
+}
+/* Add @im to in_dev's mc hash table, creating the table on demand (RTNL). */
+static void ip_mc_hash_add(struct in_device *in_dev,
+			   struct ip_mc_list *im)
+{
+	struct ip_mc_list __rcu **mc_hash = rtnl_dereference(in_dev->mc_hash);
+	u32 hash;
+
+	if (mc_hash) {
+		hash = ip_mc_hash(im);
+		im->next_hash = mc_hash[hash]; /* __rcu to __rcu: plain copy */
+		rcu_assign_pointer(mc_hash[hash], im);
+		return;
+	}
+
+	/* do not use a hash table for small number of items */
+	if (in_dev->mc_count < 4)
+		return;
+
+	mc_hash = kzalloc(sizeof(struct ip_mc_list *) << MC_HASH_SZ_LOG,
+			  GFP_KERNEL);
+	if (!mc_hash)
+		return;
+	/* @im doubles as the cursor while every entry of in_dev is hashed */
+	for_each_pmc_rtnl(in_dev, im) {
+		hash = ip_mc_hash(im);
+		im->next_hash = mc_hash[hash]; /* table is still private */
+		RCU_INIT_POINTER(mc_hash[hash], im);
+	}
+	/* table fully built: only now publish it to RCU readers */
+	rcu_assign_pointer(in_dev->mc_hash, mc_hash);
+}
+
+/* Unlink @im from in_dev's multicast hash table; no-op without a table. */
+static void ip_mc_hash_remove(struct in_device *in_dev,
+			      struct ip_mc_list *im)
+{
+	struct ip_mc_list __rcu **link = rtnl_dereference(in_dev->mc_hash);
+	struct ip_mc_list *cur;
+
+	if (!link)
+		return;
+	for (link += ip_mc_hash(im); (cur = rtnl_dereference(*link)) != im; )
+		link = &cur->next_hash;
+	*link = im->next_hash;
+}
+
 
 /*
  *	A socket has joined a multicast group on device dev.
@@ -1258,6 +1309,8 @@
 	in_dev->mc_count++;
 	rcu_assign_pointer(in_dev->mc_list, im);
 
+	ip_mc_hash_add(in_dev, im);
+
 #ifdef CONFIG_IP_MULTICAST
 	igmpv3_del_delrec(in_dev, im->multiaddr);
 #endif
@@ -1314,6 +1367,7 @@
 	     ip = &i->next_rcu) {
 		if (i->multiaddr == addr) {
 			if (--i->users == 0) {
+				ip_mc_hash_remove(in_dev, i);
 				*ip = i->next_rcu;
 				in_dev->mc_count--;
 				igmp_group_dropped(i);
@@ -2321,12 +2375,25 @@
 int ip_check_mc_rcu(struct in_device *in_dev, __be32 mc_addr, __be32 src_addr, u16 proto)
 {
 	struct ip_mc_list *im;
+	struct ip_mc_list __rcu **mc_hash;
 	struct ip_sf_list *psf;
 	int rv = 0;
 
-	for_each_pmc_rcu(in_dev, im) {
-		if (im->multiaddr == mc_addr)
-			break;
+	mc_hash = rcu_dereference(in_dev->mc_hash);
+	if (mc_hash) {
+		u32 hash = hash_32((u32)mc_addr, MC_HASH_SZ_LOG);
+
+		for (im = rcu_dereference(mc_hash[hash]);
+		     im != NULL;
+		     im = rcu_dereference(im->next_hash)) {
+			if (im->multiaddr == mc_addr)
+				break;
+		}
+	} else {
+		for_each_pmc_rcu(in_dev, im) {
+			if (im->multiaddr == mc_addr)
+				break;
+		}
 	}
 	if (im && proto == IPPROTO_IGMP) {
 		rv = 1;