]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Batch ports from security groups RPC handler
authorKevin Benton <blak111@gmail.com>
Fri, 26 Sep 2014 16:40:44 +0000 (09:40 -0700)
committerKevin Benton <blak111@gmail.com>
Wed, 29 Oct 2014 04:04:30 +0000 (21:04 -0700)
The security groups RPC handler calls get_port_from_device
individually for each device in a list it receives. Each
one of these results in a separate SQL query for the security
groups and port details. This becomes very inefficient as the
number of devices on a single node increases.

This patch adds logic to the RPC handler to see if the core
plugin has a method to lookup all of the device IDs at once.
If so, it uses that method, otherwise it continues as normal.

The ML2 plugin is modified to include the batch function, which
uses one SQL query regardless of the number of devices.

Closes-Bug: #1374556
Change-Id: I15d19c22e8c44577db190309b6636a3251a9c66a

neutron/api/rpc/handlers/securitygroups_rpc.py
neutron/db/securitygroups_rpc_base.py
neutron/plugins/ml2/db.py
neutron/plugins/ml2/plugin.py
neutron/tests/unit/ml2/test_security_group.py

index 2a748cfbcc7961a575fdee40bc8de6e3257f2859..e4a16b2f58d6566daf57cd7e507c761c76d2d786 100644 (file)
@@ -36,15 +36,11 @@ class SecurityGroupServerRpcCallback(n_rpc.RpcCallback):
         return manager.NeutronManager.get_plugin()
 
     def _get_devices_info(self, devices):
-        devices_info = {}
-        for device in devices:
-            port = self.plugin.get_port_from_device(device)
-            if not port:
-                continue
-            if port['device_owner'].startswith('network:'):
-                continue
-            devices_info[port['id']] = port
-        return devices_info
+        return dict(
+            (port['id'], port)
+            for port in self.plugin.get_ports_from_devices(devices)
+            if port and not port['device_owner'].startswith('network:')
+        )
 
     def security_group_rules_for_devices(self, context, **kwargs):
         """Callback method to return security group rules for each port.
index bcbe32c556131d8c02122b02fc61b76efe9e9782..8c1df8bea088cc1323434911a4a6e34a2a8d0e6b 100644 (file)
@@ -40,7 +40,7 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
     def get_port_from_device(self, device):
         """Get port dict from device name on an agent.
 
-        Subclass must provide this method.
+        Subclass must provide this method or get_ports_from_devices.
 
         :param device: device name which identifies a port on the agent side.
         What is specified in "device" depends on a plugin agent implementation.
@@ -54,9 +54,18 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
         - security_group_source_groups
         - fixed_ips
         """
-        raise NotImplementedError(_("%s must implement get_port_from_device.")
+        raise NotImplementedError(_("%s must implement get_port_from_device "
+                                    "or get_ports_from_devices.")
                                   % self.__class__.__name__)
 
+    def get_ports_from_devices(self, devices):
+        """Bulk method of get_port_from_device.
+
+        Subclasses may override this to provide better performance for DB
+        queries, backend calls, etc.
+        """
+        return [self.get_port_from_device(device) for device in devices]
+
     def create_security_group_rule(self, context, security_group_rule):
         bulk_rule = {'security_group_rules': [security_group_rule]}
         rule = self.create_security_group_rule_bulk_native(context,
index 37e91bc791c5c8ccdbf6f0065eb7089d77b05d4e..8dc473e2e0eeeb859d2fea819a4b9e5e17080717 100644 (file)
@@ -13,6 +13,9 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import collections
+
+from sqlalchemy import or_
 from sqlalchemy.orm import exc
 
 from oslo.db import exception as db_exc
@@ -30,6 +33,9 @@ from neutron.plugins.ml2 import models
 
 LOG = log.getLogger(__name__)
 
+# limit the number of port OR LIKE statements in one query
+MAX_PORTS_PER_QUERY = 500
+
 
 def _make_segment_dict(record):
     """Make a segment dictionary out of a DB record."""
@@ -209,32 +215,64 @@ def get_port_from_device_mac(device_mac):
     return qry.first()
 
 
-def get_port_and_sgs(port_id):
-    """Get port from database with security group info."""
+def get_ports_and_sgs(port_ids):
+    """Get ports from database with security group info."""
+
+    # break large queries into smaller parts
+    if len(port_ids) > MAX_PORTS_PER_QUERY:
+        LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
+                  "query %(maxp)s. Partitioning queries.",
+                  {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
+        return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
+                get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
+
+    LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
 
-    LOG.debug(_("get_port_and_sgs() called for port_id %s"), port_id)
+    if not port_ids:
+        # if port_ids is empty, avoid querying to DB to ask it for nothing
+        return []
+    ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
+    return [make_port_dict_with_security_groups(port, sec_groups)
+            for port, sec_groups in ports_to_sg_ids.iteritems()]
+
+
+def get_sg_ids_grouped_by_port(port_ids):
+    sg_ids_grouped_by_port = collections.defaultdict(list)
     session = db_api.get_session()
     sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
 
     with session.begin(subtransactions=True):
+        # partial UUIDs must be individually matched with startswith.
+        # full UUIDs may be matched directly in an IN statement
+        partial_uuids = set(port_id for port_id in port_ids
+                            if not uuidutils.is_uuid_like(port_id))
+        full_uuids = set(port_ids) - partial_uuids
+        or_criteria = [models_v2.Port.id.startswith(port_id)
+                       for port_id in partial_uuids]
+        if full_uuids:
+            or_criteria.append(models_v2.Port.id.in_(full_uuids))
+
         query = session.query(models_v2.Port,
                               sg_db.SecurityGroupPortBinding.security_group_id)
         query = query.outerjoin(sg_db.SecurityGroupPortBinding,
                                 models_v2.Port.id == sg_binding_port)
-        query = query.filter(models_v2.Port.id.startswith(port_id))
-        port_and_sgs = query.all()
-        if not port_and_sgs:
-            return
-        port = port_and_sgs[0][0]
-        plugin = manager.NeutronManager.get_plugin()
-        port_dict = plugin._make_port_dict(port)
-        port_dict['security_groups'] = [
-            sg_id for port_, sg_id in port_and_sgs if sg_id]
-        port_dict['security_group_rules'] = []
-        port_dict['security_group_source_groups'] = []
-        port_dict['fixed_ips'] = [ip['ip_address']
-                                  for ip in port['fixed_ips']]
-        return port_dict
+        query = query.filter(or_(*or_criteria))
+
+        for port, sg_id in query:
+            if sg_id:
+                sg_ids_grouped_by_port[port].append(sg_id)
+    return sg_ids_grouped_by_port
+
+
+def make_port_dict_with_security_groups(port, sec_groups):
+    plugin = manager.NeutronManager.get_plugin()
+    port_dict = plugin._make_port_dict(port)
+    port_dict['security_groups'] = sec_groups
+    port_dict['security_group_rules'] = []
+    port_dict['security_group_source_groups'] = []
+    port_dict['fixed_ips'] = [ip['ip_address']
+                              for ip in port['fixed_ips']]
+    return port_dict
 
 
 def get_port_binding_host(port_id):
index 6dbedd12a1c8f15fef09467f2d2b6bdc74546b10..db0c52356375267348e59ab3ec13519551ae932f 100755 (executable)
@@ -1176,12 +1176,18 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             port_host = db.get_port_binding_host(port_id)
             return (port_host == host)
 
-    def get_port_from_device(self, device):
-        port_id = self._device_to_port_id(device)
-        port = db.get_port_and_sgs(port_id)
-        if port:
-            port['device'] = device
-        return port
+    def get_ports_from_devices(self, devices):
+        port_ids_to_devices = dict((self._device_to_port_id(device), device)
+                                   for device in devices)
+        port_ids = port_ids_to_devices.keys()
+        ports = db.get_ports_and_sgs(port_ids)
+        for port in ports:
+            # map back to original requested id
+            port_id = next((port_id for port_id in port_ids
+                           if port['id'].startswith(port_id)), None)
+            port['device'] = port_ids_to_devices.get(port_id)
+
+        return ports
 
     def _device_to_port_id(self, device):
         # REVISIT(rkukura): Consider calling into MechanismDrivers to
index 5fa8063f3fec772fb9b94b28de99f92b3bba43a7..6d3d5f491aadf099c44f3b2f2289ce251d04805a 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import contextlib
+import math
 import mock
 
 from neutron.api.v2 import attributes
+from neutron.common import constants as const
 from neutron.extensions import securitygroup as ext_sg
 from neutron import manager
+from neutron.tests.unit import test_api_v2
 from neutron.tests.unit import test_extension_security_group as test_sg
 from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc
 
@@ -55,38 +59,91 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
         plugin = manager.NeutronManager.get_plugin()
         plugin.start_rpc_listeners()
 
-    def test_security_group_get_port_from_device(self):
+    def _make_port_with_new_sec_group(self, net_id):
+        sg = self._make_security_group(self.fmt, 'name', 'desc')
+        port = self._make_port(
+            self.fmt, net_id, security_groups=[sg['security_group']['id']])
+        return port['port']
+
+    def test_security_group_get_ports_from_devices(self):
         with self.network() as n:
             with self.subnet(n):
-                with self.security_group() as sg:
-                    security_group_id = sg['security_group']['id']
-                    res = self._create_port(self.fmt, n['network']['id'])
-                    port = self.deserialize(self.fmt, res)
-                    fixed_ips = port['port']['fixed_ips']
-                    data = {'port': {'fixed_ips': fixed_ips,
-                                     'name': port['port']['name'],
-                                     ext_sg.SECURITYGROUPS:
-                                     [security_group_id]}}
-
-                    req = self.new_update_request('ports', data,
-                                                  port['port']['id'])
-                    res = self.deserialize(self.fmt,
-                                           req.get_response(self.api))
-                    port_id = res['port']['id']
-                    plugin = manager.NeutronManager.get_plugin()
-                    port_dict = plugin.get_port_from_device(port_id)
-                    self.assertEqual(port_id, port_dict['id'])
-                    self.assertEqual([security_group_id],
+                port1 = self._make_port_with_new_sec_group(n['network']['id'])
+                port2 = self._make_port_with_new_sec_group(n['network']['id'])
+                plugin = manager.NeutronManager.get_plugin()
+                # should match full ID and starting chars
+                ports = plugin.get_ports_from_devices(
+                    [port1['id'], port2['id'][0:8]])
+                self.assertEqual(2, len(ports))
+                for port_dict in ports:
+                    p = port1 if port1['id'] == port_dict['id'] else port2
+                    self.assertEqual(p['id'], port_dict['id'])
+                    self.assertEqual(p['security_groups'],
                                      port_dict[ext_sg.SECURITYGROUPS])
                     self.assertEqual([], port_dict['security_group_rules'])
-                    self.assertEqual([fixed_ips[0]['ip_address']],
+                    self.assertEqual([p['fixed_ips'][0]['ip_address']],
                                      port_dict['fixed_ips'])
-                    self._delete('ports', port_id)
+                    self._delete('ports', p['id'])
+
+    def test_security_group_get_ports_from_devices_with_bad_id(self):
+        plugin = manager.NeutronManager.get_plugin()
+        ports = plugin.get_ports_from_devices(['bad_device_id'])
+        self.assertFalse(ports)
+
+    def test_security_group_no_db_calls_with_no_ports(self):
+        plugin = manager.NeutronManager.get_plugin()
+        with mock.patch(
+            'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
+        ) as get_mock:
+            self.assertFalse(plugin.get_ports_from_devices([]))
+            self.assertFalse(get_mock.called)
+
+    def test_large_port_count_broken_into_parts(self):
+        plugin = manager.NeutronManager.get_plugin()
+        max_ports_per_query = 5
+        ports_to_query = 73
+        for max_ports_per_query in (1, 2, 5, 7, 9, 31):
+            with contextlib.nested(
+                mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY',
+                           new=max_ports_per_query),
+                mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
+                           return_value={}),
+            ) as (max_mock, get_mock):
+                plugin.get_ports_from_devices(
+                    ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
+                     for i in range(ports_to_query)])
+                all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
+                last_call_args = all_call_args.pop()
+                # all but last should be getting MAX_PORTS_PER_QUERY ports
+                self.assertTrue(
+                    all(map(lambda x: len(x) == max_ports_per_query,
+                            all_call_args))
+                )
+                remaining = ports_to_query % max_ports_per_query
+                if remaining:
+                    self.assertEqual(remaining, len(last_call_args))
+                # should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls
+                self.assertEqual(
+                    math.ceil(ports_to_query / float(max_ports_per_query)),
+                    get_mock.call_count
+                )
 
-    def test_security_group_get_port_from_device_with_no_port(self):
+    def test_full_uuids_skip_port_id_lookup(self):
         plugin = manager.NeutronManager.get_plugin()
-        port_dict = plugin.get_port_from_device('bad_device_id')
-        self.assertIsNone(port_dict)
+        # when full UUIDs are provided, the _or statement should only
+        # have one matching 'IN' critiera for all of the IDs
+        with contextlib.nested(
+            mock.patch('neutron.plugins.ml2.db.or_'),
+            mock.patch('neutron.plugins.ml2.db.db_api.get_session')
+        ) as (or_mock, sess_mock):
+            fmock = sess_mock.query.return_value.outerjoin.return_value.filter
+            # return no ports to exit the method early since we are mocking
+            # the query
+            fmock.return_value.all.return_value = []
+            plugin.get_ports_from_devices([test_api_v2._uuid(),
+                                           test_api_v2._uuid()])
+            # the or_ function should only have one argument
+            or_mock.assert_called_once_with(mock.ANY)
 
 
 class TestMl2SGServerRpcCallBack(