Implements the security groups API in the Hyper-V agent.
To enable security groups on the Hyper-V agent, its config file
must contain the following option:
[SECURITYGROUP]
firewall_driver=neutron.plugins.hyperv.agent.security_groups_driver.HyperVSecurityGroupsDriver
Change-Id: I7556001557cd013c10b7f883dbf371afa8d09626
Implements: blueprint hyperv-security-groups
from neutron.agent.common import config
from neutron.agent import rpc as agent_rpc
+from neutron.agent import securitygroups_rpc as sg_rpc
from neutron.common import config as logging_config
from neutron.common import constants as n_const
from neutron.common import topics
config.register_agent_state_opts_helper(cfg.CONF)
+class HyperVSecurityAgent(sg_rpc.SecurityGroupAgentRpcMixin):
+ # Set RPC API version to 1.1 by default.
+ RPC_API_VERSION = '1.1'
+
+ def __init__(self, context, plugin_rpc):
+ self.context = context
+ self.plugin_rpc = plugin_rpc
+ self.init_firewall()
+
+ if sg_rpc.is_firewall_enabled():
+ self._setup_rpc()
+
+ def _setup_rpc(self):
+ self.topic = topics.AGENT
+ self.dispatcher = self._create_rpc_dispatcher()
+ consumers = [[topics.SECURITY_GROUP, topics.UPDATE]]
+
+ self.connection = agent_rpc.create_consumers(self.dispatcher,
+ self.topic,
+ consumers)
+
+ def _create_rpc_dispatcher(self):
+ rpc_callback = HyperVSecurityCallbackMixin(self)
+ return dispatcher.RpcDispatcher([rpc_callback])
+
+
+class HyperVSecurityCallbackMixin(sg_rpc.SecurityGroupAgentRpcCallbackMixin):
+ # Set RPC API version to 1.1 by default.
+ RPC_API_VERSION = '1.1'
+
+ def __init__(self, sg_agent):
+ self.sg_agent = sg_agent
+
+
+class HyperVPluginApi(agent_rpc.PluginApi,
+ sg_rpc.SecurityGroupServerRpcApiMixin):
+ pass
+
+
class HyperVNeutronAgent(object):
# Set RPC API version to 1.0 by default.
RPC_API_VERSION = '1.0'
def _setup_rpc(self):
self.agent_id = 'hyperv_%s' % platform.node()
self.topic = topics.AGENT
- self.plugin_rpc = agent_rpc.PluginApi(topics.PLUGIN)
+ self.plugin_rpc = HyperVPluginApi(topics.PLUGIN)
self.state_rpc = agent_rpc.PluginReportStateAPI(topics.PLUGIN)
self.connection = agent_rpc.create_consumers(self.dispatcher,
self.topic,
consumers)
+
+ self.sec_groups_agent = HyperVSecurityAgent(
+ self.context, self.plugin_rpc)
report_interval = CONF.AGENT.report_interval
if report_interval:
heartbeat = loopingcall.LoopingCall(self._report_state)
def port_update(self, context, port=None, network_type=None,
segmentation_id=None, physical_network=None):
LOG.debug(_("port_update received"))
+ if 'security_groups' in port:
+ self.sec_groups_agent.refresh_firewall()
+
self._treat_vif_port(
port['id'], port['network_id'],
network_type, physical_network,
device_details['physical_network'],
device_details['segmentation_id'],
device_details['admin_state_up'])
+
+ self.sec_groups_agent.prepare_devices_filter(devices)
self.plugin_rpc.update_device_up(self.context,
device,
self.agent_id,
--- /dev/null
+#Copyright 2014 Cloudbase Solutions SRL
+#All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+# @author: Claudiu Belu, Cloudbase Solutions Srl
+
+from neutron.agent import firewall
+from neutron.openstack.common import log as logging
+from neutron.plugins.hyperv.agent import utilsfactory
+from neutron.plugins.hyperv.agent import utilsv2
+
+LOG = logging.getLogger(__name__)
+
+
+class HyperVSecurityGroupsDriver(firewall.FirewallDriver):
+ """Security Groups Driver.
+
+ Security Groups implementation for Hyper-V VMs.
+ """
+
+ _ACL_PROP_MAP = {
+ 'direction': {'ingress': utilsv2.HyperVUtilsV2._ACL_DIR_IN,
+ 'egress': utilsv2.HyperVUtilsV2._ACL_DIR_OUT},
+ 'ethertype': {'IPv4': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV4,
+ 'IPv6': utilsv2.HyperVUtilsV2._ACL_TYPE_IPV6},
+ 'default': "ANY",
+ 'address_default': {'IPv4': '0.0.0.0/0', 'IPv6': '::/0'}
+ }
+
+ def __init__(self):
+ self._utils = utilsfactory.get_hypervutils()
+ self._security_ports = {}
+
+ def prepare_port_filter(self, port):
+ LOG.debug('Creating port %s rules' % len(port['security_group_rules']))
+
+ # newly created port, add default rules.
+ if port['device'] not in self._security_ports:
+ LOG.debug('Creating default reject rules.')
+ self._utils.create_default_reject_all_rules(port['id'])
+
+ self._security_ports[port['device']] = port
+ self._create_port_rules(port['id'], port['security_group_rules'])
+
+ def _create_port_rules(self, port_id, rules):
+ for rule in rules:
+ param_map = self._create_param_map(rule)
+ try:
+ self._utils.create_security_rule(port_id, **param_map)
+ except Exception as ex:
+ LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
+ 'adding rule: %(rule)s'),
+ dict(hyperv_exeption=ex, rule=rule))
+
+ def _remove_port_rules(self, port_id, rules):
+ for rule in rules:
+ param_map = self._create_param_map(rule)
+ try:
+ self._utils.remove_security_rule(port_id, **param_map)
+ except Exception as ex:
+ LOG.error(_('Hyper-V Exception: %(hyperv_exeption)s while '
+ 'removing rule: %(rule)s'),
+ dict(hyperv_exeption=ex, rule=rule))
+
+ def _create_param_map(self, rule):
+ if 'port_range_min' in rule and 'port_range_max' in rule:
+ local_port = '%s-%s' % (rule['port_range_min'],
+ rule['port_range_max'])
+ else:
+ local_port = self._ACL_PROP_MAP['default']
+
+ return {
+ 'direction': self._ACL_PROP_MAP['direction'][rule['direction']],
+ 'acl_type': self._ACL_PROP_MAP['ethertype'][rule['ethertype']],
+ 'local_port': local_port,
+ 'protocol': self._get_rule_prop_or_default(rule, 'protocol'),
+ 'remote_address': self._get_rule_remote_address(rule)
+ }
+
+ def apply_port_filter(self, port):
+ LOG.info('Aplying port filter.')
+
+ def update_port_filter(self, port):
+ LOG.info('Updating port rules.')
+
+ if port['device'] not in self._security_ports:
+ self.prepare_port_filter(port)
+ return
+
+ old_port = self._security_ports[port['device']]
+ rules = old_port['security_group_rules']
+ param_port_rules = port['security_group_rules']
+
+ new_rules = [r for r in param_port_rules if r not in rules]
+ remove_rules = [r for r in rules if r not in param_port_rules]
+
+ LOG.info("Creating %s new rules, removing %s old rules." % (
+ len(new_rules), len(remove_rules)))
+
+ self._remove_port_rules(old_port['id'], remove_rules)
+ self._create_port_rules(port['id'], new_rules)
+
+ self._security_ports[port['device']] = port
+
+ def remove_port_filter(self, port):
+ LOG.info('Removing port filter')
+ self._security_ports.pop(port['device'], None)
+
+ @property
+ def ports(self):
+ return self._security_ports
+
+ def _get_rule_remote_address(self, rule):
+ if rule['direction'] is 'ingress':
+ ip_prefix = 'source_ip_prefix'
+ else:
+ ip_prefix = 'dest_ip_prefix'
+
+ if ip_prefix in rule:
+ return rule[ip_prefix]
+ return self._ACL_PROP_MAP['address_default'][rule['ethertype']]
+
+ def _get_rule_prop_or_default(self, rule, prop):
+ if prop in rule:
+ return rule[prop]
+ return self._ACL_PROP_MAP['default']
return map(int, version_str.split('.')) >= [major, minor, build]
-def _get_class(v1_class, v2_class, force_v1_flag):
- # V2 classes are supported starting from Hyper-V Server 2012 and
- # Windows Server 2012 (kernel version 6.2)
- if not force_v1_flag and _check_min_windows_version(6, 2):
- cls = v2_class
+def get_hypervutils():
+ # V1 virtualization namespace features are supported up to
+ # Windows Server / Hyper-V Server 2012
+ # V2 virtualization namespace features are supported starting with
+ # Windows Server / Hyper-V Server 2012
+ # Windows Server / Hyper-V Server 2012 R2 uses the V2 namespace and
+ # introduces additional features
+
+ force_v1_flag = CONF.hyperv.force_hyperv_utils_v1
+ if _check_min_windows_version(6, 3):
+ if force_v1_flag:
+ LOG.warning('V1 virtualization namespace no longer supported on '
+ 'Windows Server / Hyper-V Server 2012 R2 or above.')
+ cls = utilsv2.HyperVUtilsV2R2
+ elif not force_v1_flag and _check_min_windows_version(6, 2):
+ cls = utilsv2.HyperVUtilsV2
else:
- cls = v1_class
+ cls = utils.HyperVUtils
LOG.debug(_("Loading class: %(module_name)s.%(class_name)s"),
{'module_name': cls.__module__, 'class_name': cls.__name__})
- return cls
-
-
-def get_hypervutils():
- return _get_class(utils.HyperVUtils, utilsv2.HyperVUtilsV2,
- CONF.hyperv.force_hyperv_utils_v1)()
+ return cls()
_ETHERNET_SWITCH_PORT = 'Msvm_EthernetSwitchPort'
_PORT_ALLOC_SET_DATA = 'Msvm_EthernetPortAllocationSettingData'
_PORT_VLAN_SET_DATA = 'Msvm_EthernetSwitchPortVlanSettingData'
+ _PORT_SECURITY_SET_DATA = 'Msvm_EthernetSwitchPortSecuritySettingData'
_PORT_ALLOC_ACL_SET_DATA = 'Msvm_EthernetSwitchPortAclSettingData'
+ _PORT_EXT_ACL_SET_DATA = _PORT_ALLOC_ACL_SET_DATA
_LAN_ENDPOINT = 'Msvm_LANEndpoint'
_STATE_DISABLED = 3
_OPERATION_MODE_ACCESS = 1
_ACL_DIR_IN = 1
_ACL_DIR_OUT = 2
+
_ACL_TYPE_IPV4 = 2
_ACL_TYPE_IPV6 = 3
+
+ _ACL_ACTION_ALLOW = 1
+ _ACL_ACTION_DENY = 2
_ACL_ACTION_METER = 3
+
_ACL_APPLICABILITY_LOCAL = 1
+ _ACL_APPLICABILITY_REMOTE = 2
+
+ _ACL_DEFAULT = 'ANY'
+ _IPV4_ANY = '0.0.0.0/0'
+ _IPV6_ANY = '::/0'
+ _TCP_PROTOCOL = 'tcp'
+ _UDP_PROTOCOL = 'udp'
+ _MAX_WEIGHT = 65500
_wmi_namespace = '//./root/virtualization/v2'
element.path_(), [res_setting_data.GetText_(1)])
self._check_job_status(ret_val, job_path)
+ def _remove_virt_feature(self, feature_resource):
+ vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
+ (job_path, ret_val) = vs_man_svc.RemoveFeatureSettings(
+ FeatureSettings=[feature_resource.path_()])
+ self._check_job_status(ret_val, job_path)
+
def disconnect_switch_port(
self, vswitch_name, switch_port_name, delete_port):
"""Disconnects the switch port."""
port_alloc, found = self._get_switch_port_allocation(switch_port_name)
if not found:
raise utils.HyperVException(
- msg=_('Port Alloc not found: %s') % switch_port_name)
+ msg=_('Port Allocation not found: %s') % switch_port_name)
vs_man_svc = self._conn.Msvm_VirtualSystemManagementService()[0]
vlan_settings = self._get_vlan_setting_data_from_port_alloc(port_alloc)
acl.Action = self._ACL_ACTION_METER
acl.Applicability = self._ACL_APPLICABILITY_LOCAL
self._add_virt_feature(port, acl)
+
+ def create_security_rule(self, switch_port_name, direction, acl_type,
+ local_port, protocol, remote_address):
+ port, found = self._get_switch_port_allocation(switch_port_name, False)
+ if not found:
+ return
+
+ # Add the ACLs only if they don't already exist
+ acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+ weight = self._get_new_weight(acls)
+ self._bind_security_rule(
+ port, direction, acl_type, self._ACL_ACTION_ALLOW, local_port,
+ protocol, remote_address, weight)
+
+ def remove_security_rule(self, switch_port_name, direction, acl_type,
+ local_port, protocol, remote_address):
+ port, found = self._get_switch_port_allocation(switch_port_name, False)
+ if not found:
+ # Port not found. It happens when the VM was already deleted.
+ return
+
+ acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+ filtered_acls = self._filter_security_acls(
+ acls, self._ACL_ACTION_ALLOW, direction, acl_type, local_port,
+ protocol, remote_address)
+
+ for acl in filtered_acls:
+ self._remove_virt_feature(acl)
+
+ def create_default_reject_all_rules(self, switch_port_name):
+ port, found = self._get_switch_port_allocation(switch_port_name, False)
+ if not found:
+ raise utils.HyperVException(
+ msg=_('Port Allocation not found: %s') % switch_port_name)
+
+ acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+ filtered_acls = [v for v in acls if v.Action == self._ACL_ACTION_DENY]
+
+ # 2 directions x 2 address types x 2 protocols = 8 ACLs
+ if len(filtered_acls) >= 8:
+ return
+
+ for acl in filtered_acls:
+ self._remove_virt_feature(acl)
+
+ weight = 0
+ ipv4_pair = (self._ACL_TYPE_IPV4, self._IPV4_ANY)
+ ipv6_pair = (self._ACL_TYPE_IPV6, self._IPV6_ANY)
+ for direction in [self._ACL_DIR_IN, self._ACL_DIR_OUT]:
+ for acl_type, address in [ipv4_pair, ipv6_pair]:
+ for protocol in [self._TCP_PROTOCOL, self._UDP_PROTOCOL]:
+ self._bind_security_rule(
+ port, direction, acl_type, self._ACL_ACTION_DENY,
+ self._ACL_DEFAULT, protocol, address, weight)
+ weight += 1
+
+ def _bind_security_rule(self, port, direction, acl_type, action,
+ local_port, protocol, remote_address, weight):
+ acls = port.associators(wmi_result_class=self._PORT_EXT_ACL_SET_DATA)
+ filtered_acls = self._filter_security_acls(
+ acls, action, direction, acl_type, local_port, protocol,
+ remote_address)
+
+ for acl in filtered_acls:
+ self._remove_virt_feature(acl)
+
+ acl = self._create_security_acl(
+ direction, acl_type, action, local_port, protocol, remote_address,
+ weight)
+
+ self._add_virt_feature(port, acl)
+
+ def _create_acl(self, direction, acl_type, action):
+ acl = self._get_default_setting_data(self._PORT_ALLOC_ACL_SET_DATA)
+ acl.set(Direction=direction,
+ AclType=acl_type,
+ Action=action,
+ Applicability=self._ACL_APPLICABILITY_LOCAL)
+ return acl
+
+ def _create_security_acl(self, direction, acl_type, action, local_port,
+ protocol, remote_ip_address, weight):
+ acl = self._create_acl(direction, acl_type, action)
+ (remote_address, remote_prefix_length) = remote_ip_address.split('/')
+ acl.set(Applicability=self._ACL_APPLICABILITY_REMOTE,
+ RemoteAddress=remote_address,
+ RemoteAddressPrefixLength=remote_prefix_length)
+ return acl
+
+ def _filter_acls(self, acls, action, direction, acl_type, remote_addr=""):
+ return [v for v in acls
+ if v.Action == action and
+ v.Direction == direction and
+ v.AclType == acl_type and
+ v.RemoteAddress == remote_addr]
+
+ def _filter_security_acls(self, acls, acl_action, direction, acl_type,
+ local_port, protocol, remote_addr=""):
+ (remote_address, remote_prefix_length) = remote_addr.split('/')
+ remote_prefix_length = int(remote_prefix_length)
+
+ return [v for v in acls
+ if v.Direction == direction and
+ v.Action in [self._ACL_ACTION_ALLOW, self._ACL_ACTION_DENY] and
+ v.AclType == acl_type and
+ v.RemoteAddress == remote_address and
+ v.RemoteAddressPrefixLength == remote_prefix_length]
+
+ def _get_new_weight(self, acls):
+ return 0
+
+
+class HyperVUtilsV2R2(HyperVUtilsV2):
+ _PORT_EXT_ACL_SET_DATA = 'Msvm_EthernetSwitchPortExtendedAclSettingData'
+ _MAX_WEIGHT = 65500
+
+ def create_security_rule(self, switch_port_name, direction, acl_type,
+ local_port, protocol, remote_address):
+ protocols = [protocol]
+ if protocol is self._ACL_DEFAULT:
+ protocols = [self._TCP_PROTOCOL, self._UDP_PROTOCOL]
+
+ for proto in protocols:
+ super(HyperVUtilsV2R2, self).create_security_rule(
+ switch_port_name, direction, acl_type, local_port,
+ proto, remote_address)
+
+ def remove_security_rule(self, switch_port_name, direction, acl_type,
+ local_port, protocol, remote_address):
+ protocols = [protocol]
+ if protocol is self._ACL_DEFAULT:
+ protocols = ['tcp', 'udp']
+
+ for proto in protocols:
+ super(HyperVUtilsV2R2, self).remove_security_rule(
+ switch_port_name, direction, acl_type,
+ local_port, proto, remote_address)
+
+ def _create_security_acl(self, direction, acl_type, action, local_port,
+ protocol, remote_addr, weight):
+ acl = self._get_default_setting_data(self._PORT_EXT_ACL_SET_DATA)
+ acl.set(Direction=direction,
+ Action=action,
+ LocalPort=str(local_port),
+ Protocol=protocol,
+ RemoteIPAddress=remote_addr,
+ IdleSessionTimeout=0,
+ Weight=weight)
+ return acl
+
+ def _filter_security_acls(self, acls, action, direction, acl_type,
+ local_port, protocol, remote_addr=""):
+ return [v for v in acls
+ if v.Action == action and
+ v.Direction == direction and
+ v.LocalPort in [str(local_port), self._ACL_DEFAULT] and
+ v.Protocol in [protocol] and
+ v.RemoteIPAddress == remote_addr]
+
+ def _get_new_weight(self, acls):
+ if not acls:
+ return self._MAX_WEIGHT - 1
+
+ weights = [a.Weight for a in acls]
+ min_weight = min(weights)
+ for weight in range(min_weight, self._MAX_WEIGHT):
+ if weight not in weights:
+ return weight
+
+ return min_weight - 1
self.agent = hyperv_neutron_agent.HyperVNeutronAgent()
self.agent.plugin_rpc = mock.Mock()
+ self.agent.sec_groups_agent = mock.MagicMock()
self.agent.context = mock.Mock()
self.agent.agent_id = mock.Mock()
--- /dev/null
+# Copyright 2014 Cloudbase Solutions SRL
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+# @author: Claudiu Belu, Cloudbase Solutions Srl
+
+"""
+Unit tests for the Hyper-V Security Groups Driver.
+"""
+
+import mock
+from oslo.config import cfg
+
+from neutron.plugins.hyperv.agent import security_groups_driver as sg_driver
+from neutron.plugins.hyperv.agent import utilsfactory
+from neutron.tests import base
+
+CONF = cfg.CONF
+
+
+class TestHyperVSecurityGroupsDriver(base.BaseTestCase):
+
+ _FAKE_DEVICE = 'fake_device'
+ _FAKE_ID = 'fake_id'
+ _FAKE_DIRECTION = 'ingress'
+ _FAKE_ETHERTYPE = 'IPv4'
+ _FAKE_ETHERTYPE_IPV6 = 'IPv6'
+ _FAKE_DEST_IP_PREFIX = 'fake_dest_ip_prefix'
+ _FAKE_SOURCE_IP_PREFIX = 'fake_source_ip_prefix'
+ _FAKE_PARAM_NAME = 'fake_param_name'
+ _FAKE_PARAM_VALUE = 'fake_param_value'
+
+ _FAKE_PORT_MIN = 9001
+ _FAKE_PORT_MAX = 9011
+
+ def setUp(self):
+ super(TestHyperVSecurityGroupsDriver, self).setUp()
+ self._mock_windows_version = mock.patch.object(utilsfactory,
+ 'get_hypervutils')
+ self._mock_windows_version.start()
+ self.addCleanup(mock.patch.stopall)
+ self._driver = sg_driver.HyperVSecurityGroupsDriver()
+ self._driver._utils = mock.MagicMock()
+
+ @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+ '.HyperVSecurityGroupsDriver._create_port_rules')
+ def test_prepare_port_filter(self, mock_create_rules):
+ mock_port = self._get_port()
+ mock_utils_method = self._driver._utils.create_default_reject_all_rules
+ self._driver.prepare_port_filter(mock_port)
+
+ self.assertEqual(mock_port,
+ self._driver._security_ports[self._FAKE_DEVICE])
+ mock_utils_method.assert_called_once_with(self._FAKE_ID)
+ self._driver._create_port_rules.assert_called_once_with(
+ self._FAKE_ID, mock_port['security_group_rules'])
+
+ def test_update_port_filter(self):
+ mock_port = self._get_port()
+ new_mock_port = self._get_port()
+ new_mock_port['id'] += '2'
+ new_mock_port['security_group_rules'][0]['ethertype'] += "2"
+
+ self._driver._security_ports[mock_port['device']] = mock_port
+ self._driver._create_port_rules = mock.MagicMock()
+ self._driver._remove_port_rules = mock.MagicMock()
+ self._driver.update_port_filter(new_mock_port)
+
+ self._driver._remove_port_rules.assert_called_once_with(
+ mock_port['id'], mock_port['security_group_rules'])
+ self._driver._create_port_rules.assert_called_once_with(
+ new_mock_port['id'], new_mock_port['security_group_rules'])
+ self.assertEqual(new_mock_port,
+ self._driver._security_ports[new_mock_port['device']])
+
+ @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+ '.HyperVSecurityGroupsDriver.prepare_port_filter')
+ def test_update_port_filter_new_port(self, mock_method):
+ mock_port = self._get_port()
+ self._driver.prepare_port_filter = mock.MagicMock()
+ self._driver.update_port_filter(mock_port)
+
+ self._driver.prepare_port_filter.assert_called_once_with(mock_port)
+
+ def test_remove_port_filter(self):
+ mock_port = self._get_port()
+ self._driver._security_ports[mock_port['device']] = mock_port
+ self._driver.remove_port_filter(mock_port)
+ self.assertFalse(mock_port['device'] in self._driver._security_ports)
+
+ def test_create_port_rules_exception(self):
+ fake_rule = self._create_security_rule()
+ self._driver._utils.create_security_rule.side_effect = Exception(
+ 'Generated Exception for testing.')
+ self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
+
+ def test_create_param_map(self):
+ fake_rule = self._create_security_rule()
+ self._driver._get_rule_remote_address = mock.MagicMock(
+ return_value=self._FAKE_SOURCE_IP_PREFIX)
+ actual = self._driver._create_param_map(fake_rule)
+ expected = {
+ 'direction': self._driver._ACL_PROP_MAP[
+ 'direction'][self._FAKE_DIRECTION],
+ 'acl_type': self._driver._ACL_PROP_MAP[
+ 'ethertype'][self._FAKE_ETHERTYPE],
+ 'local_port': '%s-%s' % (self._FAKE_PORT_MIN, self._FAKE_PORT_MAX),
+ 'protocol': self._driver._ACL_PROP_MAP['default'],
+ 'remote_address': self._FAKE_SOURCE_IP_PREFIX
+ }
+
+ self.assertEqual(expected, actual)
+
+ @mock.patch('neutron.plugins.hyperv.agent.security_groups_driver'
+ '.HyperVSecurityGroupsDriver._create_param_map')
+ def test_create_port_rules(self, mock_method):
+ fake_rule = self._create_security_rule()
+ mock_method.return_value = {
+ self._FAKE_PARAM_NAME: self._FAKE_PARAM_VALUE}
+ self._driver._create_port_rules(self._FAKE_ID, [fake_rule])
+
+ self._driver._utils.create_security_rule.assert_called_once_with(
+ self._FAKE_ID, fake_param_name=self._FAKE_PARAM_VALUE)
+
+ def test_convert_any_address_to_same_ingress(self):
+ rule = self._create_security_rule()
+ actual = self._driver._get_rule_remote_address(rule)
+ self.assertEqual(self._FAKE_SOURCE_IP_PREFIX, actual)
+
+ def test_convert_any_address_to_same_egress(self):
+ rule = self._create_security_rule()
+ rule['direction'] += '2'
+ actual = self._driver._get_rule_remote_address(rule)
+ self.assertEqual(self._FAKE_DEST_IP_PREFIX, actual)
+
+ def test_convert_any_address_to_ipv4(self):
+ rule = self._create_security_rule()
+ del rule['source_ip_prefix']
+ actual = self._driver._get_rule_remote_address(rule)
+ self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv4'],
+ actual)
+
+ def test_convert_any_address_to_ipv6(self):
+ rule = self._create_security_rule()
+ del rule['source_ip_prefix']
+ rule['ethertype'] = self._FAKE_ETHERTYPE_IPV6
+ actual = self._driver._get_rule_remote_address(rule)
+ self.assertEqual(self._driver._ACL_PROP_MAP['address_default']['IPv6'],
+ actual)
+
+ def _get_port(self):
+ return {
+ 'device': self._FAKE_DEVICE,
+ 'id': self._FAKE_ID,
+ 'security_group_rules': [self._create_security_rule()]
+ }
+
+ def _create_security_rule(self):
+ return {
+ 'direction': self._FAKE_DIRECTION,
+ 'ethertype': self._FAKE_ETHERTYPE,
+ 'dest_ip_prefix': self._FAKE_DEST_IP_PREFIX,
+ 'source_ip_prefix': self._FAKE_SOURCE_IP_PREFIX,
+ 'port_range_min': self._FAKE_PORT_MIN,
+ 'port_range_max': self._FAKE_PORT_MAX
+ }
class TestHyperVUtilsFactory(base.BaseTestCase):
+ def test_get_hypervutils_v2_r2(self):
+ self._test_returned_class(utilsv2.HyperVUtilsV2R2, True, '6.3.0')
+
def test_get_hypervutils_v2(self):
self._test_returned_class(utilsv2.HyperVUtilsV2, False, '6.2.0')
_FAKE_CLASS_NAME = "fake_class_name"
_FAKE_ELEMENT_NAME = "fake_element_name"
+ _FAKE_ACL_ACT = 'fake_acl_action'
+ _FAKE_ACL_DIR = 'fake_acl_dir'
+ _FAKE_ACL_TYPE = 'fake_acl_type'
+ _FAKE_LOCAL_PORT = 'fake_local_port'
+ _FAKE_PROTOCOL = 'fake_port_protocol'
+ _FAKE_REMOTE_ADDR = '0.0.0.0/0'
+ _FAKE_WEIGHT = 'fake_weight'
+
def setUp(self):
super(TestHyperVUtilsV2, self).setUp()
self._utils = utilsv2.HyperVUtilsV2()
mock_svc.RemoveResourceSettings.assert_called_with(
ResourceSettings=[self._FAKE_RES_PATH])
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._check_job_status')
+ def test_remove_virt_feature(self, mock_check_job_status):
+ mock_svc = self._utils._conn.Msvm_VirtualSystemManagementService()[0]
+ mock_svc.RemoveFeatureSettings.return_value = (self._FAKE_JOB_PATH,
+ self._FAKE_RET_VAL)
+ mock_res_setting_data = mock.MagicMock()
+ mock_res_setting_data.path_.return_value = self._FAKE_RES_PATH
+
+ self._utils._remove_virt_feature(mock_res_setting_data)
+
+ mock_svc.RemoveFeatureSettings.assert_called_with(
+ FeatureSettings=[self._FAKE_RES_PATH])
+
def test_disconnect_switch_port_delete_port(self):
self._test_disconnect_switch_port(True)
self.assertEqual(4, len(self._utils._add_virt_feature.mock_calls))
self._utils._add_virt_feature.assert_called_with(
mock_port, mock_acl)
+
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._remove_virt_feature')
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._bind_security_rule')
+ def test_create_default_reject_all_rules(self, mock_bind, mock_remove):
+ (m_port, m_acl) = self._setup_security_rule_test()
+ m_acl.Action = self._utils._ACL_ACTION_DENY
+ self._utils.create_default_reject_all_rules(self._FAKE_PORT_NAME)
+
+ calls = []
+ ipv4_pair = (self._utils._ACL_TYPE_IPV4, self._utils._IPV4_ANY)
+ ipv6_pair = (self._utils._ACL_TYPE_IPV6, self._utils._IPV6_ANY)
+ for direction in [self._utils._ACL_DIR_IN, self._utils._ACL_DIR_OUT]:
+ for acl_type, address in [ipv4_pair, ipv6_pair]:
+ for protocol in [self._utils._TCP_PROTOCOL,
+ self._utils._UDP_PROTOCOL]:
+ calls.append(mock.call(m_port, direction, acl_type,
+ self._utils._ACL_ACTION_DENY,
+ self._utils._ACL_DEFAULT,
+ protocol, address, mock.ANY))
+
+ self._utils._remove_virt_feature.assert_called_once_with(m_acl)
+ self._utils._bind_security_rule.assert_has_calls(calls)
+
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._remove_virt_feature')
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._add_virt_feature')
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._create_security_acl')
+ def test_bind_security_rule(self, mock_create_acl, mock_add, mock_remove):
+ (m_port, m_acl) = self._setup_security_rule_test()
+ mock_create_acl.return_value = m_acl
+
+ self._utils._bind_security_rule(
+ m_port, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+ self._FAKE_ACL_ACT, self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL,
+ self._FAKE_REMOTE_ADDR, self._FAKE_WEIGHT)
+
+ self._utils._add_virt_feature.assert_called_once_with(m_port, m_acl)
+
+ @mock.patch('neutron.plugins.hyperv.agent.utilsv2.HyperVUtilsV2'
+ '._remove_virt_feature')
+ def test_remove_security_rule(self, mock_remove_feature):
+ mock_acl = self._setup_security_rule_test()[1]
+ self._utils.remove_security_rule(
+ self._FAKE_PORT_NAME, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+ self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
+ self._utils._remove_virt_feature.assert_called_once_with(mock_acl)
+
+ def _setup_security_rule_test(self):
+ mock_port = mock.MagicMock()
+ mock_acl = mock.MagicMock()
+ mock_port.associators.return_value = [mock_acl]
+
+ self._utils._get_switch_port_allocation = mock.MagicMock(return_value=(
+ mock_port, True))
+ self._utils._filter_security_acls = mock.MagicMock(
+ return_value=[mock_acl])
+
+ return (mock_port, mock_acl)
+
+ def test_filter_acls(self):
+ mock_acl = mock.MagicMock()
+ mock_acl.Action = self._FAKE_ACL_ACT
+ mock_acl.Applicability = self._utils._ACL_APPLICABILITY_LOCAL
+ mock_acl.Direction = self._FAKE_ACL_DIR
+ mock_acl.AclType = self._FAKE_ACL_TYPE
+ mock_acl.RemoteAddress = self._FAKE_REMOTE_ADDR
+
+ acls = [mock_acl, mock_acl]
+ good_acls = self._utils._filter_acls(
+ acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR,
+ self._FAKE_ACL_TYPE, self._FAKE_REMOTE_ADDR)
+ bad_acls = self._utils._filter_acls(
+ acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE)
+
+ self.assertEqual(acls, good_acls)
+ self.assertEqual([], bad_acls)
+
+
+class TestHyperVUtilsV2R2(base.BaseTestCase):
+ _FAKE_ACL_ACT = 'fake_acl_action'
+ _FAKE_ACL_DIR = 'fake_direction'
+ _FAKE_ACL_TYPE = 'fake_acl_type'
+ _FAKE_LOCAL_PORT = 'fake_local_port'
+ _FAKE_PROTOCOL = 'fake_port_protocol'
+ _FAKE_REMOTE_ADDR = '10.0.0.0/0'
+
+ def setUp(self):
+ super(TestHyperVUtilsV2R2, self).setUp()
+ self._utils = utilsv2.HyperVUtilsV2R2()
+
+ def test_filter_security_acls(self):
+ self._test_filter_security_acls(
+ self._FAKE_LOCAL_PORT, self._FAKE_PROTOCOL, self._FAKE_REMOTE_ADDR)
+
+ def test_filter_security_acls_default(self):
+ default = self._utils._ACL_DEFAULT
+ self._test_filter_security_acls(
+ default, default, self._FAKE_REMOTE_ADDR)
+
+ def _test_filter_security_acls(self, local_port, protocol, remote_addr):
+ mock_acl = mock.MagicMock()
+ mock_acl.Action = self._utils._ACL_ACTION_ALLOW
+ mock_acl.Direction = self._FAKE_ACL_DIR
+ mock_acl.LocalPort = local_port
+ mock_acl.Protocol = protocol
+ mock_acl.RemoteIPAddress = remote_addr
+
+ acls = [mock_acl, mock_acl]
+ good_acls = self._utils._filter_security_acls(
+ acls, mock_acl.Action, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+ local_port, protocol, remote_addr)
+ bad_acls = self._utils._filter_security_acls(
+ acls, self._FAKE_ACL_ACT, self._FAKE_ACL_DIR, self._FAKE_ACL_TYPE,
+ local_port, protocol, remote_addr)
+
+ self.assertEqual(acls, good_acls)
+ self.assertEqual([], bad_acls)
+
+ def test_get_new_weight(self):
+ mockacl1 = mock.MagicMock()
+ mockacl1.Weight = self._utils._MAX_WEIGHT - 1
+ mockacl2 = mock.MagicMock()
+ mockacl2.Weight = self._utils._MAX_WEIGHT - 3
+ self.assertEqual(self._utils._MAX_WEIGHT - 2,
+ self._utils._get_new_weight([mockacl1, mockacl2]))
+
+ def test_get_new_weight_no_acls(self):
+ self.assertEqual(self._utils._MAX_WEIGHT - 1,
+ self._utils._get_new_weight([]))