def plugin(self):
return manager.NeutronManager.get_plugin()
- def _get_devices_info(self, devices):
+ def _get_devices_info(self, context, devices):
return dict(
(port['id'], port)
- for port in self.plugin.get_ports_from_devices(devices)
+ for port in self.plugin.get_ports_from_devices(context, devices)
if port and not port['device_owner'].startswith('network:')
)
:returns: port correspond to the devices with security group rules
"""
devices_info = kwargs.get('devices')
- ports = self._get_devices_info(devices_info)
+ ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_rules_for_ports(context, ports)
def security_group_info_for_devices(self, context, **kwargs):
Note that sets are serialized into lists by rpc code.
"""
devices_info = kwargs.get('devices')
- ports = self._get_devices_info(devices_info)
+ ports = self._get_devices_info(context, devices_info)
return self.plugin.security_group_info_for_ports(context, ports)
class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
"""Mixin class to add agent-based security group implementation."""
- def get_port_from_device(self, device):
+ def get_port_from_device(self, context, device):
"""Get port dict from device name on an agent.
Subclass must provide this method or get_ports_from_devices.
"or get_ports_from_devices.")
% self.__class__.__name__)
- def get_ports_from_devices(self, devices):
+ def get_ports_from_devices(self, context, devices):
"""Bulk method of get_port_from_device.
Subclasses may override this to provide better performance for DB
queries, backend calls, etc.
"""
- return [self.get_port_from_device(device) for device in devices]
+ return [self.get_port_from_device(context, device)
+ for device in devices]
def create_security_group_rule(self, context, security_group_rule):
bulk_rule = {'security_group_rules': [security_group_rule]}
from sqlalchemy.orm import exc
from neutron.common import constants as n_const
-from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.db import securitygroups_db as sg_db
from neutron.extensions import portbindings
return
-def get_port_from_device_mac(device_mac):
+def get_port_from_device_mac(context, device_mac):
LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
- session = db_api.get_session()
- qry = session.query(models_v2.Port).filter_by(mac_address=device_mac)
+ qry = context.session.query(models_v2.Port).filter_by(
+ mac_address=device_mac)
return qry.first()
-def get_ports_and_sgs(port_ids):
+def get_ports_and_sgs(context, port_ids):
"""Get ports from database with security group info."""
# break large queries into smaller parts
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
"query %(maxp)s. Partitioning queries.",
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
- return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
- get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
+ return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
+ get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
if not port_ids:
# if port_ids is empty, avoid querying to DB to ask it for nothing
return []
- ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
+ ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
return [make_port_dict_with_security_groups(port, sec_groups)
for port, sec_groups in ports_to_sg_ids.iteritems()]
-def get_sg_ids_grouped_by_port(port_ids):
+def get_sg_ids_grouped_by_port(context, port_ids):
sg_ids_grouped_by_port = {}
- session = db_api.get_session()
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
- with session.begin(subtransactions=True):
+ with context.session.begin(subtransactions=True):
# partial UUIDs must be individually matched with startswith.
# full UUIDs may be matched directly in an IN statement
partial_uuids = set(port_id for port_id in port_ids
if full_uuids:
or_criteria.append(models_v2.Port.id.in_(full_uuids))
- query = session.query(models_v2.Port,
- sg_db.SecurityGroupPortBinding.security_group_id)
+ query = context.session.query(
+ models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
models_v2.Port.id == sg_binding_port)
query = query.filter(or_(*or_criteria))
port_host = db.get_port_binding_host(context.session, port_id)
return (port_host == host)
- def get_ports_from_devices(self, devices):
- port_ids_to_devices = dict((self._device_to_port_id(device), device)
- for device in devices)
+ def get_ports_from_devices(self, context, devices):
+ port_ids_to_devices = dict(
+ (self._device_to_port_id(context, device), device)
+ for device in devices)
port_ids = port_ids_to_devices.keys()
- ports = db.get_ports_and_sgs(port_ids)
+ ports = db.get_ports_and_sgs(context, port_ids)
for port in ports:
# map back to original requested id
port_id = next((port_id for port_id in port_ids
return ports
@staticmethod
- def _device_to_port_id(device):
+ def _device_to_port_id(context, device):
# REVISIT(rkukura): Consider calling into MechanismDrivers to
# process device names, or having MechanismDrivers supply list
# of device prefixes to strip.
# REVISIT(irenab): Consider calling into bound MD to
# handle the get_device_details RPC
if not uuidutils.is_uuid_like(device):
- port = db.get_port_from_device_mac(device)
+ port = db.get_port_from_device_mac(context, device)
if port:
return port.id
return device
{'device': device, 'agent_id': agent_id, 'host': host})
plugin = manager.NeutronManager.get_plugin()
- port_id = plugin._device_to_port_id(device)
+ port_id = plugin._device_to_port_id(rpc_context, device)
port_context = plugin.get_bound_port_context(rpc_context,
port_id,
host,
"%(agent_id)s",
{'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin()
- port_id = plugin._device_to_port_id(device)
+ port_id = plugin._device_to_port_id(rpc_context, device)
port_exists = True
if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)):
LOG.debug("Device %(device)s up at agent %(agent_id)s",
{'device': device, 'agent_id': agent_id})
plugin = manager.NeutronManager.get_plugin()
- port_id = plugin._device_to_port_id(device)
+ port_id = plugin._device_to_port_id(rpc_context, device)
if (host and not plugin.port_bound_to_host(rpc_context,
port_id, host)):
LOG.debug("Device %(device)s not bound to the"
class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
@staticmethod
- def get_port_from_device(device):
+ def get_port_from_device(context, device):
port = nvsd_db.get_port_from_device(device)
if port:
port['device'] = device
self.notify_security_groups_member_updated(context, port)
del self.devices[id]
- def get_port_from_device(self, device):
+ def get_port_from_device(self, context, device):
device = self.devices.get(device)
if device:
device['security_group_rules'] = []
self._setup_neutron_network(network_id)
port = self._setup_neutron_port(network_id, port_id)
- observed_port = ml2_db.get_port_from_device_mac(port['mac_address'])
+ observed_port = ml2_db.get_port_from_device_mac(self.ctx,
+ port['mac_address'])
self.assertEqual(port_id, observed_port.id)
def test_get_locked_port_and_binding(self):
('qvo567890', '567890')]
for device, expected in input_output:
self.assertEqual(expected,
- ml2_plugin.Ml2Plugin._device_to_port_id(device))
+ ml2_plugin.Ml2Plugin._device_to_port_id(
+ self.context, device))
def test__device_to_port_id_mac_address(self):
with self.port() as p:
mac = p['port']['mac_address']
port_id = p['port']['id']
self.assertEqual(port_id,
- ml2_plugin.Ml2Plugin._device_to_port_id(mac))
+ ml2_plugin.Ml2Plugin._device_to_port_id(
+ self.context, mac))
def test__device_to_port_id_not_uuid_not_mac(self):
dev = '1234567'
- self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev))
+ self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
+ self.context, dev))
def test__device_to_port_id_UUID(self):
port_id = uuidutils.generate_uuid()
- self.assertEqual(port_id,
- ml2_plugin.Ml2Plugin._device_to_port_id(port_id))
+ self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
+ self.context, port_id))
class TestMl2DvrPortsV2(TestMl2PortsV2):
self.plugin.get_bound_port_context.return_value = None
self.assertEqual(
{'device': 'fake_device'},
- self.callbacks.get_device_details('fake_context',
+ self.callbacks.get_device_details(mock.Mock(),
device='fake_device'))
def test_get_device_details_port_context_without_bounded_segment(self):
self.plugin.get_bound_port_context().bottom_bound_segment = None
self.assertEqual(
{'device': 'fake_device'},
- self.callbacks.get_device_details('fake_context',
+ self.callbacks.get_device_details(mock.Mock(),
device='fake_device'))
def test_get_device_details_port_status_equal_new_status(self):
port['admin_state_up'] = admin_state_up
port['status'] = status
self.plugin.update_port_status.reset_mock()
- self.callbacks.get_device_details('fake_context')
+ self.callbacks.get_device_details(mock.Mock())
self.assertEqual(status == new_status,
not self.plugin.update_port_status.called)
self.plugin.get_bound_port_context().current = port
self.plugin.get_bound_port_context().network.current = (
{"id": "fake_network"})
- self.callbacks.get_device_details('fake_context', host='fake_host',
+ self.callbacks.get_device_details(mock.Mock(), host='fake_host',
cached_networks=cached_networks)
self.assertTrue('fake_port' in cached_networks)
port_context.current = port
port_context.host = 'fake'
self.plugin.update_port_status.reset_mock()
- self.callbacks.get_device_details('fake_context',
+ self.callbacks.get_device_details(mock.Mock(),
host='fake_host')
self.assertFalse(self.plugin.update_port_status.called)
port_context = self.plugin.get_bound_port_context()
port_context.current = port
self.plugin.update_port_status.reset_mock()
- self.callbacks.get_device_details('fake_context')
+ self.callbacks.get_device_details(mock.Mock())
self.assertTrue(self.plugin.update_port_status.called)
def test_get_devices_details_list(self):
def _test_update_device_not_bound_to_host(self, func):
self.plugin.port_bound_to_host.return_value = False
self.plugin._device_to_port_id.return_value = 'fake_port_id'
- res = func('fake_context', device='fake_device', host='fake_host')
- self.plugin.port_bound_to_host.assert_called_once_with('fake_context',
+ res = func(mock.Mock(), device='fake_device', host='fake_host')
+ self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
'fake_port_id',
'fake_host')
return res
self.plugin._device_to_port_id.return_value = 'fake_port_id'
self.assertEqual(
{'device': 'fake_device', 'exists': False},
- self.callbacks.update_device_down('fake_context',
+ self.callbacks.update_device_down(mock.Mock(),
device='fake_device',
host='fake_host'))
self.plugin.update_port_status.assert_called_once_with(
- 'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN,
+ mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
'fake_host')
def test_update_device_down_call_update_port_status_failed(self):
self.plugin.update_port_status.side_effect = exc.StaleDataError
self.assertEqual({'device': 'fake_device', 'exists': False},
self.callbacks.update_device_down(
- 'fake_context', device='fake_device'))
+ mock.Mock(), device='fake_device'))
class RpcApiTestCase(base.BaseTestCase):
import mock
from neutron.common import constants as const
+from neutron import context
from neutron.extensions import securitygroup as ext_sg
from neutron import manager
from neutron.tests import tools
test_sg_rpc.SGNotificationTestMixin):
def setUp(self):
super(TestMl2SecurityGroups, self).setUp()
+ self.ctx = context.get_admin_context()
plugin = manager.NeutronManager.get_plugin()
plugin.start_rpc_listeners()
]
plugin = manager.NeutronManager.get_plugin()
# should match full ID and starting chars
- ports = plugin.get_ports_from_devices(
+ ports = plugin.get_ports_from_devices(self.ctx,
[orig_ports[0]['id'], orig_ports[1]['id'][0:8],
orig_ports[2]['id']])
self.assertEqual(len(orig_ports), len(ports))
def test_security_group_get_ports_from_devices_with_bad_id(self):
plugin = manager.NeutronManager.get_plugin()
- ports = plugin.get_ports_from_devices(['bad_device_id'])
+ ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
self.assertFalse(ports)
def test_security_group_no_db_calls_with_no_ports(self):
with mock.patch(
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
) as get_mock:
- self.assertFalse(plugin.get_ports_from_devices([]))
+ self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
self.assertFalse(get_mock.called)
def test_large_port_count_broken_into_parts(self):
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
return_value={}),
) as (max_mock, get_mock):
- plugin.get_ports_from_devices(
+ plugin.get_ports_from_devices(self.ctx,
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
for i in range(ports_to_query)])
- all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
+ all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
last_call_args = all_call_args.pop()
# all but last should be getting MAX_PORTS_PER_QUERY ports
self.assertTrue(
# have one matching 'IN' critiera for all of the IDs
with contextlib.nested(
mock.patch('neutron.plugins.ml2.db.or_'),
- mock.patch('neutron.plugins.ml2.db.db_api.get_session')
- ) as (or_mock, sess_mock):
- qmock = sess_mock.return_value.query
+ mock.patch('sqlalchemy.orm.Session.query')
+ ) as (or_mock, qmock):
fmock = qmock.return_value.outerjoin.return_value.filter
# return no ports to exit the method early since we are mocking
# the query
fmock.return_value = []
- plugin.get_ports_from_devices([test_base._uuid(),
+ plugin.get_ports_from_devices(self.ctx,
+ [test_base._uuid(),
test_base._uuid()])
# the or_ function should only have one argument
or_mock.assert_called_once_with(mock.ANY)
req.get_response(self.api))
port_id = res['port']['id']
plugin = manager.NeutronManager.get_plugin()
- port_dict = plugin.get_port_from_device(port_id)
+ port_dict = plugin.get_port_from_device(mock.Mock(),
+ port_id)
self.assertEqual(port_id, port_dict['id'])
self.assertEqual([security_group_id],
port_dict[ext_sg.SECURITYGROUPS])
def test_security_group_get_port_from_device_with_no_port(self):
plugin = manager.NeutronManager.get_plugin()
- port_dict = plugin.get_port_from_device('bad_device_id')
+ port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
self.assertIsNone(port_dict)