From: Miguel Angel Ajo Date: Tue, 3 Feb 2015 13:35:40 +0000 (+0000) Subject: Refactor _remove_unused_security_group_info X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=6bc82841c5da507e4d90d2bab96b2be2a6431ef4;p=openstack-build%2Fneutron-build.git Refactor _remove_unused_security_group_info _remove_unused_security_group_info is refactored into smaller functions, to make this block easier to understand. Implements blueprint refactor-iptables-firewall-driver Change-Id: I4107f1a702d059337e7b2d701a5d0372ee2cfe11 --- diff --git a/neutron/agent/linux/iptables_firewall.py b/neutron/agent/linux/iptables_firewall.py index 9a4d3bd8b..9896559dc 100644 --- a/neutron/agent/linux/iptables_firewall.py +++ b/neutron/agent/linux/iptables_firewall.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import netaddr from oslo.config import cfg @@ -64,7 +65,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver): self.sg_rules = {} self.pre_sg_rules = None # List of security group member ips for ports residing on this host - self.sg_members = {} + self.sg_members = collections.defaultdict( + lambda: collections.defaultdict(list)) self.pre_sg_members = None self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset @@ -78,7 +80,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver): def update_security_group_members(self, sg_id, sg_members): LOG.debug("Update members of security group (%s)", sg_id) - self.sg_members[sg_id] = sg_members + self.sg_members[sg_id] = collections.defaultdict(list, sg_members) def prepare_port_filter(self, port): LOG.debug("Preparing device (%s) filter", port['device']) @@ -323,12 +325,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver): else: yield rule - def _get_remote_sg_ids(self, port, direction): + def _get_remote_sg_ids(self, port, direction=None): sg_ids = port.get('security_groups', []) remote_sg_ids = {constants.IPv4: [], constants.IPv6: []} for sg_id in sg_ids: for rule in self.sg_rules.get(sg_id, []): - if rule['direction'] == direction: + if not direction or rule['direction'] == direction: remote_sg_id = rule.get('remote_group_id') ether_type = rule.get('ethertype') if remote_sg_id and ether_type: @@ -374,15 +376,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver): ipv6_iptables_rules) self._drop_dhcp_rule(ipv4_iptables_rules, ipv6_iptables_rules) - def _get_current_sg_member_ips(self, sg_id, ethertype): - return self.sg_members.get(sg_id, {}).get(ethertype, []) - def _update_ipset_members(self, security_group_ids): - for ethertype, sg_ids in security_group_ids.items(): + for ip_version, sg_ids in security_group_ids.items(): for sg_id in sg_ids: - current_ips = self._get_current_sg_member_ips(sg_id, ethertype) + current_ips = self.sg_members[sg_id][ip_version] if current_ips: - self.ipset.set_members(sg_id, ethertype, current_ips) + self.ipset.set_members(sg_id, ip_version, current_ips) def _generate_ipset_rule_args(self, sg_rule, remote_gid): ethertype = sg_rule.get('ethertype') @@ -505,47 +504,88 @@ class IptablesFirewallDriver(firewall.FirewallDriver): self._defer_apply = True def _remove_unused_security_group_info(self): - need_removed_ipsets = {constants.IPv4: set(), - constants.IPv6: set()} - need_removed_security_groups = set() - remote_group_ids = {constants.IPv4: set(), - constants.IPv6: set()} - current_group_ids = set() - for port in self.filtered_ports.values(): - for direction in INGRESS_DIRECTION, EGRESS_DIRECTION: - for ethertype, sg_ids in self._get_remote_sg_ids( - port, direction).items(): - remote_group_ids[ethertype].update(sg_ids) - groups = port.get('security_groups', []) - current_group_ids.update(groups) - - for ethertype in [constants.IPv4, constants.IPv6]: - need_removed_ipsets[ethertype].update( - [x for x in self.pre_sg_members if x not in remote_group_ids[ - ethertype]]) - need_removed_security_groups.update( - [x for x in self.pre_sg_rules if x not in current_group_ids]) - - # Remove unused ip sets (sg_members and kernel ipset if we - # are using ipset) - for ethertype, remove_set_ids in need_removed_ipsets.items(): - for remove_set_id in remove_set_ids: - if self.sg_members.get(remove_set_id, {}).get(ethertype, []): - self.sg_members[remove_set_id][ethertype] = [] - if self.enable_ipset: - self.ipset.destroy(remove_set_id, ethertype) - - # Remove unused remote security group member ips - sg_ids = self.sg_members.keys() - for sg_id in sg_ids: - if not (self.sg_members[sg_id].get(constants.IPv4, []) - or self.sg_members[sg_id].get(constants.IPv6, [])): - self.sg_members.pop(sg_id, None) + """Remove any unnecesary local security group info or unused ipsets. + + This function has to be called after applying the last iptables + rules, so we're in a point where no iptable rule depends + on an ipset we're going to delete. + """ + filtered_ports = self.filtered_ports.values() + + remote_sgs_to_remove = self._determine_remote_sgs_to_remove( + filtered_ports) + + for ip_version, remote_sg_ids in remote_sgs_to_remove.iteritems(): + self._clear_sg_members(ip_version, remote_sg_ids) + if self.enable_ipset: + self._remove_ipsets_for_remote_sgs(ip_version, remote_sg_ids) + + self._remove_unused_sg_members() # Remove unused security group rules - for remove_group_id in need_removed_security_groups: - if remove_group_id in self.sg_rules: - self.sg_rules.pop(remove_group_id, None) + for remove_group_id in self._determine_sg_rules_to_remove( + filtered_ports): + self.sg_rules.pop(remove_group_id, None) + + def _determine_remote_sgs_to_remove(self, filtered_ports): + """Calculate which remote security groups we don't need anymore. + + We do the calculation for each ip_version. + """ + sgs_to_remove_per_ipversion = {constants.IPv4: set(), + constants.IPv6: set()} + remote_group_id_sets = self._get_remote_sg_ids_sets_by_ipversion( + filtered_ports) + for ip_version, remote_group_id_set in ( + remote_group_id_sets.iteritems()): + sgs_to_remove_per_ipversion[ip_version].update( + set(self.pre_sg_members) - remote_group_id_set) + return sgs_to_remove_per_ipversion + + def _get_remote_sg_ids_sets_by_ipversion(self, filtered_ports): + """Given a port, calculates the remote sg references by ip_version.""" + remote_group_id_sets = {constants.IPv4: set(), + constants.IPv6: set()} + for port in filtered_ports: + for ip_version, sg_ids in self._get_remote_sg_ids( + port).iteritems(): + remote_group_id_sets[ip_version].update(sg_ids) + return remote_group_id_sets + + def _determine_sg_rules_to_remove(self, filtered_ports): + """Calculate which security groups need to be removed. + + We find out by substracting our previous sg group ids, + with the security groups associated to a set of ports. + """ + port_group_ids = self._get_sg_ids_set_for_ports(filtered_ports) + return set(self.pre_sg_rules) - port_group_ids + + def _get_sg_ids_set_for_ports(self, filtered_ports): + """Get the port security group ids as a set.""" + port_group_ids = set() + for port in filtered_ports: + port_group_ids.update(port.get('security_groups', [])) + return port_group_ids + + def _clear_sg_members(self, ip_version, remote_sg_ids): + """Clear our internal cache of sg members matching the parameters.""" + for remote_sg_id in remote_sg_ids: + if self.sg_members[remote_sg_id][ip_version]: + self.sg_members[remote_sg_id][ip_version] = [] + + def _remove_ipsets_for_remote_sgs(self, ip_version, remote_sg_ids): + """Remove system ipsets matching the provided parameters.""" + for remote_sg_id in remote_sg_ids: + self.ipset.destroy(remote_sg_id, ip_version) + + def _remove_unused_sg_members(self): + """Remove sg_member entries where no IPv4 or IPv6 is associated.""" + for sg_id in self.sg_members.keys(): + sg_has_members = (self.sg_members[sg_id][constants.IPv4] or + self.sg_members[sg_id][constants.IPv6]) + if not sg_has_members: + del self.sg_members[sg_id] def filter_defer_apply_off(self): if self._defer_apply: diff --git a/neutron/tests/unit/test_iptables_firewall.py b/neutron/tests/unit/test_iptables_firewall.py index 294a06b1b..b7c08947c 100644 --- a/neutron/tests/unit/test_iptables_firewall.py +++ b/neutron/tests/unit/test_iptables_firewall.py @@ -34,8 +34,11 @@ FAKE_PREFIX = {'IPv4': '10.0.0.0/24', 'IPv6': 'fe80::/48'} FAKE_IP = {'IPv4': '10.0.0.1', 'IPv6': 'fe80::1'} -#TODO(mangelajo): replace all 'fake_sgid' strings for the constant +#TODO(mangelajo): replace all '*_sgid' strings for the constants FAKE_SGID = 'fake_sgid' +OTHER_SGID = 'other_sgid' +_IPv6 = constants.IPv6 +_IPv4 = constants.IPv4 class BaseIptablesFirewallTestCase(base.BaseTestCase): @@ -1420,16 +1423,25 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): 'security_groups': [sg_id], 'security_group_source_groups': [sg_id]} - def _fake_sg_rule_for_ethertype(self, ethertype): - return {'direction': 'ingress', 'remote_group_id': 'fake_sgid', + def _fake_sg_rule_for_ethertype(self, ethertype, remote_group): + return {'direction': 'ingress', 'remote_group_id': remote_group, 'ethertype': ethertype} - def _fake_sg_rule(self): - return {'fake_sgid': [self._fake_sg_rule_for_ethertype('IPv4'), - self._fake_sg_rule_for_ethertype('IPv6')]} + def _fake_sg_rules(self, sg_id=FAKE_SGID, remote_groups=None): + remote_groups = remote_groups or {_IPv4: [FAKE_SGID], + _IPv6: [FAKE_SGID]} + rules = [] + for ip_version, remote_group_list in remote_groups.iteritems(): + for remote_group in remote_group_list: + rules.append(self._fake_sg_rule_for_ethertype(ip_version, + remote_group)) + return {sg_id: rules} + + def _fake_sg_members(self, sg_ids=None): + return {sg_id: copy.copy(FAKE_IP) for sg_id in (sg_ids or [FAKE_SGID])} def test_prepare_port_filter_with_new_members(self): - self.firewall.sg_rules = self._fake_sg_rule() + self.firewall.sg_rules = self._fake_sg_rules() self.firewall.sg_members = {'fake_sgid': { 'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}} self.firewall.pre_sg_members = {} @@ -1444,34 +1456,97 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): self.firewall.ipset.assert_has_calls(calls) def _setup_fake_firewall_members_and_rules(self, firewall): - firewall.sg_rules = self._fake_sg_rule() - firewall.pre_sg_rules = self._fake_sg_rule() - firewall.sg_members = {'fake_sgid': { - 'IPv4': ['10.0.0.1'], - 'IPv6': ['fe80::1']}} + firewall.sg_rules = self._fake_sg_rules() + firewall.pre_sg_rules = self._fake_sg_rules() + firewall.sg_members = self._fake_sg_members() firewall.pre_sg_members = firewall.sg_members + def _prepare_rules_and_members_for_removal(self): + self._setup_fake_firewall_members_and_rules(self.firewall) + self.firewall.pre_sg_members[OTHER_SGID] = ( + self.firewall.pre_sg_members[FAKE_SGID]) + + def test_determine_remote_sgs_to_remove(self): + self._prepare_rules_and_members_for_removal() + ports = [self._fake_port()] + + self.assertEqual( + {_IPv4: set([OTHER_SGID]), _IPv6: set([OTHER_SGID])}, + self.firewall._determine_remote_sgs_to_remove(ports)) + + def test_determine_remote_sgs_to_remove_ipv6_unreferenced(self): + self._prepare_rules_and_members_for_removal() + ports = [self._fake_port()] + self.firewall.sg_rules = self._fake_sg_rules( + remote_groups={_IPv4: [OTHER_SGID, FAKE_SGID], + _IPv6: [FAKE_SGID]}) + self.assertEqual( + {_IPv4: set(), _IPv6: set([OTHER_SGID])}, + self.firewall._determine_remote_sgs_to_remove(ports)) + + def test_get_remote_sg_ids_by_ipversion(self): + self.firewall.sg_rules = self._fake_sg_rules( + remote_groups={_IPv4: [FAKE_SGID], _IPv6: [OTHER_SGID]}) + + ports = [self._fake_port()] + + self.assertEqual( + {_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])}, + self.firewall._get_remote_sg_ids_sets_by_ipversion(ports)) + + def test_determine_sg_rules_to_remove(self): + self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID) + ports = [self._fake_port()] + + self.assertEqual(set([OTHER_SGID]), + self.firewall._determine_sg_rules_to_remove(ports)) + + def test_get_sg_ids_set_for_ports(self): + sg_ids = set([FAKE_SGID, OTHER_SGID]) + ports = [self._fake_port(sg_id) for sg_id in sg_ids] + + self.assertEqual(sg_ids, + self.firewall._get_sg_ids_set_for_ports(ports)) + + def test_clear_sg_members(self): + self.firewall.sg_members = self._fake_sg_members( + sg_ids=[FAKE_SGID, OTHER_SGID]) + self.firewall._clear_sg_members(_IPv4, [OTHER_SGID]) + + self.assertEqual(0, len(self.firewall.sg_members[OTHER_SGID][_IPv4])) + + def test_remove_unused_sg_members(self): + self.firewall.sg_members = self._fake_sg_members([FAKE_SGID, + OTHER_SGID]) + self.firewall.sg_members[FAKE_SGID][_IPv4] = [] + self.firewall.sg_members[FAKE_SGID][_IPv6] = [] + self.firewall.sg_members[OTHER_SGID][_IPv6] = [] + self.firewall._remove_unused_sg_members() + + self.assertIn(OTHER_SGID, self.firewall.sg_members) + self.assertNotIn(FAKE_SGID, self.firewall.sg_members) + def test_remove_unused_security_group_info_clears_unused_rules(self): self._setup_fake_firewall_members_and_rules(self.firewall) self.firewall.prepare_port_filter(self._fake_port()) # create another SG which won't be referenced by any filtered port fake_sg_rules = self.firewall.sg_rules['fake_sgid'] - self.firewall.pre_sg_rules['other_sgid'] = fake_sg_rules - self.firewall.sg_rules['other_sgid'] = fake_sg_rules + self.firewall.pre_sg_rules[OTHER_SGID] = fake_sg_rules + self.firewall.sg_rules[OTHER_SGID] = fake_sg_rules # call the cleanup function, and check the unused sg_rules are out self.firewall._remove_unused_security_group_info() - self.assertNotIn('other_sgid', self.firewall.sg_rules) + self.assertNotIn(OTHER_SGID, self.firewall.sg_rules) - def test_remove_unused_sg_members(self): + def test_remove_unused_security_group_info(self): self._setup_fake_firewall_members_and_rules(self.firewall) # no filtered ports in 'fake_sgid', so all rules and members # are not needed and we expect them to be cleaned up - self.firewall.prepare_port_filter(self._fake_port('other_sgid')) + self.firewall.prepare_port_filter(self._fake_port(OTHER_SGID)) self.firewall._remove_unused_security_group_info() - self.assertNotIn('fake_sgid', self.firewall.sg_members) + self.assertNotIn(FAKE_SGID, self.firewall.sg_members) def test_remove_all_unused_info(self): self._setup_fake_firewall_members_and_rules(self.firewall) @@ -1481,8 +1556,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): self.assertFalse(self.firewall.sg_rules) def test_prepare_port_filter_with_deleted_member(self): - self.firewall.sg_rules = self._fake_sg_rule() - self.firewall.pre_sg_rules = self._fake_sg_rule() + self.firewall.sg_rules = self._fake_sg_rules() + self.firewall.pre_sg_rules = self._fake_sg_rules() self.firewall.sg_members = {'fake_sgid': { 'IPv4': [ '10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'], @@ -1500,7 +1575,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): self.firewall.ipset.assert_has_calls(calls, True) def test_remove_port_filter_with_destroy_ipset_chain(self): - self.firewall.sg_rules = self._fake_sg_rule() + self.firewall.sg_rules = self._fake_sg_rules() port = self._fake_port() self.firewall.sg_members = {'fake_sgid': { 'IPv4': ['10.0.0.1'], @@ -1531,13 +1606,13 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): self.firewall.ipset.assert_has_calls(calls) def test_prepare_port_filter_with_sg_no_member(self): - self.firewall.sg_rules = self._fake_sg_rule() - self.firewall.sg_rules['fake_sgid'].append( + self.firewall.sg_rules = self._fake_sg_rules() + self.firewall.sg_rules[FAKE_SGID].append( {'direction': 'ingress', 'remote_group_id': 'fake_sgid2', 'ethertype': 'IPv4'}) self.firewall.sg_rules.update() - self.firewall.sg_members = {'fake_sgid': { - 'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}} + self.firewall.sg_members['fake_sgid'] = { + 'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']} self.firewall.pre_sg_members = {} port = self._fake_port() port['security_group_source_groups'].append('fake_sgid2') @@ -1549,8 +1624,8 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): self.firewall.ipset.assert_has_calls(calls) def test_filter_defer_apply_off_with_sg_only_ipv6_rule(self): - self.firewall.sg_rules = self._fake_sg_rule() - self.firewall.pre_sg_rules = self._fake_sg_rule() + self.firewall.sg_rules = self._fake_sg_rules() + self.firewall.pre_sg_rules = self._fake_sg_rules() self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2'], 'IPv6fake_sgid': ['fe80::1']} self.firewall.sg_members = {'fake_sgid': { @@ -1579,7 +1654,7 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase): 'IPv6': [FAKE_IP['IPv6']]}} port = self._fake_port() - rule = self._fake_sg_rule_for_ethertype('IPv4') + rule = self._fake_sg_rule_for_ethertype(_IPv4, FAKE_SGID) rules = self.firewall._expand_sg_rule_with_remote_ips( rule, port, 'ingress') self.assertEqual(list(rules),