"""Update rules in a security group."""
raise NotImplementedError()
+ def security_group_updated(self, action_type, sec_group_ids,
+ device_id=None):
+ """Called when a security group is updated."""
+ raise NotImplementedError()
+
class NoopFirewallDriver(FirewallDriver):
"""Noop Firewall Driver.
def update_security_group_rules(self, sg_id, rules):
pass
+
+ def security_group_updated(self, action_type, sec_group_ids,
+ device_id=None):
+ pass
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import netaddr
+from oslo_log import log as logging
+
+from neutron.agent.linux import utils as linux_utils
+from neutron.i18n import _LE
+
+LOG = logging.getLogger(__name__)
+
+
+class IpConntrackManager(object):
+ """Smart wrapper for ip conntrack."""
+
+ def __init__(self, execute=None, namespace=None):
+ self.execute = execute or linux_utils.execute
+ self.namespace = namespace
+
+ @staticmethod
+ def _generate_conntrack_cmd_by_rule(rule, namespace):
+ ethertype = rule.get('ethertype')
+ protocol = rule.get('protocol')
+ direction = rule.get('direction')
+ cmd = ['conntrack', '-D']
+ if protocol:
+ cmd.extend(['-p', str(protocol)])
+ cmd.extend(['-f', str(ethertype).lower()])
+ cmd.append('-d' if direction == 'ingress' else '-s')
+ cmd_ns = []
+ if namespace:
+ cmd_ns.extend(['ip', 'netns', 'exec', namespace])
+ cmd_ns.extend(cmd)
+ return cmd_ns
+
+ def _get_conntrack_cmds(self, device_info_list, rule, remote_ip=None):
+ conntrack_cmds = []
+ cmd = self._generate_conntrack_cmd_by_rule(rule, self.namespace)
+ ethertype = rule.get('ethertype')
+ for device_info in device_info_list:
+ zone_id = device_info.get('zone_id')
+ if not zone_id:
+ continue
+ ips = device_info.get('fixed_ips', [])
+ for ip in ips:
+ net = netaddr.IPNetwork(ip)
+ if str(net.version) not in ethertype:
+ continue
+ ip_cmd = [str(net.ip), '-w', zone_id]
+ if remote_ip and str(
+ netaddr.IPNetwork(remote_ip).version) in ethertype:
+ ip_cmd.extend(['-s', str(remote_ip)])
+ conntrack_cmds.append(cmd + ip_cmd)
+ return conntrack_cmds
+
+ def _delete_conntrack_state(self, device_info_list, rule, remote_ip=None):
+ conntrack_cmds = self._get_conntrack_cmds(device_info_list,
+ rule, remote_ip)
+ for cmd in conntrack_cmds:
+ try:
+ self.execute(cmd, run_as_root=True,
+ check_exit_code=True,
+ extra_ok_codes=[1])
+ except RuntimeError:
+ LOG.exception(
+ _LE("Failed execute conntrack command %s"), str(cmd))
+
+ def delete_conntrack_state_by_rule(self, device_info_list, rule):
+ self._delete_conntrack_state(device_info_list, rule)
+
+ def delete_conntrack_state_by_remote_ips(self, device_info_list,
+ ethertype, remote_ips):
+ rule = {'ethertype': str(ethertype).lower(), 'direction': 'ingress'}
+ if remote_ips:
+ for remote_ip in remote_ips:
+ self._delete_conntrack_state(
+ device_info_list, rule, remote_ip)
+ else:
+ self._delete_conntrack_state(device_info_list, rule)
import six
from neutron.agent import firewall
+from neutron.agent.linux import ip_conntrack
from neutron.agent.linux import ipset_manager
from neutron.agent.linux import iptables_comments as ic
from neutron.agent.linux import iptables_manager
# TODO(majopela, shihanzhang): refactor out ipset to a separate
# driver composed over this one
self.ipset = ipset_manager.IpsetManager(namespace=namespace)
+ self.ipconntrack = ip_conntrack.IpConntrackManager(namespace=namespace)
# list of port which has security group
self.filtered_ports = {}
self.unfiltered_ports = {}
self.pre_sg_members = None
self.enable_ipset = cfg.CONF.SECURITYGROUP.enable_ipset
self._enabled_netfilter_for_bridges = False
+ self.updated_rule_sg_ids = set()
+ self.updated_sg_members = set()
+ self.devices_with_udpated_sg_members = collections.defaultdict(list)
def _enable_netfilter_for_bridges(self):
# we only need to set these values once, but it has to be when
def ports(self):
return dict(self.filtered_ports, **self.unfiltered_ports)
+ def _update_remote_security_group_members(self, sec_group_ids):
+ for sg_id in sec_group_ids:
+ for device in self.filtered_ports.values():
+ if sg_id in device.get('security_group_source_groups', []):
+ self.devices_with_udpated_sg_members[sg_id].append(device)
+
+ def security_group_updated(self, action_type, sec_group_ids,
+ device_ids=[]):
+ if action_type == 'sg_rule':
+ self.updated_rule_sg_ids.update(sec_group_ids)
+ elif action_type == 'sg_member':
+ if device_ids:
+ self.updated_sg_members.update(device_ids)
+ else:
+ self._update_remote_security_group_members(sec_group_ids)
+
def update_security_group_rules(self, sg_id, sg_rules):
LOG.debug("Update rules of security group (%s)", sg_id)
self.sg_rules[sg_id] = sg_rules
if not sg_has_members:
del self.sg_members[sg_id]
+ def _find_deleted_sg_rules(self, sg_id):
+ del_rules = list()
+ for pre_rule in self.pre_sg_rules.get(sg_id, []):
+ if pre_rule not in self.sg_rules.get(sg_id, []):
+ del_rules.append(pre_rule)
+ return del_rules
+
+ def _find_devices_on_security_group(self, sg_id):
+ device_list = list()
+ for device in self.filtered_ports.values():
+ if sg_id in device.get('security_groups', []):
+ device_list.append(device)
+ return device_list
+
+ def _clean_deleted_sg_rule_conntrack_entries(self):
+ deleted_sg_ids = set()
+ for sg_id in self.updated_rule_sg_ids:
+ del_rules = self._find_deleted_sg_rules(sg_id)
+ if not del_rules:
+ continue
+ device_list = self._find_devices_on_security_group(sg_id)
+ for rule in del_rules:
+ self.ipconntrack.delete_conntrack_state_by_rule(
+ device_list, rule)
+ deleted_sg_ids.add(sg_id)
+ for id in deleted_sg_ids:
+ self.updated_rule_sg_ids.remove(id)
+
+ def _clean_updated_sg_member_conntrack_entries(self):
+ updated_device_ids = set()
+ for device in self.updated_sg_members:
+ sec_group_change = False
+ device_info = self.filtered_ports.get(device)
+ pre_device_info = self._pre_defer_filtered_ports.get(device)
+ if not (device_info or pre_device_info):
+ continue
+ for sg_id in pre_device_info.get('security_groups', []):
+ if sg_id not in device_info.get('security_groups', []):
+ sec_group_change = True
+ break
+ if not sec_group_change:
+ continue
+ for ethertype in [constants.IPv4, constants.IPv6]:
+ self.ipconntrack.delete_conntrack_state_by_remote_ips(
+ [device_info], ethertype, set())
+ updated_device_ids.add(device)
+ for id in updated_device_ids:
+ self.updated_sg_members.remove(id)
+
+ def _clean_deleted_remote_sg_members_conntrack_entries(self):
+ deleted_sg_ids = set()
+ for sg_id, devices in self.devices_with_udpated_sg_members.items():
+ for ethertype in [constants.IPv4, constants.IPv6]:
+ pre_ips = self._get_sg_members(
+ self.pre_sg_members, sg_id, ethertype)
+ cur_ips = self._get_sg_members(
+ self.sg_members, sg_id, ethertype)
+ ips = (pre_ips - cur_ips)
+ if devices and ips:
+ self.ipconntrack.delete_conntrack_state_by_remote_ips(
+ devices, ethertype, ips)
+ deleted_sg_ids.add(sg_id)
+ for id in deleted_sg_ids:
+ self.devices_with_udpated_sg_members.pop(id, None)
+
+ def _remove_conntrack_entries_from_sg_updates(self):
+ self._clean_deleted_sg_rule_conntrack_entries()
+ self._clean_updated_sg_member_conntrack_entries()
+ self._clean_deleted_remote_sg_members_conntrack_entries()
+
+ def _get_sg_members(self, sg_info, sg_id, ethertype):
+ return set(sg_info.get(sg_id, {}).get(ethertype, []))
+
def filter_defer_apply_off(self):
if self._defer_apply:
self._defer_apply = False
self._setup_chains_apply(self.filtered_ports,
self.unfiltered_ports)
self.iptables.defer_apply_off()
+ self._remove_conntrack_entries_from_sg_updates()
self._remove_unused_security_group_info()
self._pre_defer_filtered_ports = None
self._pre_defer_unfiltered_ports = None
"rule updated %r"), security_groups)
self._security_group_updated(
security_groups,
- 'security_groups')
+ 'security_groups',
+ 'sg_rule')
def security_groups_member_updated(self, security_groups):
LOG.info(_LI("Security group "
"member updated %r"), security_groups)
self._security_group_updated(
security_groups,
- 'security_group_source_groups')
+ 'security_group_source_groups',
+ 'sg_member')
- def _security_group_updated(self, security_groups, attribute):
+ def _security_group_updated(self, security_groups, attribute, action_type):
devices = []
sec_grp_set = set(security_groups)
for device in self.firewall.ports.values():
if sec_grp_set & set(device.get(attribute, [])):
devices.append(device['device'])
if devices:
+ self.firewall.security_group_updated(action_type, sec_grp_set)
if self.defer_refresh_firewall:
LOG.debug("Adding %s devices to the list of devices "
"for which firewall needs to be refreshed",
LOG.debug("Refreshing firewall for all filtered devices")
self.refresh_firewall()
else:
+ self.firewall.security_group_updated('sg_member', [],
+ updated_devices)
# If a device is both in new and updated devices
# avoid reprocessing it
updated_devices = ((updated_devices | devices_to_refilter) -
filter_inst.assert_has_calls(calls)
+ def _test_remove_conntrack_entries(self, ethertype, protocol,
+ direction):
+ port = self._fake_port()
+ port['zone_id'] = 1
+ port['security_groups'] = 'fake_sg_id'
+ self.firewall.filtered_ports[port['device']] = port
+ self.firewall.updated_rule_sg_ids = set(['fake_sg_id'])
+ self.firewall.sg_rules['fake_sg_id'] = [
+ {'direction': direction, 'ethertype': ethertype,
+ 'protocol': protocol}]
+
+ self.firewall.filter_defer_apply_on()
+ self.firewall.sg_rules['fake_sg_id'] = []
+ self.firewall.filter_defer_apply_off()
+ cmd = ['conntrack', '-D']
+ if protocol:
+ cmd.extend(['-p', protocol])
+ if ethertype == 'IPv4':
+ cmd.extend(['-f', 'ipv4'])
+ if direction == 'ingress':
+ cmd.extend(['-d', '10.0.0.1'])
+ else:
+ cmd.extend(['-s', '10.0.0.1'])
+ else:
+ cmd.extend(['-f', 'ipv6'])
+ if direction == 'ingress':
+ cmd.extend(['-d', 'fe80::1'])
+ else:
+ cmd.extend(['-s', 'fe80::1'])
+ cmd.extend(['-w', 1])
+ calls = [
+ mock.call(cmd, run_as_root=True, check_exit_code=True,
+ extra_ok_codes=[1])]
+ self.utils_exec.assert_has_calls(calls)
+
+ def test_remove_conntrack_entries_for_delete_rule_ipv4(self):
+ for direction in ['ingress', 'egress']:
+ for pro in [None, 'tcp', 'icmp', 'udp']:
+ self._test_remove_conntrack_entries(
+ 'IPv4', pro, direction)
+
+ def test_remove_conntrack_entries_for_delete_rule_ipv6(self):
+ for direction in ['ingress', 'egress']:
+ for pro in [None, 'tcp', 'icmp', 'udp']:
+ self._test_remove_conntrack_entries(
+ 'IPv6', pro, direction)
+
+ def test_remove_conntrack_entries_for_port_sec_group_change(self):
+ port = self._fake_port()
+ port['zone_id'] = 1
+ port['security_groups'] = ['fake_sg_id']
+ self.firewall.filtered_ports[port['device']] = port
+ self.firewall.updated_sg_members = set(['tapfake_dev'])
+ self.firewall.filter_defer_apply_on()
+ new_port = copy.deepcopy(port)
+ new_port['security_groups'] = ['fake_sg_id2']
+ self.firewall.filtered_ports[port['device']] = new_port
+ self.firewall.filter_defer_apply_off()
+ calls = [
+ mock.call(['conntrack', '-D', '-f', 'ipv4', '-d', '10.0.0.1',
+ '-w', 1],
+ run_as_root=True, check_exit_code=True,
+ extra_ok_codes=[1]),
+ mock.call(['conntrack', '-D', '-f', 'ipv6', '-d', 'fe80::1',
+ '-w', 1],
+ run_as_root=True, check_exit_code=True,
+ extra_ok_codes=[1])]
+ self.utils_exec.assert_has_calls(calls)
+
def test_update_delete_port_filter(self):
port = self._fake_port()
port['security_group_rules'] = [{'ethertype': 'IPv4',