]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
objects.qos.policy: support per type rule lists as synthetic fields
authorIhar Hrachyshka <ihrachys@redhat.com>
Fri, 10 Jul 2015 16:00:34 +0000 (18:00 +0200)
committerIhar Hrachyshka <ihrachys@redhat.com>
Thu, 16 Jul 2015 13:51:47 +0000 (15:51 +0200)
This is a significant piece of work.

It enables neutron objects to define fields that are lazily loaded on
field access. To achieve that,

- field should be mentioned in cls.synthetic_fields
- obj_load_attr should be extended to lazily fetch and cache the field

Based on this work, we define per type rule fields that are lists of
appropriate neutron objects. (At the moment, we have only single type
supported, but I tried hard to make it easily extendable, with little or
no coding needed when a new rule type object definition is added to
rule.py: for example, we inspect object definitions based on
VALID_RULE_TYPES, and define appropriate fields for the policy object).

To implement lazy loading for those fields, I redefined get_by_id for
rules that now meld fields from both base and subtype db models into the
corresponding neutron object.

Added a simple test that checks bandwidth_rules attribute behaves for
policies.

Some objects unit test framework rework was needed to accomodate
synthetic fields that are not propagated to db layer.

Change-Id: Ia16393453b1ed48651fbd778bbe0ac6427560117

neutron/common/exceptions.py
neutron/common/utils.py
neutron/db/api.py
neutron/objects/base.py
neutron/objects/qos/policy.py
neutron/objects/qos/rule.py
neutron/tests/unit/common/test_utils.py
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/objects/qos/test_rule.py
neutron/tests/unit/objects/test_base.py

index c6ec6ccca5407117f89eff47a7a0d4ee074766c3..163dd9818278d2390cfe09c6afb4d45520f943c0 100644 (file)
@@ -470,3 +470,7 @@ class DeviceNotFoundError(NeutronException):
 class NetworkSubnetPoolAffinityError(BadRequest):
     message = _("Subnets hosted on the same network must be allocated from "
                 "the same subnet pool")
+
+
+class ObjectActionError(NeutronException):
+    message = _('Object action %(action)s failed because: %(reason)s')
index bd2dccdb0d2082f59ae06313b9601d7a6c1515b0..ec16b775752f3ab283476c28e2d79d78eee9d1b2 100644 (file)
@@ -423,3 +423,7 @@ class DelayedStringRenderer(object):
 
     def __str__(self):
         return str(self.function(*self.args, **self.kwargs))
+
+
+def camelize(s):
+    return ''.join(s.replace('_', ' ').title().split())
index 2bada2f6e98493a903b4fd81d608b319bf0aae0f..c1619c51b46a015fd9e484876081fbb00c3514d4 100644 (file)
@@ -91,7 +91,7 @@ class convert_db_exception_to_retry(object):
 
 
 # Common database operation implementations
-# TODO(QoS): consider handling multiple objects found, or no objects at all
+# TODO(QoS): consider reusing get_objects below
 # TODO(QoS): consider changing the name and making it public, officially
 def _find_object(context, model, **kwargs):
     with context.session.begin(subtransactions=True):
@@ -101,15 +101,18 @@ def _find_object(context, model, **kwargs):
 
 
 def get_object(context, model, id):
+    # TODO(QoS): consider reusing get_objects below
     with context.session.begin(subtransactions=True):
         return (common_db_mixin.model_query(context, model)
                 .filter_by(id=id)
                 .first())
 
 
-def get_objects(context, model):
+def get_objects(context, model, **kwargs):
     with context.session.begin(subtransactions=True):
-        return common_db_mixin.model_query(context, model).all()
+        return (common_db_mixin.model_query(context, model)
+                .filter_by(**kwargs)
+                .all())
 
 
 def create_object(context, model, values):
index f2b18511db4ffe8543471d8dc17136d4ac377eb7..e41ac9ec4d93700284caf3fce346dffd94a36a75 100644 (file)
@@ -32,6 +32,8 @@ class NeutronObject(obj_base.VersionedObject,
     # fields that are not allowed to update
     fields_no_update = []
 
+    synthetic_fields = []
+
     def from_db_object(self, *objs):
         for field in self.fields:
             for db_obj in objs:
@@ -53,21 +55,27 @@ class NeutronObject(obj_base.VersionedObject,
             return obj
 
     @classmethod
-    def get_objects(cls, context):
-        db_objs = db_api.get_objects(context, cls.db_model)
+    def get_objects(cls, context, **kwargs):
+        db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
         objs = [cls(context, **db_obj) for db_obj in db_objs]
         for obj in objs:
             obj.obj_reset_changes()
         return objs
 
-    def create(self):
+    def _get_changed_persistent_fields(self):
         fields = self.obj_get_changes()
+        for field in self.synthetic_fields:
+            if field in fields:
+                del fields[field]
+        return fields
+
+    def create(self):
+        fields = self._get_changed_persistent_fields()
         db_obj = db_api.create_object(self._context, self.db_model, fields)
         self.from_db_object(db_obj)
 
     def update(self):
-        # TODO(QoS): enforce fields_no_update
-        updates = self.obj_get_changes()
+        updates = self._get_changed_persistent_fields()
         if updates:
             db_obj = db_api.update_object(self._context, self.db_model,
                                           self.id, updates)
index 83c481a02b1fd87d0cbe872544fb1c9f35366543..09ba2b59bb970a1fa6e7e44077f56ff43194e91e 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import abc
+
 from oslo_versionedobjects import base as obj_base
 from oslo_versionedobjects import fields as obj_fields
+import six
 
+from neutron.common import exceptions
+from neutron.common import utils
 from neutron.db import api as db_api
 from neutron.db.qos import api as qos_db_api
 from neutron.db.qos import models as qos_db_model
+from neutron.extensions import qos as qos_extension
 from neutron.objects import base
+from neutron.objects.qos import rule as rule_obj_impl
+
+
+class QosRulesExtenderMeta(abc.ABCMeta):
+
+    def __new__(cls, *args, **kwargs):
+        cls_ = super(QosRulesExtenderMeta, cls).__new__(cls, *args, **kwargs)
 
+        cls_.rule_fields = {}
+        for rule in qos_extension.VALID_RULE_TYPES:
+            rule_cls_name = 'Qos%sRule' % utils.camelize(rule)
+            field = '%s_rules' % rule
+            cls_.fields[field] = obj_fields.ListOfObjectsField(rule_cls_name)
+            cls_.rule_fields[field] = rule_cls_name
 
-# TODO(QoS): add rule lists to object fields
-# TODO(QoS): implement something for binding networks and ports with policies
+        cls_.synthetic_fields = list(cls_.rule_fields.keys())
+
+        return cls_
 
 
 @obj_base.VersionedObjectRegistry.register
+@six.add_metaclass(QosRulesExtenderMeta)
 class QosPolicy(base.NeutronObject):
 
     db_model = qos_db_model.QosPolicy
@@ -44,6 +65,16 @@ class QosPolicy(base.NeutronObject):
 
     fields_no_update = ['id', 'tenant_id']
 
+    def obj_load_attr(self, attrname):
+        if attrname not in self.rule_fields:
+            raise exceptions.ObjectActionError(
+                action='obj_load_attr', reason='unable to load %s' % attrname)
+
+        rule_cls = getattr(rule_obj_impl, self.rule_fields[attrname])
+        rules = rule_cls.get_rules_by_policy(self._context, self.id)
+        setattr(self, attrname, rules)
+        self.obj_reset_changes([attrname])
+
     @classmethod
     def _get_object_policy(cls, context, model, **kwargs):
         # TODO(QoS): we should make sure we use public functions
index 3de9476d622c9e8a7473a3bbe82b665555613c0a..b9aead64b71499e4f5304b29b2b1f6289dde86af 100644 (file)
@@ -21,6 +21,7 @@ import six
 
 from neutron.db import api as db_api
 from neutron.db.qos import models as qos_db_model
+from neutron.extensions import qos as qos_extension
 from neutron.objects import base
 
 
@@ -37,6 +38,9 @@ class QosRule(base.NeutronObject):
 
     fields_no_update = ['id', 'tenant_id', 'qos_policy_id']
 
+    # each rule subclass should redefine it
+    rule_type = None
+
     _core_fields = list(fields.keys())
 
     _common_fields = ['id']
@@ -60,8 +64,6 @@ class QosRule(base.NeutronObject):
             if func(key)
         }
 
-    # TODO(QoS): reimplement get_by_id to merge both core and addn fields
-
     def _get_changed_core_fields(self):
         fields = self.obj_get_changes()
         return self._filter_fields(fields, self._is_core_field)
@@ -75,9 +77,32 @@ class QosRule(base.NeutronObject):
         for field in self._common_fields:
             to_[field] = from_[field]
 
+    @classmethod
+    def get_objects(cls, context, **kwargs):
+        # TODO(QoS): support searching for subtype fields
+        db_objs = db_api.get_objects(context, cls.base_db_model, **kwargs)
+        return [cls.get_by_id(context, db_obj['id']) for db_obj in db_objs]
+
+    @classmethod
+    def get_by_id(cls, context, id):
+        obj = super(QosRule, cls).get_by_id(context, id)
+
+        if obj:
+            # the object above does not contain fields from base QosRule yet,
+            # so fetch it and mix its fields into the object
+            base_db_obj = db_api.get_object(context, cls.base_db_model, id)
+            for field in cls._core_fields:
+                setattr(obj, field, base_db_obj[field])
+
+            obj.obj_reset_changes()
+            return obj
+
     # TODO(QoS): create and update are not transactional safe
     def create(self):
 
+        # TODO(QoS): enforce that type field value is bound to specific class
+        self.type = self.rule_type
+
         # create base qos_rule
         core_fields = self._get_changed_core_fields()
         base_db_obj = db_api.create_object(
@@ -95,6 +120,8 @@ class QosRule(base.NeutronObject):
     def update(self):
         updated_db_objs = []
 
+        # TODO(QoS): enforce that type field cannot be changed
+
         # update base qos_rule, if needed
         core_fields = self._get_changed_core_fields()
         if core_fields:
@@ -113,13 +140,19 @@ class QosRule(base.NeutronObject):
 
     # delete is the same, additional rule object cleanup is done thru cascading
 
+    @classmethod
+    def get_rules_by_policy(cls, context, policy_id):
+        return cls.get_objects(context, qos_policy_id=policy_id)
+
 
 @obj_base.VersionedObjectRegistry.register
 class QosBandwidthLimitRule(QosRule):
 
     db_model = qos_db_model.QosBandwidthLimitRule
 
+    rule_type = qos_extension.RULE_TYPE_BANDWIDTH_LIMIT
+
     fields = {
-        'max_kbps': obj_fields.IntegerField(),
-        'max_burst_kbps': obj_fields.IntegerField()
+        'max_kbps': obj_fields.IntegerField(nullable=True),
+        'max_burst_kbps': obj_fields.IntegerField(nullable=True)
     }
index 82c84904c0005e78717b0db28c2a43c3f89f0b9c..1f5cfb2e46a9772f2b370f14bc3c5e320e058632 100644 (file)
@@ -663,3 +663,14 @@ class TestDelayedStringRenderer(base.BaseTestCase):
         LOG.logger.setLevel(logging.logging.DEBUG)
         LOG.debug("Hello %s", delayed)
         self.assertTrue(my_func.called)
+
+
+class TestCamelize(base.BaseTestCase):
+    def test_camelize(self):
+        data = {'bandwidth_limit': 'BandwidthLimit',
+                'test': 'Test',
+                'some__more__dashes': 'SomeMoreDashes',
+                'a_penguin_walks_into_a_bar': 'APenguinWalksIntoABar'}
+
+        for s, expected in data.items():
+            self.assertEqual(expected, utils.camelize(s))
index 9c208b994959ed850425cb3525683833a9411b17..d3b720cdd7ad2f48a6fb5d4b569df635886bd3b2 100644 (file)
@@ -13,6 +13,7 @@
 from neutron.db import api as db_api
 from neutron.db import models_v2
 from neutron.objects.qos import policy
+from neutron.objects.qos import rule
 from neutron.tests.unit.objects import test_base
 from neutron.tests.unit import testlib_api
 
@@ -112,3 +113,18 @@ class QosPolicyDbObjectTestCase(QosPolicyBaseTestCase,
         policy_obj = policy.QosPolicy.get_network_policy(self.context,
                                                          self._network['id'])
         self.assertIsNone(policy_obj)
+
+    def test_synthetic_rule_fields(self):
+        obj = policy.QosPolicy(self.context, **self.db_obj)
+        obj.create()
+
+        rule_fields = self.get_random_fields(
+            obj_cls=rule.QosBandwidthLimitRule)
+        rule_fields['qos_policy_id'] = obj.id
+        rule_fields['tenant_id'] = obj.tenant_id
+
+        rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
+        rule_obj.create()
+
+        obj = policy.QosPolicy.get_by_id(self.context, obj.id)
+        self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
index 867a0b97744b42a69768b654366d57775eebb986..52364fba63729c4f6af1e1c0e849f4ba6a9bf40e 100644 (file)
@@ -21,6 +21,15 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
     _test_class = rule.QosBandwidthLimitRule
 
+    @classmethod
+    def get_random_fields(cls):
+        # object middleware should not allow random types, so override it with
+        # proper type
+        fields = (super(QosBandwidthLimitPolicyObjectTestCase, cls)
+                  .get_random_fields())
+        fields['type'] = cls._test_class.rule_type
+        return fields
+
     def _filter_db_object(self, func):
         return {
             field: self.db_obj[field]
@@ -36,6 +45,36 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
         return self._filter_db_object(
             lambda field: self._test_class._is_addn_field(field))
 
+    def test_get_by_id(self):
+        with mock.patch.object(db_api, 'get_object',
+                               return_value=self.db_obj) as get_object_mock:
+            obj = self._test_class.get_by_id(self.context, id='fake_id')
+            self.assertTrue(self._is_test_class(obj))
+            self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
+            get_object_mock.assert_has_calls([
+                mock.call(self.context, model, 'fake_id')
+                for model in (self._test_class.db_model,
+                              self._test_class.base_db_model)
+            ], any_order=True)
+
+    def test_get_objects(self):
+        with mock.patch.object(db_api, 'get_objects',
+                               return_value=self.db_objs):
+
+            @classmethod
+            def _get_by_id(cls, context, id):
+                for db_obj in self.db_objs:
+                    if db_obj['id'] == id:
+                        return self._test_class(context, **db_obj)
+
+            with mock.patch.object(rule.QosRule, 'get_by_id', new=_get_by_id):
+                objs = self._test_class.get_objects(self.context)
+                self.assertFalse(
+                    filter(lambda obj: not self._is_test_class(obj), objs))
+                self.assertEqual(
+                    sorted(self.db_objs),
+                    sorted(test_base.get_obj_db_fields(obj) for obj in objs))
+
     def test_create(self):
         with mock.patch.object(db_api, 'create_object',
                                return_value=self.db_obj) as create_mock:
@@ -46,13 +85,13 @@ class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
             self._check_equal(obj, self.db_obj)
 
             core_db_obj = self._get_core_db_obj()
-            create_mock.assert_any_call(
-                self.context, self._test_class.base_db_model, core_db_obj)
-
             addn_db_obj = self._get_addn_db_obj()
-            create_mock.assert_any_call(
-                self.context, self._test_class.db_model,
-                addn_db_obj)
+            create_mock.assert_has_calls(
+                [mock.call(self.context, self._test_class.base_db_model,
+                           core_db_obj),
+                 mock.call(self.context, self._test_class.db_model,
+                           addn_db_obj)]
+            )
 
     def test_update_changes(self):
         with mock.patch.object(db_api, 'update_object',
index a56d6cb3fd74fbd18b76fa957347147b99275f96..45725c52975377756203f19130b0b571b2d878d3 100644 (file)
@@ -52,11 +52,13 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = {
     obj_fields.IntegerField: _random_integer,
     obj_fields.StringField: _random_string,
     obj_fields.UUIDField: _random_string,
+    obj_fields.ListOfObjectsField: lambda: []
 }
 
 
-def get_obj_fields(obj):
-    return {field: getattr(obj, field) for field in obj.fields}
+def get_obj_db_fields(obj):
+    return {field: getattr(obj, field) for field in obj.fields
+            if field not in obj.synthetic_fields}
 
 
 class _BaseObjectTestCase(object):
@@ -66,15 +68,17 @@ class _BaseObjectTestCase(object):
     def setUp(self):
         super(_BaseObjectTestCase, self).setUp()
         self.context = context.get_admin_context()
-        self.db_objs = list(self._get_random_fields() for _ in range(3))
+        self.db_objs = list(self.get_random_fields() for _ in range(3))
         self.db_obj = self.db_objs[0]
 
     @classmethod
-    def _get_random_fields(cls):
+    def get_random_fields(cls, obj_cls=None):
+        obj_cls = obj_cls or cls._test_class
         fields = {}
-        for field in cls._test_class.fields:
-            field_obj = cls._test_class.fields[field]
-            fields[field] = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]()
+        for field, field_obj in obj_cls.fields.items():
+            if field not in obj_cls.synthetic_fields:
+                generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
+                fields[field] = generator()
         return fields
 
     @classmethod
@@ -89,7 +93,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
                                return_value=self.db_obj) as get_object_mock:
             obj = self._test_class.get_by_id(self.context, id='fake_id')
             self.assertTrue(self._is_test_class(obj))
-            self.assertEqual(self.db_obj, get_obj_fields(obj))
+            self.assertEqual(self.db_obj, get_obj_db_fields(obj))
             get_object_mock.assert_called_once_with(
                 self.context, self._test_class.db_model, 'fake_id')
 
@@ -106,14 +110,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
                 filter(lambda obj: not self._is_test_class(obj), objs))
             self.assertEqual(
                 sorted(self.db_objs),
-                sorted(get_obj_fields(obj) for obj in objs))
+                sorted(get_obj_db_fields(obj) for obj in objs))
             get_objects_mock.assert_called_once_with(
                 self.context, self._test_class.db_model)
 
     def _check_equal(self, obj, db_obj):
         self.assertEqual(
             sorted(db_obj),
-            sorted(get_obj_fields(obj)))
+            sorted(get_obj_db_fields(obj)))
 
     def test_create(self):
         with mock.patch.object(db_api, 'create_object',