diff --git a/net/batman-adv/multicast.c b/net/batman-adv/multicast.c index a11d3d89f012..a35f597e8c8b 100644 --- a/net/batman-adv/multicast.c +++ b/net/batman-adv/multicast.c @@ -1536,7 +1536,7 @@ out: if (!ret && primary_if) *primary_if = hard_iface; - else + else if (hard_iface) batadv_hardif_put(hard_iface); return ret; diff --git a/net/batman-adv/translation-table.c b/net/batman-adv/translation-table.c index 0225616d5771..3986551397ca 100644 --- a/net/batman-adv/translation-table.c +++ b/net/batman-adv/translation-table.c @@ -862,7 +862,7 @@ batadv_tt_prepare_tvlv_global_data(struct batadv_orig_node *orig_node, struct batadv_orig_node_vlan *vlan; u8 *tt_change_ptr; - rcu_read_lock(); + spin_lock_bh(&orig_node->vlan_list_lock); hlist_for_each_entry_rcu(vlan, &orig_node->vlan_list, list) { num_vlan++; num_entries += atomic_read(&vlan->tt.num_entries); @@ -900,7 +900,7 @@ batadv_tt_prepare_tvlv_global_data(struct batadv_orig_node *orig_node, *tt_change = (struct batadv_tvlv_tt_change *)tt_change_ptr; out: - rcu_read_unlock(); + spin_unlock_bh(&orig_node->vlan_list_lock); return tvlv_len; } @@ -931,15 +931,20 @@ batadv_tt_prepare_tvlv_local_data(struct batadv_priv *bat_priv, struct batadv_tvlv_tt_vlan_data *tt_vlan; struct batadv_softif_vlan *vlan; u16 num_vlan = 0; - u16 num_entries = 0; + u16 vlan_entries = 0; + u16 total_entries = 0; u16 tvlv_len; u8 *tt_change_ptr; int change_offset; - rcu_read_lock(); + spin_lock_bh(&bat_priv->softif_vlan_list_lock); hlist_for_each_entry_rcu(vlan, &bat_priv->softif_vlan_list, list) { + vlan_entries = atomic_read(&vlan->tt.num_entries); + if (vlan_entries < 1) + continue; + num_vlan++; - num_entries += atomic_read(&vlan->tt.num_entries); + total_entries += vlan_entries; } change_offset = sizeof(**tt_data); @@ -947,7 +952,7 @@ batadv_tt_prepare_tvlv_local_data(struct batadv_priv *bat_priv, /* if tt_len is negative, allocate the space needed by the full table */ if (*tt_len < 0) - *tt_len = batadv_tt_len(num_entries); + *tt_len = batadv_tt_len(total_entries); tvlv_len = *tt_len; tvlv_len += change_offset; @@ -964,6 +969,10 @@ batadv_tt_prepare_tvlv_local_data(struct batadv_priv *bat_priv, tt_vlan = (struct batadv_tvlv_tt_vlan_data *)(*tt_data + 1); hlist_for_each_entry_rcu(vlan, &bat_priv->softif_vlan_list, list) { + vlan_entries = atomic_read(&vlan->tt.num_entries); + if (vlan_entries < 1) + continue; + tt_vlan->vid = htons(vlan->vid); tt_vlan->crc = htonl(vlan->tt.crc); @@ -974,7 +983,7 @@ batadv_tt_prepare_tvlv_local_data(struct batadv_priv *bat_priv, *tt_change = (struct batadv_tvlv_tt_change *)tt_change_ptr; out: - rcu_read_unlock(); + spin_unlock_bh(&bat_priv->softif_vlan_list_lock); return tvlv_len; } @@ -1538,6 +1547,8 @@ batadv_tt_global_orig_entry_find(const struct batadv_tt_global_entry *entry, * handled by a given originator * @entry: the TT global entry to check * @orig_node: the originator to search in the list + * @flags: a pointer to store TT flags for the given @entry received + * from @orig_node * * find out if an orig_node is already in the list of a tt_global_entry. * @@ -1545,7 +1556,8 @@ batadv_tt_global_orig_entry_find(const struct batadv_tt_global_entry *entry, */ static bool batadv_tt_global_entry_has_orig(const struct batadv_tt_global_entry *entry, - const struct batadv_orig_node *orig_node) + const struct batadv_orig_node *orig_node, + u8 *flags) { struct batadv_tt_orig_list_entry *orig_entry; bool found = false; @@ -1553,6 +1565,10 @@ batadv_tt_global_entry_has_orig(const struct batadv_tt_global_entry *entry, orig_entry = batadv_tt_global_orig_entry_find(entry, orig_node); if (orig_entry) { found = true; + + if (flags) + *flags = orig_entry->flags; + batadv_tt_orig_list_entry_put(orig_entry); } @@ -1731,7 +1747,7 @@ static bool batadv_tt_global_add(struct batadv_priv *bat_priv, if (!(common->flags & BATADV_TT_CLIENT_TEMP)) goto out; if (batadv_tt_global_entry_has_orig(tt_global_entry, - orig_node)) + orig_node, NULL)) goto out_remove; batadv_tt_global_del_orig_list(tt_global_entry); goto add_orig_entry; @@ -2880,23 +2896,46 @@ unlock: } /** - * batadv_tt_local_valid() - verify that given tt entry is a valid one + * batadv_tt_local_valid() - verify local tt entry and get flags * @entry_ptr: to be checked local tt entry * @data_ptr: not used but definition required to satisfy the callback prototype + * @flags: a pointer to store TT flags for this client to + * + * Checks the validity of the given local TT entry. If it is, then the provided + * flags pointer is updated. * * Return: true if the entry is a valid, false otherwise. */ -static bool batadv_tt_local_valid(const void *entry_ptr, const void *data_ptr) +static bool batadv_tt_local_valid(const void *entry_ptr, + const void *data_ptr, + u8 *flags) { const struct batadv_tt_common_entry *tt_common_entry = entry_ptr; if (tt_common_entry->flags & BATADV_TT_CLIENT_NEW) return false; + + if (flags) + *flags = tt_common_entry->flags; + return true; } +/** + * batadv_tt_global_valid() - verify global tt entry and get flags + * @entry_ptr: to be checked global tt entry + * @data_ptr: an orig_node object (may be NULL) + * @flags: a pointer to store TT flags for this client to + * + * Checks the validity of the given global TT entry. If it is, then the provided + * flags pointer is updated either with the common (summed) TT flags if data_ptr + * is NULL or the specific, per originator TT flags otherwise. + * + * Return: true if the entry is a valid, false otherwise. + */ static bool batadv_tt_global_valid(const void *entry_ptr, - const void *data_ptr) + const void *data_ptr, + u8 *flags) { const struct batadv_tt_common_entry *tt_common_entry = entry_ptr; const struct batadv_tt_global_entry *tt_global_entry; @@ -2910,7 +2949,8 @@ static bool batadv_tt_global_valid(const void *entry_ptr, struct batadv_tt_global_entry, common); - return batadv_tt_global_entry_has_orig(tt_global_entry, orig_node); + return batadv_tt_global_entry_has_orig(tt_global_entry, orig_node, + flags); } /** @@ -2920,25 +2960,34 @@ static bool batadv_tt_global_valid(const void *entry_ptr, * @hash: hash table containing the tt entries * @tt_len: expected tvlv tt data buffer length in number of bytes * @tvlv_buff: pointer to the buffer to fill with the TT data - * @valid_cb: function to filter tt change entries + * @valid_cb: function to filter tt change entries and to return TT flags * @cb_data: data passed to the filter function as argument + * + * Fills the tvlv buff with the tt entries from the specified hash. If valid_cb + * is not provided then this becomes a no-op. */ static void batadv_tt_tvlv_generate(struct batadv_priv *bat_priv, struct batadv_hashtable *hash, void *tvlv_buff, u16 tt_len, bool (*valid_cb)(const void *, - const void *), + const void *, + u8 *flags), void *cb_data) { struct batadv_tt_common_entry *tt_common_entry; struct batadv_tvlv_tt_change *tt_change; struct hlist_head *head; u16 tt_tot, tt_num_entries = 0; + u8 flags; + bool ret; u32 i; tt_tot = batadv_tt_entries(tt_len); tt_change = (struct batadv_tvlv_tt_change *)tvlv_buff; + if (!valid_cb) + return; + rcu_read_lock(); for (i = 0; i < hash->size; i++) { head = &hash->table[i]; @@ -2948,11 +2997,12 @@ static void batadv_tt_tvlv_generate(struct batadv_priv *bat_priv, if (tt_tot == tt_num_entries) break; - if ((valid_cb) && (!valid_cb(tt_common_entry, cb_data))) + ret = valid_cb(tt_common_entry, cb_data, &flags); + if (!ret) continue; ether_addr_copy(tt_change->addr, tt_common_entry->addr); - tt_change->flags = tt_common_entry->flags; + tt_change->flags = flags; tt_change->vid = htons(tt_common_entry->vid); memset(tt_change->reserved, 0, sizeof(tt_change->reserved));