"""Implements iptables rules using linux utilities."""
+import collections
import contextlib
import os
import re
return rules_index
- def _find_last_entry(self, filter_list, match_str):
- # find a matching entry, starting from the bottom
- for s in reversed(filter_list):
- if match_str in s:
- return s.strip()
+ def _find_last_entry(self, filter_map, match_str):
+ # find last matching entry
+ try:
+ return filter_map[match_str][-1]
+ except KeyError:
+ pass
def _modify_rules(self, current_lines, table, table_name):
# Chains are stored as sets to avoid duplicates.
(old_filter if self.wrap_name in line else
new_filter).append(line.strip())
+ old_filter_map = make_filter_map(old_filter)
+ new_filter_map = make_filter_map(new_filter)
+
rules_index = self._find_rules_index(new_filter)
all_chains = [':%s' % name for name in unwrapped_chains]
for chain in all_chains:
chain_str = str(chain).strip()
- old = self._find_last_entry(old_filter, chain_str)
+ old = self._find_last_entry(old_filter_map, chain_str)
if not old:
- dup = self._find_last_entry(new_filter, chain_str)
+ dup = self._find_last_entry(new_filter_map, chain_str)
new_filter = [s for s in new_filter if chain_str not in s.strip()]
# if no old or duplicates, use original chain
# Further down, we weed out duplicates from the bottom of the
# list, so here we remove the dupes ahead of time.
- old = self._find_last_entry(old_filter, rule_str)
+ old = self._find_last_entry(old_filter_map, rule_str)
if not old:
- dup = self._find_last_entry(new_filter, rule_str)
+ dup = self._find_last_entry(new_filter_map, rule_str)
new_filter = [s for s in new_filter if rule_str not in s.strip()]
# if no old or duplicates, use original rule
acc['bytes'] += int(data[1])
return acc
+
+
+def make_filter_map(filter_list):
+ filter_map = collections.defaultdict(list)
+ for data in filter_list:
+ # strip any [packet:byte] counts at start or end of lines,
+ # for example, chains look like ":neutron-foo - [0:0]"
+ # and rules look like "[0:0] -A neutron-foo..."
+ if data.startswith('['):
+ key = data.rpartition('] ')[2]
+ elif data.endswith(']'):
+ key = data.rsplit(' [', 1)[0]
+ if key.endswith(' -'):
+ key = key[:-2]
+ else:
+ # things like COMMIT, *filter, and *nat land here
+ continue
+ filter_map[key].append(data)
+ # regular IP(v6) entries are translated into /32s or /128s so we
+ # include a lookup without the CIDR here to match as well
+ for cidr in ('/32', '/128'):
+ if cidr in key:
+ alt_key = key.replace(cidr, '')
+ filter_map[alt_key].append(data)
+ # return a regular dict so readers don't accidentally add entries
+ return dict(filter_map)
'[0:0] -A FORWARD -j neutron-filter-top',
'[0:0] -A OUTPUT -j neutron-filter-top'
% IPTABLES_ARG]
-
- return self.iptables._find_last_entry(filter_list, find_str)
+ filter_map = iptables_manager.make_filter_map(filter_list)
+ return self.iptables._find_last_entry(filter_map, find_str)
def test_find_last_entry_old_dup(self):
- find_str = 'neutron-filter-top'
+ find_str = '-A OUTPUT -j neutron-filter-top'
match_str = '[0:0] -A OUTPUT -j neutron-filter-top'
ret_str = self._test_find_last_entry(find_str)
self.assertEqual(ret_str, match_str)
ret_str = self._test_find_last_entry(find_str)
self.assertIsNone(ret_str)
+ def test_make_filter_map_cidr_stripping(self):
+ filter_rules = ('[0:0] -A OUTPUT -j DROP',
+ '[0:0] -A INPUT -d 192.168.0.2/32 -j DROP',
+ '[0:0] -A INPUT -d 1234:31::001F/128 -j DROP',
+ 'OUTPUT - [0:0]')
+ filter_map = iptables_manager.make_filter_map(filter_rules)
+ # make sure /128 works without CIDR
+ self.assertEqual(filter_rules[2],
+ filter_map['-A INPUT -d 1234:31::001F -j DROP'][0])
+ # make sure /32 works without CIDR
+ self.assertEqual(filter_rules[1],
+ filter_map['-A INPUT -d 192.168.0.2 -j DROP'][0])
+
class IptablesManagerStateLessTestCase(base.BaseTestCase):