]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Refactor _remove_unused_security_group_info
authorMiguel Angel Ajo <mangelajo@redhat.com>
Tue, 3 Feb 2015 13:35:40 +0000 (13:35 +0000)
committerMiguel Angel Ajo <mangelajo@redhat.com>
Fri, 13 Mar 2015 13:58:08 +0000 (13:58 +0000)
_remove_unused_security_group_info is refactored into smaller
functions, to make this block easier to understand.

Implements blueprint refactor-iptables-firewall-driver

Change-Id: I4107f1a702d059337e7b2d701a5d0372ee2cfe11

neutron/agent/linux/iptables_firewall.py
neutron/tests/unit/test_iptables_firewall.py

index 9a4d3bd8bcbc789e227597ed541387d35a73df81..9896559dcd18d5ac4879018ac03a96099ad9293f 100644 (file)
@@ -13,6 +13,7 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import collections
 import netaddr
 from oslo.config import cfg
 
@@ -64,7 +65,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         self.sg_rules = {}
         self.pre_sg_rules = None
         # List of security group member ips for ports residing on this host
-        self.sg_members = {}
+        self.sg_members = collections.defaultdict(
+            lambda: collections.defaultdict(list))
         self.pre_sg_members = None
         self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
 
@@ -78,7 +80,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
 
     def update_security_group_members(self, sg_id, sg_members):
         LOG.debug("Update members of security group (%s)", sg_id)
-        self.sg_members[sg_id] = sg_members
+        self.sg_members[sg_id] = collections.defaultdict(list, sg_members)
 
     def prepare_port_filter(self, port):
         LOG.debug("Preparing device (%s) filter", port['device'])
@@ -323,12 +325,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         else:
             yield rule
 
-    def _get_remote_sg_ids(self, port, direction):
+    def _get_remote_sg_ids(self, port, direction=None):
         sg_ids = port.get('security_groups', [])
         remote_sg_ids = {constants.IPv4: [], constants.IPv6: []}
         for sg_id in sg_ids:
             for rule in self.sg_rules.get(sg_id, []):
-                if rule['direction'] == direction:
+                if not direction or rule['direction'] == direction:
                     remote_sg_id = rule.get('remote_group_id')
                     ether_type = rule.get('ethertype')
                     if remote_sg_id and ether_type:
@@ -374,15 +376,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
                             ipv6_iptables_rules)
         self._drop_dhcp_rule(ipv4_iptables_rules, ipv6_iptables_rules)
 
-    def _get_current_sg_member_ips(self, sg_id, ethertype):
-        return self.sg_members.get(sg_id, {}).get(ethertype, [])
-
     def _update_ipset_members(self, security_group_ids):
-        for ethertype, sg_ids in security_group_ids.items():
+        for ip_version, sg_ids in security_group_ids.items():
             for sg_id in sg_ids:
-                current_ips = self._get_current_sg_member_ips(sg_id, ethertype)
+                current_ips = self.sg_members[sg_id][ip_version]
                 if current_ips:
-                    self.ipset.set_members(sg_id, ethertype, current_ips)
+                    self.ipset.set_members(sg_id, ip_version, current_ips)
 
     def _generate_ipset_rule_args(self, sg_rule, remote_gid):
         ethertype = sg_rule.get('ethertype')
@@ -505,47 +504,88 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
             self._defer_apply = True
 
     def _remove_unused_security_group_info(self):
-        need_removed_ipsets = {constants.IPv4: set(),
-                               constants.IPv6: set()}
-        need_removed_security_groups = set()
-        remote_group_ids = {constants.IPv4: set(),
-                            constants.IPv6: set()}
-        current_group_ids = set()
-        for port in self.filtered_ports.values():
-            for direction in INGRESS_DIRECTION, EGRESS_DIRECTION:
-                for ethertype, sg_ids in self._get_remote_sg_ids(
-                        port, direction).items():
-                    remote_group_ids[ethertype].update(sg_ids)
-            groups = port.get('security_groups', [])
-            current_group_ids.update(groups)
-
-        for ethertype in [constants.IPv4, constants.IPv6]:
-            need_removed_ipsets[ethertype].update(
-                [x for x in self.pre_sg_members if x not in remote_group_ids[
-                    ethertype]])
-            need_removed_security_groups.update(
-                [x for x in self.pre_sg_rules if x not in current_group_ids])
-
-        # Remove unused ip sets (sg_members and kernel ipset if we
-        # are using ipset)
-        for ethertype, remove_set_ids in need_removed_ipsets.items():
-            for remove_set_id in remove_set_ids:
-                if self.sg_members.get(remove_set_id, {}).get(ethertype, []):
-                    self.sg_members[remove_set_id][ethertype] = []
-                if self.enable_ipset:
-                    self.ipset.destroy(remove_set_id, ethertype)
-
-        # Remove unused remote security group member ips
-        sg_ids = self.sg_members.keys()
-        for sg_id in sg_ids:
-            if not (self.sg_members[sg_id].get(constants.IPv4, [])
-                    or self.sg_members[sg_id].get(constants.IPv6, [])):
-                self.sg_members.pop(sg_id, None)
+        """Remove any unnecesary local security group info or unused ipsets.
+
+        This function has to be called after applying the last iptables
+        rules, so we're in a point where no iptable rule depends
+        on an ipset we're going to delete.
+        """
+        filtered_ports = self.filtered_ports.values()
+
+        remote_sgs_to_remove = self._determine_remote_sgs_to_remove(
+            filtered_ports)
+
+        for ip_version, remote_sg_ids in remote_sgs_to_remove.iteritems():
+            self._clear_sg_members(ip_version, remote_sg_ids)
+            if self.enable_ipset:
+                self._remove_ipsets_for_remote_sgs(ip_version, remote_sg_ids)
+
+        self._remove_unused_sg_members()
 
         # Remove unused security group rules
-        for remove_group_id in need_removed_security_groups:
-            if remove_group_id in self.sg_rules:
-                self.sg_rules.pop(remove_group_id, None)
+        for remove_group_id in self._determine_sg_rules_to_remove(
+                filtered_ports):
+            self.sg_rules.pop(remove_group_id, None)
+
+    def _determine_remote_sgs_to_remove(self, filtered_ports):
+        """Calculate which remote security groups we don't need anymore.
+
+        We do the calculation for each ip_version.
+        """
+        sgs_to_remove_per_ipversion = {constants.IPv4: set(),
+                                       constants.IPv6: set()}
+        remote_group_id_sets = self._get_remote_sg_ids_sets_by_ipversion(
+            filtered_ports)
+        for ip_version, remote_group_id_set in (
+                remote_group_id_sets.iteritems()):
+            sgs_to_remove_per_ipversion[ip_version].update(
+                set(self.pre_sg_members) - remote_group_id_set)
+        return sgs_to_remove_per_ipversion
+
+    def _get_remote_sg_ids_sets_by_ipversion(self, filtered_ports):
+        """Given a port, calculates the remote sg references by ip_version."""
+        remote_group_id_sets = {constants.IPv4: set(),
+                                constants.IPv6: set()}
+        for port in filtered_ports:
+            for ip_version, sg_ids in self._get_remote_sg_ids(
+                    port).iteritems():
+                remote_group_id_sets[ip_version].update(sg_ids)
+        return remote_group_id_sets
+
+    def _determine_sg_rules_to_remove(self, filtered_ports):
+        """Calculate which security groups need to be removed.
+
+        We find out by substracting our previous sg group ids,
+        with the security groups associated to a set of ports.
+        """
+        port_group_ids = self._get_sg_ids_set_for_ports(filtered_ports)
+        return set(self.pre_sg_rules) - port_group_ids
+
+    def _get_sg_ids_set_for_ports(self, filtered_ports):
+        """Get the port security group ids as a set."""
+        port_group_ids = set()
+        for port in filtered_ports:
+            port_group_ids.update(port.get('security_groups', []))
+        return port_group_ids
+
+    def _clear_sg_members(self, ip_version, remote_sg_ids):
+        """Clear our internal cache of sg members matching the parameters."""
+        for remote_sg_id in remote_sg_ids:
+            if self.sg_members[remote_sg_id][ip_version]:
+                self.sg_members[remote_sg_id][ip_version] = []
+
+    def _remove_ipsets_for_remote_sgs(self, ip_version, remote_sg_ids):
+        """Remove system ipsets matching the provided parameters."""
+        for remote_sg_id in remote_sg_ids:
+            self.ipset.destroy(remote_sg_id, ip_version)
+
+    def _remove_unused_sg_members(self):
+        """Remove sg_member entries where no IPv4 or IPv6 is associated."""
+        for sg_id in self.sg_members.keys():
+            sg_has_members = (self.sg_members[sg_id][constants.IPv4] or
+                              self.sg_members[sg_id][constants.IPv6])
+            if not sg_has_members:
+                del self.sg_members[sg_id]
 
     def filter_defer_apply_off(self):
         if self._defer_apply:
index 294a06b1b4df1b451a810e2c72996d892177f8e9..b7c08947cf0a16dcac76fb83d52a91de1490c6fe 100644 (file)
@@ -34,8 +34,11 @@ FAKE_PREFIX = {'IPv4': '10.0.0.0/24',
                'IPv6': 'fe80::/48'}
 FAKE_IP = {'IPv4': '10.0.0.1',
            'IPv6': 'fe80::1'}
-#TODO(mangelajo): replace all 'fake_sgid' strings for the constant
+#TODO(mangelajo): replace all '*_sgid' strings for the constants
 FAKE_SGID = 'fake_sgid'
+OTHER_SGID = 'other_sgid'
+_IPv6 = constants.IPv6
+_IPv4 = constants.IPv4
 
 
 class BaseIptablesFirewallTestCase(base.BaseTestCase):
@@ -1420,16 +1423,25 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
                 'security_groups': [sg_id],
                 'security_group_source_groups': [sg_id]}
 
-    def _fake_sg_rule_for_ethertype(self, ethertype):
-        return {'direction': 'ingress', 'remote_group_id': 'fake_sgid',
+    def _fake_sg_rule_for_ethertype(self, ethertype, remote_group):
+        return {'direction': 'ingress', 'remote_group_id': remote_group,
                 'ethertype': ethertype}
 
-    def _fake_sg_rule(self):
-        return {'fake_sgid': [self._fake_sg_rule_for_ethertype('IPv4'),
-                              self._fake_sg_rule_for_ethertype('IPv6')]}
+    def _fake_sg_rules(self, sg_id=FAKE_SGID, remote_groups=None):
+        remote_groups = remote_groups or {_IPv4: [FAKE_SGID],
+                                          _IPv6: [FAKE_SGID]}
+        rules = []
+        for ip_version, remote_group_list in remote_groups.iteritems():
+            for remote_group in remote_group_list:
+                rules.append(self._fake_sg_rule_for_ethertype(ip_version,
+                                                              remote_group))
+        return {sg_id: rules}
+
+    def _fake_sg_members(self, sg_ids=None):
+        return {sg_id: copy.copy(FAKE_IP) for sg_id in (sg_ids or [FAKE_SGID])}
 
     def test_prepare_port_filter_with_new_members(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
+        self.firewall.sg_rules = self._fake_sg_rules()
         self.firewall.sg_members = {'fake_sgid': {
             'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
         self.firewall.pre_sg_members = {}
@@ -1444,34 +1456,97 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.firewall.ipset.assert_has_calls(calls)
 
     def _setup_fake_firewall_members_and_rules(self, firewall):
-        firewall.sg_rules = self._fake_sg_rule()
-        firewall.pre_sg_rules = self._fake_sg_rule()
-        firewall.sg_members = {'fake_sgid': {
-            'IPv4': ['10.0.0.1'],
-            'IPv6': ['fe80::1']}}
+        firewall.sg_rules = self._fake_sg_rules()
+        firewall.pre_sg_rules = self._fake_sg_rules()
+        firewall.sg_members = self._fake_sg_members()
         firewall.pre_sg_members = firewall.sg_members
 
+    def _prepare_rules_and_members_for_removal(self):
+        self._setup_fake_firewall_members_and_rules(self.firewall)
+        self.firewall.pre_sg_members[OTHER_SGID] = (
+            self.firewall.pre_sg_members[FAKE_SGID])
+
+    def test_determine_remote_sgs_to_remove(self):
+        self._prepare_rules_and_members_for_removal()
+        ports = [self._fake_port()]
+
+        self.assertEqual(
+            {_IPv4: set([OTHER_SGID]), _IPv6: set([OTHER_SGID])},
+            self.firewall._determine_remote_sgs_to_remove(ports))
+
+    def test_determine_remote_sgs_to_remove_ipv6_unreferenced(self):
+        self._prepare_rules_and_members_for_removal()
+        ports = [self._fake_port()]
+        self.firewall.sg_rules = self._fake_sg_rules(
+            remote_groups={_IPv4: [OTHER_SGID, FAKE_SGID],
+                           _IPv6: [FAKE_SGID]})
+        self.assertEqual(
+            {_IPv4: set(), _IPv6: set([OTHER_SGID])},
+            self.firewall._determine_remote_sgs_to_remove(ports))
+
+    def test_get_remote_sg_ids_by_ipversion(self):
+        self.firewall.sg_rules = self._fake_sg_rules(
+            remote_groups={_IPv4: [FAKE_SGID], _IPv6: [OTHER_SGID]})
+
+        ports = [self._fake_port()]
+
+        self.assertEqual(
+            {_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
+            self.firewall._get_remote_sg_ids_sets_by_ipversion(ports))
+
+    def test_determine_sg_rules_to_remove(self):
+        self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID)
+        ports = [self._fake_port()]
+
+        self.assertEqual(set([OTHER_SGID]),
+                         self.firewall._determine_sg_rules_to_remove(ports))
+
+    def test_get_sg_ids_set_for_ports(self):
+        sg_ids = set([FAKE_SGID, OTHER_SGID])
+        ports = [self._fake_port(sg_id) for sg_id in sg_ids]
+
+        self.assertEqual(sg_ids,
+                         self.firewall._get_sg_ids_set_for_ports(ports))
+
+    def test_clear_sg_members(self):
+        self.firewall.sg_members = self._fake_sg_members(
+            sg_ids=[FAKE_SGID, OTHER_SGID])
+        self.firewall._clear_sg_members(_IPv4, [OTHER_SGID])
+
+        self.assertEqual(0, len(self.firewall.sg_members[OTHER_SGID][_IPv4]))
+
+    def test_remove_unused_sg_members(self):
+        self.firewall.sg_members = self._fake_sg_members([FAKE_SGID,
+                                                          OTHER_SGID])
+        self.firewall.sg_members[FAKE_SGID][_IPv4] = []
+        self.firewall.sg_members[FAKE_SGID][_IPv6] = []
+        self.firewall.sg_members[OTHER_SGID][_IPv6] = []
+        self.firewall._remove_unused_sg_members()
+
+        self.assertIn(OTHER_SGID, self.firewall.sg_members)
+        self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
+
     def test_remove_unused_security_group_info_clears_unused_rules(self):
         self._setup_fake_firewall_members_and_rules(self.firewall)
         self.firewall.prepare_port_filter(self._fake_port())
 
         # create another SG which won't be referenced by any filtered port
         fake_sg_rules = self.firewall.sg_rules['fake_sgid']
-        self.firewall.pre_sg_rules['other_sgid'] = fake_sg_rules
-        self.firewall.sg_rules['other_sgid'] = fake_sg_rules
+        self.firewall.pre_sg_rules[OTHER_SGID] = fake_sg_rules
+        self.firewall.sg_rules[OTHER_SGID] = fake_sg_rules
 
         # call the cleanup function, and check the unused sg_rules are out
         self.firewall._remove_unused_security_group_info()
-        self.assertNotIn('other_sgid', self.firewall.sg_rules)
+        self.assertNotIn(OTHER_SGID, self.firewall.sg_rules)
 
-    def test_remove_unused_sg_members(self):
+    def test_remove_unused_security_group_info(self):
         self._setup_fake_firewall_members_and_rules(self.firewall)
         # no filtered ports in 'fake_sgid', so all rules and members
         # are not needed and we expect them to be cleaned up
-        self.firewall.prepare_port_filter(self._fake_port('other_sgid'))
+        self.firewall.prepare_port_filter(self._fake_port(OTHER_SGID))
         self.firewall._remove_unused_security_group_info()
 
-        self.assertNotIn('fake_sgid', self.firewall.sg_members)
+        self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
 
     def test_remove_all_unused_info(self):
         self._setup_fake_firewall_members_and_rules(self.firewall)
@@ -1481,8 +1556,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.assertFalse(self.firewall.sg_rules)
 
     def test_prepare_port_filter_with_deleted_member(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.pre_sg_rules = self._fake_sg_rule()
+        self.firewall.sg_rules = self._fake_sg_rules()
+        self.firewall.pre_sg_rules = self._fake_sg_rules()
         self.firewall.sg_members = {'fake_sgid': {
             'IPv4': [
                 '10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
@@ -1500,7 +1575,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.firewall.ipset.assert_has_calls(calls, True)
 
     def test_remove_port_filter_with_destroy_ipset_chain(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
+        self.firewall.sg_rules = self._fake_sg_rules()
         port = self._fake_port()
         self.firewall.sg_members = {'fake_sgid': {
             'IPv4': ['10.0.0.1'],
@@ -1531,13 +1606,13 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.firewall.ipset.assert_has_calls(calls)
 
     def test_prepare_port_filter_with_sg_no_member(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.sg_rules['fake_sgid'].append(
+        self.firewall.sg_rules = self._fake_sg_rules()
+        self.firewall.sg_rules[FAKE_SGID].append(
             {'direction': 'ingress', 'remote_group_id': 'fake_sgid2',
              'ethertype': 'IPv4'})
         self.firewall.sg_rules.update()
-        self.firewall.sg_members = {'fake_sgid': {
-            'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
+        self.firewall.sg_members['fake_sgid'] = {
+            'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}
         self.firewall.pre_sg_members = {}
         port = self._fake_port()
         port['security_group_source_groups'].append('fake_sgid2')
@@ -1549,8 +1624,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.firewall.ipset.assert_has_calls(calls)
 
     def test_filter_defer_apply_off_with_sg_only_ipv6_rule(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.pre_sg_rules = self._fake_sg_rule()
+        self.firewall.sg_rules = self._fake_sg_rules()
+        self.firewall.pre_sg_rules = self._fake_sg_rules()
         self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2'],
                                       'IPv6fake_sgid': ['fe80::1']}
         self.firewall.sg_members = {'fake_sgid': {
@@ -1579,7 +1654,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
             'IPv6': [FAKE_IP['IPv6']]}}
 
         port = self._fake_port()
-        rule = self._fake_sg_rule_for_ethertype('IPv4')
+        rule = self._fake_sg_rule_for_ethertype(_IPv4, FAKE_SGID)
         rules = self.firewall._expand_sg_rule_with_remote_ips(
             rule, port, 'ingress')
         self.assertEqual(list(rules),