]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
IpsetManager refactoring
authorMiguel Angel Ajo <majopela@redhat.com>
Thu, 11 Sep 2014 15:05:57 +0000 (17:05 +0200)
committerMiguel Angel Ajo <mangelajo@redhat.com>
Sat, 6 Dec 2014 17:21:43 +0000 (18:21 +0100)
Refactor the IpsetManager to move all the low level
knowledge about the ipset behaviour and performance
considerations from the firewall driver to the manager,
reduce redundant function names, and change missleading
variables talking about ipset chains, where they
should say ipset sets.

No logical changes to behaviour, just responsibilities
moved from one class to another.

Unit testing is move from iptables_firewall to ipset_manager
to test the new responsibilities of the class.

Implements: blueprint ipset-manager-refactor

Change-Id: I93000a37a71cd22753b32edbd0a5f3c9cb8b0bde

neutron/agent/linux/ipset_manager.py
neutron/agent/linux/iptables_firewall.py
neutron/tests/functional/agent/linux/test_ipset.py
neutron/tests/unit/agent/linux/test_ipset_manager.py [new file with mode: 0644]
neutron/tests/unit/test_iptables_firewall.py

index ddd736e8e8afca8ad2a2d56dbc0fd0845d513c94..703eb00d1fc661dd0685bcfac6225a0b5c0af8ac 100644 (file)
 from neutron.agent.linux import utils as linux_utils
 from neutron.common import utils
 
+IPSET_ADD_BULK_THRESHOLD = 5
+SWAP_SUFFIX = '-new'
+IPSET_NAME_MAX_LENGTH = 31 - len(SWAP_SUFFIX)
+
 
 class IpsetManager(object):
-    """Wrapper for ipset."""
+    """Smart wrapper for ipset.
+
+       Keeps track of ip addresses per set, using bulk
+       or single ip add/remove for smaller changes.
+    """
 
     def __init__(self, execute=None, root_helper=None, namespace=None):
         self.execute = execute or linux_utils.execute
         self.root_helper = root_helper
         self.namespace = namespace
+        self.ipset_sets = {}
+
+    @staticmethod
+    def get_name(id, ethertype):
+        """Returns the given ipset name for an id+ethertype pair.
+        This reference can be used from iptables.
+        """
+        name = ethertype + id
+        return name[:IPSET_NAME_MAX_LENGTH]
+
+    def set_exists(self, id, ethertype):
+        """Returns true if the id+ethertype pair is known to the manager."""
+        set_name = self.get_name(id, ethertype)
+        return set_name in self.ipset_sets
 
     @utils.synchronized('ipset', external=True)
-    def create_ipset_chain(self, chain_name, ethertype):
-        cmd = ['ipset', 'create', '-exist', chain_name, 'hash:ip', 'family',
-               self._get_ipset_chain_type(ethertype)]
-        self._apply(cmd)
+    def set_members(self, id, ethertype, member_ips):
+        """Create or update a specific set by name and ethertype.
+        It will make sure that a set is created, updated to
+        add / remove new members, or swapped atomically if
+        that's faster.
+        """
+        set_name = self.get_name(id, ethertype)
+        if not self.set_exists(id, ethertype):
+            # The initial creation is handled with create/refresh to
+            # avoid any downtime for existing sets (i.e. avoiding
+            # a flush/restore), as the restore operation of ipset is
+            # additive to the existing set.
+            self._create_set(set_name, ethertype)
+            self._refresh_set(set_name, member_ips, ethertype)
+            # TODO(majopela,shihanzhang,haleyb): Optimize this by
+            # gathering the system ipsets at start. So we can determine
+            # if a normal restore is enough for initial creation.
+            # That should speed up agent boot up time.
+        else:
+            add_ips = self._get_new_set_ips(set_name, member_ips)
+            del_ips = self._get_deleted_set_ips(set_name, member_ips)
+            if (len(add_ips) + len(del_ips) < IPSET_ADD_BULK_THRESHOLD):
+                self._add_members_to_set(set_name, add_ips)
+                self._del_members_from_set(set_name, del_ips)
+            else:
+                self._refresh_set(set_name, member_ips, ethertype)
 
     @utils.synchronized('ipset', external=True)
-    def add_member_to_ipset_chain(self, chain_name, member_ip):
-        cmd = ['ipset', 'add', '-exist', chain_name, member_ip]
+    def destroy(self, id, ethertype, forced=False):
+        set_name = self.get_name(id, ethertype)
+        self._destroy(set_name, forced)
+
+    def _add_member_to_set(self, set_name, member_ip):
+        cmd = ['ipset', 'add', '-exist', set_name, member_ip]
         self._apply(cmd)
+        self.ipset_sets[set_name].append(member_ip)
 
-    @utils.synchronized('ipset', external=True)
-    def refresh_ipset_chain_by_name(self, chain_name, member_ips, ethertype):
-        new_chain_name = chain_name + '-new'
-        chain_type = self._get_ipset_chain_type(ethertype)
-        process_input = ["create %s hash:ip family %s" % (new_chain_name,
-                                                          chain_type)]
+    def _refresh_set(self, set_name, member_ips, ethertype):
+        new_set_name = set_name + SWAP_SUFFIX
+        set_type = self._get_ipset_set_type(ethertype)
+        process_input = ["create %s hash:ip family %s" % (new_set_name,
+                                                          set_type)]
         for ip in member_ips:
-            process_input.append("add %s %s" % (new_chain_name, ip))
+            process_input.append("add %s %s" % (new_set_name, ip))
 
-        self._restore_ipset_chains(process_input)
-        self._swap_ipset_chains(new_chain_name, chain_name)
-        self._destroy_ipset_chain(new_chain_name)
+        self._restore_sets(process_input)
+        self._swap_sets(new_set_name, set_name)
+        self._destroy(new_set_name, True)
+        self.ipset_sets[set_name] = member_ips
 
-    @utils.synchronized('ipset', external=True)
-    def del_ipset_chain_member(self, chain_name, member_ip):
-        cmd = ['ipset', 'del', chain_name, member_ip]
+    def _del_member_from_set(self, set_name, member_ip):
+        cmd = ['ipset', 'del', set_name, member_ip]
         self._apply(cmd)
+        self.ipset_sets[set_name].remove(member_ip)
 
-    @utils.synchronized('ipset', external=True)
-    def destroy_ipset_chain_by_name(self, chain_name):
-        self._destroy_ipset_chain(chain_name)
+    def _create_set(self, set_name, ethertype):
+        cmd = ['ipset', 'create', '-exist', set_name, 'hash:ip', 'family',
+               self._get_ipset_set_type(ethertype)]
+        self._apply(cmd)
+        self.ipset_sets[set_name] = []
 
     def _apply(self, cmd, input=None):
         input = '\n'.join(input) if input else None
@@ -66,17 +117,39 @@ class IpsetManager(object):
                      root_helper=self.root_helper,
                      process_input=input)
 
-    def _get_ipset_chain_type(self, ethertype):
+    def _get_new_set_ips(self, set_name, expected_ips):
+        new_member_ips = (set(expected_ips) -
+                          set(self.ipset_sets.get(set_name, [])))
+        return list(new_member_ips)
+
+    def _get_deleted_set_ips(self, set_name, expected_ips):
+        deleted_member_ips = (set(self.ipset_sets.get(set_name, [])) -
+                              set(expected_ips))
+        return list(deleted_member_ips)
+
+    def _add_members_to_set(self, set_name, add_ips):
+        for ip in add_ips:
+            if ip not in self.ipset_sets[set_name]:
+                self._add_member_to_set(set_name, ip)
+
+    def _del_members_from_set(self, set_name, del_ips):
+        for ip in del_ips:
+            if ip in self.ipset_sets[set_name]:
+                self._del_member_from_set(set_name, ip)
+
+    def _get_ipset_set_type(self, ethertype):
         return 'inet6' if ethertype == 'IPv6' else 'inet'
 
-    def _restore_ipset_chains(self, process_input):
+    def _restore_sets(self, process_input):
         cmd = ['ipset', 'restore', '-exist']
         self._apply(cmd, process_input)
 
-    def _swap_ipset_chains(self, src_chain, dest_chain):
-        cmd = ['ipset', 'swap', src_chain, dest_chain]
+    def _swap_sets(self, src_set, dest_set):
+        cmd = ['ipset', 'swap', src_set, dest_set]
         self._apply(cmd)
 
-    def _destroy_ipset_chain(self, chain_name):
-        cmd = ['ipset', 'destroy', chain_name]
-        self._apply(cmd)
+    def _destroy(self, set_name, forced=False):
+        if set_name in self.ipset_sets or forced:
+            cmd = ['ipset', 'destroy', set_name]
+            self._apply(cmd)
+            self.ipset_sets.pop(set_name, None)
index 73b21a4e4b62f94433892b6090621d1a7257b1ef..33793a0e67284adcec84a0615b959a184f6d511e 100644 (file)
@@ -39,9 +39,6 @@ DIRECTION_IP_PREFIX = {'ingress': 'source_ip_prefix',
 IPSET_DIRECTION = {INGRESS_DIRECTION: 'src',
                    EGRESS_DIRECTION: 'dst'}
 LINUX_DEV_LEN = 14
-IPSET_CHAIN_LEN = 20
-IPSET_CHANGE_BULK_THRESHOLD = 10
-IPSET_ADD_BULK_THRESHOLD = 5
 comment_rule = iptables_manager.comment_rule
 
 
@@ -69,7 +66,6 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         # List of security group member ips for ports residing on this host
         self.sg_members = {}
         self.pre_sg_members = None
-        self.ipset_chains = {}
         self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
 
     @property
@@ -338,8 +334,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         security_group_rules += self._select_sg_rules_for_port(port, direction)
         if self.enable_ipset:
             remote_sg_ids = self._get_remote_sg_ids(port, direction)
-            # update the corresponding ipset chain member
-            self._update_ipset_chain_member(remote_sg_ids)
+            # update the corresponding ipset members
+            self._update_ipset_members(remote_sg_ids)
         # split groups by ip version
         # for ipv4, iptables command is used
         # for ipv6, iptables6 command is used
@@ -365,56 +361,12 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
     def _get_cur_sg_member_ips(self, sg_id, ethertype):
         return self.sg_members.get(sg_id, {}).get(ethertype, [])
 
-    def _get_pre_sg_member_ips(self, sg_id, ethertype):
-        return self.pre_sg_members.get(sg_id, {}).get(ethertype, [])
-
-    def _get_new_sg_member_ips(self, sg_id, ethertype):
-        add_member_ips = (set(self._get_cur_sg_member_ips(sg_id, ethertype)) -
-                          set(self._get_pre_sg_member_ips(sg_id, ethertype)))
-        return list(add_member_ips)
-
-    def _get_deleted_sg_member_ips(self, sg_id, ethertype):
-        del_member_ips = (set(self._get_pre_sg_member_ips(sg_id, ethertype)) -
-                          set(self._get_cur_sg_member_ips(sg_id, ethertype)))
-        return list(del_member_ips)
-
-    def _bulk_set_ips_to_chain(self, chain_name, member_ips, ethertype):
-        self.ipset.refresh_ipset_chain_by_name(chain_name, member_ips,
-                                               ethertype)
-        self.ipset_chains[chain_name] = member_ips
-
-    def _add_ips_to_ipset_chain(self, chain_name, add_ips):
-        for ip in add_ips:
-            if ip not in self.ipset_chains[chain_name]:
-                self.ipset.add_member_to_ipset_chain(chain_name, ip)
-                self.ipset_chains[chain_name].append(ip)
-
-    def _del_ips_from_ipset_chain(self, chain_name, del_ips):
-        if chain_name in self.ipset_chains:
-            for del_ip in del_ips:
-                if del_ip in self.ipset_chains[chain_name]:
-                    self.ipset.del_ipset_chain_member(chain_name, del_ip)
-                    self.ipset_chains[chain_name].remove(del_ip)
-
-    def _update_ipset_chain_member(self, security_group_ids):
+    def _update_ipset_members(self, security_group_ids):
         for ethertype, sg_ids in security_group_ids.items():
             for sg_id in sg_ids:
-                add_ips = self._get_new_sg_member_ips(sg_id, ethertype)
-                del_ips = self._get_deleted_sg_member_ips(sg_id, ethertype)
                 cur_member_ips = self._get_cur_sg_member_ips(sg_id, ethertype)
-                chain_name = ethertype + sg_id[:IPSET_CHAIN_LEN]
-                if chain_name not in self.ipset_chains and cur_member_ips:
-                    self.ipset_chains[chain_name] = []
-                    self.ipset.create_ipset_chain(chain_name, ethertype)
-                    self._bulk_set_ips_to_chain(chain_name,
-                                                cur_member_ips, ethertype)
-                elif (len(add_ips) + len(del_ips)
-                      < IPSET_CHANGE_BULK_THRESHOLD):
-                    self._add_ips_to_ipset_chain(chain_name, add_ips)
-                    self._del_ips_from_ipset_chain(chain_name, del_ips)
-                else:
-                    self._bulk_set_ips_to_chain(chain_name,
-                                                cur_member_ips, ethertype)
+                if cur_member_ips:
+                    self.ipset.set_members(sg_id, ethertype, cur_member_ips)
 
     def _generate_ipset_chain(self, sg_rule, remote_gid):
         iptables_rules = []
@@ -429,12 +381,10 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
                                sg_rule.get('port_range_max'))
         direction = sg_rule.get('direction')
         ethertype = sg_rule.get('ethertype')
-        # the length of ipset chain name require less than 31
-        # characters
-        ipset_chain_name = (ethertype + remote_gid[:IPSET_CHAIN_LEN])
-        if ipset_chain_name in self.ipset_chains:
+
+        if self.ipset.set_exists(remote_gid, ethertype):
             args += ['-m set', '--match-set',
-                     ipset_chain_name,
+                     self.ipset.get_name(remote_gid, ethertype),
                      IPSET_DIRECTION[direction]]
             args += ['-j RETURN']
             iptables_rules += [' '.join(args)]
@@ -539,8 +489,8 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
             self._defer_apply = True
 
     def _remove_unused_security_group_info(self):
-        need_removed_ipset_chains = {constants.IPv4: set(),
-                                     constants.IPv6: set()}
+        need_removed_ipsets = {constants.IPv4: set(),
+                               constants.IPv6: set()}
         need_removed_security_groups = set()
         remote_group_ids = {constants.IPv4: set(),
                             constants.IPv6: set()}
@@ -554,24 +504,20 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
             cur_group_ids.update(groups)
 
         for ethertype in [constants.IPv4, constants.IPv6]:
-            need_removed_ipset_chains[ethertype].update(
+            need_removed_ipsets[ethertype].update(
                 [x for x in self.pre_sg_members if x not in remote_group_ids[
                     ethertype]])
             need_removed_security_groups.update(
                 [x for x in self.pre_sg_rules if x not in cur_group_ids])
 
-        # Remove unused remote ipset set
-        for ethertype, remove_chain_ids in need_removed_ipset_chains.items():
-            for remove_chain_id in remove_chain_ids:
-                if self.sg_members.get(remove_chain_id, {}).get(ethertype, []):
-                    self.sg_members[remove_chain_id][ethertype] = []
+        # Remove unused ip sets (sg_members and kernel ipset if we
+        # are using ipset)
+        for ethertype, remove_set_ids in need_removed_ipsets.items():
+            for remove_set_id in remove_set_ids:
+                if self.sg_members.get(remove_set_id, {}).get(ethertype, []):
+                    self.sg_members[remove_set_id][ethertype] = []
                 if self.enable_ipset:
-                    removed_chain = (
-                        ethertype + remove_chain_id[:IPSET_CHAIN_LEN])
-                    if removed_chain in self.ipset_chains:
-                        self.ipset.destroy_ipset_chain_by_name(
-                            removed_chain)
-                        self.ipset_chains.pop(removed_chain, None)
+                    self.ipset.destroy(remove_set_id, ethertype)
 
         # Remove unused remote security group member ips
         sg_ids = self.sg_members.keys()
index d8bbfc98619d90e4a1c0d553344f6aac52a105c7..c447e0018e17cb32cd70bc66c9ea0975d42e61cb 100644 (file)
@@ -17,9 +17,9 @@ from neutron.agent.linux import ipset_manager
 from neutron.agent.linux import iptables_manager
 from neutron.tests.functional.agent.linux import base
 
-IPSET_CHAIN = 'test-chain'
+IPSET_SET = 'test-set'
 IPSET_ETHERTYPE = 'IPv4'
-ICMP_ACCEPT_RULE = '-p icmp -m set --match-set %s src -j ACCEPT' % IPSET_CHAIN
+ICMP_ACCEPT_RULE = '-p icmp -m set --match-set %s src -j ACCEPT' % IPSET_SET
 UNRELATED_IP = '1.1.1.1'
 
 
@@ -29,8 +29,8 @@ class IpsetBase(base.BaseIPVethTestCase):
         super(IpsetBase, self).setUp()
 
         self.src_ns, self.dst_ns = self.prepare_veth_pairs()
-        self.ipset = self._create_ipset_manager_and_chain(self.dst_ns,
-                                                          IPSET_CHAIN)
+        self.ipset = self._create_ipset_manager_and_set(self.dst_ns,
+                                                        IPSET_SET)
 
         self.dst_iptables = iptables_manager.IptablesManager(
             root_helper=self.root_helper,
@@ -38,12 +38,12 @@ class IpsetBase(base.BaseIPVethTestCase):
 
         self._add_iptables_ipset_rules(self.dst_iptables)
 
-    def _create_ipset_manager_and_chain(self, dst_ns, chain_name):
+    def _create_ipset_manager_and_set(self, dst_ns, set_name):
         ipset = ipset_manager.IpsetManager(
             root_helper=self.root_helper,
             namespace=dst_ns.namespace)
 
-        ipset.create_ipset_chain(chain_name, IPSET_ETHERTYPE)
+        ipset._create_set(set_name, IPSET_ETHERTYPE)
         return ipset
 
     @staticmethod
@@ -62,31 +62,29 @@ class IpsetManagerTestCase(IpsetBase):
 
     def test_add_member_allows_ping(self):
         self.pinger.assert_no_ping_from_ns(self.src_ns, self.DST_ADDRESS)
-        self.ipset.add_member_to_ipset_chain(IPSET_CHAIN, self.SRC_ADDRESS)
+        self.ipset._add_member_to_set(IPSET_SET, self.SRC_ADDRESS)
         self.pinger.assert_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
     def test_del_member_denies_ping(self):
-        self.ipset.add_member_to_ipset_chain(IPSET_CHAIN, self.SRC_ADDRESS)
+        self.ipset._add_member_to_set(IPSET_SET, self.SRC_ADDRESS)
         self.pinger.assert_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
-        self.ipset.del_ipset_chain_member(IPSET_CHAIN, self.SRC_ADDRESS)
+        self.ipset._del_member_from_set(IPSET_SET, self.SRC_ADDRESS)
         self.pinger.assert_no_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
     def test_refresh_ipset_allows_ping(self):
-        self.ipset.refresh_ipset_chain_by_name(IPSET_CHAIN, [UNRELATED_IP],
-                                               IPSET_ETHERTYPE)
+        self.ipset._refresh_set(IPSET_SET, [UNRELATED_IP], IPSET_ETHERTYPE)
         self.pinger.assert_no_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
-        self.ipset.refresh_ipset_chain_by_name(
-            IPSET_CHAIN, [UNRELATED_IP, self.SRC_ADDRESS], IPSET_ETHERTYPE)
+        self.ipset._refresh_set(IPSET_SET, [UNRELATED_IP, self.SRC_ADDRESS],
+                                IPSET_ETHERTYPE)
         self.pinger.assert_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
-        self.ipset.refresh_ipset_chain_by_name(
-            IPSET_CHAIN, [self.SRC_ADDRESS, UNRELATED_IP], IPSET_ETHERTYPE)
+        self.ipset._refresh_set(IPSET_SET, [self.SRC_ADDRESS, UNRELATED_IP],
+                                IPSET_ETHERTYPE)
         self.pinger.assert_ping_from_ns(self.src_ns, self.DST_ADDRESS)
 
-    def test_destroy_ipset_chain(self):
-        self.assertRaises(RuntimeError,
-                          self.ipset.destroy_ipset_chain_by_name, IPSET_CHAIN)
+    def test_destroy_ipset_set(self):
+        self.assertRaises(RuntimeError, self.ipset._destroy, IPSET_SET)
         self._remove_iptables_ipset_rules(self.dst_iptables)
-        self.ipset.destroy_ipset_chain_by_name(IPSET_CHAIN)
+        self.ipset._destroy(IPSET_SET)
diff --git a/neutron/tests/unit/agent/linux/test_ipset_manager.py b/neutron/tests/unit/agent/linux/test_ipset_manager.py
new file mode 100644 (file)
index 0000000..70b9637
--- /dev/null
@@ -0,0 +1,121 @@
+#
+#    Licensed under the Apache License, Version 2.0 (the "License");
+#    you may not use this file except in compliance with the License.
+#    You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS,
+#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#    See the License for the specific language governing permissions and
+#    limitations under the License.
+
+import mock
+
+from neutron.agent.linux import ipset_manager
+from neutron.tests import base
+
+TEST_SET_ID = 'fake_sgid'
+ETHERTYPE = 'IPv4'
+TEST_SET_NAME = ipset_manager.IpsetManager.get_name(TEST_SET_ID, ETHERTYPE)
+TEST_SET_NAME_NEW = TEST_SET_NAME + ipset_manager.SWAP_SUFFIX
+FAKE_IPS = ['10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4',
+            '10.0.0.5', '10.0.0.6']
+
+
+class BaseIpsetManagerTest(base.BaseTestCase):
+    def setUp(self):
+        super(BaseIpsetManagerTest, self).setUp()
+        self.root_helper = 'sudo'
+        self.ipset = ipset_manager.IpsetManager(
+            root_helper=self.root_helper)
+        self.execute = mock.patch.object(self.ipset, "execute").start()
+        self.expected_calls = []
+        self.expect_create()
+
+    def verify_mock_calls(self):
+        self.execute.assert_has_calls(self.expected_calls, any_order=False)
+
+    def expect_set(self, addresses):
+        temp_input = ['create IPv4fake_sgid-new hash:ip family inet']
+        temp_input.extend('add IPv4fake_sgid-new %s' % ip for ip in addresses)
+        input = '\n'.join(temp_input)
+        self.expected_calls.extend([
+            mock.call(['ipset', 'restore', '-exist'],
+                      process_input=input,
+                      root_helper=self.root_helper),
+            mock.call(['ipset', 'swap', TEST_SET_NAME_NEW, TEST_SET_NAME],
+                      process_input=None,
+                      root_helper=self.root_helper),
+            mock.call(['ipset', 'destroy', TEST_SET_NAME_NEW],
+                      process_input=None,
+                      root_helper=self.root_helper)])
+
+    def expect_add(self, addresses):
+        self.expected_calls.extend(
+            mock.call(['ipset', 'add', '-exist', TEST_SET_NAME, ip],
+                      process_input=None,
+                      root_helper=self.root_helper) for ip in addresses)
+
+    def expect_del(self, addresses):
+        self.expected_calls.extend(
+            mock.call(['ipset', 'del', TEST_SET_NAME, ip],
+                      process_input=None,
+                      root_helper=self.root_helper) for ip in addresses)
+
+    def expect_create(self):
+        self.expected_calls.append(
+            mock.call(['ipset', 'create', '-exist', TEST_SET_NAME,
+                       'hash:ip', 'family', 'inet'],
+                      process_input=None,
+                      root_helper=self.root_helper))
+
+    def expect_destroy(self):
+        self.expected_calls.append(
+            mock.call(['ipset', 'destroy', TEST_SET_NAME],
+                      process_input=None,
+                      root_helper=self.root_helper))
+
+    def add_first_ip(self):
+        self.expect_set([FAKE_IPS[0]])
+        self.ipset.set_members(TEST_SET_ID, ETHERTYPE, [FAKE_IPS[0]])
+
+    def add_all_ips(self):
+        self.expect_set(FAKE_IPS)
+        self.ipset.set_members(TEST_SET_ID, ETHERTYPE, FAKE_IPS)
+
+
+class IpsetManagerTestCase(BaseIpsetManagerTest):
+
+    def test_set_exists(self):
+        self.add_first_ip()
+        self.assertTrue(self.ipset.set_exists(TEST_SET_ID, ETHERTYPE))
+
+    def test_set_members_with_first_add_member(self):
+        self.add_first_ip()
+        self.verify_mock_calls()
+
+    def test_set_members_adding_less_than_5(self):
+        self.add_first_ip()
+        self.expect_add(reversed(FAKE_IPS[1:5]))
+        self.ipset.set_members(TEST_SET_ID, ETHERTYPE, FAKE_IPS[0:5])
+        self.verify_mock_calls()
+
+    def test_set_members_deleting_less_than_5(self):
+        self.add_all_ips()
+        self.expect_del(reversed(FAKE_IPS[4:5]))
+        self.ipset.set_members(TEST_SET_ID, ETHERTYPE, FAKE_IPS[0:3])
+        self.verify_mock_calls()
+
+    def test_set_members_adding_more_than_5(self):
+        self.add_first_ip()
+        self.expect_set(FAKE_IPS)
+        self.ipset.set_members(TEST_SET_ID, ETHERTYPE, FAKE_IPS)
+        self.verify_mock_calls()
+
+    def test_destroy(self):
+        self.add_first_ip()
+        self.expect_destroy()
+        self.ipset.destroy(TEST_SET_ID, ETHERTYPE)
+        self.verify_mock_calls()
index 52c61f5be04b19b454fb00f4d06c90e992a6ee67..21b7516f64eb7e82a1642dd37b7eadffef9074c5 100644 (file)
@@ -19,6 +19,7 @@ import mock
 from oslo.config import cfg
 
 from neutron.agent.common import config as a_cfg
+from neutron.agent.linux import ipset_manager
 from neutron.agent.linux import iptables_comments as ic
 from neutron.agent.linux import iptables_firewall
 from neutron.agent import securitygroups_rpc as sg_cfg
@@ -33,10 +34,6 @@ FAKE_PREFIX = {'IPv4': '10.0.0.0/24',
 FAKE_IP = {'IPv4': '10.0.0.1',
            'IPv6': 'fe80::1'}
 
-TEST_IP_RANGE = ['10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4',
-                 '10.0.0.5', '10.0.0.6', '10.0.0.7', '10.0.0.8',
-                 '10.0.0.9', '10.0.0.10']
-
 
 class BaseIptablesFirewallTestCase(base.BaseTestCase):
     def setUp(self):
@@ -1408,6 +1405,9 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
     def setUp(self):
         super(IptablesFirewallEnhancedIpsetTestCase, self).setUp()
         self.firewall.ipset = mock.Mock()
+        self.firewall.ipset.get_name.side_effect = (
+            ipset_manager.IpsetManager.get_name)
+        self.firewall.ipset.set_exists.return_value = True
 
     def _fake_port(self):
         return {'device': 'tapfake_dev',
@@ -1424,44 +1424,26 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
             {'direction': 'ingress', 'remote_group_id': 'fake_sgid',
              'ethertype': 'IPv6'}]}
 
-    def test_prepare_port_filter_with_default_sg(self):
+    def test_prepare_port_filter_with_new_members(self):
         self.firewall.sg_rules = self._fake_sg_rule()
         self.firewall.sg_members = {'fake_sgid': {
             'IPv4': ['10.0.0.1', '10.0.0.2'], 'IPv6': ['fe80::1']}}
         self.firewall.pre_sg_members = {}
         port = self._fake_port()
         self.firewall.prepare_port_filter(port)
-        calls = [mock.call.create_ipset_chain('IPv4fake_sgid', 'IPv4'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv4fake_sgid', ['10.0.0.1', '10.0.0.2'], 'IPv4'),
-                 mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
-
-        self.firewall.ipset.assert_has_calls(calls)
-
-    def test_prepare_port_filter_with_add_members_beyond_4(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.sg_members = {'fake_sgid': {
-            'IPv4': TEST_IP_RANGE[:5],
-            'IPv6': ['fe80::1']}}
-        self.firewall.pre_sg_members = {}
-        port = self._fake_port()
-        self.firewall.prepare_port_filter(port)
-        calls = [mock.call.create_ipset_chain('IPv4fake_sgid', 'IPv4'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv4fake_sgid', TEST_IP_RANGE[:5], 'IPv4'),
-                 mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
-
+        calls = [
+            mock.call.set_members('fake_sgid', 'IPv4',
+                                  ['10.0.0.1', '10.0.0.2']),
+            mock.call.set_members('fake_sgid', 'IPv6',
+                                  ['fe80::1'])
+        ]
         self.firewall.ipset.assert_has_calls(calls)
 
-    def test_prepare_port_filter_with_ipset_chain_exist(self):
+    def test_prepare_port_filter_with_deleted_member(self):
         self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2']}
         self.firewall.sg_members = {'fake_sgid': {
-            'IPv4': TEST_IP_RANGE[:5],
+            'IPv4': [
+                '10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
             'IPv6': ['fe80::1']}}
         self.firewall.pre_sg_members = {'fake_sgid': {
             'IPv4': ['10.0.0.2'],
@@ -1469,57 +1451,41 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         port = self._fake_port()
         self.firewall.prepare_port_filter(port)
         calls = [
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.1'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.3'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.4'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.5'),
-            mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-            mock.call.refresh_ipset_chain_by_name(
-                'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
+            mock.call.set_members('fake_sgid', 'IPv4',
+                                  ['10.0.0.1', '10.0.0.3', '10.0.0.4',
+                                   '10.0.0.5']),
+            mock.call.set_members('fake_sgid', 'IPv6', ['fe80::1'])]
 
         self.firewall.ipset.assert_has_calls(calls, True)
 
-    def test_prepare_port_filter_with_del_member(self):
+    def test_remove_port_filter_with_destroy_ipset_chain(self):
         self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.ipset_chains = {'IPv4fake_sgid': ['10.0.0.2']}
+        port = self._fake_port()
         self.firewall.sg_members = {'fake_sgid': {
-            'IPv4': [
-                '10.0.0.1', '10.0.0.3', '10.0.0.4', '10.0.0.5'],
+            'IPv4': ['10.0.0.1'],
             'IPv6': ['fe80::1']}}
         self.firewall.pre_sg_members = {'fake_sgid': {
-            'IPv4': ['10.0.0.2'],
-            'IPv6': ['fe80::1']}}
-        port = self._fake_port()
+            'IPv4': [],
+            'IPv6': []}}
         self.firewall.prepare_port_filter(port)
-        calls = [
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.1'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.3'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.4'),
-            mock.call.add_member_to_ipset_chain('IPv4fake_sgid', '10.0.0.5'),
-            mock.call.del_ipset_chain_member('IPv4fake_sgid', '10.0.0.2'),
-            mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-            mock.call.refresh_ipset_chain_by_name(
-                'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
-
-        self.firewall.ipset.assert_has_calls(calls, True)
-
-    def test_prepare_port_filter_change_beyond_9(self):
-        self.firewall.sg_rules = self._fake_sg_rule()
-        self.firewall.ipset_chains = {'IPv4fake_sgid': TEST_IP_RANGE[5:]}
+        self.firewall.filter_defer_apply_on()
         self.firewall.sg_members = {'fake_sgid': {
-            'IPv4': TEST_IP_RANGE[:5],
-            'IPv6': ['fe80::1']}}
+            'IPv4': [],
+            'IPv6': []}}
         self.firewall.pre_sg_members = {'fake_sgid': {
-            'IPv4': TEST_IP_RANGE[5:],
+            'IPv4': ['10.0.0.1'],
             'IPv6': ['fe80::1']}}
-        port = self._fake_port()
-        self.firewall.prepare_port_filter(port)
+        self.firewall.remove_port_filter(port)
+        self.firewall.filter_defer_apply_off()
         calls = [
-            mock.call.refresh_ipset_chain_by_name('IPv4fake_sgid',
-                                                  TEST_IP_RANGE[:5], 'IPv4'),
-            mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-            mock.call.refresh_ipset_chain_by_name(
-                'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
+            mock.call.set_members('fake_sgid', 'IPv4', ['10.0.0.1']),
+            mock.call.set_members('fake_sgid', 'IPv6', ['fe80::1']),
+            mock.call.set_exists('fake_sgid', 'IPv4'),
+            mock.call.get_name('fake_sgid', 'IPv4'),
+            mock.call.set_exists('fake_sgid', 'IPv6'),
+            mock.call.get_name('fake_sgid', 'IPv6'),
+            mock.call.destroy('fake_sgid', 'IPv4'),
+            mock.call.destroy('fake_sgid', 'IPv6')]
 
         self.firewall.ipset.assert_has_calls(calls)
 
@@ -1535,12 +1501,9 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         port = self._fake_port()
         port['security_group_source_groups'].append('fake_sgid2')
         self.firewall.prepare_port_filter(port)
-        calls = [mock.call.create_ipset_chain('IPv4fake_sgid', 'IPv4'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv4fake_sgid', ['10.0.0.1', '10.0.0.2'], 'IPv4'),
-                 mock.call.create_ipset_chain('IPv6fake_sgid', 'IPv6'),
-                 mock.call.refresh_ipset_chain_by_name(
-                     'IPv6fake_sgid', ['fe80::1'], 'IPv6')]
+        calls = [mock.call.set_members('fake_sgid', 'IPv4',
+                                       ['10.0.0.1', '10.0.0.2']),
+                 mock.call.set_members('fake_sgid', 'IPv6', ['fe80::1'])]
 
         self.firewall.ipset.assert_has_calls(calls)
 
@@ -1564,6 +1527,6 @@ class IptablesFirewallEnhancedIpsetTestCase(BaseIptablesFirewallTestCase):
         self.firewall.filtered_ports['tapfake_dev'] = port
         self.firewall._pre_defer_filtered_ports = {}
         self.firewall.filter_defer_apply_off()
-        calls = [mock.call.destroy_ipset_chain_by_name('IPv4fake_sgid')]
+        calls = [mock.call.destroy('fake_sgid', 'IPv4')]
 
         self.firewall.ipset.assert_has_calls(calls, True)