From 47dd65cf986d712e9c6ca5dcf4420dfc44900b66 Mon Sep 17 00:00:00 2001 From: Dane LeBlanc Date: Tue, 14 Apr 2015 09:18:18 -0400 Subject: [PATCH] Reuse caller's session in ML2 DB methods This patch changes the get_port_from_device_mac() and get_sg_ids_grouped_by_port() methods in ML2 db.py module so that they do not create a new database session (via get_session()), but instead reuse the session associated with the caller's context. In order to make the session that is associated with the caller's context available to these ML2 DB methods, the get_ports_from_devices plugin API in securitygroups_rps_base.py needs to be modified so that the context can be passed down to the ML2 plugin. (A similar change is made to the get_port_from_device plugin API for consistency.) Change-Id: I3f990895887e156de929bd7ac3732df114dd4a4b Closes-Bug: 1441205 --- .../api/rpc/handlers/securitygroups_rpc.py | 8 +++---- neutron/db/securitygroups_rpc_base.py | 7 +++--- neutron/plugins/ml2/db.py | 24 +++++++++---------- neutron/plugins/ml2/plugin.py | 13 +++++----- neutron/plugins/ml2/rpc.py | 6 ++--- neutron/plugins/oneconvergence/plugin.py | 2 +- .../unit/agent/test_securitygroups_rpc.py | 2 +- neutron/tests/unit/plugins/ml2/test_db.py | 3 ++- neutron/tests/unit/plugins/ml2/test_plugin.py | 13 ++++++---- neutron/tests/unit/plugins/ml2/test_rpc.py | 22 ++++++++--------- .../unit/plugins/ml2/test_security_group.py | 20 +++++++++------- .../oneconvergence/test_security_group.py | 5 ++-- 12 files changed, 66 insertions(+), 59 deletions(-) diff --git a/neutron/api/rpc/handlers/securitygroups_rpc.py b/neutron/api/rpc/handlers/securitygroups_rpc.py index 58d9c7d3d..d63a13351 100644 --- a/neutron/api/rpc/handlers/securitygroups_rpc.py +++ b/neutron/api/rpc/handlers/securitygroups_rpc.py @@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object): def plugin(self): return manager.NeutronManager.get_plugin() - def _get_devices_info(self, devices): + def _get_devices_info(self, context, devices): return dict( (port['id'], port) - for port in self.plugin.get_ports_from_devices(devices) + for port in self.plugin.get_ports_from_devices(context, devices) if port and not port['device_owner'].startswith('network:') ) @@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object): :returns: port correspond to the devices with security group rules """ devices_info = kwargs.get('devices') - ports = self._get_devices_info(devices_info) + ports = self._get_devices_info(context, devices_info) return self.plugin.security_group_rules_for_ports(context, ports) def security_group_info_for_devices(self, context, **kwargs): @@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object): Note that sets are serialized into lists by rpc code. """ devices_info = kwargs.get('devices') - ports = self._get_devices_info(devices_info) + ports = self._get_devices_info(context, devices_info) return self.plugin.security_group_info_for_ports(context, ports) diff --git a/neutron/db/securitygroups_rpc_base.py b/neutron/db/securitygroups_rpc_base.py index c47493599..e6005025e 100644 --- a/neutron/db/securitygroups_rpc_base.py +++ b/neutron/db/securitygroups_rpc_base.py @@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)} class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): """Mixin class to add agent-based security group implementation.""" - def get_port_from_device(self, device): + def get_port_from_device(self, context, device): """Get port dict from device name on an agent. Subclass must provide this method or get_ports_from_devices. @@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): "or get_ports_from_devices.") % self.__class__.__name__) - def get_ports_from_devices(self, devices): + def get_ports_from_devices(self, context, devices): """Bulk method of get_port_from_device. Subclasses may override this to provide better performance for DB queries, backend calls, etc. """ - return [self.get_port_from_device(device) for device in devices] + return [self.get_port_from_device(context, device) + for device in devices] def create_security_group_rule(self, context, security_group_rule): bulk_rule = {'security_group_rules': [security_group_rule]} diff --git a/neutron/plugins/ml2/db.py b/neutron/plugins/ml2/db.py index 9e4f8ca14..c6aef07d6 100644 --- a/neutron/plugins/ml2/db.py +++ b/neutron/plugins/ml2/db.py @@ -19,7 +19,6 @@ from sqlalchemy import or_ from sqlalchemy.orm import exc from neutron.common import constants as n_const -from neutron.db import api as db_api from neutron.db import models_v2 from neutron.db import securitygroups_db as sg_db from neutron.extensions import portbindings @@ -244,14 +243,14 @@ def get_port(session, port_id): return -def get_port_from_device_mac(device_mac): +def get_port_from_device_mac(context, device_mac): LOG.debug("get_port_from_device_mac() called for mac %s", device_mac) - session = db_api.get_session() - qry = session.query(models_v2.Port).filter_by(mac_address=device_mac) + qry = context.session.query(models_v2.Port).filter_by( + mac_address=device_mac) return qry.first() -def get_ports_and_sgs(port_ids): +def get_ports_and_sgs(context, port_ids): """Get ports from database with security group info.""" # break large queries into smaller parts @@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids): LOG.debug("Number of ports %(pcount)s exceeds the maximum per " "query %(maxp)s. Partitioning queries.", {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY}) - return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) + - get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:])) + return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) + + get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:])) LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids) if not port_ids: # if port_ids is empty, avoid querying to DB to ask it for nothing return [] - ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids) + ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids) return [make_port_dict_with_security_groups(port, sec_groups) for port, sec_groups in ports_to_sg_ids.iteritems()] -def get_sg_ids_grouped_by_port(port_ids): +def get_sg_ids_grouped_by_port(context, port_ids): sg_ids_grouped_by_port = {} - session = db_api.get_session() sg_binding_port = sg_db.SecurityGroupPortBinding.port_id - with session.begin(subtransactions=True): + with context.session.begin(subtransactions=True): # partial UUIDs must be individually matched with startswith. # full UUIDs may be matched directly in an IN statement partial_uuids = set(port_id for port_id in port_ids @@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids): if full_uuids: or_criteria.append(models_v2.Port.id.in_(full_uuids)) - query = session.query(models_v2.Port, - sg_db.SecurityGroupPortBinding.security_group_id) + query = context.session.query( + models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id) query = query.outerjoin(sg_db.SecurityGroupPortBinding, models_v2.Port.id == sg_binding_port) query = query.filter(or_(*or_criteria)) diff --git a/neutron/plugins/ml2/plugin.py b/neutron/plugins/ml2/plugin.py index 2f209db77..c1b4bde99 100644 --- a/neutron/plugins/ml2/plugin.py +++ b/neutron/plugins/ml2/plugin.py @@ -1451,11 +1451,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, port_host = db.get_port_binding_host(context.session, port_id) return (port_host == host) - def get_ports_from_devices(self, devices): - port_ids_to_devices = dict((self._device_to_port_id(device), device) - for device in devices) + def get_ports_from_devices(self, context, devices): + port_ids_to_devices = dict( + (self._device_to_port_id(context, device), device) + for device in devices) port_ids = port_ids_to_devices.keys() - ports = db.get_ports_and_sgs(port_ids) + ports = db.get_ports_and_sgs(context, port_ids) for port in ports: # map back to original requested id port_id = next((port_id for port_id in port_ids @@ -1465,7 +1466,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, return ports @staticmethod - def _device_to_port_id(device): + def _device_to_port_id(context, device): # REVISIT(rkukura): Consider calling into MechanismDrivers to # process device names, or having MechanismDrivers supply list # of device prefixes to strip. @@ -1475,7 +1476,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, # REVISIT(irenab): Consider calling into bound MD to # handle the get_device_details RPC if not uuidutils.is_uuid_like(device): - port = db.get_port_from_device_mac(device) + port = db.get_port_from_device_mac(context, device) if port: return port.id return device diff --git a/neutron/plugins/ml2/rpc.py b/neutron/plugins/ml2/rpc.py index bdbf4b510..6338642fc 100644 --- a/neutron/plugins/ml2/rpc.py +++ b/neutron/plugins/ml2/rpc.py @@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin): {'device': device, 'agent_id': agent_id, 'host': host}) plugin = manager.NeutronManager.get_plugin() - port_id = plugin._device_to_port_id(device) + port_id = plugin._device_to_port_id(rpc_context, device) port_context = plugin.get_bound_port_context(rpc_context, port_id, host, @@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin): "%(agent_id)s", {'device': device, 'agent_id': agent_id}) plugin = manager.NeutronManager.get_plugin() - port_id = plugin._device_to_port_id(device) + port_id = plugin._device_to_port_id(rpc_context, device) port_exists = True if (host and not plugin.port_bound_to_host(rpc_context, port_id, host)): @@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin): LOG.debug("Device %(device)s up at agent %(agent_id)s", {'device': device, 'agent_id': agent_id}) plugin = manager.NeutronManager.get_plugin() - port_id = plugin._device_to_port_id(device) + port_id = plugin._device_to_port_id(rpc_context, device) if (host and not plugin.port_bound_to_host(rpc_context, port_id, host)): LOG.debug("Device %(device)s not bound to the" diff --git a/neutron/plugins/oneconvergence/plugin.py b/neutron/plugins/oneconvergence/plugin.py index af6b5c1a6..50b425c84 100644 --- a/neutron/plugins/oneconvergence/plugin.py +++ b/neutron/plugins/oneconvergence/plugin.py @@ -56,7 +56,7 @@ IPv6 = 6 class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin): @staticmethod - def get_port_from_device(device): + def get_port_from_device(context, device): port = nvsd_db.get_port_from_device(device) if port: port['device'] = device diff --git a/neutron/tests/unit/agent/test_securitygroups_rpc.py b/neutron/tests/unit/agent/test_securitygroups_rpc.py index 8dd6c90b0..c07ad3519 100644 --- a/neutron/tests/unit/agent/test_securitygroups_rpc.py +++ b/neutron/tests/unit/agent/test_securitygroups_rpc.py @@ -94,7 +94,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin, self.notify_security_groups_member_updated(context, port) del self.devices[id] - def get_port_from_device(self, device): + def get_port_from_device(self, context, device): device = self.devices.get(device) if device: device['security_group_rules'] = [] diff --git a/neutron/tests/unit/plugins/ml2/test_db.py b/neutron/tests/unit/plugins/ml2/test_db.py index c34b82abf..db2c12337 100644 --- a/neutron/tests/unit/plugins/ml2/test_db.py +++ b/neutron/tests/unit/plugins/ml2/test_db.py @@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase): self._setup_neutron_network(network_id) port = self._setup_neutron_port(network_id, port_id) - observed_port = ml2_db.get_port_from_device_mac(port['mac_address']) + observed_port = ml2_db.get_port_from_device_mac(self.ctx, + port['mac_address']) self.assertEqual(port_id, observed_port.id) def test_get_locked_port_and_binding(self): diff --git a/neutron/tests/unit/plugins/ml2/test_plugin.py b/neutron/tests/unit/plugins/ml2/test_plugin.py index 21b90976a..77e00ef4f 100644 --- a/neutron/tests/unit/plugins/ml2/test_plugin.py +++ b/neutron/tests/unit/plugins/ml2/test_plugin.py @@ -614,23 +614,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase): ('qvo567890', '567890')] for device, expected in input_output: self.assertEqual(expected, - ml2_plugin.Ml2Plugin._device_to_port_id(device)) + ml2_plugin.Ml2Plugin._device_to_port_id( + self.context, device)) def test__device_to_port_id_mac_address(self): with self.port() as p: mac = p['port']['mac_address'] port_id = p['port']['id'] self.assertEqual(port_id, - ml2_plugin.Ml2Plugin._device_to_port_id(mac)) + ml2_plugin.Ml2Plugin._device_to_port_id( + self.context, mac)) def test__device_to_port_id_not_uuid_not_mac(self): dev = '1234567' - self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev)) + self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id( + self.context, dev)) def test__device_to_port_id_UUID(self): port_id = uuidutils.generate_uuid() - self.assertEqual(port_id, - ml2_plugin.Ml2Plugin._device_to_port_id(port_id)) + self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id( + self.context, port_id)) class TestMl2DvrPortsV2(TestMl2PortsV2): diff --git a/neutron/tests/unit/plugins/ml2/test_rpc.py b/neutron/tests/unit/plugins/ml2/test_rpc.py index 56cbdbcb9..cd7bcba14 100644 --- a/neutron/tests/unit/plugins/ml2/test_rpc.py +++ b/neutron/tests/unit/plugins/ml2/test_rpc.py @@ -75,14 +75,14 @@ class RpcCallbacksTestCase(base.BaseTestCase): self.plugin.get_bound_port_context.return_value = None self.assertEqual( {'device': 'fake_device'}, - self.callbacks.get_device_details('fake_context', + self.callbacks.get_device_details(mock.Mock(), device='fake_device')) def test_get_device_details_port_context_without_bounded_segment(self): self.plugin.get_bound_port_context().bottom_bound_segment = None self.assertEqual( {'device': 'fake_device'}, - self.callbacks.get_device_details('fake_context', + self.callbacks.get_device_details(mock.Mock(), device='fake_device')) def test_get_device_details_port_status_equal_new_status(self): @@ -99,7 +99,7 @@ class RpcCallbacksTestCase(base.BaseTestCase): port['admin_state_up'] = admin_state_up port['status'] = status self.plugin.update_port_status.reset_mock() - self.callbacks.get_device_details('fake_context') + self.callbacks.get_device_details(mock.Mock()) self.assertEqual(status == new_status, not self.plugin.update_port_status.called) @@ -109,7 +109,7 @@ class RpcCallbacksTestCase(base.BaseTestCase): self.plugin.get_bound_port_context().current = port self.plugin.get_bound_port_context().network.current = ( {"id": "fake_network"}) - self.callbacks.get_device_details('fake_context', host='fake_host', + self.callbacks.get_device_details(mock.Mock(), host='fake_host', cached_networks=cached_networks) self.assertTrue('fake_port' in cached_networks) @@ -119,7 +119,7 @@ class RpcCallbacksTestCase(base.BaseTestCase): port_context.current = port port_context.host = 'fake' self.plugin.update_port_status.reset_mock() - self.callbacks.get_device_details('fake_context', + self.callbacks.get_device_details(mock.Mock(), host='fake_host') self.assertFalse(self.plugin.update_port_status.called) @@ -128,7 +128,7 @@ class RpcCallbacksTestCase(base.BaseTestCase): port_context = self.plugin.get_bound_port_context() port_context.current = port self.plugin.update_port_status.reset_mock() - self.callbacks.get_device_details('fake_context') + self.callbacks.get_device_details(mock.Mock()) self.assertTrue(self.plugin.update_port_status.called) def test_get_devices_details_list(self): @@ -155,8 +155,8 @@ class RpcCallbacksTestCase(base.BaseTestCase): def _test_update_device_not_bound_to_host(self, func): self.plugin.port_bound_to_host.return_value = False self.plugin._device_to_port_id.return_value = 'fake_port_id' - res = func('fake_context', device='fake_device', host='fake_host') - self.plugin.port_bound_to_host.assert_called_once_with('fake_context', + res = func(mock.Mock(), device='fake_device', host='fake_host') + self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY, 'fake_port_id', 'fake_host') return res @@ -176,18 +176,18 @@ class RpcCallbacksTestCase(base.BaseTestCase): self.plugin._device_to_port_id.return_value = 'fake_port_id' self.assertEqual( {'device': 'fake_device', 'exists': False}, - self.callbacks.update_device_down('fake_context', + self.callbacks.update_device_down(mock.Mock(), device='fake_device', host='fake_host')) self.plugin.update_port_status.assert_called_once_with( - 'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN, + mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN, 'fake_host') def test_update_device_down_call_update_port_status_failed(self): self.plugin.update_port_status.side_effect = exc.StaleDataError self.assertEqual({'device': 'fake_device', 'exists': False}, self.callbacks.update_device_down( - 'fake_context', device='fake_device')) + mock.Mock(), device='fake_device')) class RpcApiTestCase(base.BaseTestCase): diff --git a/neutron/tests/unit/plugins/ml2/test_security_group.py b/neutron/tests/unit/plugins/ml2/test_security_group.py index 97f36a75b..772853938 100644 --- a/neutron/tests/unit/plugins/ml2/test_security_group.py +++ b/neutron/tests/unit/plugins/ml2/test_security_group.py @@ -19,6 +19,7 @@ import math import mock from neutron.common import constants as const +from neutron import context from neutron.extensions import securitygroup as ext_sg from neutron import manager from neutron.tests import tools @@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, test_sg_rpc.SGNotificationTestMixin): def setUp(self): super(TestMl2SecurityGroups, self).setUp() + self.ctx = context.get_admin_context() plugin = manager.NeutronManager.get_plugin() plugin.start_rpc_listeners() @@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, ] plugin = manager.NeutronManager.get_plugin() # should match full ID and starting chars - ports = plugin.get_ports_from_devices( + ports = plugin.get_ports_from_devices(self.ctx, [orig_ports[0]['id'], orig_ports[1]['id'][0:8], orig_ports[2]['id']]) self.assertEqual(len(orig_ports), len(ports)) @@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, def test_security_group_get_ports_from_devices_with_bad_id(self): plugin = manager.NeutronManager.get_plugin() - ports = plugin.get_ports_from_devices(['bad_device_id']) + ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id']) self.assertFalse(ports) def test_security_group_no_db_calls_with_no_ports(self): @@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, with mock.patch( 'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port' ) as get_mock: - self.assertFalse(plugin.get_ports_from_devices([])) + self.assertFalse(plugin.get_ports_from_devices(self.ctx, [])) self.assertFalse(get_mock.called) def test_large_port_count_broken_into_parts(self): @@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port', return_value={}), ) as (max_mock, get_mock): - plugin.get_ports_from_devices( + plugin.get_ports_from_devices(self.ctx, ['%s%s' % (const.TAP_DEVICE_PREFIX, i) for i in range(ports_to_query)]) - all_call_args = map(lambda x: x[1][0], get_mock.mock_calls) + all_call_args = map(lambda x: x[1][1], get_mock.mock_calls) last_call_args = all_call_args.pop() # all but last should be getting MAX_PORTS_PER_QUERY ports self.assertTrue( @@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, # have one matching 'IN' critiera for all of the IDs with contextlib.nested( mock.patch('neutron.plugins.ml2.db.or_'), - mock.patch('neutron.plugins.ml2.db.db_api.get_session') - ) as (or_mock, sess_mock): - qmock = sess_mock.return_value.query + mock.patch('sqlalchemy.orm.Session.query') + ) as (or_mock, qmock): fmock = qmock.return_value.outerjoin.return_value.filter # return no ports to exit the method early since we are mocking # the query fmock.return_value = [] - plugin.get_ports_from_devices([test_base._uuid(), + plugin.get_ports_from_devices(self.ctx, + [test_base._uuid(), test_base._uuid()]) # the or_ function should only have one argument or_mock.assert_called_once_with(mock.ANY) diff --git a/neutron/tests/unit/plugins/oneconvergence/test_security_group.py b/neutron/tests/unit/plugins/oneconvergence/test_security_group.py index dcde1915a..203d2a7e0 100644 --- a/neutron/tests/unit/plugins/oneconvergence/test_security_group.py +++ b/neutron/tests/unit/plugins/oneconvergence/test_security_group.py @@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase, req.get_response(self.api)) port_id = res['port']['id'] plugin = manager.NeutronManager.get_plugin() - port_dict = plugin.get_port_from_device(port_id) + port_dict = plugin.get_port_from_device(mock.Mock(), + port_id) self.assertEqual(port_id, port_dict['id']) self.assertEqual([security_group_id], port_dict[ext_sg.SECURITYGROUPS]) @@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase, def test_security_group_get_port_from_device_with_no_port(self): plugin = manager.NeutronManager.get_plugin() - port_dict = plugin.get_port_from_device('bad_device_id') + port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id') self.assertIsNone(port_dict) -- 2.45.2