]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Adds Hyper-V Security Groups implementation
authorClaudiu Belu <cbelu@cloudbasesolutions.com>
Thu, 13 Feb 2014 00:52:47 +0000 (16:52 -0800)
committerClaudiu Belu <cbelu@cloudbasesolutions.com>
Tue, 4 Mar 2014 15:43:18 +0000 (07:43 -0800)
Implements the security groups API in the Hyper-V agent.

To enable security groups on the Hyper-V agent, its config file
must contain the following option:

[SECURITYGROUP]
firewall_driver=neutron.plugins.hyperv.agent.security_groups_driver.HyperVSecurityGroupsDriver

Change-Id: I7556001557cd013c10b7f883dbf371afa8d09626
Implements: blueprint hyperv-security-groups

neutron/plugins/hyperv/agent/hyperv_neutron_agent.py
neutron/plugins/hyperv/agent/security_groups_driver.py [new file with mode: 0644]
neutron/plugins/hyperv/agent/utilsfactory.py
neutron/plugins/hyperv/agent/utilsv2.py
neutron/tests/unit/hyperv/test_hyperv_neutron_agent.py
neutron/tests/unit/hyperv/test_hyperv_security_groups_driver.py [new file with mode: 0644]
neutron/tests/unit/hyperv/test_hyperv_utilsfactory.py
neutron/tests/unit/hyperv/test_hyperv_utilsv2.py

index 755ed67f12c7671b11dc928cdac83cff86b90cfd..cb3054f12058162652e1907bf9bd3abee1f8ffd8 100644 (file)
@@ -27,6 +27,7 @@ from oslo.config import cfg
 
 from neutron.agent.common import config
 from neutron.agent import rpc as agent_rpc
+from neutron.agent import securitygroups_rpc as sg_rpc
 from neutron.common import config as logging_config
 from neutron.common import constants as n_const
 from neutron.common import topics
@@ -70,6 +71,45 @@ CONF.register_opts(agent_opts, "AGENT")
 config.register_agent_state_opts_helper(cfg.CONF)
 
 
+class HyperVSecurityAgent(sg_rpc.SecurityGroupAgentRpcMixin):
+    # Set RPC API version to 1.1 by default.
+    RPC_API_VERSION = '1.1'
+
+    def __init__(self, context, plugin_rpc):
+        self.context = context
+        self.plugin_rpc = plugin_rpc
+        self.init_firewall()
+
+        if sg_rpc.is_firewall_enabled():
+            self._setup_rpc()
+
+    def _setup_rpc(self):
+        self.topic = topics.AGENT
+        self.dispatcher = self._create_rpc_dispatcher()
+        consumers = [[topics.SECURITY_GROUP, topics.UPDATE]]
+
+        self.connection = agent_rpc.create_consumers(self.dispatcher,
+                                                     self.topic,
+                                                     consumers)
+
+    def _create_rpc_dispatcher(self):
+        rpc_callback = HyperVSecurityCallbackMixin(self)
+        return dispatcher.RpcDispatcher([rpc_callback])
+
+
+class HyperVSecurityCallbackMixin(sg_rpc.SecurityGroupAgentRpcCallbackMixin):
+    # Set RPC API version to 1.1 by default.
+    RPC_API_VERSION = '1.1'
+
+    def __init__(self, sg_agent):
+        self.sg_agent = sg_agent
+
+
+class HyperVPluginApi(agent_rpc.PluginApi,
+                      sg_rpc.SecurityGroupServerRpcApiMixin):
+    pass
+
+
 class HyperVNeutronAgent(object):
     # Set RPC API version to 1.0 by default.
     RPC_API_VERSION = '1.0'
@@ -103,7 +143,7 @@ class HyperVNeutronAgent(object):
     def _setup_rpc(self):
         self.agent_id = 'hyperv_%s' % platform.node()
         self.topic = topics.AGENT
-        self.plugin_rpc = agent_rpc.PluginApi(topics.PLUGIN)
+        self.plugin_rpc = HyperVPluginApi(topics.PLUGIN)
 
         self.state_rpc = agent_rpc.PluginReportStateAPI(topics.PLUGIN)
 
@@ -119,6 +159,9 @@ class HyperVNeutronAgent(object):
         self.connection = agent_rpc.create_consumers(self.dispatcher,
                                                      self.topic,
                                                      consumers)
+
+        self.sec_groups_agent = HyperVSecurityAgent(
+            self.context, self.plugin_rpc)
         report_interval = CONF.AGENT.report_interval
         if report_interval:
             heartbeat = loopingcall.LoopingCall(self._report_state)
@@ -165,6 +208,9 @@ class HyperVNeutronAgent(object):
     def port_update(self, context, port=None, network_type=None,
                     segmentation_id=None, physical_network=None):
         LOG.debug(_("port_update received"))
+        if 'security_groups' in port:
+            self.sec_groups_agent.refresh_firewall()
+
         self._treat_vif_port(
             port['id'], port['network_id'],
             network_type, physical_network,
@@ -311,6 +357,8 @@ class HyperVNeutronAgent(object):
                     device_details['physical_network'],
                     device_details['segmentation_id'],
                     device_details['admin_state_up'])
+
+                self.sec_groups_agent.prepare_devices_filter(devices)
                 self.plugin_rpc.update_device_up(self.context,
                                                  device,
                                                  self.agent_id,
diff --git a/neutron/plugins/hyperv/agent/security_groups_driver.py b/neutron/plugins/hyperv/agent/security_groups_driver.py
new file mode 100644 (file)
index 0000000..ac3f4c9
--- /dev/null
@@ -0,0 +1,136 @@
+#Copyright 2014 Cloudbase Solutions SRL
+#All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+# @author: Claudiu Belu, Cloudbase Solutions Srl
+
+from neutron.agent import firewall
+from neutron.openstack.common import log as logging
+from neutron.plugins.hyperv.agent import utilsfactory
+from neutron.plugins.hyperv.agent import utilsv2
+
+LOG = logging.getLogger(__name__)
+
+
+class HyperVSecurityGroupsDriver(firewall.FirewallDriver):
+    """Security Groups Driver.
+
+    Security Groups implementation for Hyper-V VMs.
+    """
+
+    _ACL_PROP_MAP = {
+        'direction': {'ingress': utilsv2.HyperVUtilsV2._ACL_DIR_IN,
+                      'egress': utilsv2.HyperVUtilsV2._ACL_DIR_OUT},
+        'ethertype': {'IPv4': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV4,
+                      'IPv6': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV6},
+        'default': "ANY",
+        'address_default': {'IPv4': '0.0.0.0/0', 'IPv6': '::/0'}
+    }
+
+    def __init__(self):
+        self._utils = utilsfactory.get_hypervutils()
+        self._security_ports = {}
+
+    def prepare_port_filter(self, port):
+        LOG.debug('Creating port %s rules' % len(port['security_group_rules']))
+
+        # newly created port, add default rules.
+        if port['device'] not in self._security_ports:
+            LOG.debug('Creating default reject rules.')
+            self._utils.create_default_reject_all_rules(port['id'])
+
+        self._security_ports[port['device']] = port
+        self._create_port_rules(port['id'], port['security_group_rules'])
+
+    def _create_port_rules(self, port_id, rules):
+        for rule in rules:
+            param_map = self._create_param_map(rule)
+            try:
+                self._utils.create_security_rule(port_id, **param_map)
+            except Exception as ex:
+                LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
+                            'adding rule: %(rule)s'),
+                          dict(hyperv_exeption=ex, rule=rule))
+
+    def _remove_port_rules(self, port_id, rules):
+        for rule in rules:
+            param_map = self._create_param_map(rule)
+            try:
+                self._utils.remove_security_rule(port_id, **param_map)
+            except Exception as ex:
+                LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
+                            'removing rule: %(rule)s'),
+                          dict(hyperv_exeption=ex, rule=rule))
+
+    def _create_param_map(self, rule):
+        if 'port_range_min' in rule and 'port_range_max' in rule:
+            local_port = '%s-%s' % (rule['port_range_min'],
+                                    rule['port_range_max'])
+        else:
+            local_port = self._ACL_PROP_MAP['default']
+
+        return {
+            'direction': self._ACL_PROP_MAP['direction'][rule['direction']],
+            'acl_type': self._ACL_PROP_MAP['ethertype'][rule['ethertype']],
+            'local_port': local_port,
+            'protocol': self._get_rule_prop_or_default(rule, 'protocol'),
+            'remote_address': self._get_rule_remote_address(rule)
+        }
+
+    def apply_port_filter(self, port):
+        LOG.info('Aplying port filter.')
+
+    def update_port_filter(self, port):
+        LOG.info('Updating port rules.')
+
+        if port['device'] not in self._security_ports:
+            self.prepare_port_filter(port)
+            return
+
+        old_port = self._security_ports[port['device']]
+        rules = old_port['security_group_rules']
+        param_port_rules = port['security_group_rules']
+
+        new_rules = [r for r in param_port_rules if r not in rules]
+        remove_rules = [r for r in rules if r not in param_port_rules]
+
+        LOG.info("Creating %s new rules, removing %s old rules." % (
+                 len(new_rules), len(remove_rules)))
+
+        self._remove_port_rules(old_port['id'], remove_rules)
+        self._create_port_rules(port['id'], new_rules)
+
+        self._security_ports[port['device']] = port
+
+    def remove_port_filter(self, port):
+        LOG.info('Removing port filter')
+        self._security_ports.pop(port['device'], None)
+
+    @property
+    def ports(self):
+        return self._security_ports
+
+    def _get_rule_remote_address(self, rule):
+        if rule['direction'] is 'ingress':
+            ip_prefix = 'source_ip_prefix'
+        else:
+            ip_prefix = 'dest_ip_prefix'
+
+        if ip_prefix in rule:
+            return rule[ip_prefix]
+        return self._ACL_PROP_MAP['address_default'][rule['ethertype']]
+
+    def _get_rule_prop_or_default(self, rule, prop):
+        if prop in rule:
+            return rule[prop]
+        return self._ACL_PROP_MAP['default']
index 8d594a223ad1c0ea6640f8cb541a8dcce0272e94..8b6ae77293fd7c7ef8f604a7c5ffb2b8f771aba5 100644 (file)
@@ -49,18 +49,24 @@ def _check_min_windows_version(major, minor, build=0):
     return map(int, version_str.split('.')) >= [major, minor, build]
 
 
-def _get_class(v1_class, v2_class, force_v1_flag):
-    # V2 classes are supported starting from Hyper-V Server 2012 and
-    # Windows Server 2012 (kernel version 6.2)
-    if not force_v1_flag and _check_min_windows_version(6, 2):
-        cls = v2_class
+def get_hypervutils():
+    # V1 virtualization namespace features are supported up to
+    # Windows Server / Hyper-V Server 2012
+    # V2 virtualization namespace features are supported starting with
+    # Windows Server / Hyper-V Server 2012
+    # Windows Server / Hyper-V Server 2012 R2 uses the V2 namespace and
+    # introduces additional features
+
+    force_v1_flag = CONF.hyperv.force_hyperv_utils_v1
+    if _check_min_windows_version(6, 3):
+        if force_v1_flag:
+            LOG.warning('V1 virtualization namespace no longer supported on '
+                        'Windows Server / Hyper-V Server 2012 R2 or above.')
+        cls = utilsv2.HyperVUtilsV2R2
+    elif not force_v1_flag and _check_min_windows_version(6, 2):
+        cls = utilsv2.HyperVUtilsV2
     else:
-        cls = v1_class
+        cls = utils.HyperVUtils
     LOG.debug(_("Loading class: %(module_name)s.%(class_name)s"),
               {'module_name': cls.__module__, 'class_name': cls.__name__})
-    return cls
-
-
-def get_hypervutils():
-    return _get_class(utils.HyperVUtils, utilsv2.HyperVUtilsV2,
-                      CONF.hyperv.force_hyperv_utils_v1)()
+    return cls()
index d2a9a7d3ef78ff49d24381018736636b6171a5cd..5b280c1853ea562e1c33bcebf4b1f603ba89dd4e 100644 (file)
@@ -26,17 +26,32 @@ class HyperVUtilsV2(utils.HyperVUtils):
     _ETHERNET_SWITCH_PORT = 'Msvm_EthernetSwitchPort'
     _PORT_ALLOC_SET_DATA = 'Msvm_EthernetPortAllocationSettingData'
     _PORT_VLAN_SET_DATA = 'Msvm_EthernetSwitchPortVlanSettingData'
+    _PORT_SECURITY_SET_DATA = 'Msvm_EthernetSwitchPortSecuritySettingData'
     _PORT_ALLOC_ACL_SET_DATA = 'Msvm_EthernetSwitchPortAclSettingData'
+    _PORT_EXT_ACL_SET_DATA = _PORT_ALLOC_ACL_SET_DATA
     _LAN_ENDPOINT = 'Msvm_LANEndpoint'
     _STATE_DISABLED = 3
     _OPERATION_MODE_ACCESS = 1
 
     _ACL_DIR_IN = 1
     _ACL_DIR_OUT = 2
+
     _ACL_TYPE_IPV4 = 2
     _ACL_TYPE_IPV6 = 3
+
+    _ACL_ACTION_ALLOW = 1
+    _ACL_ACTION_DENY = 2
     _ACL_ACTION_METER = 3
+
     _ACL_APPLICABILITY_LOCAL = 1
+    _ACL_APPLICABILITY_REMOTE = 2
+
+    _ACL_DEFAULT = 'ANY'
+    _IPV4_ANY = '0.0.0.0/0'
+    _IPV6_ANY = '::/0'
+    _TCP_PROTOCOL = 'tcp'
+    _UDP_PROTOCOL = 'udp'
+    _MAX_WEIGHT = 65500
 
     _wmi_namespace = '//./root/virtualization/v2'
 
@@ -80,6 +95,12 @@ class HyperVUtilsV2(utils.HyperVUtils):
             element.path_(), [res_setting_data.GetText_(1)])
         self._check_job_status(ret_val, job_path)
 
+    def _remove_virt_feature(self, feature_resource):
+        vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
+        (job_path, ret_val) = vs_man_svc.RemoveFeatureSettings(
+            FeatureSettings=[feature_resource.path_()])
+        self._check_job_status(ret_val, job_path)
+
     def disconnect_switch_port(
             self, vswitch_name, switch_port_name, delete_port):
         """Disconnects the switch port."""
@@ -121,7 +142,7 @@ class HyperVUtilsV2(utils.HyperVUtils):
         port_alloc, found = self._get_switch_port_allocation(switch_port_name)
         if not found:
             raise utils.HyperVException(
-                msg=_('Port Alloc not found: %s') % switch_port_name)
+                msg=_('Port Allocation not found: %s') % switch_port_name)
 
         vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
         vlan_settings = self._get_vlan_setting_data_from_port_alloc(port_alloc)
@@ -196,3 +217,173 @@ class HyperVUtilsV2(utils.HyperVUtils):
                     acl.Action = self._ACL_ACTION_METER
                     acl.Applicability = self._ACL_APPLICABILITY_LOCAL
                     self._add_virt_feature(port, acl)
+
+    def create_security_rule(self, switch_port_name, direction, acl_type,
+                             local_port, protocol, remote_address):
+        port, found = self._get_switch_port_allocation(switch_port_name, False)
+        if not found:
+            return
+
+        # Add the ACLs only if they don't already exist
+        acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+        weight = self._get_new_weight(acls)
+        self._bind_security_rule(
+            port, direction, acl_type, self._ACL_ACTION_ALLOW, local_port,
+            protocol, remote_address, weight)
+
+    def remove_security_rule(self, switch_port_name, direction, acl_type,
+                             local_port, protocol, remote_address):
+        port, found = self._get_switch_port_allocation(switch_port_name, False)
+        if not found:
+            # Port not found. It happens when the VM was already deleted.
+            return
+
+        acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+        filtered_acls = self._filter_security_acls(
+            acls, self._ACL_ACTION_ALLOW, direction, acl_type, local_port,
+            protocol, remote_address)
+
+        for acl in filtered_acls:
+            self._remove_virt_feature(acl)
+
+    def create_default_reject_all_rules(self, switch_port_name):
+        port, found = self._get_switch_port_allocation(switch_port_name, False)
+        if not found:
+            raise utils.HyperVException(
+                msg=_('Port Allocation not found: %s') % switch_port_name)
+
+        acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+        filtered_acls = [v for v in acls if v.Action == self._ACL_ACTION_DENY]
+
+        # 2 directions x 2 address types x 2 protocols = 8 ACLs
+        if len(filtered_acls) >= 8:
+            return
+
+        for acl in filtered_acls:
+            self._remove_virt_feature(acl)
+
+        weight = 0
+        ipv4_pair = (self._ACL_TYPE_IPV4, self._IPV4_ANY)
+        ipv6_pair = (self._ACL_TYPE_IPV6, self._IPV6_ANY)
+        for direction in [self._ACL_DIR_IN, self._ACL_DIR_OUT]:
+            for acl_type, address in [ipv4_pair, ipv6_pair]:
+                for protocol in [self._TCP_PROTOCOL, self._UDP_PROTOCOL]:
+                    self._bind_security_rule(
+                        port, direction, acl_type, self._ACL_ACTION_DENY,
+                        self._ACL_DEFAULT, protocol, address, weight)
+                    weight += 1
+
+    def _bind_security_rule(self, port, direction, acl_type, action,
+                            local_port, protocol, remote_address, weight):
+        acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+        filtered_acls = self._filter_security_acls(
+            acls, action, direction, acl_type, local_port, protocol,
+            remote_address)
+
+        for acl in filtered_acls:
+            self._remove_virt_feature(acl)
+
+        acl = self._create_security_acl(
+            direction, acl_type, action, local_port, protocol, remote_address,
+            weight)
+
+        self._add_virt_feature(port, acl)
+
+    def _create_acl(self, direction, acl_type, action):
+        acl = self._get_default_setting_data(self._PORT_ALLOC_ACL_SET_DATA)
+        acl.set(Direction=direction,
+                AclType=acl_type,
+                Action=action,
+                Applicability=self._ACL_APPLICABILITY_LOCAL)
+        return acl
+
+    def _create_security_acl(self, direction, acl_type, action, local_port,
+                             protocol, remote_ip_address, weight):
+        acl = self._create_acl(direction, acl_type, action)
+        (remote_address, remote_prefix_length) = remote_ip_address.split('/')
+        acl.set(Applicability=self._ACL_APPLICABILITY_REMOTE,
+                RemoteAddress=remote_address,
+                RemoteAddressPrefixLength=remote_prefix_length)
+        return acl
+
+    def _filter_acls(self, acls, action, direction, acl_type, remote_addr=""):
+        return [v for v in acls
+                if v.Action == action and
+                v.Direction == direction and
+                v.AclType == acl_type and
+                v.RemoteAddress == remote_addr]
+
+    def _filter_security_acls(self, acls, acl_action, direction, acl_type,
+                              local_port, protocol, remote_addr=""):
+        (remote_address, remote_prefix_length) = remote_addr.split('/')
+        remote_prefix_length = int(remote_prefix_length)
+
+        return [v for v in acls
+                if v.Direction == direction and
+                v.Action in [self._ACL_ACTION_ALLOW, self._ACL_ACTION_DENY] and
+                v.AclType == acl_type and
+                v.RemoteAddress == remote_address and
+                v.RemoteAddressPrefixLength == remote_prefix_length]
+
+    def _get_new_weight(self, acls):
+        return 0
+
+
+class HyperVUtilsV2R2(HyperVUtilsV2):
+    _PORT_EXT_ACL_SET_DATA = 'Msvm_EthernetSwitchPortExtendedAclSettingData'
+    _MAX_WEIGHT = 65500
+
+    def create_security_rule(self, switch_port_name, direction, acl_type,
+                             local_port, protocol, remote_address):
+        protocols = [protocol]
+        if protocol is self._ACL_DEFAULT:
+            protocols = [self._TCP_PROTOCOL, self._UDP_PROTOCOL]
+
+        for proto in protocols:
+            super(HyperVUtilsV2R2, self).create_security_rule(
+                switch_port_name, direction, acl_type, local_port,
+                proto, remote_address)
+
+    def remove_security_rule(self, switch_port_name, direction, acl_type,
+                             local_port, protocol, remote_address):
+        protocols = [protocol]
+        if protocol is self._ACL_DEFAULT:
+            protocols = ['tcp', 'udp']
+
+        for proto in protocols:
+            super(HyperVUtilsV2R2, self).remove_security_rule(
+                switch_port_name, direction, acl_type,
+                local_port, proto, remote_address)
+
+    def _create_security_acl(self, direction, acl_type, action, local_port,
+                             protocol, remote_addr, weight):
+        acl = self._get_default_setting_data(self._PORT_EXT_ACL_SET_DATA)
+        acl.set(Direction=direction,
+                Action=action,
+                LocalPort=str(local_port),
+                Protocol=protocol,
+                RemoteIPAddress=remote_addr,
+                IdleSessionTimeout=0,
+                Weight=weight)
+        return acl
+
+    def _filter_security_acls(self, acls, action, direction, acl_type,
+                              local_port, protocol, remote_addr=""):
+        return [v for v in acls
+                if v.Action == action and
+                v.Direction == direction and
+                v.LocalPort in [str(local_port), self._ACL_DEFAULT] and
+                v.Protocol in [protocol] and
+                v.RemoteIPAddress == remote_addr]
+
+    def _get_new_weight(self, acls):
+        if not acls:
+            return self._MAX_WEIGHT - 1
+
+        weights = [a.Weight for a in acls]
+        min_weight = min(weights)
+        for weight in range(min_weight, self._MAX_WEIGHT):
+            if weight not in weights:
+                return weight
+
+        return min_weight - 1
index dc835647d7db7e52e6488595d5cebf1cb76a0c11..28455c6b203575fc293dd937e2ca66a6177a2e3e 100644 (file)
@@ -57,6 +57,7 @@ class TestHyperVNeutronAgent(base.BaseTestCase):
 
         self.agent = hyperv_neutron_agent.HyperVNeutronAgent()
         self.agent.plugin_rpc = mock.Mock()
+        self.agent.sec_groups_agent = mock.MagicMock()
         self.agent.context = mock.Mock()
         self.agent.agent_id = mock.Mock()
 
diff --git a/neutron/tests/unit/hyperv/test_hyperv_security_groups_driver.py b/neutron/tests/unit/hyperv/test_hyperv_security_groups_driver.py
new file mode 100644 (file)
index 0000000..1ac5b31
--- /dev/null
@@ -0,0 +1,176 @@
+# Copyright 2014 Cloudbase Solutions SRL
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+# @author: Claudiu Belu, Cloudbase Solutions Srl
+
+"""
+Unit tests for the Hyper-V Security Groups Driver.
+"""
+
+import mock
+from oslo.config import cfg
+
+from neutron.plugins.hyperv.agent import security_groups_driver as sg_driver
+from neutron.plugins.hyperv.agent import utilsfactory
+from neutron.tests import base
+
+CONF = cfg.CONF
+
+
+class TestHyperVSecurityGroupsDriver(base.BaseTestCase):
+
+    _FAKE_DEVICE = 'fake_device'
+    _FAKE_ID = 'fake_id'
+    _FAKE_DIRECTION = 'ingress'
+    _FAKE_ETHERTYPE = 'IPv4'
+    _FAKE_ETHERTYPE_IPV6 = 'IPv6'
+    _FAKE_DEST_IP_PREFIX = 'fake_dest_ip_prefix'
+    _FAKE_SOURCE_IP_PREFIX = 'fake_source_ip_prefix'
+    _FAKE_PARAM_NAME = 'fake_param_name'
+    _FAKE_PARAM_VALUE = 'fake_param_value'
+
+    _FAKE_PORT_MIN = 9001
+    _FAKE_PORT_MAX = 9011
+
+    def setUp(self):
+        super(TestHyperVSecurityGroupsDriver, self).setUp()
+        self._mock_windows_version = mock.patch.object(utilsfactory,
+                                                       'get_hypervutils')
+        self._mock_windows_version.start()
+        self.addCleanup(mock.patch.stopall)
+        self._driver = sg_driver.HyperVSecurityGroupsDriver()
+        self._driver._utils = mock.MagicMock()
+
+    @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+                '.HyperVSecurityGroupsDriver._create_port_rules')
+    def test_prepare_port_filter(self, mock_create_rules):
+        mock_port = self._get_port()
+        mock_utils_method = self._driver._utils.create_default_reject_all_rules
+        self._driver.prepare_port_filter(mock_port)
+
+        self.assertEqual(mock_port,
+                         self._driver._security_ports[self._FAKE_DEVICE])
+        mock_utils_method.assert_called_once_with(self._FAKE_ID)
+        self._driver._create_port_rules.assert_called_once_with(
+            self._FAKE_ID, mock_port['security_group_rules'])
+
+    def test_update_port_filter(self):
+        mock_port = self._get_port()
+        new_mock_port = self._get_port()
+        new_mock_port['id'] += '2'
+        new_mock_port['security_group_rules'][0]['ethertype'] += "2"
+
+        self._driver._security_ports[mock_port['device']] = mock_port
+        self._driver._create_port_rules = mock.MagicMock()
+        self._driver._remove_port_rules = mock.MagicMock()
+        self._driver.update_port_filter(new_mock_port)
+
+        self._driver._remove_port_rules.assert_called_once_with(
+            mock_port['id'], mock_port['security_group_rules'])
+        self._driver._create_port_rules.assert_called_once_with(
+            new_mock_port['id'], new_mock_port['security_group_rules'])
+        self.assertEqual(new_mock_port,
+                         self._driver._security_ports[new_mock_port['device']])
+
+    @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+                '.HyperVSecurityGroupsDriver.prepare_port_filter')
+    def test_update_port_filter_new_port(self, mock_method):
+        mock_port = self._get_port()
+        self._driver.prepare_port_filter = mock.MagicMock()
+        self._driver.update_port_filter(mock_port)
+
+        self._driver.prepare_port_filter.assert_called_once_with(mock_port)
+
+    def test_remove_port_filter(self):
+        mock_port = self._get_port()
+        self._driver._security_ports[mock_port['device']] = mock_port
+        self._driver.remove_port_filter(mock_port)
+        self.assertFalse(mock_port['device'] in self._driver._security_ports)
+
+    def test_create_port_rules_exception(self):
+        fake_rule = self._create_security_rule()
+        self._driver._utils.create_security_rule.side_effect = Exception(
+            'Generated Exception for testing.')
+        self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
+
+    def test_create_param_map(self):
+        fake_rule = self._create_security_rule()
+        self._driver._get_rule_remote_address = mock.MagicMock(
+            return_value=self._FAKE_SOURCE_IP_PREFIX)
+        actual = self._driver._create_param_map(fake_rule)
+        expected = {
+            'direction': self._driver._ACL_PROP_MAP[
+                'direction'][self._FAKE_DIRECTION],
+            'acl_type': self._driver._ACL_PROP_MAP[
+                'ethertype'][self._FAKE_ETHERTYPE],
+            'local_port': '%s-%s' % (self._FAKE_PORT_MIN, self._FAKE_PORT_MAX),
+            'protocol': self._driver._ACL_PROP_MAP['default'],
+            'remote_address': self._FAKE_SOURCE_IP_PREFIX
+        }
+
+        self.assertEqual(expected, actual)
+
+    @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+                '.HyperVSecurityGroupsDriver._create_param_map')
+    def test_create_port_rules(self, mock_method):
+        fake_rule = self._create_security_rule()
+        mock_method.return_value = {
+            self._FAKE_PARAM_NAME: self._FAKE_PARAM_VALUE}
+        self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
+
+        self._driver._utils.create_security_rule.assert_called_once_with(
+            self._FAKE_ID, fake_param_name=self._FAKE_PARAM_VALUE)
+
+    def test_convert_any_address_to_same_ingress(self):
+        rule = self._create_security_rule()
+        actual = self._driver._get_rule_remote_address(rule)
+        self.assertEqual(self._FAKE_SOURCE_IP_PREFIX, actual)
+
+    def test_convert_any_address_to_same_egress(self):
+        rule = self._create_security_rule()
+        rule['direction'] += '2'
+        actual = self._driver._get_rule_remote_address(rule)
+        self.assertEqual(self._FAKE_DEST_IP_PREFIX, actual)
+
+    def test_convert_any_address_to_ipv4(self):
+        rule = self._create_security_rule()
+        del rule['source_ip_prefix']
+        actual = self._driver._get_rule_remote_address(rule)
+        self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv4'],
+                         actual)
+
+    def test_convert_any_address_to_ipv6(self):
+        rule = self._create_security_rule()
+        del rule['source_ip_prefix']
+        rule['ethertype'] = self._FAKE_ETHERTYPE_IPV6
+        actual = self._driver._get_rule_remote_address(rule)
+        self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv6'],
+                         actual)
+
+    def _get_port(self):
+        return {
+            'device': self._FAKE_DEVICE,
+            'id': self._FAKE_ID,
+            'security_group_rules': [self._create_security_rule()]
+        }
+
+    def _create_security_rule(self):
+        return {
+            'direction': self._FAKE_DIRECTION,
+            'ethertype': self._FAKE_ETHERTYPE,
+            'dest_ip_prefix': self._FAKE_DEST_IP_PREFIX,
+            'source_ip_prefix': self._FAKE_SOURCE_IP_PREFIX,
+            'port_range_min': self._FAKE_PORT_MIN,
+            'port_range_max': self._FAKE_PORT_MAX
+        }
index bc2622844746d815a75dcfb39cc12f014cfaa60b..fef96d734bd87f970c6b932d15e1b73efdc04703 100644 (file)
@@ -34,6 +34,9 @@ CONF = cfg.CONF
 
 class TestHyperVUtilsFactory(base.BaseTestCase):
 
+    def test_get_hypervutils_v2_r2(self):
+        self._test_returned_class(utilsv2.HyperVUtilsV2R2, True, '6.3.0')
+
     def test_get_hypervutils_v2(self):
         self._test_returned_class(utilsv2.HyperVUtilsV2, False, '6.2.0')
 
index 82786c918802906900b60e562c38cd18cabc945c..b30d4785ab2ed83108aa42bb9aedd54c6778262d 100644 (file)
@@ -41,6 +41,14 @@ class TestHyperVUtilsV2(base.BaseTestCase):
     _FAKE_CLASS_NAME = "fake_class_name"
     _FAKE_ELEMENT_NAME = "fake_element_name"
 
+    _FAKE_ACL_ACT = 'fake_acl_action'
+    _FAKE_ACL_DIR = 'fake_acl_dir'
+    _FAKE_ACL_TYPE = 'fake_acl_type'
+    _FAKE_LOCAL_PORT = 'fake_local_port'
+    _FAKE_PROTOCOL = 'fake_port_protocol'
+    _FAKE_REMOTE_ADDR = '0.0.0.0/0'
+    _FAKE_WEIGHT = 'fake_weight'
+
     def setUp(self):
         super(TestHyperVUtilsV2, self).setUp()
         self._utils = utilsv2.HyperVUtilsV2()
@@ -144,6 +152,20 @@ class TestHyperVUtilsV2(base.BaseTestCase):
         mock_svc.RemoveResourceSettings.assert_called_with(
             ResourceSettings=[self._FAKE_RES_PATH])
 
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._check_job_status')
+    def test_remove_virt_feature(self, mock_check_job_status):
+        mock_svc = self._utils._conn.Msvm_VirtualSystemManagementService()[0]
+        mock_svc.RemoveFeatureSettings.return_value = (self._FAKE_JOB_PATH,
+                                                       self._FAKE_RET_VAL)
+        mock_res_setting_data = mock.MagicMock()
+        mock_res_setting_data.path_.return_value = self._FAKE_RES_PATH
+
+        self._utils._remove_virt_feature(mock_res_setting_data)
+
+        mock_svc.RemoveFeatureSettings.assert_called_with(
+            FeatureSettings=[self._FAKE_RES_PATH])
+
     def test_disconnect_switch_port_delete_port(self):
         self._test_disconnect_switch_port(True)
 
@@ -249,3 +271,136 @@ class TestHyperVUtilsV2(base.BaseTestCase):
             self.assertEqual(4, len(self._utils._add_virt_feature.mock_calls))
             self._utils._add_virt_feature.assert_called_with(
                 mock_port, mock_acl)
+
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._remove_virt_feature')
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._bind_security_rule')
+    def test_create_default_reject_all_rules(self, mock_bind, mock_remove):
+        (m_port, m_acl) = self._setup_security_rule_test()
+        m_acl.Action = self._utils._ACL_ACTION_DENY
+        self._utils.create_default_reject_all_rules(self._FAKE_PORT_NAME)
+
+        calls = []
+        ipv4_pair = (self._utils._ACL_TYPE_IPV4, self._utils._IPV4_ANY)
+        ipv6_pair = (self._utils._ACL_TYPE_IPV6, self._utils._IPV6_ANY)
+        for direction in [self._utils._ACL_DIR_IN, self._utils._ACL_DIR_OUT]:
+            for acl_type, address in [ipv4_pair, ipv6_pair]:
+                for protocol in [self._utils._TCP_PROTOCOL,
+                                 self._utils._UDP_PROTOCOL]:
+                    calls.append(mock.call(m_port, direction, acl_type,
+                                           self._utils._ACL_ACTION_DENY,
+                                           self._utils._ACL_DEFAULT,
+                                           protocol, address, mock.ANY))
+
+        self._utils._remove_virt_feature.assert_called_once_with(m_acl)
+        self._utils._bind_security_rule.assert_has_calls(calls)
+
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._remove_virt_feature')
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._add_virt_feature')
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._create_security_acl')
+    def test_bind_security_rule(self, mock_create_acl, mock_add, mock_remove):
+        (m_port, m_acl) = self._setup_security_rule_test()
+        mock_create_acl.return_value = m_acl
+
+        self._utils._bind_security_rule(
+            m_port, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+            self._FAKE_ACL_ACT, self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL,
+            self._FAKE_REMOTE_ADDR, self._FAKE_WEIGHT)
+
+        self._utils._add_virt_feature.assert_called_once_with(m_port, m_acl)
+
+    @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+                '._remove_virt_feature')
+    def test_remove_security_rule(self, mock_remove_feature):
+        mock_acl = self._setup_security_rule_test()[1]
+        self._utils.remove_security_rule(
+            self._FAKE_PORT_NAME, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+            self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
+        self._utils._remove_virt_feature.assert_called_once_with(mock_acl)
+
+    def _setup_security_rule_test(self):
+        mock_port = mock.MagicMock()
+        mock_acl = mock.MagicMock()
+        mock_port.associators.return_value = [mock_acl]
+
+        self._utils._get_switch_port_allocation = mock.MagicMock(return_value=(
+            mock_port, True))
+        self._utils._filter_security_acls = mock.MagicMock(
+            return_value=[mock_acl])
+
+        return (mock_port, mock_acl)
+
+    def test_filter_acls(self):
+        mock_acl = mock.MagicMock()
+        mock_acl.Action = self._FAKE_ACL_ACT
+        mock_acl.Applicability = self._utils._ACL_APPLICABILITY_LOCAL
+        mock_acl.Direction = self._FAKE_ACL_DIR
+        mock_acl.AclType = self._FAKE_ACL_TYPE
+        mock_acl.RemoteAddress = self._FAKE_REMOTE_ADDR
+
+        acls = [mock_acl, mock_acl]
+        good_acls = self._utils._filter_acls(
+            acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR,
+            self._FAKE_ACL_TYPE, self._FAKE_REMOTE_ADDR)
+        bad_acls = self._utils._filter_acls(
+            acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE)
+
+        self.assertEqual(acls, good_acls)
+        self.assertEqual([], bad_acls)
+
+
+class TestHyperVUtilsV2R2(base.BaseTestCase):
+    _FAKE_ACL_ACT = 'fake_acl_action'
+    _FAKE_ACL_DIR = 'fake_direction'
+    _FAKE_ACL_TYPE = 'fake_acl_type'
+    _FAKE_LOCAL_PORT = 'fake_local_port'
+    _FAKE_PROTOCOL = 'fake_port_protocol'
+    _FAKE_REMOTE_ADDR = '10.0.0.0/0'
+
+    def setUp(self):
+        super(TestHyperVUtilsV2R2, self).setUp()
+        self._utils = utilsv2.HyperVUtilsV2R2()
+
+    def test_filter_security_acls(self):
+        self._test_filter_security_acls(
+            self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
+
+    def test_filter_security_acls_default(self):
+        default = self._utils._ACL_DEFAULT
+        self._test_filter_security_acls(
+            default, default, self._FAKE_REMOTE_ADDR)
+
+    def _test_filter_security_acls(self, local_port, protocol, remote_addr):
+        mock_acl = mock.MagicMock()
+        mock_acl.Action = self._utils._ACL_ACTION_ALLOW
+        mock_acl.Direction = self._FAKE_ACL_DIR
+        mock_acl.LocalPort = local_port
+        mock_acl.Protocol = protocol
+        mock_acl.RemoteIPAddress = remote_addr
+
+        acls = [mock_acl, mock_acl]
+        good_acls = self._utils._filter_security_acls(
+            acls, mock_acl.Action, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+            local_port, protocol, remote_addr)
+        bad_acls = self._utils._filter_security_acls(
+            acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+            local_port, protocol, remote_addr)
+
+        self.assertEqual(acls, good_acls)
+        self.assertEqual([], bad_acls)
+
+    def test_get_new_weight(self):
+        mockacl1 = mock.MagicMock()
+        mockacl1.Weight = self._utils._MAX_WEIGHT - 1
+        mockacl2 = mock.MagicMock()
+        mockacl2.Weight = self._utils._MAX_WEIGHT - 3
+        self.assertEqual(self._utils._MAX_WEIGHT - 2,
+                         self._utils._get_new_weight([mockacl1, mockacl2]))
+
+    def test_get_new_weight_no_acls(self):
+        self.assertEqual(self._utils._MAX_WEIGHT - 1,
+                         self._utils._get_new_weight([]))