return manager.NeutronManager.get_plugin()
def _get_devices_info(self, devices):
- devices_info = {}
- for device in devices:
- port = self.plugin.get_port_from_device(device)
- if not port:
- continue
- if port['device_owner'].startswith('network:'):
- continue
- devices_info[port['id']] = port
- return devices_info
+ return dict(
+ (port['id'], port)
+ for port in self.plugin.get_ports_from_devices(devices)
+ if port and not port['device_owner'].startswith('network:')
+ )
def security_group_rules_for_devices(self, context, **kwargs):
"""Callback method to return security group rules for each port.
def get_port_from_device(self, device):
"""Get port dict from device name on an agent.
- Subclass must provide this method.
+ Subclass must provide this method or get_ports_from_devices.
:param device: device name which identifies a port on the agent side.
What is specified in "device" depends on a plugin agent implementation.
- security_group_source_groups
- fixed_ips
"""
- raise NotImplementedError(_("%s must implement get_port_from_device.")
+ raise NotImplementedError(_("%s must implement get_port_from_device "
+ "or get_ports_from_devices.")
% self.__class__.__name__)
+ def get_ports_from_devices(self, devices):
+ """Bulk method of get_port_from_device.
+
+ Subclasses may override this to provide better performance for DB
+ queries, backend calls, etc.
+ """
+ return [self.get_port_from_device(device) for device in devices]
+
def create_security_group_rule(self, context, security_group_rule):
bulk_rule = {'security_group_rules': [security_group_rule]}
rule = self.create_security_group_rule_bulk_native(context,
# License for the specific language governing permissions and limitations
# under the License.
+import collections
+
+from sqlalchemy import or_
from sqlalchemy.orm import exc
from oslo.db import exception as db_exc
LOG = log.getLogger(__name__)
+# limit the number of port OR LIKE statements in one query
+MAX_PORTS_PER_QUERY = 500
+
def _make_segment_dict(record):
"""Make a segment dictionary out of a DB record."""
return qry.first()
-def get_port_and_sgs(port_id):
- """Get port from database with security group info."""
+def get_ports_and_sgs(port_ids):
+ """Get ports from database with security group info."""
+
+ # break large queries into smaller parts
+ if len(port_ids) > MAX_PORTS_PER_QUERY:
+ LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
+ "query %(maxp)s. Partitioning queries.",
+ {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
+ return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
+ get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
+
+ LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
- LOG.debug(_("get_port_and_sgs() called for port_id %s"), port_id)
+ if not port_ids:
+ # if port_ids is empty, avoid querying to DB to ask it for nothing
+ return []
+ ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
+ return [make_port_dict_with_security_groups(port, sec_groups)
+ for port, sec_groups in ports_to_sg_ids.iteritems()]
+
+
+def get_sg_ids_grouped_by_port(port_ids):
+ sg_ids_grouped_by_port = collections.defaultdict(list)
session = db_api.get_session()
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
with session.begin(subtransactions=True):
+ # partial UUIDs must be individually matched with startswith.
+ # full UUIDs may be matched directly in an IN statement
+ partial_uuids = set(port_id for port_id in port_ids
+ if not uuidutils.is_uuid_like(port_id))
+ full_uuids = set(port_ids) - partial_uuids
+ or_criteria = [models_v2.Port.id.startswith(port_id)
+ for port_id in partial_uuids]
+ if full_uuids:
+ or_criteria.append(models_v2.Port.id.in_(full_uuids))
+
query = session.query(models_v2.Port,
sg_db.SecurityGroupPortBinding.security_group_id)
query = query.outerjoin(sg_db.SecurityGroupPortBinding,
models_v2.Port.id == sg_binding_port)
- query = query.filter(models_v2.Port.id.startswith(port_id))
- port_and_sgs = query.all()
- if not port_and_sgs:
- return
- port = port_and_sgs[0][0]
- plugin = manager.NeutronManager.get_plugin()
- port_dict = plugin._make_port_dict(port)
- port_dict['security_groups'] = [
- sg_id for port_, sg_id in port_and_sgs if sg_id]
- port_dict['security_group_rules'] = []
- port_dict['security_group_source_groups'] = []
- port_dict['fixed_ips'] = [ip['ip_address']
- for ip in port['fixed_ips']]
- return port_dict
+ query = query.filter(or_(*or_criteria))
+
+ for port, sg_id in query:
+ if sg_id:
+ sg_ids_grouped_by_port[port].append(sg_id)
+ return sg_ids_grouped_by_port
+
+
+def make_port_dict_with_security_groups(port, sec_groups):
+ plugin = manager.NeutronManager.get_plugin()
+ port_dict = plugin._make_port_dict(port)
+ port_dict['security_groups'] = sec_groups
+ port_dict['security_group_rules'] = []
+ port_dict['security_group_source_groups'] = []
+ port_dict['fixed_ips'] = [ip['ip_address']
+ for ip in port['fixed_ips']]
+ return port_dict
def get_port_binding_host(port_id):
port_host = db.get_port_binding_host(port_id)
return (port_host == host)
- def get_port_from_device(self, device):
- port_id = self._device_to_port_id(device)
- port = db.get_port_and_sgs(port_id)
- if port:
- port['device'] = device
- return port
+ def get_ports_from_devices(self, devices):
+ port_ids_to_devices = dict((self._device_to_port_id(device), device)
+ for device in devices)
+ port_ids = port_ids_to_devices.keys()
+ ports = db.get_ports_and_sgs(port_ids)
+ for port in ports:
+ # map back to original requested id
+ port_id = next((port_id for port_id in port_ids
+ if port['id'].startswith(port_id)), None)
+ port['device'] = port_ids_to_devices.get(port_id)
+
+ return ports
def _device_to_port_id(self, device):
# REVISIT(rkukura): Consider calling into MechanismDrivers to
# License for the specific language governing permissions and limitations
# under the License.
+import contextlib
+import math
import mock
from neutron.api.v2 import attributes
+from neutron.common import constants as const
from neutron.extensions import securitygroup as ext_sg
from neutron import manager
+from neutron.tests.unit import test_api_v2
from neutron.tests.unit import test_extension_security_group as test_sg
from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc
plugin = manager.NeutronManager.get_plugin()
plugin.start_rpc_listeners()
- def test_security_group_get_port_from_device(self):
+ def _make_port_with_new_sec_group(self, net_id):
+ sg = self._make_security_group(self.fmt, 'name', 'desc')
+ port = self._make_port(
+ self.fmt, net_id, security_groups=[sg['security_group']['id']])
+ return port['port']
+
+ def test_security_group_get_ports_from_devices(self):
with self.network() as n:
with self.subnet(n):
- with self.security_group() as sg:
- security_group_id = sg['security_group']['id']
- res = self._create_port(self.fmt, n['network']['id'])
- port = self.deserialize(self.fmt, res)
- fixed_ips = port['port']['fixed_ips']
- data = {'port': {'fixed_ips': fixed_ips,
- 'name': port['port']['name'],
- ext_sg.SECURITYGROUPS:
- [security_group_id]}}
-
- req = self.new_update_request('ports', data,
- port['port']['id'])
- res = self.deserialize(self.fmt,
- req.get_response(self.api))
- port_id = res['port']['id']
- plugin = manager.NeutronManager.get_plugin()
- port_dict = plugin.get_port_from_device(port_id)
- self.assertEqual(port_id, port_dict['id'])
- self.assertEqual([security_group_id],
+ port1 = self._make_port_with_new_sec_group(n['network']['id'])
+ port2 = self._make_port_with_new_sec_group(n['network']['id'])
+ plugin = manager.NeutronManager.get_plugin()
+ # should match full ID and starting chars
+ ports = plugin.get_ports_from_devices(
+ [port1['id'], port2['id'][0:8]])
+ self.assertEqual(2, len(ports))
+ for port_dict in ports:
+ p = port1 if port1['id'] == port_dict['id'] else port2
+ self.assertEqual(p['id'], port_dict['id'])
+ self.assertEqual(p['security_groups'],
port_dict[ext_sg.SECURITYGROUPS])
self.assertEqual([], port_dict['security_group_rules'])
- self.assertEqual([fixed_ips[0]['ip_address']],
+ self.assertEqual([p['fixed_ips'][0]['ip_address']],
port_dict['fixed_ips'])
- self._delete('ports', port_id)
+ self._delete('ports', p['id'])
+
+ def test_security_group_get_ports_from_devices_with_bad_id(self):
+ plugin = manager.NeutronManager.get_plugin()
+ ports = plugin.get_ports_from_devices(['bad_device_id'])
+ self.assertFalse(ports)
- def test_security_group_get_port_from_device_with_no_port(self):
+ def test_security_group_no_db_calls_with_no_ports(self):
+ plugin = manager.NeutronManager.get_plugin()
+ with mock.patch(
+ 'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
+ ) as get_mock:
+ self.assertFalse(plugin.get_ports_from_devices([]))
+ self.assertFalse(get_mock.called)
+
+ def test_large_port_count_broken_into_parts(self):
+ plugin = manager.NeutronManager.get_plugin()
+ max_ports_per_query = 5
+ ports_to_query = 73
+ for max_ports_per_query in (1, 2, 5, 7, 9, 31):
+ with contextlib.nested(
+ mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY',
+ new=max_ports_per_query),
+ mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
+ return_value={}),
+ ) as (max_mock, get_mock):
+ plugin.get_ports_from_devices(
+ ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
+ for i in range(ports_to_query)])
+ all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
+ last_call_args = all_call_args.pop()
+ # all but last should be getting MAX_PORTS_PER_QUERY ports
+ self.assertTrue(
+ all(map(lambda x: len(x) == max_ports_per_query,
+ all_call_args))
+ )
+ remaining = ports_to_query % max_ports_per_query
+ if remaining:
+ self.assertEqual(remaining, len(last_call_args))
+ # should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls
+ self.assertEqual(
+ math.ceil(ports_to_query / float(max_ports_per_query)),
+ get_mock.call_count
+ )
+
+ def test_full_uuids_skip_port_id_lookup(self):
plugin = manager.NeutronManager.get_plugin()
- port_dict = plugin.get_port_from_device('bad_device_id')
- self.assertIsNone(port_dict)
+ # when full UUIDs are provided, the _or statement should only
+ # have one matching 'IN' critiera for all of the IDs
+ with contextlib.nested(
+ mock.patch('neutron.plugins.ml2.db.or_'),
+ mock.patch('neutron.plugins.ml2.db.db_api.get_session')
+ ) as (or_mock, sess_mock):
+ fmock = sess_mock.query.return_value.outerjoin.return_value.filter
+ # return no ports to exit the method early since we are mocking
+ # the query
+ fmock.return_value.all.return_value = []
+ plugin.get_ports_from_devices([test_api_v2._uuid(),
+ test_api_v2._uuid()])
+ # the or_ function should only have one argument
+ or_mock.assert_called_once_with(mock.ANY)
class TestMl2SecurityGroupsXML(TestMl2SecurityGroups):