]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Support qos rules and fields parameters in GET requests
authorJakub Libosvar <libosvar@redhat.com>
Tue, 21 Jul 2015 08:04:00 +0000 (08:04 +0000)
committerJakub Libosvar <libosvar@redhat.com>
Fri, 24 Jul 2015 11:47:10 +0000 (11:47 +0000)
Previously we didn't load the rules into policy object. This patch adds
loading the rules and defines bandwidth_limit_rules as a policy
resource in a single transaction. As a part of moving towards usage of
single transaction, create() and update() of rule were modified
accordingly.
Finally, we support types in GET requests in this patch.

API tests will follow in different patch.

Change-Id: I25c72aae74469b687766754bbeb749dfd1b8867c

neutron/db/db_base_plugin_common.py
neutron/extensions/qos.py
neutron/objects/qos/policy.py
neutron/objects/qos/rule.py
neutron/services/qos/qos_plugin.py
neutron/tests/unit/db/test_db_base_plugin_common.py [new file with mode: 0644]
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/objects/test_base.py

index 54257ed971cc1030d2c88c22e1169edcd50e4748..4ce5daab7b684413cd242dcd289b9046b848a0aa 100644 (file)
@@ -29,6 +29,26 @@ from neutron.db import models_v2
 LOG = logging.getLogger(__name__)
 
 
+def filter_fields(f):
+    @functools.wraps(f)
+    def inner_filter(*args, **kwargs):
+        result = f(*args, **kwargs)
+        fields = kwargs.get('fields')
+        if not fields:
+            pos = f.func_code.co_varnames.index('fields')
+            try:
+                fields = args[pos]
+            except IndexError:
+                return result
+
+        do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
+        if isinstance(result, list):
+            return [do_filter(obj) for obj in result]
+        else:
+            return do_filter(result)
+    return inner_filter
+
+
 class DbBasePluginCommon(common_db_mixin.CommonDbMixin):
     """Stores getters and helper methods for db_base_plugin_v2
 
index e845e533435508a2591784e4541268abdf32b309..1c89acac1158217d8190ff053d14123f5dfae030 100644 (file)
@@ -61,7 +61,9 @@ RESOURCE_ATTRIBUTE_MAP = {
                    'convert_to': attr.convert_to_boolean},
         'tenant_id': {'allow_post': True, 'allow_put': False,
                       'required_by_policy': True,
-                      'is_visible': True}
+                      'is_visible': True},
+        'bandwidth_limit_rules': {'allow_post': False, 'allow_put': False,
+                                  'is_visible': True},
     },
     'rule_types': {
         'type': {'allow_post': False, 'allow_put': False,
index 0c1718ef4865fb50e570e27b941eaeb302c3412f..8f2c605c8e021a362dccf43286e50514617ca832 100644 (file)
@@ -75,12 +75,37 @@ class QosPolicy(base.NeutronObject):
         setattr(self, attrname, rules)
         self.obj_reset_changes([attrname])
 
+    def _load_rules(self):
+        for attr in self.rule_fields:
+            self.obj_load_attr(attr)
+
+    @classmethod
+    def get_by_id(cls, context, id):
+        with db_api.autonested_transaction(context.session):
+            policy_obj = super(QosPolicy, cls).get_by_id(context, id)
+            if policy_obj:
+                policy_obj._load_rules()
+        return policy_obj
+
+    # TODO(QoS): Test that all objects are fetched within one transaction
+    @classmethod
+    def get_objects(cls, context, **kwargs):
+        with db_api.autonested_transaction(context.session):
+            db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
+            objs = list()
+            for db_obj in db_objs:
+                obj = cls(context, **db_obj)
+                obj._load_rules()
+                objs.append(obj)
+        return objs
+
     @classmethod
     def _get_object_policy(cls, context, model, **kwargs):
-        binding_db_obj = db_api.get_object(context, model, **kwargs)
-        # TODO(QoS): rethink handling missing binding case
-        if binding_db_obj:
-            return cls.get_by_id(context, binding_db_obj['policy_id'])
+        with db_api.autonested_transaction(context.session):
+            binding_db_obj = db_api.get_object(context, model, **kwargs)
+            # TODO(QoS): rethink handling missing binding case
+            if binding_db_obj:
+                return cls.get_by_id(context, binding_db_obj['policy_id'])
 
     @classmethod
     def get_network_policy(cls, context, network_id):
@@ -92,6 +117,11 @@ class QosPolicy(base.NeutronObject):
         return cls._get_object_policy(context, cls.port_binding_model,
                                       port_id=port_id)
 
+    def create(self):
+        with db_api.autonested_transaction(self._context.session):
+            super(QosPolicy, self).create()
+            self._load_rules()
+
     def attach_network(self, network_id):
         qos_db_api.create_policy_network_binding(self._context,
                                                  policy_id=self.id,
index 6269e8dbb225b759a2ab22fad81db9be0706d6b5..0b5713e73b4fb0e1b8c2b22a88653c2ce3a56933 100644 (file)
@@ -96,7 +96,7 @@ class QosRule(base.NeutronObject):
             obj.obj_reset_changes()
             return obj
 
-    # TODO(QoS): create and update are not transactional safe
+    # TODO(QoS): Test that create is in single transaction
     def create(self):
 
         # TODO(QoS): enforce that type field value is bound to specific class
@@ -104,18 +104,21 @@ class QosRule(base.NeutronObject):
 
         # create base qos_rule
         core_fields = self._get_changed_core_fields()
-        base_db_obj = db_api.create_object(
-            self._context, self.base_db_model, core_fields)
 
-        # create type specific qos_..._rule
-        addn_fields = self._get_changed_addn_fields()
-        self._copy_common_fields(core_fields, addn_fields)
-        addn_db_obj = db_api.create_object(
-            self._context, self.db_model, addn_fields)
+        with db_api.autonested_transaction(self._context.session):
+            base_db_obj = db_api.create_object(
+                self._context, self.base_db_model, core_fields)
+
+            # create type specific qos_..._rule
+            addn_fields = self._get_changed_addn_fields()
+            self._copy_common_fields(core_fields, addn_fields)
+            addn_db_obj = db_api.create_object(
+                self._context, self.db_model, addn_fields)
 
         # merge two db objects into single neutron one
         self.from_db_object(base_db_obj, addn_db_obj)
 
+    # TODO(QoS): Test that update is in single transaction
     def update(self):
         updated_db_objs = []
 
@@ -123,16 +126,18 @@ class QosRule(base.NeutronObject):
 
         # update base qos_rule, if needed
         core_fields = self._get_changed_core_fields()
-        if core_fields:
-            base_db_obj = db_api.update_object(
-                self._context, self.base_db_model, self.id, core_fields)
-            updated_db_objs.append(base_db_obj)
-
-        addn_fields = self._get_changed_addn_fields()
-        if addn_fields:
-            addn_db_obj = db_api.update_object(
-                self._context, self.db_model, self.id, addn_fields)
-            updated_db_objs.append(addn_db_obj)
+
+        with db_api.autonested_transaction(self._context.session):
+            if core_fields:
+                base_db_obj = db_api.update_object(
+                    self._context, self.base_db_model, self.id, core_fields)
+                updated_db_objs.append(base_db_obj)
+
+            addn_fields = self._get_changed_addn_fields()
+            if addn_fields:
+                addn_db_obj = db_api.update_object(
+                    self._context, self.db_model, self.id, addn_fields)
+                updated_db_objs.append(addn_db_obj)
 
         # update neutron object with values from both database objects
         self.from_db_object(*updated_db_objs)
index 0b227c8a382632aecc530b8d0b637f816483476d..f1d9a147021ead5807bb543f45cbee7162ad3d2f 100644 (file)
@@ -17,6 +17,7 @@ from neutron import manager
 
 from neutron.api.rpc.callbacks import registry as rpc_registry
 from neutron.api.rpc.callbacks import resources as rpc_resources
+from neutron.db import db_base_plugin_common
 from neutron.extensions import qos
 from neutron.i18n import _LW
 from neutron.objects.qos import policy as policy_object
@@ -134,10 +135,11 @@ class QoSPlugin(qos.QoSPluginBase):
     def _get_policy_obj(self, context, policy_id):
         return policy_object.QosPolicy.get_by_id(context, policy_id)
 
+    @db_base_plugin_common.filter_fields
     def get_policy(self, context, policy_id, fields=None):
-        #TODO(QoS): Support the fields parameter
         return self._get_policy_obj(context, policy_id).to_dict()
 
+    @db_base_plugin_common.filter_fields
     def get_policies(self, context, filters=None, fields=None,
                      sorts=None, limit=None, marker=None,
                      page_reverse=False):
@@ -174,12 +176,13 @@ class QoSPlugin(qos.QoSPluginBase):
         rule.id = rule_id
         rule.delete()
 
+    @db_base_plugin_common.filter_fields
     def get_policy_bandwidth_limit_rule(self, context, rule_id,
                                         policy_id, fields=None):
-        #TODO(QoS): Support the fields parameter
         return rule_object.QosBandwidthLimitRule.get_by_id(context,
                                                            rule_id).to_dict()
 
+    @db_base_plugin_common.filter_fields
     def get_policy_bandwidth_limit_rules(self, context, policy_id,
                                          filters=None, fields=None,
                                          sorts=None, limit=None,
@@ -188,6 +191,7 @@ class QoSPlugin(qos.QoSPluginBase):
         return [rule_obj.to_dict() for rule_obj in
                 rule_object.QosBandwidthLimitRule.get_objects(context)]
 
+    @db_base_plugin_common.filter_fields
     def get_rule_types(self, context, filters=None, fields=None,
                        sorts=None, limit=None,
                        marker=None, page_reverse=False):
diff --git a/neutron/tests/unit/db/test_db_base_plugin_common.py b/neutron/tests/unit/db/test_db_base_plugin_common.py
new file mode 100644 (file)
index 0000000..9074bf6
--- /dev/null
@@ -0,0 +1,64 @@
+# Copyright (c) 2015 Red Hat, Inc.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+from neutron.db import db_base_plugin_common
+from neutron.tests import base
+
+
+class FilterFieldsTestCase(base.BaseTestCase):
+
+    @db_base_plugin_common.filter_fields
+    def method_dict(self, fields=None):
+        return {'one': 1, 'two': 2, 'three': 3}
+
+    @db_base_plugin_common.filter_fields
+    def method_list(self, fields=None):
+        return [self.method_dict() for _ in range(3)]
+
+    @db_base_plugin_common.filter_fields
+    def method_multiple_arguments(self, not_used, fields=None,
+                                  also_not_used=None):
+        return {'one': 1, 'two': 2, 'three': 3}
+
+    def test_no_fields(self):
+        expected = {'one': 1, 'two': 2, 'three': 3}
+        observed = self.method_dict()
+        self.assertEqual(expected, observed)
+
+    def test_dict(self):
+        expected = {'two': 2}
+        observed = self.method_dict(['two'])
+        self.assertEqual(expected, observed)
+
+    def test_list(self):
+        expected = [{'two': 2}, {'two': 2}, {'two': 2}]
+        observed = self.method_list(['two'])
+        self.assertEqual(expected, observed)
+
+    def test_multiple_arguments_positional(self):
+        expected = {'two': 2}
+        observed = self.method_multiple_arguments(list(), ['two'])
+        self.assertEqual(expected, observed)
+
+    def test_multiple_arguments_positional_and_keywords(self):
+        expected = {'two': 2}
+        observed = self.method_multiple_arguments(fields=['two'],
+                                                  not_used=None)
+        self.assertEqual(expected, observed)
+
+    def test_multiple_arguments_keyword(self):
+        expected = {'two': 2}
+        observed = self.method_multiple_arguments(list(), fields=['two'])
+        self.assertEqual(expected, observed)
index b73af22c6ccc7a116f86b90b9c04ccecd6a6ad3b..afd6a79829b22356452d2c31e45fa9ba46055913 100644 (file)
@@ -10,6 +10,8 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import mock
+
 from neutron.db import api as db_api
 from neutron.db import models_v2
 from neutron.objects.qos import policy
@@ -22,6 +24,50 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
     _test_class = policy.QosPolicy
 
+    def setUp(self):
+        super(QosPolicyObjectTestCase, self).setUp()
+        self.db_qos_rules = [self.get_random_fields(rule.QosRule)
+                             for _ in range(3)]
+
+        # Tie qos rules with policies
+        self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id']
+        self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id']
+        self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id']
+
+        self.db_qos_bandwidth_rules = [
+            self.get_random_fields(rule.QosBandwidthLimitRule)
+            for _ in range(3)]
+
+        # Tie qos rules with qos bandwidth limit rules
+        for i, qos_rule in enumerate(self.db_qos_rules):
+            self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id']
+
+        self.model_map = {
+            self._test_class.db_model: self.db_objs,
+            rule.QosRule.base_db_model: self.db_qos_rules,
+            rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules}
+
+    def fake_get_objects(self, context, model, qos_policy_id=None):
+        objs = self.model_map[model]
+        if model is rule.QosRule.base_db_model and qos_policy_id:
+            return [obj for obj in objs
+                    if obj['qos_policy_id'] == qos_policy_id]
+        return objs
+
+    def fake_get_object(self, context, model, id):
+        objects = self.model_map[model]
+        return [obj for obj in objects if obj['id'] == id][0]
+
+    def test_get_objects(self):
+        with mock.patch.object(
+                    db_api, 'get_objects',
+                    side_effect=self.fake_get_objects),\
+                mock.patch.object(
+                    db_api, 'get_object',
+                    side_effect=self.fake_get_object):
+            objs = self._test_class.get_objects(self.context)
+        self._validate_objects(self.db_objs, objs)
+
 
 class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
                                 testlib_api.SqlTestCase):
@@ -42,6 +88,19 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
         policy_obj.create()
         return policy_obj
 
+    def _create_test_policy_with_rule(self):
+        policy_obj = self._create_test_policy()
+
+        rule_fields = self.get_random_fields(
+            obj_cls=rule.QosBandwidthLimitRule)
+        rule_fields['qos_policy_id'] = policy_obj.id
+        rule_fields['tenant_id'] = policy_obj.tenant_id
+
+        rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
+        rule_obj.create()
+
+        return policy_obj, rule_obj
+
     def _create_test_network(self):
         # TODO(ihrachys): replace with network.create() once we get an object
         # implementation for networks
@@ -111,16 +170,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
         self.assertIsNone(policy_obj)
 
     def test_synthetic_rule_fields(self):
-        obj = policy.QosPolicy(self.context, **self.db_obj)
-        obj.create()
-
-        rule_fields = self.get_random_fields(
-            obj_cls=rule.QosBandwidthLimitRule)
-        rule_fields['qos_policy_id'] = obj.id
-        rule_fields['tenant_id'] = obj.tenant_id
-
-        rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
-        rule_obj.create()
-
-        obj = policy.QosPolicy.get_by_id(self.context, obj.id)
-        self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
+        policy_obj, rule_obj = self._create_test_policy_with_rule()
+        policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
+        self.assertEqual([rule_obj], policy_obj.bandwidth_limit_rules)
+
+    def test_create_is_in_single_transaction(self):
+        obj = self._test_class(self.context, **self.db_obj)
+        with mock.patch('sqlalchemy.engine.'
+                        'Transaction.commit') as mock_commit,\
+                mock.patch.object(obj._context.session, 'add'):
+            obj.create()
+        self.assertEqual(1, mock_commit.call_count)
+
+    def test_get_by_id_fetches_rules_non_lazily(self):
+        policy_obj, rule_obj = self._create_test_policy_with_rule()
+        policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
+
+        primitive = policy_obj.obj_to_primitive()
+        self.assertNotEqual([], (primitive['versioned_object.data']
+                                          ['bandwidth_limit_rules']))
index 5e15dc79717b6a3c178509a622fc84b2ac8824fe..0b1c4b2390a14901b5ab44fc6b7d45e9c1643b1a 100644 (file)
@@ -23,10 +23,15 @@ from neutron.objects import base
 from neutron.tests import base as test_base
 
 
+class FakeModel(object):
+    def __init__(self, *args, **kwargs):
+        pass
+
+
 @obj_base.VersionedObjectRegistry.register
 class FakeNeutronObject(base.NeutronObject):
 
-    db_model = 'fake_model'
+    db_model = FakeModel
 
     fields = {
         'id': obj_fields.UUIDField(),
@@ -106,13 +111,16 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
         with mock.patch.object(db_api, 'get_objects',
                                return_value=self.db_objs) as get_objects_mock:
             objs = self._test_class.get_objects(self.context)
-            self.assertFalse(
-                filter(lambda obj: not self._is_test_class(obj), objs))
-            self.assertEqual(
-                sorted(self.db_objs),
-                sorted(get_obj_db_fields(obj) for obj in objs))
-            get_objects_mock.assert_called_once_with(
-                self.context, self._test_class.db_model)
+            self._validate_objects(self.db_objs, objs)
+        get_objects_mock.assert_called_once_with(
+            self.context, self._test_class.db_model)
+
+    def _validate_objects(self, expected, observed):
+        self.assertFalse(
+            filter(lambda obj: not self._is_test_class(obj), observed))
+        self.assertEqual(
+            sorted(expected),
+            sorted(get_obj_db_fields(obj) for obj in observed))
 
     def _check_equal(self, obj, db_obj):
         self.assertEqual(