]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Fix get_objects to allow filtering
authorNir Magnezi <nmagnezi@redhat.com>
Sun, 2 Aug 2015 12:56:56 +0000 (08:56 -0400)
committerJohn Schwarz <jschwarz@redhat.com>
Sun, 9 Aug 2015 11:57:52 +0000 (14:57 +0300)
At the moment, an attempt to retrieve a list of objects (like qos
policy) and filter by name fails, because get_objects does not use
filters and therefore, upon query by object name, the server replies
with a list of all created objects (instead of a partial list).

Change-Id: I9df9981129b8f3b82e867c8423986f5e0150186b
Partially-Implements: blueprint quantum-qos-api

neutron/objects/base.py
neutron/objects/qos/policy.py
neutron/objects/qos/rule_type.py
neutron/services/qos/qos_plugin.py
neutron/tests/api/test_qos.py
neutron/tests/tempest/services/network/json/network_client.py
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/objects/test_base.py

index 230f53dcdeeab76cd05c8609663bb11d74f1c5a1..c4bb98f5672aa811b83aea28e35c9060088549a4 100644 (file)
@@ -41,6 +41,8 @@ class NeutronObject(obj_base.VersionedObject,
                     obj_base.VersionedObjectDictCompat,
                     obj_base.ComparableVersionedObject):
 
+    synthetic_fields = []
+
     def __init__(self, context=None, **kwargs):
         super(NeutronObject, self).__init__(context, **kwargs)
         self.obj_set_defaults()
@@ -58,6 +60,15 @@ class NeutronObject(obj_base.VersionedObject,
     def get_by_id(cls, context, id):
         raise NotImplementedError()
 
+    @classmethod
+    def validate_filters(cls, **kwargs):
+        bad_filters = [key for key in kwargs
+                       if key not in cls.fields or key in cls.synthetic_fields]
+        if bad_filters:
+            bad_filters = ', '.join(bad_filters)
+            msg = _("'%s' is not supported for filtering") % bad_filters
+            raise exceptions.InvalidInput(error_message=msg)
+
     @classmethod
     @abc.abstractmethod
     def get_objects(cls, context, **kwargs):
@@ -78,8 +89,6 @@ class NeutronDbObject(NeutronObject):
     # should be overridden for all persistent objects
     db_model = None
 
-    synthetic_fields = []
-
     fields_no_update = []
 
     def from_db_object(self, *objs):
@@ -100,6 +109,7 @@ class NeutronDbObject(NeutronObject):
 
     @classmethod
     def get_objects(cls, context, **kwargs):
+        cls.validate_filters(**kwargs)
         db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
         objs = [cls(context, **db_obj) for db_obj in db_objs]
         for obj in objs:
index 96d1536e8da56afb6bac7504e509ee3fedde999b..258512221fe39f9bc4b766a278f7d6daffc4b202 100644 (file)
@@ -92,15 +92,15 @@ class QosPolicy(base.NeutronDbObject):
         # sure the tenant has permission to access the policy later on.
         admin_context = context.elevated()
         with db_api.autonested_transaction(admin_context.session):
-            db_objs = db_api.get_objects(admin_context, cls.db_model, **kwargs)
-            objs = []
-            for db_obj in db_objs:
-                if not cls._is_policy_accessible(context, db_obj):
+            objs = super(QosPolicy, cls).get_objects(admin_context,
+                                                     **kwargs)
+            result = []
+            for obj in objs:
+                if not cls._is_policy_accessible(context, obj):
                     continue
-                obj = cls(context, **db_obj)
                 obj.reload_rules()
-                objs.append(obj)
-        return objs
+                result.append(obj)
+            return result
 
     @classmethod
     def _get_object_policy(cls, context, model, **kwargs):
index 1a009b559c86deb6e556fb11271b60460ec09c62..fb0754b9394f7e74b02062d1bb1ec60267663022 100644 (file)
@@ -36,6 +36,7 @@ class QosRuleType(base.NeutronObject):
     # we don't receive context because we don't need db access at all
     @classmethod
     def get_objects(cls, **kwargs):
+        cls.validate_filters(**kwargs)
         core_plugin = manager.NeutronManager.get_plugin()
         return [cls(type=type_)
                 for type_ in core_plugin.supported_qos_rule_types]
index 7111c4e94b3df131227dce27e053b83091df497b..331ec56fd92997a48538d8287be5075d245b73b3 100644 (file)
@@ -80,8 +80,7 @@ class QoSPlugin(qos.QoSPluginBase):
     def get_policies(self, context, filters=None, fields=None,
                      sorts=None, limit=None, marker=None,
                      page_reverse=False):
-        #TODO(QoS): Support all the optional parameters
-        return policy_object.QosPolicy.get_objects(context)
+        return policy_object.QosPolicy.get_objects(context, **filters)
 
     #TODO(QoS): Consider adding a proxy catch-all for rules, so
     #           we capture the API function call, and just pass
@@ -148,12 +147,12 @@ class QoSPlugin(qos.QoSPluginBase):
                                          filters=None, fields=None,
                                          sorts=None, limit=None,
                                          marker=None, page_reverse=False):
-        #TODO(QoS): Support all the optional parameters
         # make sure we have access to the policy when fetching rules
         with db_api.autonested_transaction(context.session):
             # first, validate that we have access to the policy
             self._get_policy_obj(context, policy_id)
-            return rule_object.QosBandwidthLimitRule.get_objects(context)
+            return rule_object.QosBandwidthLimitRule.get_objects(context,
+                                                                 **filters)
 
     # TODO(QoS): enforce rule types when accessing rule objects
     @db_base_plugin_common.filter_fields
@@ -161,4 +160,4 @@ class QoSPlugin(qos.QoSPluginBase):
     def get_rule_types(self, context, filters=None, fields=None,
                        sorts=None, limit=None,
                        marker=None, page_reverse=False):
-        return rule_type_object.QosRuleType.get_objects()
+        return rule_type_object.QosRuleType.get_objects(**filters)
index c609f9437e78c6077f1737eacff25b129d7c6ee4..b4cb4cc864d508e372172d0dcceb50983cde693d 100644 (file)
@@ -34,14 +34,14 @@ class QosTestJSON(base.BaseAdminNetworkTest):
     @test.idempotent_id('108fbdf7-3463-4e47-9871-d07f3dcf5bbb')
     def test_create_policy(self):
         policy = self.create_qos_policy(name='test-policy',
-                                        description='test policy desc',
+                                        description='test policy desc1',
                                         shared=False)
 
         # Test 'show policy'
         retrieved_policy = self.admin_client.show_qos_policy(policy['id'])
         retrieved_policy = retrieved_policy['policy']
         self.assertEqual('test-policy', retrieved_policy['name'])
-        self.assertEqual('test policy desc', retrieved_policy['description'])
+        self.assertEqual('test policy desc1', retrieved_policy['description'])
         self.assertFalse(retrieved_policy['shared'])
 
         # Test 'list policies'
@@ -49,6 +49,21 @@ class QosTestJSON(base.BaseAdminNetworkTest):
         policies_ids = [p['id'] for p in policies]
         self.assertIn(policy['id'], policies_ids)
 
+    @test.attr(type='smoke')
+    @test.idempotent_id('f8d20e92-f06d-4805-b54f-230f77715815')
+    def test_list_policy_filter_by_name(self):
+        self.create_qos_policy(name='test', description='test policy',
+                               shared=False)
+        self.create_qos_policy(name='test2', description='test policy',
+                               shared=False)
+
+        policies = (self.admin_client.
+                    list_qos_policies(name='test')['policies'])
+        self.assertEqual(1, len(policies))
+
+        retrieved_policy = policies[0]
+        self.assertEqual('test', retrieved_policy['name'])
+
     @test.attr(type='smoke')
     @test.idempotent_id('8e88a54b-f0b2-4b7d-b061-a15d93c2c7d6')
     def test_policy_update(self):
@@ -56,12 +71,12 @@ class QosTestJSON(base.BaseAdminNetworkTest):
                                         description='',
                                         shared=False)
         self.admin_client.update_qos_policy(policy['id'],
-                                            description='test policy desc',
+                                            description='test policy desc2',
                                             shared=True)
 
         retrieved_policy = self.admin_client.show_qos_policy(policy['id'])
         retrieved_policy = retrieved_policy['policy']
-        self.assertEqual('test policy desc', retrieved_policy['description'])
+        self.assertEqual('test policy desc2', retrieved_policy['description'])
         self.assertTrue(retrieved_policy['shared'])
         self.assertEqual([], retrieved_policy['rules'])
 
index f811abecbb75c30ea4842995cfffb6e8cd457e40..9c5ef4aa1a22000bc674e05e0f6a22b984d75a12 100644 (file)
@@ -12,6 +12,8 @@
 
 import json
 import time
+import urllib
+
 
 from six.moves.urllib import parse
 from tempest_lib.common.utils import misc
@@ -625,8 +627,12 @@ class NetworkClientJSON(service_client.ServiceClient):
         body = json.loads(body)
         return service_client.ResponseBody(resp, body)
 
-    def list_qos_policies(self):
-        uri = '%s/qos/policies' % self.uri_prefix
+    def list_qos_policies(self, **filters):
+        if filters:
+            uri = '%s/qos/policies?%s' % (self.uri_prefix,
+                                          urllib.urlencode(filters))
+        else:
+            uri = '%s/qos/policies' % self.uri_prefix
         resp, body = self.get(uri)
         self.expected_success(200, resp.status)
         body = json.loads(body)
index 97af37bbb2f9035e6b5b8e407b1fb924ef74c1c7..6b29b06bb59af52f1335437da959d4f382acd251 100644 (file)
@@ -64,6 +64,27 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
                 admin_context, self._test_class.db_model)
         self._validate_objects(self.db_objs, objs)
 
+    def test_get_objects_valid_fields(self):
+        admin_context = self.context.elevated()
+
+        with mock.patch.object(
+            db_api, 'get_objects',
+            return_value=[self.db_obj]) as get_objects_mock:
+
+            with mock.patch.object(
+                self.context,
+                'elevated',
+                return_value=admin_context) as context_mock:
+
+                objs = self._test_class.get_objects(
+                    self.context,
+                    **self.valid_field_filter)
+                context_mock.assert_called_once_with()
+            get_objects_mock.assert_any_call(
+                admin_context, self._test_class.db_model,
+                **self.valid_field_filter)
+        self._validate_objects([self.db_obj], objs)
+
     def test_get_by_id(self):
         admin_context = self.context.elevated()
         with mock.patch.object(db_api, 'get_object',
index 14e8b1d1733821aeaeb3d718d383412d3ccf49d8..381ff8b29fc3f480eeab1b307d53e20e03d8bd1a 100644 (file)
@@ -10,6 +10,7 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import copy
 import random
 import string
 
@@ -48,6 +49,8 @@ class FakeNeutronObject(base.NeutronDbObject):
 
     fields_no_update = ['id']
 
+    synthetic_fields = ['field2']
+
 
 def _random_string(n=10):
     return ''.join(random.choice(string.ascii_lowercase) for _ in range(n))
@@ -85,6 +88,10 @@ class _BaseObjectTestCase(object):
         self.db_objs = list(self.get_random_fields() for _ in range(3))
         self.db_obj = self.db_objs[0]
 
+        valid_field = [f for f in self._test_class.fields
+                       if f not in self._test_class.synthetic_fields][0]
+        self.valid_field_filter = {valid_field: self.db_obj[valid_field]}
+
     @classmethod
     def get_random_fields(cls, obj_cls=None):
         obj_cls = obj_cls or cls._test_class
@@ -127,6 +134,53 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
         get_objects_mock.assert_called_once_with(
             self.context, self._test_class.db_model)
 
+    def test_get_objects_valid_fields(self):
+        with mock.patch.object(
+            db_api, 'get_objects',
+            return_value=[self.db_obj]) as get_objects_mock:
+
+            objs = self._test_class.get_objects(self.context,
+                                                **self.valid_field_filter)
+            self._validate_objects([self.db_obj], objs)
+
+        get_objects_mock.assert_called_with(
+            self.context, self._test_class.db_model,
+            **self.valid_field_filter)
+
+    def test_get_objects_mixed_fields(self):
+        synthetic_fields = self._test_class.synthetic_fields
+        if not synthetic_fields:
+            self.skipTest('No synthetic fields found in test class %r' %
+                          self._test_class)
+
+        filters = copy.copy(self.valid_field_filter)
+        filters[synthetic_fields[0]] = 'xxx'
+
+        with mock.patch.object(db_api, 'get_objects',
+                               return_value=self.db_objs):
+            self.assertRaises(base.exceptions.InvalidInput,
+                              self._test_class.get_objects, self.context,
+                              **filters)
+
+    def test_get_objects_synthetic_fields(self):
+        synthetic_fields = self._test_class.synthetic_fields
+        if not synthetic_fields:
+            self.skipTest('No synthetic fields found in test class %r' %
+                          self._test_class)
+
+        with mock.patch.object(db_api, 'get_objects',
+                               return_value=self.db_objs):
+            self.assertRaises(base.exceptions.InvalidInput,
+                              self._test_class.get_objects, self.context,
+                              **{synthetic_fields[0]: 'xxx'})
+
+    def test_get_objects_invalid_fields(self):
+        with mock.patch.object(db_api, 'get_objects',
+                               return_value=self.db_objs):
+            self.assertRaises(base.exceptions.InvalidInput,
+                              self._test_class.get_objects, self.context,
+                              fake_field='xxx')
+
     def _validate_objects(self, expected, observed):
         self.assertFalse(
             filter(lambda obj: not self._is_test_class(obj), observed))