From a6fbe838954ebc9432da58322ddece3ae47178c0 Mon Sep 17 00:00:00 2001 From: Maru Newby Date: Mon, 17 Sep 2012 18:46:38 -0700 Subject: [PATCH] Removed eval of unchecked strings. * eval() was previously used to marshall unchecked strings as filter parameters for QuantumDbPluginV2.get_ports() via the --fixed_ips flag. * This change removes the use of eval and cleans up the filtering implementation for get_ports(). * The new filtering implementation does not support arbitrary OR'ing or AND'ing. Instead, multiple values for a given filter key are logically OR'ed, and filters across keys are AND'ed. * Example usage - filter for .2 or .3 in the given subnet: quantum port-list -- --fixed_ips ip_address=10.0.0.3 \ ip_address=10.0.0.2 subnet_id=nOtaRealId * Addresses bug 1052179 Change-Id: I451f33ae53e623f86015b3fc2e6a7ca2f51ee836 --- quantum/api/v2/attributes.py | 33 ++++++++++++++ quantum/api/v2/base.py | 25 ++++++----- quantum/db/db_base_plugin_v2.py | 64 ++++++++++++--------------- quantum/tests/unit/test_api_v2.py | 21 +++++++++ quantum/tests/unit/test_attributes.py | 33 ++++++++++++++ quantum/tests/unit/test_db_plugin.py | 13 ++++++ 6 files changed, 142 insertions(+), 47 deletions(-) diff --git a/quantum/api/v2/attributes.py b/quantum/api/v2/attributes.py index 1425375fa..ff8341696 100644 --- a/quantum/api/v2/attributes.py +++ b/quantum/api/v2/attributes.py @@ -126,6 +126,38 @@ def convert_to_boolean(data): msg = _("%s is not boolean") % data raise q_exc.InvalidInput(error_message=msg) + +def convert_kvp_str_to_list(data): + """Convert a value of the form 'key=value' to ['key', 'value']. + + :raises: q_exc.InvalidInput if any of the strings are malformed + (e.g. do not contain a key). + """ + kvp = [x.strip() for x in data.split('=', 1)] + if len(kvp) == 2 and kvp[0]: + return kvp + msg = _("'%s' is not of the form =[value]") % data + raise q_exc.InvalidInput(error_message=msg) + + +def convert_kvp_list_to_dict(kvp_list): + """Convert a list of 'key=value' strings to a dict. + + :raises: q_exc.InvalidInput if any of the strings are malformed + (e.g. do not contain a key) or if any + of the keys appear more than once. + """ + if kvp_list == ['True']: + # No values were provided (i.e. '--flag-name') + return {} + kvp_map = {} + for kvp_str in kvp_list: + key, value = convert_kvp_str_to_list(kvp_str) + kvp_map.setdefault(key, set()) + kvp_map[key].add(value) + return dict((x, list(y)) for x, y in kvp_map.iteritems()) + + HEX_ELEM = '[0-9A-Fa-f]' UUID_PATTERN = '-'.join([HEX_ELEM + '{8}', HEX_ELEM + '{4}', HEX_ELEM + '{4}', HEX_ELEM + '{4}', @@ -218,6 +250,7 @@ RESOURCE_ATTRIBUTE_MAP = { 'is_visible': True}, 'fixed_ips': {'allow_post': True, 'allow_put': True, 'default': ATTR_NOT_SPECIFIED, + 'convert_list_to': convert_kvp_list_to_dict, 'enforce_policy': True, 'is_visible': True}, 'device_id': {'allow_post': True, 'allow_put': True, diff --git a/quantum/api/v2/base.py b/quantum/api/v2/base.py index c23bcd119..58027da6c 100644 --- a/quantum/api/v2/base.py +++ b/quantum/api/v2/base.py @@ -82,20 +82,23 @@ def _filters(request, attr_info): if key == 'fields': continue values = [v for v in request.GET.getall(key) if v] - if not attr_info.get(key) and values: + key_attr_info = attr_info.get(key, {}) + if not key_attr_info and values: res[key] = values continue - result_values = [] - convert_to = (attr_info.get(key) and attr_info[key].get('convert_to') - or None) - for value in values: + convert_list_to = key_attr_info.get('convert_list_to') + if not convert_list_to: + convert_to = key_attr_info.get('convert_to') if convert_to: - try: - result_values.append(convert_to(value)) - except exceptions.InvalidInput as e: - raise webob.exc.HTTPBadRequest(str(e)) - else: - result_values.append(value) + convert_list_to = lambda values_: [convert_to(x) + for x in values_] + if convert_list_to: + try: + result_values = convert_list_to(values) + except exceptions.InvalidInput as e: + raise webob.exc.HTTPBadRequest(str(e)) + else: + result_values = values if result_values: res[key] = result_values return res diff --git a/quantum/db/db_base_plugin_v2.py b/quantum/db/db_base_plugin_v2.py index ab9ba928e..594048765 100644 --- a/quantum/db/db_base_plugin_v2.py +++ b/quantum/db/db_base_plugin_v2.py @@ -200,14 +200,18 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): if key in fields)) return resource - def _get_collection(self, context, model, dict_func, filters=None, - fields=None): - collection = self._model_query(context, model) + def _apply_filters_to_query(self, query, model, filters): if filters: for key, value in filters.iteritems(): column = getattr(model, key, None) if column: - collection = collection.filter(column.in_(value)) + query = query.filter(column.in_(value)) + return query + + def _get_collection(self, context, model, dict_func, filters=None, + fields=None): + collection = self._model_query(context, model) + collection = self._apply_filters_to_query(collection, model, filters) return [dict_func(c, fields) for c in collection.all()] @staticmethod @@ -1236,35 +1240,23 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2): return self._make_port_dict(port, fields) def get_ports(self, context, filters=None, fields=None): - fixed_ips = filters.pop('fixed_ips', []) if filters else [] - ports = self._get_collection(context, models_v2.Port, - self._make_port_dict, - filters=filters, fields=fields) - - if ports and fixed_ips: - filtered_ports = [] - for port in ports: - if port['fixed_ips']: - ips = port['fixed_ips'] - for fixed in fixed_ips: - found = False - # Convert to dictionary (deserialize) - fixed = eval(fixed) - for ip in ips: - if 'ip_address' in fixed and 'subnet_id' in fixed: - if (ip['ip_address'] == fixed['ip_address'] and - ip['subnet_id'] == fixed['subnet_id']): - found = True - elif 'ip_address' in fixed: - if ip['ip_address'] == fixed['ip_address']: - found = True - elif 'subnet_id' in fixed: - if ip['subnet_id'] == fixed['subnet_id']: - found = True - if found: - filtered_ports.append(port) - break - if found: - break - return filtered_ports - return ports + Port = models_v2.Port + IPAllocation = models_v2.IPAllocation + + if not filters: + filters = {} + + query = self._model_query(context, Port) + + fixed_ips = filters.pop('fixed_ips', {}) + ip_addresses = fixed_ips.get('ip_address') + subnet_ids = fixed_ips.get('subnet_id') + if ip_addresses or subnet_ids: + query = query.join(Port.fixed_ips) + if ip_addresses: + query = query.filter(IPAllocation.ip_address.in_(ip_addresses)) + if subnet_ids: + query = query.filter(IPAllocation.subnet_id.in_(subnet_ids)) + + query = self._apply_filters_to_query(query, Port, filters) + return [self._make_port_dict(c, fields) for c in query.all()] diff --git a/quantum/tests/unit/test_api_v2.py b/quantum/tests/unit/test_api_v2.py index af23f157e..efa46e5e2 100644 --- a/quantum/tests/unit/test_api_v2.py +++ b/quantum/tests/unit/test_api_v2.py @@ -293,6 +293,27 @@ class APIv2TestCase(APIv2TestBase): filters=filters, fields=fields) + def test_filters_with_convert_to(self): + instance = self.plugin.return_value + instance.get_ports.return_value = [] + + self.api.get(_get_path('ports'), {'admin_state_up': 'true'}) + filters = {'admin_state_up': [True]} + instance.get_ports.assert_called_once_with(mock.ANY, + filters=filters, + fields=mock.ANY) + + def test_filters_with_convert_list_to(self): + instance = self.plugin.return_value + instance.get_ports.return_value = [] + + self.api.get(_get_path('ports'), + {'fixed_ips': ['ip_address=foo', 'subnet_id=bar']}) + filters = {'fixed_ips': {'ip_address': ['foo'], 'subnet_id': ['bar']}} + instance.get_ports.assert_called_once_with(mock.ANY, + filters=filters, + fields=mock.ANY) + # Note: since all resources use the same controller and validation # logic, we actually get really good coverage from testing just networks. diff --git a/quantum/tests/unit/test_attributes.py b/quantum/tests/unit/test_attributes.py index bb788b455..b4704a6b5 100644 --- a/quantum/tests/unit/test_attributes.py +++ b/quantum/tests/unit/test_attributes.py @@ -18,6 +18,7 @@ import unittest2 from quantum.api.v2 import attributes +from quantum.common import exceptions as q_exc class TestAttributes(unittest2.TestCase): @@ -83,3 +84,35 @@ class TestAttributes(unittest2.TestCase): attributes.MAC_PATTERN) error = '%s is not valid' % base_mac self.assertEquals(msg, error) + + +class TestConvertKvp(unittest2.TestCase): + + def test_convert_kvp_list_to_dict_succeeds_for_missing_values(self): + result = attributes.convert_kvp_list_to_dict(['True']) + self.assertEqual({}, result) + + def test_convert_kvp_list_to_dict_succeeds_for_multiple_values(self): + result = attributes.convert_kvp_list_to_dict( + ['a=b', 'a=c', 'a=c', 'b=a']) + self.assertEqual({'a': ['c', 'b'], 'b': ['a']}, result) + + def test_convert_kvp_list_to_dict_succeeds_for_values(self): + result = attributes.convert_kvp_list_to_dict(['a=b', 'c=d']) + self.assertEqual({'a': ['b'], 'c': ['d']}, result) + + def test_convert_kvp_str_to_list_fails_for_missing_key(self): + with self.assertRaises(q_exc.InvalidInput): + attributes.convert_kvp_str_to_list('=a') + + def test_convert_kvp_str_to_list_fails_for_missing_equals(self): + with self.assertRaises(q_exc.InvalidInput): + attributes.convert_kvp_str_to_list('a') + + def test_convert_kvp_str_to_list_succeeds_for_one_equals(self): + result = attributes.convert_kvp_str_to_list('a=') + self.assertEqual(['a', ''], result) + + def test_convert_kvp_str_to_list_succeeds_for_two_equals(self): + result = attributes.convert_kvp_str_to_list('a=a=a') + self.assertEqual(['a', 'a=a'], result) diff --git a/quantum/tests/unit/test_db_plugin.py b/quantum/tests/unit/test_db_plugin.py index eb69d2531..4fb3897ce 100644 --- a/quantum/tests/unit/test_db_plugin.py +++ b/quantum/tests/unit/test_db_plugin.py @@ -717,6 +717,19 @@ class TestPortsV2(QuantumDbPluginV2TestCase): self.assertTrue(port1['port']['id'] in ids) self.assertTrue(port2['port']['id'] in ids) + def test_list_ports_filtered_by_fixed_ip(self): + with contextlib.nested(self.port(), self.port()) as (port1, port2): + fixed_ips = port1['port']['fixed_ips'][0] + query_params = """ +fixed_ips=ip_address%%3D%s&fixed_ips=ip_address%%3D%s&fixed_ips=subnet_id%%3D%s +""".strip() % (fixed_ips['ip_address'], + '192.168.126.5', + fixed_ips['subnet_id']) + req = self.new_list_request('ports', 'json', query_params) + port_list = self.deserialize('json', req.get_response(self.api)) + self.assertEqual(len(port_list['ports']), 1) + self.assertEqual(port_list['ports'][0]['id'], port1['port']['id']) + def test_list_ports_public_network(self): with self.network(shared=True) as network: portres_1 = self._create_port('json', -- 2.45.2