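batman-adv: protect originator nodes with reference counters

Originator nodes are now reference counted through a kref embedded in
struct orig_node. get_orig_node() hands an extra reference to its
caller, and orig_node_free_ref() releases the node once the last
reference is dropped. The node now stores its bat_priv, so the release
function no longer needs it passed in. The purge path drops the
reference held by the originator hash from an RCU callback
(bucket_free_orig_rcu) instead of freeing the node directly.

Signed-off-by: Marek Lindner <lindner_marek@yahoo.de>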
diff --git a/net/batman-adv/originator.c b/net/batman-adv/originator.c
index 5c32314..fcdb0b7 100644
--- a/net/batman-adv/originator.c
+++ b/net/batman-adv/originator.c
@@ -103,12 +103,13 @@
 	return neigh_node;
 }
 
-static void free_orig_node(void *data, void *arg)
+void orig_node_free_ref(struct kref *refcount)
 {
 	struct hlist_node *node, *node_tmp;
 	struct neigh_node *neigh_node;
-	struct orig_node *orig_node = (struct orig_node *)data;
-	struct bat_priv *bat_priv = (struct bat_priv *)arg;
+	struct orig_node *orig_node;
+
+	orig_node = container_of(refcount, struct orig_node, refcount);
 
 	spin_lock_bh(&orig_node->neigh_list_lock);
 
@@ -122,7 +123,8 @@
 	spin_unlock_bh(&orig_node->neigh_list_lock);
 
 	frag_list_free(&orig_node->frag_list);
-	hna_global_del_orig(bat_priv, orig_node, "originator timed out");
+	hna_global_del_orig(orig_node->bat_priv, orig_node,
+			    "originator timed out");
 
 	kfree(orig_node->bcast_own);
 	kfree(orig_node->bcast_own_sum);
@@ -131,17 +133,53 @@
 
+/* RCU callback: release the originator hash's orig_node reference and the
+ * bucket only after readers that found them via the hash are finished */
+static void bucket_free_orig_rcu(struct rcu_head *rcu)
+{
+	struct element_t *bucket;
+	struct orig_node *orig_node;
+
+	bucket = container_of(rcu, struct element_t, rcu);
+	orig_node = bucket->data;
+
+	kref_put(&orig_node->refcount, orig_node_free_ref);
+	kfree(bucket);
+}
+
 void originator_free(struct bat_priv *bat_priv)
 {
-	if (!bat_priv->orig_hash)
+	struct hashtable_t *hash = bat_priv->orig_hash;
+	struct hlist_node *walk, *safe;
+	struct hlist_head *head;
+	struct element_t *bucket;
+	spinlock_t *list_lock; /* spinlock to protect write access */
+	int i;
+
+	if (!hash)
 		return;
 
 	cancel_delayed_work_sync(&bat_priv->orig_work);
 
 	spin_lock_bh(&bat_priv->orig_hash_lock);
-	hash_delete(bat_priv->orig_hash, free_orig_node, bat_priv);
 	bat_priv->orig_hash = NULL;
+
+	for (i = 0; i < hash->size; i++) {
+		head = &hash->table[i];
+		list_lock = &hash->list_locks[i];
+
+		spin_lock_bh(list_lock);
+		hlist_for_each_entry_safe(bucket, walk, safe, head, hlist) {
+			hlist_del_rcu(walk);
+			/* as in the purge path: drop the hash's orig_node
+			 * reference only after the RCU grace period */
+			call_rcu(&bucket->rcu, bucket_free_orig_rcu);
+		}
+		spin_unlock_bh(list_lock);
+	}
+
+	hash_destroy(hash);
 	spin_unlock_bh(&bat_priv->orig_hash_lock);
 }
 
 /* this function finds or creates an originator entry for the given
  * address if it does not exits */
 struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr)
@@ -156,8 +194,10 @@
 						   addr));
+	/* take the reference before leaving the RCU read-side section */
+	if (orig_node)
+		kref_get(&orig_node->refcount);
 	rcu_read_unlock();
-
 	if (orig_node)
 		return orig_node;
 
 	bat_dbg(DBG_BATMAN, bat_priv,
 		"Creating new originator: %pM\n", addr);
@@ -168,7 +208,9 @@
 
 	INIT_HLIST_HEAD(&orig_node->neigh_list);
 	spin_lock_init(&orig_node->neigh_list_lock);
+	kref_init(&orig_node->refcount);
 
+	orig_node->bat_priv = bat_priv;
 	memcpy(orig_node->orig, addr, ETH_ALEN);
 	orig_node->router = NULL;
 	orig_node->hna_buff = NULL;
@@ -197,6 +239,8 @@
 	if (hash_added < 0)
 		goto free_bcast_own_sum;
 
+	/* extra reference for return */
+	kref_get(&orig_node->refcount);
 	return orig_node;
 free_bcast_own_sum:
 	kfree(orig_node->bcast_own_sum);
@@ -318,8 +362,7 @@
 				if (orig_node->gw_flags)
 					gw_node_delete(bat_priv, orig_node);
 				hlist_del_rcu(walk);
-				call_rcu(&bucket->rcu, bucket_free_rcu);
-				free_orig_node(orig_node, bat_priv);
+				call_rcu(&bucket->rcu, bucket_free_orig_rcu);
 				continue;
 			}