]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Reuse caller's session in ML2 DB methods
authorDane LeBlanc <leblancd@cisco.com>
Tue, 14 Apr 2015 13:18:18 +0000 (09:18 -0400)
committerDane LeBlanc <leblancd@cisco.com>
Thu, 7 May 2015 21:26:25 +0000 (17:26 -0400)
This patch changes the get_port_from_device_mac() and
get_sg_ids_grouped_by_port() methods in ML2 db.py module so that
they do not create a new database session (via get_session()), but
instead reuse the session associated with the caller's context.

In order to make the session that is associated with the caller's
context available to these ML2 DB methods, the
get_ports_from_devices plugin API in securitygroups_rps_base.py
needs to be modified so that the context can be passed down to the
ML2 plugin. (A similar change is made to the get_port_from_device
plugin API for consistency.)

Change-Id: I3f990895887e156de929bd7ac3732df114dd4a4b
Closes-Bug: 1441205

12 files changed:
neutron/api/rpc/handlers/securitygroups_rpc.py
neutron/db/securitygroups_rpc_base.py
neutron/plugins/ml2/db.py
neutron/plugins/ml2/plugin.py
neutron/plugins/ml2/rpc.py
neutron/plugins/oneconvergence/plugin.py
neutron/tests/unit/agent/test_securitygroups_rpc.py
neutron/tests/unit/plugins/ml2/test_db.py
neutron/tests/unit/plugins/ml2/test_plugin.py
neutron/tests/unit/plugins/ml2/test_rpc.py
neutron/tests/unit/plugins/ml2/test_security_group.py
neutron/tests/unit/plugins/oneconvergence/test_security_group.py

index 58d9c7d3dcd2d88c62c340fdc3096c65047c2e24..d63a1335171a6cc7b61615eef7e3346ef6dd6cea 100644 (file)
@@ -76,10 +76,10 @@ class SecurityGroupServerRpcCallback(object):
     def plugin(self):
         return manager.NeutronManager.get_plugin()
 
-    def _get_devices_info(self, devices):
+    def _get_devices_info(self, context, devices):
         return dict(
             (port['id'], port)
-            for port in self.plugin.get_ports_from_devices(devices)
+            for port in self.plugin.get_ports_from_devices(context, devices)
             if port and not port['device_owner'].startswith('network:')
         )
 
@@ -93,7 +93,7 @@ class SecurityGroupServerRpcCallback(object):
         :returns: port correspond to the devices with security group rules
         """
         devices_info = kwargs.get('devices')
-        ports = self._get_devices_info(devices_info)
+        ports = self._get_devices_info(context, devices_info)
         return self.plugin.security_group_rules_for_ports(context, ports)
 
     def security_group_info_for_devices(self, context, **kwargs):
@@ -110,7 +110,7 @@ class SecurityGroupServerRpcCallback(object):
         Note that sets are serialized into lists by rpc code.
         """
         devices_info = kwargs.get('devices')
-        ports = self._get_devices_info(devices_info)
+        ports = self._get_devices_info(context, devices_info)
         return self.plugin.security_group_info_for_ports(context, ports)
 
 
index c47493599e165f494bd7742d40a3c8763af6e2fc..e6005025edca4e6358b3e2221b8bf126069a7d10 100644 (file)
@@ -38,7 +38,7 @@ DHCP_RULE_PORT = {4: (67, 68, q_const.IPv4), 6: (547, 546, q_const.IPv6)}
 class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
     """Mixin class to add agent-based security group implementation."""
 
-    def get_port_from_device(self, device):
+    def get_port_from_device(self, context, device):
         """Get port dict from device name on an agent.
 
         Subclass must provide this method or get_ports_from_devices.
@@ -59,13 +59,14 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
                                     "or get_ports_from_devices.")
                                   % self.__class__.__name__)
 
-    def get_ports_from_devices(self, devices):
+    def get_ports_from_devices(self, context, devices):
         """Bulk method of get_port_from_device.
 
         Subclasses may override this to provide better performance for DB
         queries, backend calls, etc.
         """
-        return [self.get_port_from_device(device) for device in devices]
+        return [self.get_port_from_device(context, device)
+                for device in devices]
 
     def create_security_group_rule(self, context, security_group_rule):
         bulk_rule = {'security_group_rules': [security_group_rule]}
index 9e4f8ca14d4bda2a77f21d4a93b82451297bb624..c6aef07d682f09be473c76120e6a8ca1bd5797dc 100644 (file)
@@ -19,7 +19,6 @@ from sqlalchemy import or_
 from sqlalchemy.orm import exc
 
 from neutron.common import constants as n_const
-from neutron.db import api as db_api
 from neutron.db import models_v2
 from neutron.db import securitygroups_db as sg_db
 from neutron.extensions import portbindings
@@ -244,14 +243,14 @@ def get_port(session, port_id):
             return
 
 
-def get_port_from_device_mac(device_mac):
+def get_port_from_device_mac(context, device_mac):
     LOG.debug("get_port_from_device_mac() called for mac %s", device_mac)
-    session = db_api.get_session()
-    qry = session.query(models_v2.Port).filter_by(mac_address=device_mac)
+    qry = context.session.query(models_v2.Port).filter_by(
+        mac_address=device_mac)
     return qry.first()
 
 
-def get_ports_and_sgs(port_ids):
+def get_ports_and_sgs(context, port_ids):
     """Get ports from database with security group info."""
 
     # break large queries into smaller parts
@@ -259,25 +258,24 @@ def get_ports_and_sgs(port_ids):
         LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
                   "query %(maxp)s. Partitioning queries.",
                   {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
-        return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
-                get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
+        return (get_ports_and_sgs(context, port_ids[:MAX_PORTS_PER_QUERY]) +
+                get_ports_and_sgs(context, port_ids[MAX_PORTS_PER_QUERY:]))
 
     LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
 
     if not port_ids:
         # if port_ids is empty, avoid querying to DB to ask it for nothing
         return []
-    ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
+    ports_to_sg_ids = get_sg_ids_grouped_by_port(context, port_ids)
     return [make_port_dict_with_security_groups(port, sec_groups)
             for port, sec_groups in ports_to_sg_ids.iteritems()]
 
 
-def get_sg_ids_grouped_by_port(port_ids):
+def get_sg_ids_grouped_by_port(context, port_ids):
     sg_ids_grouped_by_port = {}
-    session = db_api.get_session()
     sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
 
-    with session.begin(subtransactions=True):
+    with context.session.begin(subtransactions=True):
         # partial UUIDs must be individually matched with startswith.
         # full UUIDs may be matched directly in an IN statement
         partial_uuids = set(port_id for port_id in port_ids
@@ -288,8 +286,8 @@ def get_sg_ids_grouped_by_port(port_ids):
         if full_uuids:
             or_criteria.append(models_v2.Port.id.in_(full_uuids))
 
-        query = session.query(models_v2.Port,
-                              sg_db.SecurityGroupPortBinding.security_group_id)
+        query = context.session.query(
+            models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id)
         query = query.outerjoin(sg_db.SecurityGroupPortBinding,
                                 models_v2.Port.id == sg_binding_port)
         query = query.filter(or_(*or_criteria))
index 2f209db7723a58d27d69a9d8ce9882c69e05a5aa..c1b4bde99697526fdaa9916dd0b381c15f510be4 100644 (file)
@@ -1451,11 +1451,12 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             port_host = db.get_port_binding_host(context.session, port_id)
             return (port_host == host)
 
-    def get_ports_from_devices(self, devices):
-        port_ids_to_devices = dict((self._device_to_port_id(device), device)
-                                   for device in devices)
+    def get_ports_from_devices(self, context, devices):
+        port_ids_to_devices = dict(
+            (self._device_to_port_id(context, device), device)
+            for device in devices)
         port_ids = port_ids_to_devices.keys()
-        ports = db.get_ports_and_sgs(port_ids)
+        ports = db.get_ports_and_sgs(context, port_ids)
         for port in ports:
             # map back to original requested id
             port_id = next((port_id for port_id in port_ids
@@ -1465,7 +1466,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
         return ports
 
     @staticmethod
-    def _device_to_port_id(device):
+    def _device_to_port_id(context, device):
         # REVISIT(rkukura): Consider calling into MechanismDrivers to
         # process device names, or having MechanismDrivers supply list
         # of device prefixes to strip.
@@ -1475,7 +1476,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
         # REVISIT(irenab): Consider calling into bound MD to
         # handle the get_device_details RPC
         if not uuidutils.is_uuid_like(device):
-            port = db.get_port_from_device_mac(device)
+            port = db.get_port_from_device_mac(context, device)
             if port:
                 return port.id
         return device
index bdbf4b510aa6fa5842373fad3e9f561e89d7c1e8..6338642fc694d11b16f3336cfe386ca316efccd8 100644 (file)
@@ -67,7 +67,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
                   {'device': device, 'agent_id': agent_id, 'host': host})
 
         plugin = manager.NeutronManager.get_plugin()
-        port_id = plugin._device_to_port_id(device)
+        port_id = plugin._device_to_port_id(rpc_context, device)
         port_context = plugin.get_bound_port_context(rpc_context,
                                                      port_id,
                                                      host,
@@ -144,7 +144,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
                   "%(agent_id)s",
                   {'device': device, 'agent_id': agent_id})
         plugin = manager.NeutronManager.get_plugin()
-        port_id = plugin._device_to_port_id(device)
+        port_id = plugin._device_to_port_id(rpc_context, device)
         port_exists = True
         if (host and not plugin.port_bound_to_host(rpc_context,
                                                    port_id, host)):
@@ -173,7 +173,7 @@ class RpcCallbacks(type_tunnel.TunnelRpcCallbackMixin):
         LOG.debug("Device %(device)s up at agent %(agent_id)s",
                   {'device': device, 'agent_id': agent_id})
         plugin = manager.NeutronManager.get_plugin()
-        port_id = plugin._device_to_port_id(device)
+        port_id = plugin._device_to_port_id(rpc_context, device)
         if (host and not plugin.port_bound_to_host(rpc_context,
                                                    port_id, host)):
             LOG.debug("Device %(device)s not bound to the"
index af6b5c1a6bdeca3af601c102b6931dba6e941c96..50b425c848a2ac8bb4854671b136988459000277 100644 (file)
@@ -56,7 +56,7 @@ IPv6 = 6
 class SecurityGroupServerRpcMixin(sg_db_rpc.SecurityGroupServerRpcMixin):
 
     @staticmethod
-    def get_port_from_device(device):
+    def get_port_from_device(context, device):
         port = nvsd_db.get_port_from_device(device)
         if port:
             port['device'] = device
index 8dd6c90b0a651df325cdfe099ebc760addf99715..c07ad3519316f8ab671b163c4bbe18af8f6790f7 100644 (file)
@@ -94,7 +94,7 @@ class SecurityGroupRpcTestPlugin(test_sg.SecurityGroupTestPlugin,
         self.notify_security_groups_member_updated(context, port)
         del self.devices[id]
 
-    def get_port_from_device(self, device):
+    def get_port_from_device(self, context, device):
         device = self.devices.get(device)
         if device:
             device['security_group_rules'] = []
index c34b82abfdcbdbc1474e7dc8b447d392b863025a..db2c123376fbaa1efe9228fb4c1721b00b04f70d 100644 (file)
@@ -201,7 +201,8 @@ class Ml2DBTestCase(testlib_api.SqlTestCase):
         self._setup_neutron_network(network_id)
         port = self._setup_neutron_port(network_id, port_id)
 
-        observed_port = ml2_db.get_port_from_device_mac(port['mac_address'])
+        observed_port = ml2_db.get_port_from_device_mac(self.ctx,
+                                                        port['mac_address'])
         self.assertEqual(port_id, observed_port.id)
 
     def test_get_locked_port_and_binding(self):
index 21b90976a321b0639067c050d85924754c30fd30..77e00ef4f4e0cda7e1fbd54703b8d51ced7076fa 100644 (file)
@@ -614,23 +614,26 @@ class TestMl2PluginOnly(Ml2PluginV2TestCase):
                         ('qvo567890', '567890')]
         for device, expected in input_output:
             self.assertEqual(expected,
-                             ml2_plugin.Ml2Plugin._device_to_port_id(device))
+                             ml2_plugin.Ml2Plugin._device_to_port_id(
+                                 self.context, device))
 
     def test__device_to_port_id_mac_address(self):
         with self.port() as p:
             mac = p['port']['mac_address']
             port_id = p['port']['id']
             self.assertEqual(port_id,
-                             ml2_plugin.Ml2Plugin._device_to_port_id(mac))
+                             ml2_plugin.Ml2Plugin._device_to_port_id(
+                                 self.context, mac))
 
     def test__device_to_port_id_not_uuid_not_mac(self):
         dev = '1234567'
-        self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(dev))
+        self.assertEqual(dev, ml2_plugin.Ml2Plugin._device_to_port_id(
+            self.context, dev))
 
     def test__device_to_port_id_UUID(self):
         port_id = uuidutils.generate_uuid()
-        self.assertEqual(port_id,
-                         ml2_plugin.Ml2Plugin._device_to_port_id(port_id))
+        self.assertEqual(port_id, ml2_plugin.Ml2Plugin._device_to_port_id(
+            self.context, port_id))
 
 
 class TestMl2DvrPortsV2(TestMl2PortsV2):
index 56cbdbcb97947cd7baa3f439af7f48a177fb451b..cd7bcba14eec4d31b765017f5a0dbfbbb1576d3b 100644 (file)
@@ -75,14 +75,14 @@ class RpcCallbacksTestCase(base.BaseTestCase):
         self.plugin.get_bound_port_context.return_value = None
         self.assertEqual(
             {'device': 'fake_device'},
-            self.callbacks.get_device_details('fake_context',
+            self.callbacks.get_device_details(mock.Mock(),
                                               device='fake_device'))
 
     def test_get_device_details_port_context_without_bounded_segment(self):
         self.plugin.get_bound_port_context().bottom_bound_segment = None
         self.assertEqual(
             {'device': 'fake_device'},
-            self.callbacks.get_device_details('fake_context',
+            self.callbacks.get_device_details(mock.Mock(),
                                               device='fake_device'))
 
     def test_get_device_details_port_status_equal_new_status(self):
@@ -99,7 +99,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
                 port['admin_state_up'] = admin_state_up
                 port['status'] = status
                 self.plugin.update_port_status.reset_mock()
-                self.callbacks.get_device_details('fake_context')
+                self.callbacks.get_device_details(mock.Mock())
                 self.assertEqual(status == new_status,
                                  not self.plugin.update_port_status.called)
 
@@ -109,7 +109,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
         self.plugin.get_bound_port_context().current = port
         self.plugin.get_bound_port_context().network.current = (
             {"id": "fake_network"})
-        self.callbacks.get_device_details('fake_context', host='fake_host',
+        self.callbacks.get_device_details(mock.Mock(), host='fake_host',
                                           cached_networks=cached_networks)
         self.assertTrue('fake_port' in cached_networks)
 
@@ -119,7 +119,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
         port_context.current = port
         port_context.host = 'fake'
         self.plugin.update_port_status.reset_mock()
-        self.callbacks.get_device_details('fake_context',
+        self.callbacks.get_device_details(mock.Mock(),
                                           host='fake_host')
         self.assertFalse(self.plugin.update_port_status.called)
 
@@ -128,7 +128,7 @@ class RpcCallbacksTestCase(base.BaseTestCase):
         port_context = self.plugin.get_bound_port_context()
         port_context.current = port
         self.plugin.update_port_status.reset_mock()
-        self.callbacks.get_device_details('fake_context')
+        self.callbacks.get_device_details(mock.Mock())
         self.assertTrue(self.plugin.update_port_status.called)
 
     def test_get_devices_details_list(self):
@@ -155,8 +155,8 @@ class RpcCallbacksTestCase(base.BaseTestCase):
     def _test_update_device_not_bound_to_host(self, func):
         self.plugin.port_bound_to_host.return_value = False
         self.plugin._device_to_port_id.return_value = 'fake_port_id'
-        res = func('fake_context', device='fake_device', host='fake_host')
-        self.plugin.port_bound_to_host.assert_called_once_with('fake_context',
+        res = func(mock.Mock(), device='fake_device', host='fake_host')
+        self.plugin.port_bound_to_host.assert_called_once_with(mock.ANY,
                                                                'fake_port_id',
                                                                'fake_host')
         return res
@@ -176,18 +176,18 @@ class RpcCallbacksTestCase(base.BaseTestCase):
         self.plugin._device_to_port_id.return_value = 'fake_port_id'
         self.assertEqual(
             {'device': 'fake_device', 'exists': False},
-            self.callbacks.update_device_down('fake_context',
+            self.callbacks.update_device_down(mock.Mock(),
                                               device='fake_device',
                                               host='fake_host'))
         self.plugin.update_port_status.assert_called_once_with(
-            'fake_context', 'fake_port_id', constants.PORT_STATUS_DOWN,
+            mock.ANY, 'fake_port_id', constants.PORT_STATUS_DOWN,
             'fake_host')
 
     def test_update_device_down_call_update_port_status_failed(self):
         self.plugin.update_port_status.side_effect = exc.StaleDataError
         self.assertEqual({'device': 'fake_device', 'exists': False},
                          self.callbacks.update_device_down(
-                             'fake_context', device='fake_device'))
+                             mock.Mock(), device='fake_device'))
 
 
 class RpcApiTestCase(base.BaseTestCase):
index 97f36a75ba43fd17121d92bf1b33496be1cbd0d1..772853938ddf66c1c060c3e113935fb023b707c8 100644 (file)
@@ -19,6 +19,7 @@ import math
 import mock
 
 from neutron.common import constants as const
+from neutron import context
 from neutron.extensions import securitygroup as ext_sg
 from neutron import manager
 from neutron.tests import tools
@@ -51,6 +52,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
                             test_sg_rpc.SGNotificationTestMixin):
     def setUp(self):
         super(TestMl2SecurityGroups, self).setUp()
+        self.ctx = context.get_admin_context()
         plugin = manager.NeutronManager.get_plugin()
         plugin.start_rpc_listeners()
 
@@ -75,7 +77,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
                 ]
                 plugin = manager.NeutronManager.get_plugin()
                 # should match full ID and starting chars
-                ports = plugin.get_ports_from_devices(
+                ports = plugin.get_ports_from_devices(self.ctx,
                     [orig_ports[0]['id'], orig_ports[1]['id'][0:8],
                      orig_ports[2]['id']])
                 self.assertEqual(len(orig_ports), len(ports))
@@ -92,7 +94,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
 
     def test_security_group_get_ports_from_devices_with_bad_id(self):
         plugin = manager.NeutronManager.get_plugin()
-        ports = plugin.get_ports_from_devices(['bad_device_id'])
+        ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
         self.assertFalse(ports)
 
     def test_security_group_no_db_calls_with_no_ports(self):
@@ -100,7 +102,7 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
         with mock.patch(
             'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
         ) as get_mock:
-            self.assertFalse(plugin.get_ports_from_devices([]))
+            self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
             self.assertFalse(get_mock.called)
 
     def test_large_port_count_broken_into_parts(self):
@@ -114,10 +116,10 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
                 mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
                            return_value={}),
             ) as (max_mock, get_mock):
-                plugin.get_ports_from_devices(
+                plugin.get_ports_from_devices(self.ctx,
                     ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
                      for i in range(ports_to_query)])
-                all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
+                all_call_args = map(lambda x: x[1][1], get_mock.mock_calls)
                 last_call_args = all_call_args.pop()
                 # all but last should be getting MAX_PORTS_PER_QUERY ports
                 self.assertTrue(
@@ -139,14 +141,14 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
         # have one matching 'IN' critiera for all of the IDs
         with contextlib.nested(
             mock.patch('neutron.plugins.ml2.db.or_'),
-            mock.patch('neutron.plugins.ml2.db.db_api.get_session')
-        ) as (or_mock, sess_mock):
-            qmock = sess_mock.return_value.query
+            mock.patch('sqlalchemy.orm.Session.query')
+        ) as (or_mock, qmock):
             fmock = qmock.return_value.outerjoin.return_value.filter
             # return no ports to exit the method early since we are mocking
             # the query
             fmock.return_value = []
-            plugin.get_ports_from_devices([test_base._uuid(),
+            plugin.get_ports_from_devices(self.ctx,
+                                          [test_base._uuid(),
                                            test_base._uuid()])
             # the or_ function should only have one argument
             or_mock.assert_called_once_with(mock.ANY)
index dcde1915a9b67ae298f55453cd58274811694e95..203d2a7e0e3de43ae92de1a0cea19086b2b5d7a3 100644 (file)
@@ -89,7 +89,8 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
                                            req.get_response(self.api))
                     port_id = res['port']['id']
                     plugin = manager.NeutronManager.get_plugin()
-                    port_dict = plugin.get_port_from_device(port_id)
+                    port_dict = plugin.get_port_from_device(mock.Mock(),
+                                                            port_id)
                     self.assertEqual(port_id, port_dict['id'])
                     self.assertEqual([security_group_id],
                                      port_dict[ext_sg.SECURITYGROUPS])
@@ -101,5 +102,5 @@ class TestOneConvergenceSecurityGroups(OneConvergenceSecurityGroupsTestCase,
     def test_security_group_get_port_from_device_with_no_port(self):
 
         plugin = manager.NeutronManager.get_plugin()
-        port_dict = plugin.get_port_from_device('bad_device_id')
+        port_dict = plugin.get_port_from_device(mock.Mock(), 'bad_device_id')
         self.assertIsNone(port_dict)