]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Batch ports from security groups RPC handler
authorKevin Benton <blak111@gmail.com>
Fri, 26 Sep 2014 16:40:44 +0000 (09:40 -0700)
committerKevin Benton <kevinbenton@buttewifi.com>
Fri, 31 Oct 2014 23:01:33 +0000 (23:01 +0000)
The security groups RPC handler calls get_port_from_device
individually for each device in a list it receives. Each
one of these results in a separate SQL query for the security
groups and port details. This becomes very inefficient as the
number of devices on a single node increases.

This patch adds logic to the RPC handler to see if the core
plugin has a method to lookup all of the device IDs at once.
If so, it uses that method, otherwise it continues as normal.

The ML2 plugin is modified to include the batch function, which
uses one SQL query regardless of the number of devices.

Closes-Bug: #1374556
Change-Id: I15d19c22e8c44577db190309b6636a3251a9c66a
(cherry picked from commit abc16ebfcf8fd1fbdb4ef68590140d4d355b0a7c)

neutron/api/rpc/handlers/securitygroups_rpc.py
neutron/db/securitygroups_rpc_base.py
neutron/plugins/ml2/db.py
neutron/plugins/ml2/plugin.py
neutron/tests/unit/ml2/test_security_group.py

index 2a748cfbcc7961a575fdee40bc8de6e3257f2859..e4a16b2f58d6566daf57cd7e507c761c76d2d786 100644 (file)
@@ -36,15 +36,11 @@ class SecurityGroupServerRpcCallback(n_rpc.RpcCallback):
         return manager.NeutronManager.get_plugin()
 
     def _get_devices_info(self, devices):
-        devices_info = {}
-        for device in devices:
-            port = self.plugin.get_port_from_device(device)
-            if not port:
-                continue
-            if port['device_owner'].startswith('network:'):
-                continue
-            devices_info[port['id']] = port
-        return devices_info
+        return dict(
+            (port['id'], port)
+            for port in self.plugin.get_ports_from_devices(devices)
+            if port and not port['device_owner'].startswith('network:')
+        )
 
     def security_group_rules_for_devices(self, context, **kwargs):
         """Callback method to return security group rules for each port.
index 1dda6bb46982325494e504fe61dc7a072f0b7802..76872a74691f9ab8b0f550f446ee0f195c305488 100644 (file)
@@ -39,7 +39,7 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
     def get_port_from_device(self, device):
         """Get port dict from device name on an agent.
 
-        Subclass must provide this method.
+        Subclass must provide this method or get_ports_from_devices.
 
         :param device: device name which identifies a port on the agent side.
         What is specified in "device" depends on a plugin agent implementation.
@@ -53,9 +53,18 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
         - security_group_source_groups
         - fixed_ips
         """
-        raise NotImplementedError(_("%s must implement get_port_from_device.")
+        raise NotImplementedError(_("%s must implement get_port_from_device "
+                                    "or get_ports_from_devices.")
                                   % self.__class__.__name__)
 
+    def get_ports_from_devices(self, devices):
+        """Bulk method of get_port_from_device.
+
+        Subclasses may override this to provide better performance for DB
+        queries, backend calls, etc.
+        """
+        return [self.get_port_from_device(device) for device in devices]
+
     def create_security_group_rule(self, context, security_group_rule):
         bulk_rule = {'security_group_rules': [security_group_rule]}
         rule = self.create_security_group_rule_bulk_native(context,
index d8caa9384af1bb8567a20e1421adca9df3be3c3e..40e1c22e52fd521717667408282f581f57305ad8 100644 (file)
@@ -13,6 +13,9 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import collections
+
+from sqlalchemy import or_
 from sqlalchemy.orm import exc
 
 from oslo.db import exception as db_exc
@@ -30,6 +33,9 @@ from neutron.plugins.ml2 import models
 
 LOG = log.getLogger(__name__)
 
+# limit the number of port OR LIKE statements in one query
+MAX_PORTS_PER_QUERY = 500
+
 
 def _make_segment_dict(record):
     """Make a segment dictionary out of a DB record."""
@@ -206,32 +212,64 @@ def get_port_from_device_mac(device_mac):
     return qry.first()
 
 
-def get_port_and_sgs(port_id):
-    """Get port from database with security group info."""
+def get_ports_and_sgs(port_ids):
+    """Get ports from database with security group info."""
+
+    # break large queries into smaller parts
+    if len(port_ids) > MAX_PORTS_PER_QUERY:
+        LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
+                  "query %(maxp)s. Partitioning queries.",
+                  {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
+        return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
+                get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
+
+    LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
 
-    LOG.debug(_("get_port_and_sgs() called for port_id %s"), port_id)
+    if not port_ids:
+        # if port_ids is empty, avoid querying to DB to ask it for nothing
+        return []
+    ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
+    return [make_port_dict_with_security_groups(port, sec_groups)
+            for port, sec_groups in ports_to_sg_ids.iteritems()]
+
+
+def get_sg_ids_grouped_by_port(port_ids):
+    sg_ids_grouped_by_port = collections.defaultdict(list)
     session = db_api.get_session()
     sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
 
     with session.begin(subtransactions=True):
+        # partial UUIDs must be individually matched with startswith.
+        # full UUIDs may be matched directly in an IN statement
+        partial_uuids = set(port_id for port_id in port_ids
+                            if not uuidutils.is_uuid_like(port_id))
+        full_uuids = set(port_ids) - partial_uuids
+        or_criteria = [models_v2.Port.id.startswith(port_id)
+                       for port_id in partial_uuids]
+        if full_uuids:
+            or_criteria.append(models_v2.Port.id.in_(full_uuids))
+
         query = session.query(models_v2.Port,
                               sg_db.SecurityGroupPortBinding.security_group_id)
         query = query.outerjoin(sg_db.SecurityGroupPortBinding,
                                 models_v2.Port.id == sg_binding_port)
-        query = query.filter(models_v2.Port.id.startswith(port_id))
-        port_and_sgs = query.all()
-        if not port_and_sgs:
-            return
-        port = port_and_sgs[0][0]
-        plugin = manager.NeutronManager.get_plugin()
-        port_dict = plugin._make_port_dict(port)
-        port_dict['security_groups'] = [
-            sg_id for port_, sg_id in port_and_sgs if sg_id]
-        port_dict['security_group_rules'] = []
-        port_dict['security_group_source_groups'] = []
-        port_dict['fixed_ips'] = [ip['ip_address']
-                                  for ip in port['fixed_ips']]
-        return port_dict
+        query = query.filter(or_(*or_criteria))
+
+        for port, sg_id in query:
+            if sg_id:
+                sg_ids_grouped_by_port[port].append(sg_id)
+    return sg_ids_grouped_by_port
+
+
+def make_port_dict_with_security_groups(port, sec_groups):
+    plugin = manager.NeutronManager.get_plugin()
+    port_dict = plugin._make_port_dict(port)
+    port_dict['security_groups'] = sec_groups
+    port_dict['security_group_rules'] = []
+    port_dict['security_group_source_groups'] = []
+    port_dict['fixed_ips'] = [ip['ip_address']
+                              for ip in port['fixed_ips']]
+    return port_dict
 
 
 def get_port_binding_host(port_id):
index 72cf151006ac3ee697a79cbc4f520e8db86f3a7d..d29deda6cef8eef6bb502a06d9e24ebeed6ee1a2 100644 (file)
@@ -1156,12 +1156,18 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             port_host = db.get_port_binding_host(port_id)
             return (port_host == host)
 
-    def get_port_from_device(self, device):
-        port_id = self._device_to_port_id(device)
-        port = db.get_port_and_sgs(port_id)
-        if port:
-            port['device'] = device
-        return port
+    def get_ports_from_devices(self, devices):
+        port_ids_to_devices = dict((self._device_to_port_id(device), device)
+                                   for device in devices)
+        port_ids = port_ids_to_devices.keys()
+        ports = db.get_ports_and_sgs(port_ids)
+        for port in ports:
+            # map back to original requested id
+            port_id = next((port_id for port_id in port_ids
+                           if port['id'].startswith(port_id)), None)
+            port['device'] = port_ids_to_devices.get(port_id)
+
+        return ports
 
     def _device_to_port_id(self, device):
         # REVISIT(rkukura): Consider calling into MechanismDrivers to
index 39c3cc2baedb69f845bd78da63e4ad5e33793270..cc8468ae233d6293331d9e6142a5c4e21e8413e4 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import contextlib
+import math
 import mock
 
 from neutron.api.v2 import attributes
+from neutron.common import constants as const
 from neutron.extensions import securitygroup as ext_sg
 from neutron import manager
+from neutron.tests.unit import test_api_v2
 from neutron.tests.unit import test_extension_security_group as test_sg
 from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc
 
@@ -55,38 +59,91 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
         plugin = manager.NeutronManager.get_plugin()
         plugin.start_rpc_listeners()
 
-    def test_security_group_get_port_from_device(self):
+    def _make_port_with_new_sec_group(self, net_id):
+        sg = self._make_security_group(self.fmt, 'name', 'desc')
+        port = self._make_port(
+            self.fmt, net_id, security_groups=[sg['security_group']['id']])
+        return port['port']
+
+    def test_security_group_get_ports_from_devices(self):
         with self.network() as n:
             with self.subnet(n):
-                with self.security_group() as sg:
-                    security_group_id = sg['security_group']['id']
-                    res = self._create_port(self.fmt, n['network']['id'])
-                    port = self.deserialize(self.fmt, res)
-                    fixed_ips = port['port']['fixed_ips']
-                    data = {'port': {'fixed_ips': fixed_ips,
-                                     'name': port['port']['name'],
-                                     ext_sg.SECURITYGROUPS:
-                                     [security_group_id]}}
-
-                    req = self.new_update_request('ports', data,
-                                                  port['port']['id'])
-                    res = self.deserialize(self.fmt,
-                                           req.get_response(self.api))
-                    port_id = res['port']['id']
-                    plugin = manager.NeutronManager.get_plugin()
-                    port_dict = plugin.get_port_from_device(port_id)
-                    self.assertEqual(port_id, port_dict['id'])
-                    self.assertEqual([security_group_id],
+                port1 = self._make_port_with_new_sec_group(n['network']['id'])
+                port2 = self._make_port_with_new_sec_group(n['network']['id'])
+                plugin = manager.NeutronManager.get_plugin()
+                # should match full ID and starting chars
+                ports = plugin.get_ports_from_devices(
+                    [port1['id'], port2['id'][0:8]])
+                self.assertEqual(2, len(ports))
+                for port_dict in ports:
+                    p = port1 if port1['id'] == port_dict['id'] else port2
+                    self.assertEqual(p['id'], port_dict['id'])
+                    self.assertEqual(p['security_groups'],
                                      port_dict[ext_sg.SECURITYGROUPS])
                     self.assertEqual([], port_dict['security_group_rules'])
-                    self.assertEqual([fixed_ips[0]['ip_address']],
+                    self.assertEqual([p['fixed_ips'][0]['ip_address']],
                                      port_dict['fixed_ips'])
-                    self._delete('ports', port_id)
+                    self._delete('ports', p['id'])
+
+    def test_security_group_get_ports_from_devices_with_bad_id(self):
+        plugin = manager.NeutronManager.get_plugin()
+        ports = plugin.get_ports_from_devices(['bad_device_id'])
+        self.assertFalse(ports)
 
-    def test_security_group_get_port_from_device_with_no_port(self):
+    def test_security_group_no_db_calls_with_no_ports(self):
+        plugin = manager.NeutronManager.get_plugin()
+        with mock.patch(
+            'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
+        ) as get_mock:
+            self.assertFalse(plugin.get_ports_from_devices([]))
+            self.assertFalse(get_mock.called)
+
+    def test_large_port_count_broken_into_parts(self):
+        plugin = manager.NeutronManager.get_plugin()
+        max_ports_per_query = 5
+        ports_to_query = 73
+        for max_ports_per_query in (1, 2, 5, 7, 9, 31):
+            with contextlib.nested(
+                mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY',
+                           new=max_ports_per_query),
+                mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
+                           return_value={}),
+            ) as (max_mock, get_mock):
+                plugin.get_ports_from_devices(
+                    ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
+                     for i in range(ports_to_query)])
+                all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
+                last_call_args = all_call_args.pop()
+                # all but last should be getting MAX_PORTS_PER_QUERY ports
+                self.assertTrue(
+                    all(map(lambda x: len(x) == max_ports_per_query,
+                            all_call_args))
+                )
+                remaining = ports_to_query % max_ports_per_query
+                if remaining:
+                    self.assertEqual(remaining, len(last_call_args))
+                # should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls
+                self.assertEqual(
+                    math.ceil(ports_to_query / float(max_ports_per_query)),
+                    get_mock.call_count
+                )
+
+    def test_full_uuids_skip_port_id_lookup(self):
         plugin = manager.NeutronManager.get_plugin()
-        port_dict = plugin.get_port_from_device('bad_device_id')
-        self.assertIsNone(port_dict)
+        # when full UUIDs are provided, the _or statement should only
+        # have one matching 'IN' critiera for all of the IDs
+        with contextlib.nested(
+            mock.patch('neutron.plugins.ml2.db.or_'),
+            mock.patch('neutron.plugins.ml2.db.db_api.get_session')
+        ) as (or_mock, sess_mock):
+            fmock = sess_mock.query.return_value.outerjoin.return_value.filter
+            # return no ports to exit the method early since we are mocking
+            # the query
+            fmock.return_value.all.return_value = []
+            plugin.get_ports_from_devices([test_api_v2._uuid(),
+                                           test_api_v2._uuid()])
+            # the or_ function should only have one argument
+            or_mock.assert_called_once_with(mock.ANY)
 
 
 class TestMl2SecurityGroupsXML(TestMl2SecurityGroups):