# License for the specific language governing permissions and limitations
# under the License.
+import collections
import netaddr
from oslo.config import cfg
self.sg_rules = {}
self.pre_sg_rules = None
# List of security group member ips for ports residing on this host
- self.sg_members = {}
+ self.sg_members = collections.defaultdict(
+ lambda: collections.defaultdict(list))
self.pre_sg_members = None
self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
def update_security_group_members(self, sg_id, sg_members):
LOG.debug("Update members of security group (%s)", sg_id)
- self.sg_members[sg_id] = sg_members
+ self.sg_members[sg_id] = collections.defaultdict(list, sg_members)
def prepare_port_filter(self, port):
LOG.debug("Preparing device (%s) filter", port['device'])
else:
yield rule
- def _get_remote_sg_ids(self, port, direction):
+ def _get_remote_sg_ids(self, port, direction=None):
sg_ids = port.get('security_groups', [])
remote_sg_ids = {constants.IPv4: [], constants.IPv6: []}
for sg_id in sg_ids:
for rule in self.sg_rules.get(sg_id, []):
- if rule['direction'] == direction:
+ if not direction or rule['direction'] == direction:
remote_sg_id = rule.get('remote_group_id')
ether_type = rule.get('ethertype')
if remote_sg_id and ether_type:
ipv6_iptables_rules)
self._drop_dhcp_rule(ipv4_iptables_rules, ipv6_iptables_rules)
- def _get_current_sg_member_ips(self, sg_id, ethertype):
- return self.sg_members.get(sg_id, {}).get(ethertype, [])
-
def _update_ipset_members(self, security_group_ids):
- for ethertype, sg_ids in security_group_ids.items():
+ for ip_version, sg_ids in security_group_ids.items():
for sg_id in sg_ids:
- current_ips = self._get_current_sg_member_ips(sg_id, ethertype)
+ current_ips = self.sg_members[sg_id][ip_version]
if current_ips:
- self.ipset.set_members(sg_id, ethertype, current_ips)
+ self.ipset.set_members(sg_id, ip_version, current_ips)
def _generate_ipset_rule_args(self, sg_rule, remote_gid):
ethertype = sg_rule.get('ethertype')
self._defer_apply = True
def _remove_unused_security_group_info(self):
- need_removed_ipsets = {constants.IPv4: set(),
- constants.IPv6: set()}
- need_removed_security_groups = set()
- remote_group_ids = {constants.IPv4: set(),
- constants.IPv6: set()}
- current_group_ids = set()
- for port in self.filtered_ports.values():
- for direction in INGRESS_DIRECTION, EGRESS_DIRECTION:
- for ethertype, sg_ids in self._get_remote_sg_ids(
- port, direction).items():
- remote_group_ids[ethertype].update(sg_ids)
- groups = port.get('security_groups', [])
- current_group_ids.update(groups)
-
- for ethertype in [constants.IPv4, constants.IPv6]:
- need_removed_ipsets[ethertype].update(
- [x for x in self.pre_sg_members if x not in remote_group_ids[
- ethertype]])
- need_removed_security_groups.update(
- [x for x in self.pre_sg_rules if x not in current_group_ids])
-
- # Remove unused ip sets (sg_members and kernel ipset if we
- # are using ipset)
- for ethertype, remove_set_ids in need_removed_ipsets.items():
- for remove_set_id in remove_set_ids:
- if self.sg_members.get(remove_set_id, {}).get(ethertype, []):
- self.sg_members[remove_set_id][ethertype] = []
- if self.enable_ipset:
- self.ipset.destroy(remove_set_id, ethertype)
-
- # Remove unused remote security group member ips
- sg_ids = self.sg_members.keys()
- for sg_id in sg_ids:
- if not (self.sg_members[sg_id].get(constants.IPv4, [])
- or self.sg_members[sg_id].get(constants.IPv6, [])):
- self.sg_members.pop(sg_id, None)
+ """Remove any unnecesary local security group info or unused ipsets.
+
+ This function has to be called after applying the last iptables
+ rules, so we're in a point where no iptable rule depends
+ on an ipset we're going to delete.
+ """
+ filtered_ports = self.filtered_ports.values()
+
+ remote_sgs_to_remove = self._determine_remote_sgs_to_remove(
+ filtered_ports)
+
+ for ip_version, remote_sg_ids in remote_sgs_to_remove.iteritems():
+ self._clear_sg_members(ip_version, remote_sg_ids)
+ if self.enable_ipset:
+ self._remove_ipsets_for_remote_sgs(ip_version, remote_sg_ids)
+
+ self._remove_unused_sg_members()
# Remove unused security group rules
- for remove_group_id in need_removed_security_groups:
- if remove_group_id in self.sg_rules:
- self.sg_rules.pop(remove_group_id, None)
+ for remove_group_id in self._determine_sg_rules_to_remove(
+ filtered_ports):
+ self.sg_rules.pop(remove_group_id, None)
+
+ def _determine_remote_sgs_to_remove(self, filtered_ports):
+ """Calculate which remote security groups we don't need anymore.
+
+ We do the calculation for each ip_version.
+ """
+ sgs_to_remove_per_ipversion = {constants.IPv4: set(),
+ constants.IPv6: set()}
+ remote_group_id_sets = self._get_remote_sg_ids_sets_by_ipversion(
+ filtered_ports)
+ for ip_version, remote_group_id_set in (
+ remote_group_id_sets.iteritems()):
+ sgs_to_remove_per_ipversion[ip_version].update(
+ set(self.pre_sg_members) - remote_group_id_set)
+ return sgs_to_remove_per_ipversion
+
+ def _get_remote_sg_ids_sets_by_ipversion(self, filtered_ports):
+ """Given a port, calculates the remote sg references by ip_version."""
+ remote_group_id_sets = {constants.IPv4: set(),
+ constants.IPv6: set()}
+ for port in filtered_ports:
+ for ip_version, sg_ids in self._get_remote_sg_ids(
+ port).iteritems():
+ remote_group_id_sets[ip_version].update(sg_ids)
+ return remote_group_id_sets
+
+ def _determine_sg_rules_to_remove(self, filtered_ports):
+ """Calculate which security groups need to be removed.
+
+ We find out by substracting our previous sg group ids,
+ with the security groups associated to a set of ports.
+ """
+ port_group_ids = self._get_sg_ids_set_for_ports(filtered_ports)
+ return set(self.pre_sg_rules) - port_group_ids
+
+ def _get_sg_ids_set_for_ports(self, filtered_ports):
+ """Get the port security group ids as a set."""
+ port_group_ids = set()
+ for port in filtered_ports:
+ port_group_ids.update(port.get('security_groups', []))
+ return port_group_ids
+
+ def _clear_sg_members(self, ip_version, remote_sg_ids):
+ """Clear our internal cache of sg members matching the parameters."""
+ for remote_sg_id in remote_sg_ids:
+ if self.sg_members[remote_sg_id][ip_version]:
+ self.sg_members[remote_sg_id][ip_version] = []
+
+ def _remove_ipsets_for_remote_sgs(self, ip_version, remote_sg_ids):
+ """Remove system ipsets matching the provided parameters."""
+ for remote_sg_id in remote_sg_ids:
+ self.ipset.destroy(remote_sg_id, ip_version)
+
+ def _remove_unused_sg_members(self):
+ """Remove sg_member entries where no IPv4 or IPv6 is associated."""
+ for sg_id in self.sg_members.keys():
+ sg_has_members = (self.sg_members[sg_id][constants.IPv4] or
+ self.sg_members[sg_id][constants.IPv6])
+ if not sg_has_members:
+ del self.sg_members[sg_id]
def filter_defer_apply_off(self):
if self._defer_apply:
'IPv6': 'fe80::/48'}
FAKE_IP = {'IPv4': '10.0.0.1',
'IPv6': 'fe80::1'}
-#TODO(mangelajo): replace all 'fake_sgid' strings for the constant
+#TODO(mangelajo): replace all '*_sgid' strings for the constants
FAKE_SGID = 'fake_sgid'
+OTHER_SGID = 'other_sgid'
+_IPv6 = constants.IPv6
+_IPv4 = constants.IPv4
class BaseIptablesFirewallTestCase(base.BaseTestCase):
'security_groups': [sg_id],
'security_group_source_groups': [sg_id]}
- def _fake_sg_rule_for_ethertype(self, ethertype):
- return {'direction': 'ingress', 'remote_group_id': 'fake_sgid',
+ def _fake_sg_rule_for_ethertype(self, ethertype, remote_group):
+ return {'direction': 'ingress', 'remote_group_id': remote_group,
'ethertype': ethertype}
- def _fake_sg_rule(self):
- return {'fake_sgid': [self._fake_sg_rule_for_ethertype('IPv4'),
- self._fake_sg_rule_for_ethertype('IPv6')]}
+ def _fake_sg_rules(self, sg_id=FAKE_SGID, remote_groups=None):
+ remote_groups = remote_groups or {_IPv4: [FAKE_SGID],
+ _IPv6: [FAKE_SGID]}
+ rules = []
+ for ip_version, remote_group_list in remote_groups.iteritems():
+ for remote_group in remote_group_list:
+ rules.append(self._fake_sg_rule_for_ethertype(ip_version,
+ remote_group))
+ return {sg_id: rules}
+
+ def _fake_sg_members(self, sg_ids=None):
+ return {sg_id: copy.copy(FAKE_IP) for sg_id in (sg_ids or [FAKE_SGID])}
def test_prepare_port_filter_with_new_members(self):
- self.firewall.sg_rules = self._fake_sg_rule()
+ self.firewall.sg_rules = self._fake_sg_rules()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
self.firewall.pre_sg_members = {}
self.firewall.ipset.assert_has_calls(calls)
def _setup_fake_firewall_members_and_rules(self, firewall):
- firewall.sg_rules = self._fake_sg_rule()
- firewall.pre_sg_rules = self._fake_sg_rule()
- firewall.sg_members = {'fake_sgid': {
- 'IPv4': ['10.0.0.1'],
- 'IPv6': ['fe80::1']}}
+ firewall.sg_rules = self._fake_sg_rules()
+ firewall.pre_sg_rules = self._fake_sg_rules()
+ firewall.sg_members = self._fake_sg_members()
firewall.pre_sg_members = firewall.sg_members
+ def _prepare_rules_and_members_for_removal(self):
+ self._setup_fake_firewall_members_and_rules(self.firewall)
+ self.firewall.pre_sg_members[OTHER_SGID] = (
+ self.firewall.pre_sg_members[FAKE_SGID])
+
+ def test_determine_remote_sgs_to_remove(self):
+ self._prepare_rules_and_members_for_removal()
+ ports = [self._fake_port()]
+
+ self.assertEqual(
+ {_IPv4: set([OTHER_SGID]), _IPv6: set([OTHER_SGID])},
+ self.firewall._determine_remote_sgs_to_remove(ports))
+
+ def test_determine_remote_sgs_to_remove_ipv6_unreferenced(self):
+ self._prepare_rules_and_members_for_removal()
+ ports = [self._fake_port()]
+ self.firewall.sg_rules = self._fake_sg_rules(
+ remote_groups={_IPv4: [OTHER_SGID, FAKE_SGID],
+ _IPv6: [FAKE_SGID]})
+ self.assertEqual(
+ {_IPv4: set(), _IPv6: set([OTHER_SGID])},
+ self.firewall._determine_remote_sgs_to_remove(ports))
+
+ def test_get_remote_sg_ids_by_ipversion(self):
+ self.firewall.sg_rules = self._fake_sg_rules(
+ remote_groups={_IPv4: [FAKE_SGID], _IPv6: [OTHER_SGID]})
+
+ ports = [self._fake_port()]
+
+ self.assertEqual(
+ {_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
+ self.firewall._get_remote_sg_ids_sets_by_ipversion(ports))
+
+ def test_determine_sg_rules_to_remove(self):
+ self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID)
+ ports = [self._fake_port()]
+
+ self.assertEqual(set([OTHER_SGID]),
+ self.firewall._determine_sg_rules_to_remove(ports))
+
+ def test_get_sg_ids_set_for_ports(self):
+ sg_ids = set([FAKE_SGID, OTHER_SGID])
+ ports = [self._fake_port(sg_id) for sg_id in sg_ids]
+
+ self.assertEqual(sg_ids,
+ self.firewall._get_sg_ids_set_for_ports(ports))
+
+ def test_clear_sg_members(self):
+ self.firewall.sg_members = self._fake_sg_members(
+ sg_ids=[FAKE_SGID, OTHER_SGID])
+ self.firewall._clear_sg_members(_IPv4, [OTHER_SGID])
+
+ self.assertEqual(0, len(self.firewall.sg_members[OTHER_SGID][_IPv4]))
+
+ def test_remove_unused_sg_members(self):
+ self.firewall.sg_members = self._fake_sg_members([FAKE_SGID,
+ OTHER_SGID])
+ self.firewall.sg_members[FAKE_SGID][_IPv4] = []
+ self.firewall.sg_members[FAKE_SGID][_IPv6] = []
+ self.firewall.sg_members[OTHER_SGID][_IPv6] = []
+ self.firewall._remove_unused_sg_members()
+
+ self.assertIn(OTHER_SGID, self.firewall.sg_members)
+ self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
+
def test_remove_unused_security_group_info_clears_unused_rules(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
self.firewall.prepare_port_filter(self._fake_port())
# create another SG which won't be referenced by any filtered port
fake_sg_rules = self.firewall.sg_rules['fake_sgid']
- self.firewall.pre_sg_rules['other_sgid'] = fake_sg_rules
- self.firewall.sg_rules['other_sgid'] = fake_sg_rules
+ self.firewall.pre_sg_rules[OTHER_SGID] = fake_sg_rules
+ self.firewall.sg_rules[OTHER_SGID] = fake_sg_rules
# call the cleanup function, and check the unused sg_rules are out
self.firewall._remove_unused_security_group_info()
- self.assertNotIn('other_sgid', self.firewall.sg_rules)
+ self.assertNotIn(OTHER_SGID, self.firewall.sg_rules)
- def test_remove_unused_sg_members(self):
+ def test_remove_unused_security_group_info(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
# no filtered ports in 'fake_sgid', so all rules and members
# are not needed and we expect them to be cleaned up
- self.firewall.prepare_port_filter(self._fake_port('other_sgid'))
+ self.firewall.prepare_port_filter(self._fake_port(OTHER_SGID))
self.firewall._remove_unused_security_group_info()
- self.assertNotIn('fake_sgid', self.firewall.sg_members)
+ self.assertNotIn(FAKE_SGID, self.firewall.sg_members)
def test_remove_all_unused_info(self):
self._setup_fake_firewall_members_and_rules(self.firewall)
self.assertFalse(self.firewall.sg_rules)
def test_prepare_port_filter_with_deleted_member(self):
- self.firewall.sg_rules = self._fake_sg_rule()
- self.firewall.pre_sg_rules = self._fake_sg_rule()
+ self.firewall.sg_rules = self._fake_sg_rules()
+ self.firewall.pre_sg_rules = self._fake_sg_rules()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': [
'10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
self.firewall.ipset.assert_has_calls(calls, True)
def test_remove_port_filter_with_destroy_ipset_chain(self):
- self.firewall.sg_rules = self._fake_sg_rule()
+ self.firewall.sg_rules = self._fake_sg_rules()
port = self._fake_port()
self.firewall.sg_members = {'fake_sgid': {
'IPv4': ['10.0.0.1'],
self.firewall.ipset.assert_has_calls(calls)
def test_prepare_port_filter_with_sg_no_member(self):
- self.firewall.sg_rules = self._fake_sg_rule()
- self.firewall.sg_rules['fake_sgid'].append(
+ self.firewall.sg_rules = self._fake_sg_rules()
+ self.firewall.sg_rules[FAKE_SGID].append(
{'direction': 'ingress', 'remote_group_id': 'fake_sgid2',
'ethertype': 'IPv4'})
self.firewall.sg_rules.update()
- self.firewall.sg_members = {'fake_sgid': {
- 'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
+ self.firewall.sg_members['fake_sgid'] = {
+ 'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}
self.firewall.pre_sg_members = {}
port = self._fake_port()
port['security_group_source_groups'].append('fake_sgid2')
self.firewall.ipset.assert_has_calls(calls)
def test_filter_defer_apply_off_with_sg_only_ipv6_rule(self):
- self.firewall.sg_rules = self._fake_sg_rule()
- self.firewall.pre_sg_rules = self._fake_sg_rule()
+ self.firewall.sg_rules = self._fake_sg_rules()
+ self.firewall.pre_sg_rules = self._fake_sg_rules()
self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2'],
'IPv6fake_sgid': ['fe80::1']}
self.firewall.sg_members = {'fake_sgid': {
'IPv6': [FAKE_IP['IPv6']]}}
port = self._fake_port()
- rule = self._fake_sg_rule_for_ethertype('IPv4')
+ rule = self._fake_sg_rule_for_ethertype(_IPv4, FAKE_SGID)
rules = self.firewall._expand_sg_rule_with_remote_ips(
rule, port, 'ingress')
self.assertEqual(list(rules),