]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Removed eval of unchecked strings.
authorMaru Newby <mnewby@internap.com>
Tue, 18 Sep 2012 01:46:38 +0000 (18:46 -0700)
committerMaru Newby <mnewby@internap.com>
Mon, 17 Sep 2012 16:11:43 +0000 (09:11 -0700)
 * eval() was previously used to marshall unchecked strings as
   filter parameters for QuantumDbPluginV2.get_ports() via
   the --fixed_ips flag.
 * This change removes the use of eval and cleans up the filtering
   implementation for get_ports().
 * The new filtering implementation does not support arbitrary
   OR'ing or AND'ing.  Instead, multiple values for a given filter
   key are logically OR'ed, and filters across keys are AND'ed.
 * Example usage - filter for .2 or .3 in the given subnet:

     quantum port-list -- --fixed_ips ip_address=10.0.0.3 \
         ip_address=10.0.0.2 subnet_id=nOtaRealId

 * Addresses bug 1052179

Change-Id: I451f33ae53e623f86015b3fc2e6a7ca2f51ee836

quantum/api/v2/attributes.py
quantum/api/v2/base.py
quantum/db/db_base_plugin_v2.py
quantum/tests/unit/test_api_v2.py
quantum/tests/unit/test_attributes.py
quantum/tests/unit/test_db_plugin.py

index 1425375faf4e24afbc4fb02e51b7c27da7262491..ff834169644259e1139d395259612e474f80ba4d 100644 (file)
@@ -126,6 +126,38 @@ def convert_to_boolean(data):
     msg = _("%s is not boolean") % data
     raise q_exc.InvalidInput(error_message=msg)
 
+
+def convert_kvp_str_to_list(data):
+    """Convert a value of the form 'key=value' to ['key', 'value'].
+
+    :raises: q_exc.InvalidInput if any of the strings are malformed
+                                (e.g. do not contain a key).
+    """
+    kvp = [x.strip() for x in data.split('=', 1)]
+    if len(kvp) == 2 and kvp[0]:
+        return kvp
+    msg = _("'%s' is not of the form <key>=[value]") % data
+    raise q_exc.InvalidInput(error_message=msg)
+
+
+def convert_kvp_list_to_dict(kvp_list):
+    """Convert a list of 'key=value' strings to a dict.
+
+    :raises: q_exc.InvalidInput if any of the strings are malformed
+                                (e.g. do not contain a key) or if any
+                                of the keys appear more than once.
+    """
+    if kvp_list == ['True']:
+        # No values were provided (i.e. '--flag-name')
+        return {}
+    kvp_map = {}
+    for kvp_str in kvp_list:
+        key, value = convert_kvp_str_to_list(kvp_str)
+        kvp_map.setdefault(key, set())
+        kvp_map[key].add(value)
+    return dict((x, list(y)) for x, y in kvp_map.iteritems())
+
+
 HEX_ELEM = '[0-9A-Fa-f]'
 UUID_PATTERN = '-'.join([HEX_ELEM + '{8}', HEX_ELEM + '{4}',
                          HEX_ELEM + '{4}', HEX_ELEM + '{4}',
@@ -218,6 +250,7 @@ RESOURCE_ATTRIBUTE_MAP = {
                         'is_visible': True},
         'fixed_ips': {'allow_post': True, 'allow_put': True,
                       'default': ATTR_NOT_SPECIFIED,
+                      'convert_list_to': convert_kvp_list_to_dict,
                       'enforce_policy': True,
                       'is_visible': True},
         'device_id': {'allow_post': True, 'allow_put': True,
index c23bcd11942b96bff55467e440cb52b8906b5ebf..58027da6c54c378c8640cee56c002420d9b9f207 100644 (file)
@@ -82,20 +82,23 @@ def _filters(request, attr_info):
         if key == 'fields':
             continue
         values = [v for v in request.GET.getall(key) if v]
-        if not attr_info.get(key) and values:
+        key_attr_info = attr_info.get(key, {})
+        if not key_attr_info and values:
             res[key] = values
             continue
-        result_values = []
-        convert_to = (attr_info.get(key) and attr_info[key].get('convert_to')
-                      or None)
-        for value in values:
+        convert_list_to = key_attr_info.get('convert_list_to')
+        if not convert_list_to:
+            convert_to = key_attr_info.get('convert_to')
             if convert_to:
-                try:
-                    result_values.append(convert_to(value))
-                except exceptions.InvalidInput as e:
-                    raise webob.exc.HTTPBadRequest(str(e))
-            else:
-                result_values.append(value)
+                convert_list_to = lambda values_: [convert_to(x)
+                                                   for x in values_]
+        if convert_list_to:
+            try:
+                result_values = convert_list_to(values)
+            except exceptions.InvalidInput as e:
+                raise webob.exc.HTTPBadRequest(str(e))
+        else:
+            result_values = values
         if result_values:
             res[key] = result_values
     return res
index ab9ba928e6d5ea12f35bfa3d9f9e67a16a845571..5940487652cfb4c935f80d1445911b600edc948b 100644 (file)
@@ -200,14 +200,18 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2):
                          if key in fields))
         return resource
 
-    def _get_collection(self, context, model, dict_func, filters=None,
-                        fields=None):
-        collection = self._model_query(context, model)
+    def _apply_filters_to_query(self, query, model, filters):
         if filters:
             for key, value in filters.iteritems():
                 column = getattr(model, key, None)
                 if column:
-                    collection = collection.filter(column.in_(value))
+                    query = query.filter(column.in_(value))
+        return query
+
+    def _get_collection(self, context, model, dict_func, filters=None,
+                        fields=None):
+        collection = self._model_query(context, model)
+        collection = self._apply_filters_to_query(collection, model, filters)
         return [dict_func(c, fields) for c in collection.all()]
 
     @staticmethod
@@ -1236,35 +1240,23 @@ class QuantumDbPluginV2(quantum_plugin_base_v2.QuantumPluginBaseV2):
         return self._make_port_dict(port, fields)
 
     def get_ports(self, context, filters=None, fields=None):
-        fixed_ips = filters.pop('fixed_ips', []) if filters else []
-        ports = self._get_collection(context, models_v2.Port,
-                                     self._make_port_dict,
-                                     filters=filters, fields=fields)
-
-        if ports and fixed_ips:
-            filtered_ports = []
-            for port in ports:
-                if port['fixed_ips']:
-                    ips = port['fixed_ips']
-                    for fixed in fixed_ips:
-                        found = False
-                        # Convert to dictionary (deserialize)
-                        fixed = eval(fixed)
-                        for ip in ips:
-                            if 'ip_address' in fixed and 'subnet_id' in fixed:
-                                if (ip['ip_address'] == fixed['ip_address'] and
-                                        ip['subnet_id'] == fixed['subnet_id']):
-                                    found = True
-                            elif 'ip_address' in fixed:
-                                if ip['ip_address'] == fixed['ip_address']:
-                                    found = True
-                            elif 'subnet_id' in fixed:
-                                if ip['subnet_id'] == fixed['subnet_id']:
-                                    found = True
-                            if found:
-                                filtered_ports.append(port)
-                                break
-                        if found:
-                            break
-            return filtered_ports
-        return ports
+        Port = models_v2.Port
+        IPAllocation = models_v2.IPAllocation
+
+        if not filters:
+            filters = {}
+
+        query = self._model_query(context, Port)
+
+        fixed_ips = filters.pop('fixed_ips', {})
+        ip_addresses = fixed_ips.get('ip_address')
+        subnet_ids = fixed_ips.get('subnet_id')
+        if ip_addresses or subnet_ids:
+            query = query.join(Port.fixed_ips)
+            if ip_addresses:
+                query = query.filter(IPAllocation.ip_address.in_(ip_addresses))
+            if subnet_ids:
+                query = query.filter(IPAllocation.subnet_id.in_(subnet_ids))
+
+        query = self._apply_filters_to_query(query, Port, filters)
+        return [self._make_port_dict(c, fields) for c in query.all()]
index af23f157e56cdbd40cf9e0e42c3fd0490efc8171..efa46e5e271ddbe384c27ad666544bef4f664a06 100644 (file)
@@ -293,6 +293,27 @@ class APIv2TestCase(APIv2TestBase):
                                                       filters=filters,
                                                       fields=fields)
 
+    def test_filters_with_convert_to(self):
+        instance = self.plugin.return_value
+        instance.get_ports.return_value = []
+
+        self.api.get(_get_path('ports'), {'admin_state_up': 'true'})
+        filters = {'admin_state_up': [True]}
+        instance.get_ports.assert_called_once_with(mock.ANY,
+                                                   filters=filters,
+                                                   fields=mock.ANY)
+
+    def test_filters_with_convert_list_to(self):
+        instance = self.plugin.return_value
+        instance.get_ports.return_value = []
+
+        self.api.get(_get_path('ports'),
+                     {'fixed_ips': ['ip_address=foo', 'subnet_id=bar']})
+        filters = {'fixed_ips': {'ip_address': ['foo'], 'subnet_id': ['bar']}}
+        instance.get_ports.assert_called_once_with(mock.ANY,
+                                                   filters=filters,
+                                                   fields=mock.ANY)
+
 
 # Note: since all resources use the same controller and validation
 # logic, we actually get really good coverage from testing just networks.
index bb788b455a4bcd4294be40ac00c5b318c9b15426..b4704a6b51c8c1ac6ce7b6063421495b32c3007b 100644 (file)
@@ -18,6 +18,7 @@
 import unittest2
 
 from quantum.api.v2 import attributes
+from quantum.common import exceptions as q_exc
 
 
 class TestAttributes(unittest2.TestCase):
@@ -83,3 +84,35 @@ class TestAttributes(unittest2.TestCase):
                                          attributes.MAC_PATTERN)
         error = '%s is not valid' % base_mac
         self.assertEquals(msg, error)
+
+
+class TestConvertKvp(unittest2.TestCase):
+
+    def test_convert_kvp_list_to_dict_succeeds_for_missing_values(self):
+        result = attributes.convert_kvp_list_to_dict(['True'])
+        self.assertEqual({}, result)
+
+    def test_convert_kvp_list_to_dict_succeeds_for_multiple_values(self):
+        result = attributes.convert_kvp_list_to_dict(
+            ['a=b', 'a=c', 'a=c', 'b=a'])
+        self.assertEqual({'a': ['c', 'b'], 'b': ['a']}, result)
+
+    def test_convert_kvp_list_to_dict_succeeds_for_values(self):
+        result = attributes.convert_kvp_list_to_dict(['a=b', 'c=d'])
+        self.assertEqual({'a': ['b'], 'c': ['d']}, result)
+
+    def test_convert_kvp_str_to_list_fails_for_missing_key(self):
+        with self.assertRaises(q_exc.InvalidInput):
+            attributes.convert_kvp_str_to_list('=a')
+
+    def test_convert_kvp_str_to_list_fails_for_missing_equals(self):
+        with self.assertRaises(q_exc.InvalidInput):
+            attributes.convert_kvp_str_to_list('a')
+
+    def test_convert_kvp_str_to_list_succeeds_for_one_equals(self):
+        result = attributes.convert_kvp_str_to_list('a=')
+        self.assertEqual(['a', ''], result)
+
+    def test_convert_kvp_str_to_list_succeeds_for_two_equals(self):
+        result = attributes.convert_kvp_str_to_list('a=a=a')
+        self.assertEqual(['a', 'a=a'], result)
index eb69d2531827c92e8efce055fdf59fcb93de4ed3..4fb3897ceb3b74fd155b0288460bd9d2657b9137 100644 (file)
@@ -717,6 +717,19 @@ class TestPortsV2(QuantumDbPluginV2TestCase):
             self.assertTrue(port1['port']['id'] in ids)
             self.assertTrue(port2['port']['id'] in ids)
 
+    def test_list_ports_filtered_by_fixed_ip(self):
+        with contextlib.nested(self.port(), self.port()) as (port1, port2):
+            fixed_ips = port1['port']['fixed_ips'][0]
+            query_params = """
+fixed_ips=ip_address%%3D%s&fixed_ips=ip_address%%3D%s&fixed_ips=subnet_id%%3D%s
+""".strip() % (fixed_ips['ip_address'],
+               '192.168.126.5',
+               fixed_ips['subnet_id'])
+            req = self.new_list_request('ports', 'json', query_params)
+            port_list = self.deserialize('json', req.get_response(self.api))
+            self.assertEqual(len(port_list['ports']), 1)
+            self.assertEqual(port_list['ports'][0]['id'], port1['port']['id'])
+
     def test_list_ports_public_network(self):
         with self.network(shared=True) as network:
             portres_1 = self._create_port('json',