]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Add conntrack-tool to manage security groups
authorshihanzhang <shihanzhang@huawei.com>
Thu, 15 Jan 2015 12:16:21 +0000 (20:16 +0800)
committershihanzhang <shihanzhang@huawei.com>
Tue, 21 Jul 2015 12:24:53 +0000 (20:24 +0800)
This patch introduces conntrack-tool to manage security groups. When a
security group rule is deleted, the corresponding tracked connection
entries will also be removed from the kernel for the address.

Closes-Bug: #1335375
Partially-Implements: bp conntrack-in-security-group

Change-Id: Ibfd2d6a11aa970ea9e5009f4c4b858544d8b7463

neutron/agent/firewall.py
neutron/agent/linux/ip_conntrack.py [new file with mode: 0644]
neutron/agent/linux/iptables_firewall.py
neutron/agent/securitygroups_rpc.py
neutron/tests/unit/agent/linux/test_iptables_firewall.py

index afb0f18f59e0d1949caf09e176a67c1ddf8c3ef7..04a327b21df66f01bd77898c2d6275592634d4ab 100644 (file)
@@ -117,6 +117,11 @@ class FirewallDriver(object):
         """Update rules in a security group."""
         raise NotImplementedError()
 
+    def security_group_updated(self, action_type, sec_group_ids,
+                               device_id=None):
+        """Called when a security group is updated."""
+        raise NotImplementedError()
+
 
 class NoopFirewallDriver(FirewallDriver):
     """Noop Firewall Driver.
@@ -152,3 +157,7 @@ class NoopFirewallDriver(FirewallDriver):
 
     def update_security_group_rules(self, sg_id, rules):
         pass
+
+    def security_group_updated(self, action_type, sec_group_ids,
+                               device_id=None):
+        pass
diff --git a/neutron/agent/linux/ip_conntrack.py b/neutron/agent/linux/ip_conntrack.py
new file mode 100644 (file)
index 0000000..97c94e0
--- /dev/null
@@ -0,0 +1,89 @@
+#
+#    Licensed under the Apache License, Version 2.0 (the "License");
+#    you may not use this file except in compliance with the License.
+#    You may obtain a copy of the License at
+#
+#        http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS,
+#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#    See the License for the specific language governing permissions and
+#    limitations under the License.
+
+import netaddr
+from oslo_log import log as logging
+
+from neutron.agent.linux import utils as linux_utils
+from neutron.i18n import _LE
+
+LOG = logging.getLogger(__name__)
+
+
+class IpConntrackManager(object):
+    """Smart wrapper for ip conntrack."""
+
+    def __init__(self, execute=None, namespace=None):
+        self.execute = execute or linux_utils.execute
+        self.namespace = namespace
+
+    @staticmethod
+    def _generate_conntrack_cmd_by_rule(rule, namespace):
+        ethertype = rule.get('ethertype')
+        protocol = rule.get('protocol')
+        direction = rule.get('direction')
+        cmd = ['conntrack', '-D']
+        if protocol:
+            cmd.extend(['-p', str(protocol)])
+        cmd.extend(['-f', str(ethertype).lower()])
+        cmd.append('-d' if direction == 'ingress' else '-s')
+        cmd_ns = []
+        if namespace:
+            cmd_ns.extend(['ip', 'netns', 'exec', namespace])
+        cmd_ns.extend(cmd)
+        return cmd_ns
+
+    def _get_conntrack_cmds(self, device_info_list, rule, remote_ip=None):
+        conntrack_cmds = []
+        cmd = self._generate_conntrack_cmd_by_rule(rule, self.namespace)
+        ethertype = rule.get('ethertype')
+        for device_info in device_info_list:
+            zone_id = device_info.get('zone_id')
+            if not zone_id:
+                continue
+            ips = device_info.get('fixed_ips', [])
+            for ip in ips:
+                net = netaddr.IPNetwork(ip)
+                if str(net.version) not in ethertype:
+                    continue
+                ip_cmd = [str(net.ip), '-w', zone_id]
+                if remote_ip and str(
+                        netaddr.IPNetwork(remote_ip).version) in ethertype:
+                    ip_cmd.extend(['-s', str(remote_ip)])
+                conntrack_cmds.append(cmd + ip_cmd)
+        return conntrack_cmds
+
+    def _delete_conntrack_state(self, device_info_list, rule, remote_ip=None):
+        conntrack_cmds = self._get_conntrack_cmds(device_info_list,
+                                                  rule, remote_ip)
+        for cmd in conntrack_cmds:
+            try:
+                self.execute(cmd, run_as_root=True,
+                             check_exit_code=True,
+                             extra_ok_codes=[1])
+            except RuntimeError:
+                LOG.exception(
+                    _LE("Failed execute conntrack command %s"), str(cmd))
+
+    def delete_conntrack_state_by_rule(self, device_info_list, rule):
+        self._delete_conntrack_state(device_info_list, rule)
+
+    def delete_conntrack_state_by_remote_ips(self, device_info_list,
+                                             ethertype, remote_ips):
+        rule = {'ethertype': str(ethertype).lower(), 'direction': 'ingress'}
+        if remote_ips:
+            for remote_ip in remote_ips:
+                self._delete_conntrack_state(
+                    device_info_list, rule, remote_ip)
+        else:
+            self._delete_conntrack_state(device_info_list, rule)
index e4e1f171172fd83efa7b652e6c63294c18159c74..9414d40f631c0771392c8e429108c5da72e6fa13 100644 (file)
@@ -20,6 +20,7 @@ from oslo_log import log as logging
 import six
 
 from neutron.agent import firewall
+from neutron.agent.linux import ip_conntrack
 from neutron.agent.linux import ipset_manager
 from neutron.agent.linux import iptables_comments as ic
 from neutron.agent.linux import iptables_manager
@@ -56,6 +57,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         # TODO(majopela, shihanzhang): refactor out ipset to a separate
         # driver composed over this one
         self.ipset = ipset_manager.IpsetManager(namespace=namespace)
+        self.ipconntrack = ip_conntrack.IpConntrackManager(namespace=namespace)
         # list of port which has security group
         self.filtered_ports = {}
         self.unfiltered_ports = {}
@@ -72,6 +74,9 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
         self.pre_sg_members = None
         self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
         self._enabled_netfilter_for_bridges = False
+        self.updated_rule_sg_ids = set()
+        self.updated_sg_members = set()
+        self.devices_with_udpated_sg_members = collections.defaultdict(list)
 
     def _enable_netfilter_for_bridges(self):
         # we only need to set these values once, but it has to be when
@@ -102,6 +107,22 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
     def ports(self):
         return dict(self.filtered_ports, **self.unfiltered_ports)
 
+    def _update_remote_security_group_members(self, sec_group_ids):
+        for sg_id in sec_group_ids:
+            for device in self.filtered_ports.values():
+                if sg_id in device.get('security_group_source_groups', []):
+                    self.devices_with_udpated_sg_members[sg_id].append(device)
+
+    def security_group_updated(self, action_type, sec_group_ids,
+                               device_ids=[]):
+        if action_type == 'sg_rule':
+            self.updated_rule_sg_ids.update(sec_group_ids)
+        elif action_type == 'sg_member':
+            if device_ids:
+                self.updated_sg_members.update(device_ids)
+            else:
+                self._update_remote_security_group_members(sec_group_ids)
+
     def update_security_group_rules(self, sg_id, sg_rules):
         LOG.debug("Update rules of security group (%s)", sg_id)
         self.sg_rules[sg_id] = sg_rules
@@ -688,6 +709,79 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
             if not sg_has_members:
                 del self.sg_members[sg_id]
 
+    def _find_deleted_sg_rules(self, sg_id):
+        del_rules = list()
+        for pre_rule in self.pre_sg_rules.get(sg_id, []):
+            if pre_rule not in self.sg_rules.get(sg_id, []):
+                del_rules.append(pre_rule)
+        return del_rules
+
+    def _find_devices_on_security_group(self, sg_id):
+        device_list = list()
+        for device in self.filtered_ports.values():
+            if sg_id in device.get('security_groups', []):
+                device_list.append(device)
+        return device_list
+
+    def _clean_deleted_sg_rule_conntrack_entries(self):
+        deleted_sg_ids = set()
+        for sg_id in self.updated_rule_sg_ids:
+            del_rules = self._find_deleted_sg_rules(sg_id)
+            if not del_rules:
+                continue
+            device_list = self._find_devices_on_security_group(sg_id)
+            for rule in del_rules:
+                self.ipconntrack.delete_conntrack_state_by_rule(
+                    device_list, rule)
+            deleted_sg_ids.add(sg_id)
+        for id in deleted_sg_ids:
+            self.updated_rule_sg_ids.remove(id)
+
+    def _clean_updated_sg_member_conntrack_entries(self):
+        updated_device_ids = set()
+        for device in self.updated_sg_members:
+            sec_group_change = False
+            device_info = self.filtered_ports.get(device)
+            pre_device_info = self._pre_defer_filtered_ports.get(device)
+            if not (device_info or pre_device_info):
+                continue
+            for sg_id in pre_device_info.get('security_groups', []):
+                if sg_id not in device_info.get('security_groups', []):
+                    sec_group_change = True
+                    break
+            if not sec_group_change:
+                continue
+            for ethertype in [constants.IPv4, constants.IPv6]:
+                self.ipconntrack.delete_conntrack_state_by_remote_ips(
+                    [device_info], ethertype, set())
+            updated_device_ids.add(device)
+        for id in updated_device_ids:
+            self.updated_sg_members.remove(id)
+
+    def _clean_deleted_remote_sg_members_conntrack_entries(self):
+        deleted_sg_ids = set()
+        for sg_id, devices in self.devices_with_udpated_sg_members.items():
+            for ethertype in [constants.IPv4, constants.IPv6]:
+                pre_ips = self._get_sg_members(
+                    self.pre_sg_members, sg_id, ethertype)
+                cur_ips = self._get_sg_members(
+                    self.sg_members, sg_id, ethertype)
+                ips = (pre_ips - cur_ips)
+                if devices and ips:
+                    self.ipconntrack.delete_conntrack_state_by_remote_ips(
+                        devices, ethertype, ips)
+            deleted_sg_ids.add(sg_id)
+        for id in deleted_sg_ids:
+            self.devices_with_udpated_sg_members.pop(id, None)
+
+    def _remove_conntrack_entries_from_sg_updates(self):
+        self._clean_deleted_sg_rule_conntrack_entries()
+        self._clean_updated_sg_member_conntrack_entries()
+        self._clean_deleted_remote_sg_members_conntrack_entries()
+
+    def _get_sg_members(self, sg_info, sg_id, ethertype):
+        return set(sg_info.get(sg_id, {}).get(ethertype, []))
+
     def filter_defer_apply_off(self):
         if self._defer_apply:
             self._defer_apply = False
@@ -696,6 +790,7 @@ class IptablesFirewallDriver(firewall.FirewallDriver):
             self._setup_chains_apply(self.filtered_ports,
                                      self.unfiltered_ports)
             self.iptables.defer_apply_off()
+            self._remove_conntrack_entries_from_sg_updates()
             self._remove_unused_security_group_info()
             self._pre_defer_filtered_ports = None
             self._pre_defer_unfiltered_ports = None
index 598519879708dd7fff56529786c772391c2ee75d..ec1ad6b2c9f4453178d2fcb90e4932b9c5c1513b 100644 (file)
@@ -198,22 +198,25 @@ class SecurityGroupAgentRpc(object):
                  "rule updated %r"), security_groups)
         self._security_group_updated(
             security_groups,
-            'security_groups')
+            'security_groups',
+            'sg_rule')
 
     def security_groups_member_updated(self, security_groups):
         LOG.info(_LI("Security group "
                  "member updated %r"), security_groups)
         self._security_group_updated(
             security_groups,
-            'security_group_source_groups')
+            'security_group_source_groups',
+            'sg_member')
 
-    def _security_group_updated(self, security_groups, attribute):
+    def _security_group_updated(self, security_groups, attribute, action_type):
         devices = []
         sec_grp_set = set(security_groups)
         for device in self.firewall.ports.values():
             if sec_grp_set & set(device.get(attribute, [])):
                 devices.append(device['device'])
         if devices:
+            self.firewall.security_group_updated(action_type, sec_grp_set)
             if self.defer_refresh_firewall:
                 LOG.debug("Adding %s devices to the list of devices "
                           "for which firewall needs to be refreshed",
@@ -307,6 +310,8 @@ class SecurityGroupAgentRpc(object):
             LOG.debug("Refreshing firewall for all filtered devices")
             self.refresh_firewall()
         else:
+            self.firewall.security_group_updated('sg_member', [],
+                                                 updated_devices)
             # If a device is both in new and updated devices
             # avoid reprocessing it
             updated_devices = ((updated_devices | devices_to_refilter) -
index d43532df01027a361f89a37bc8944be8dbada746..0837dd1aa7f9aa3d5ffb22c5240eb58795a0da31 100644 (file)
@@ -1008,6 +1008,75 @@ class IptablesFirewallTestCase(BaseIptablesFirewallTestCase):
 
         filter_inst.assert_has_calls(calls)
 
+    def _test_remove_conntrack_entries(self, ethertype, protocol,
+                                       direction):
+        port = self._fake_port()
+        port['zone_id'] = 1
+        port['security_groups'] = 'fake_sg_id'
+        self.firewall.filtered_ports[port['device']] = port
+        self.firewall.updated_rule_sg_ids = set(['fake_sg_id'])
+        self.firewall.sg_rules['fake_sg_id'] = [
+            {'direction': direction, 'ethertype': ethertype,
+             'protocol': protocol}]
+
+        self.firewall.filter_defer_apply_on()
+        self.firewall.sg_rules['fake_sg_id'] = []
+        self.firewall.filter_defer_apply_off()
+        cmd = ['conntrack', '-D']
+        if protocol:
+            cmd.extend(['-p', protocol])
+        if ethertype == 'IPv4':
+            cmd.extend(['-f', 'ipv4'])
+            if direction == 'ingress':
+                cmd.extend(['-d', '10.0.0.1'])
+            else:
+                cmd.extend(['-s', '10.0.0.1'])
+        else:
+            cmd.extend(['-f', 'ipv6'])
+            if direction == 'ingress':
+                cmd.extend(['-d', 'fe80::1'])
+            else:
+                cmd.extend(['-s', 'fe80::1'])
+        cmd.extend(['-w', 1])
+        calls = [
+            mock.call(cmd, run_as_root=True, check_exit_code=True,
+                      extra_ok_codes=[1])]
+        self.utils_exec.assert_has_calls(calls)
+
+    def test_remove_conntrack_entries_for_delete_rule_ipv4(self):
+        for direction in ['ingress', 'egress']:
+            for pro in [None, 'tcp', 'icmp', 'udp']:
+                self._test_remove_conntrack_entries(
+                    'IPv4', pro, direction)
+
+    def test_remove_conntrack_entries_for_delete_rule_ipv6(self):
+        for direction in ['ingress', 'egress']:
+            for pro in [None, 'tcp', 'icmp', 'udp']:
+                self._test_remove_conntrack_entries(
+                    'IPv6', pro, direction)
+
+    def test_remove_conntrack_entries_for_port_sec_group_change(self):
+        port = self._fake_port()
+        port['zone_id'] = 1
+        port['security_groups'] = ['fake_sg_id']
+        self.firewall.filtered_ports[port['device']] = port
+        self.firewall.updated_sg_members = set(['tapfake_dev'])
+        self.firewall.filter_defer_apply_on()
+        new_port = copy.deepcopy(port)
+        new_port['security_groups'] = ['fake_sg_id2']
+        self.firewall.filtered_ports[port['device']] = new_port
+        self.firewall.filter_defer_apply_off()
+        calls = [
+            mock.call(['conntrack', '-D', '-f', 'ipv4', '-d', '10.0.0.1',
+                       '-w', 1],
+                      run_as_root=True, check_exit_code=True,
+                      extra_ok_codes=[1]),
+            mock.call(['conntrack', '-D', '-f', 'ipv6', '-d', 'fe80::1',
+                       '-w', 1],
+                      run_as_root=True, check_exit_code=True,
+                      extra_ok_codes=[1])]
+        self.utils_exec.assert_has_calls(calls)
+
     def test_update_delete_port_filter(self):
         port = self._fake_port()
         port['security_group_rules'] = [{'ethertype': 'IPv4',