def _get_remote_sg_ids(self, port, direction=None):
sg_ids = port.get('security_groups', [])
- remote_sg_ids = {constants.IPv4: [], constants.IPv6: []}
+ remote_sg_ids = {constants.IPv4: set(), constants.IPv6: set()}
for sg_id in sg_ids:
for rule in self.sg_rules.get(sg_id, []):
if not direction or rule['direction'] == direction:
remote_sg_id = rule.get('remote_group_id')
ether_type = rule.get('ethertype')
if remote_sg_id and ether_type:
- remote_sg_ids[ether_type].append(remote_sg_id)
+ remote_sg_ids[ether_type].add(remote_sg_id)
return remote_sg_ids
def _add_rules_by_security_group(self, port, direction):
constants.IPv6: set()}
for port in filtered_ports:
remote_sg_ids = self._get_remote_sg_ids(port)
- for ip_version, sg_ids in six.iteritems(remote_sg_ids):
- remote_group_id_sets[ip_version].update(sg_ids)
+ for ip_version in (constants.IPv4, constants.IPv6):
+ remote_group_id_sets[ip_version] |= remote_sg_ids[ip_version]
return remote_group_id_sets
def _determine_sg_rules_to_remove(self, filtered_ports):
{_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
self.firewall._get_remote_sg_ids_sets_by_ipversion(ports))
+ def test_get_remote_sg_ids(self):
+ self.firewall.sg_rules = self._fake_sg_rules(
+ remote_groups={_IPv4: [FAKE_SGID, FAKE_SGID, FAKE_SGID],
+ _IPv6: [OTHER_SGID, OTHER_SGID, OTHER_SGID]})
+
+ port = self._fake_port()
+
+ self.assertEqual(
+ {_IPv4: set([FAKE_SGID]), _IPv6: set([OTHER_SGID])},
+ self.firewall._get_remote_sg_ids(port))
+
def test_determine_sg_rules_to_remove(self):
self.firewall.pre_sg_rules = self._fake_sg_rules(sg_id=OTHER_SGID)
ports = [self._fake_port()]