]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Unite qos_rules and qos_*_rules tables
authorIhar Hrachyshka <ihrachys@redhat.com>
Sat, 25 Jul 2015 14:21:37 +0000 (16:21 +0200)
committerIhar Hrachyshka <ihrachys@redhat.com>
Sat, 1 Aug 2015 20:13:14 +0000 (22:13 +0200)
The only values in qos_rules table are: type, id and qos_policy_id. Both
id fields point to qos_*_rules and qos_policies objects.

Type is redundant since qos_rule and qos_*_rule objects maintain 1-to-1
relationship.

Keeping a separate table just to link qos_*_rule and qos_policy objects
has no meaning. At the same time, it complicates the code for rule
objects significantly.

So instead of copying with all those issues, we just squash the tables
into single one. It allows us to reuse all base methods from
NeutronObject for rules.

LOC stats for the patch clearly shows the point:

65 insertions(+), 267 deletions(-)

And no actual functionality is lost.

While at it, the following changes were applied:

- some base tests are reimplemented to test objects in a more explicit
  way;
- fields_no_update class attribute is now actually enforced in base
  object class.

Partially-Implements: blueprint quantum-qos-api
Change-Id: Iadabd14c3490c842608e53ceccf38c79dcdf8d85

12 files changed:
doc/source/devref/quality_of_service.rst
neutron/common/exceptions.py
neutron/db/api.py
neutron/db/migration/alembic_migrations/versions/liberty/expand/48153cb5f051_qos_db_changes.py
neutron/db/qos/models.py
neutron/objects/base.py
neutron/objects/qos/policy.py
neutron/objects/qos/rule.py
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/objects/qos/test_rule.py
neutron/tests/unit/objects/test_base.py
neutron/tests/unit/services/qos/test_qos_plugin.py

index 1c5570205c3c2349d79bc55f4276f9390d70be85..2742f1da6a287ba69d851b06a033ed9252c4763c 100644 (file)
@@ -65,15 +65,8 @@ From database point of view, following objects are defined in schema:
 * QosPolicy: directly maps to the conceptual policy resource.
 * QosNetworkPolicyBinding, QosPortPolicyBinding: defines attachment between a
   Neutron resource and a QoS policy.
-* QosRule: defines common rule fields for all supported rule types.
-* QosBandwidthLimitRule: defines rule fields that are specific to
-  bandwidth_limit type (the only type supported by the service as of time of
-  writing).
+* QosBandwidthLimitRule: defines the only rule type available at the moment.
 
-There is a one-to-one relationship between QosRule and type specific
-Qos<type>Rule database objects. We represent the single object with two tables
-to avoid duplication of common fields. (That introduces some complexity in
-neutron objects for rule resources, but see below).
 
 All database models are defined under:
 
@@ -138,10 +131,10 @@ Note that synthetic fields are lazily loaded, meaning there is no hit into
 the database if the field is not inspected by consumers.
 
 For Qos<type>Rule objects, an extendable approach was taken to allow easy
-addition of objects for new rule types. To accomodate this, all the methods
-that access the database were implemented in a base class called QosRule that
-is then inherited into type-specific rule implementations that, ideally, only
-define additional fields and some other minor things.
+addition of objects for new rule types. To accomodate this, fields common to
+all types are put into a base class called QosRule that is then inherited into
+type-specific rule implementations that, ideally, only define additional fields
+and some other minor things.
 
 Note that the QosRule base class is not registered with oslo.versionedobjects
 registry, because it's not expected that 'generic' rules should be
index b4d3f5a4b2500ef69fd11275330a44f611f67eb1..7dc39bf4800ccba5713d5b0c7698df746c2e92aa 100644 (file)
@@ -73,6 +73,10 @@ class AdminRequired(NotAuthorized):
     message = _("User does not have admin privileges: %(reason)s")
 
 
+class ObjectNotFound(NotFound):
+    message = _("Object %(id)s not found.")
+
+
 class NetworkNotFound(NotFound):
     message = _("Network %(net_id)s could not be found")
 
index 2c438055ccc5e75cfa8aacd65d836eff907b3f0e..b4384eec0c01cd3ab213c41767641f82eeffdc11 100644 (file)
@@ -24,6 +24,7 @@ from oslo_utils import uuidutils
 from sqlalchemy import exc
 from sqlalchemy import orm
 
+from neutron.common import exceptions as n_exc
 from neutron.db import common_db_mixin
 
 
@@ -117,9 +118,16 @@ def create_object(context, model, values):
     return db_obj.__dict__
 
 
+def _safe_get_object(context, model, id):
+    db_obj = get_object(context, model, id=id)
+    if db_obj is None:
+        raise n_exc.ObjectNotFound(id=id)
+    return db_obj
+
+
 def update_object(context, model, id, values):
     with context.session.begin(subtransactions=True):
-        db_obj = get_object(context, model, id=id)
+        db_obj = _safe_get_object(context, model, id)
         db_obj.update(values)
         db_obj.save(session=context.session)
     return db_obj.__dict__
@@ -127,5 +135,5 @@ def update_object(context, model, id, values):
 
 def delete_object(context, model, id):
     with context.session.begin(subtransactions=True):
-        db_obj = get_object(context, model, id=id)
+        db_obj = _safe_get_object(context, model, id)
         context.session.delete(db_obj)
index 03711ca03d46dca8121a36022d37062b501b33a6..d20048b0e394a705a6ed7dd856a15cdfe11b40aa 100755 (executable)
@@ -60,18 +60,10 @@ def upgrade():
                   nullable=False, unique=True))
 
     op.create_table(
-        'qos_rules',
+        'qos_bandwidth_limit_rules',
         sa.Column('id', sa.String(length=36), primary_key=True),
         sa.Column('qos_policy_id', sa.String(length=36),
                   sa.ForeignKey('qos_policies.id', ondelete='CASCADE'),
                   nullable=False),
-        sa.Column('type', sa.String(length=255)))
-
-    op.create_table(
-        'qos_bandwidth_limit_rules',
-        sa.Column('id', sa.String(length=36),
-                  sa.ForeignKey('qos_rules.id', ondelete='CASCADE'),
-                  nullable=False,
-                  primary_key=True),
         sa.Column('max_kbps', sa.Integer()),
         sa.Column('max_burst_kbps', sa.Integer()))
index f40ee0f49a30ed4b5d861f7b7ee1413249e3f9ed..89594618ff1949c025f6ebcb6d9dbf0fa8fcc94d 100755 (executable)
@@ -69,21 +69,16 @@ class QosPortPolicyBinding(model_base.BASEV2):
                                cascade='delete', lazy='joined'))
 
 
-class QosRule(model_base.BASEV2, models_v2.HasId):
-    __tablename__ = 'qos_rules'
-    type = sa.Column(sa.String(255))
-    qos_policy_id = sa.Column(sa.String(36),
-                              sa.ForeignKey('qos_policies.id',
-                                            ondelete='CASCADE'),
-                              nullable=False)
+class QosRuleColumns(models_v2.HasId):
+    qos_policy_id = sa.Column(sa.String(36), nullable=False)
 
+    __table_args__ = (
+        sa.ForeignKeyConstraint(['qos_policy_id'], ['qos_policies.id']),
+        model_base.BASEV2.__table_args__
+    )
 
-class QosBandwidthLimitRule(model_base.BASEV2):
+
+class QosBandwidthLimitRule(QosRuleColumns, model_base.BASEV2):
     __tablename__ = 'qos_bandwidth_limit_rules'
     max_kbps = sa.Column(sa.Integer)
     max_burst_kbps = sa.Column(sa.Integer)
-    id = sa.Column(sa.String(36),
-                   sa.ForeignKey('qos_rules.id',
-                                 ondelete='CASCADE'),
-                   nullable=False,
-                   primary_key=True)
index 264bbf9af9d03e61af2de951a19e599e202a8d76..5339fce2741654bd452026d5beedb588837c2924 100644 (file)
@@ -15,9 +15,22 @@ import abc
 from oslo_versionedobjects import base as obj_base
 import six
 
+from neutron.common import exceptions
 from neutron.db import api as db_api
 
 
+class NeutronObjectUpdateForbidden(exceptions.NeutronException):
+    message = _("Unable to update the following object fields: %(fields)s")
+
+
+def get_updatable_fields(cls, fields):
+    fields = fields.copy()
+    for field in cls.fields_no_update:
+        if field in fields:
+            del fields[field]
+    return fields
+
+
 @six.add_metaclass(abc.ABCMeta)
 class NeutronObject(obj_base.VersionedObject,
                     obj_base.VersionedObjectDictCompat,
@@ -54,11 +67,10 @@ class NeutronDbObject(NeutronObject):
     # should be overridden for all persistent objects
     db_model = None
 
-    # fields that are not allowed to update
-    fields_no_update = []
-
     synthetic_fields = []
 
+    fields_no_update = []
+
     def from_db_object(self, *objs):
         for field in self.fields:
             for db_obj in objs:
@@ -90,6 +102,18 @@ class NeutronDbObject(NeutronObject):
                 del fields[field]
         return fields
 
+    def _validate_changed_fields(self, fields):
+        fields = fields.copy()
+        # We won't allow id update anyway, so let's pop it out not to trigger
+        # update on id field touched by the consumer
+        fields.pop('id', None)
+
+        forbidden_updates = set(self.fields_no_update) & set(fields.keys())
+        if forbidden_updates:
+            raise NeutronObjectUpdateForbidden(fields=forbidden_updates)
+
+        return fields
+
     def create(self):
         fields = self._get_changed_persistent_fields()
         db_obj = db_api.create_object(self._context, self.db_model, fields)
@@ -97,6 +121,8 @@ class NeutronDbObject(NeutronObject):
 
     def update(self):
         updates = self._get_changed_persistent_fields()
+        updates = self._validate_changed_fields(updates)
+
         if updates:
             db_obj = db_api.update_object(self._context, self.db_model,
                                           self.id, updates)
index cc7cdc981a18141a4301a56ca9451c942ef03b3c..b86636e76bf4abf09f24e53d4d55e7abb0cb32ca 100644 (file)
@@ -78,7 +78,7 @@ class QosPolicy(base.NeutronDbObject):
                 action='obj_load_attr', reason='unable to load %s' % attrname)
 
         rule_cls = getattr(rule_obj_impl, self.rule_fields[attrname])
-        rules = rule_cls.get_rules_by_policy(self._context, self.id)
+        rules = rule_cls.get_objects(self._context, qos_policy_id=self.id)
         setattr(self, attrname, rules)
         self.obj_reset_changes([attrname])
 
@@ -142,6 +142,7 @@ class QosPolicy(base.NeutronDbObject):
         return cls._get_object_policy(context, cls.port_binding_model,
                                       port_id=port_id)
 
+    # TODO(QoS): Consider extending base to trigger registered methods for us
     def create(self):
         with db_api.autonested_transaction(self._context.session):
             super(QosPolicy, self).create()
index d62ad9419571e00b42cf651c836c2f332e11b3e5..d9e44d1f1ec4f4bcb59af35554c73939d271f47e 100644 (file)
@@ -19,135 +19,19 @@ from oslo_versionedobjects import base as obj_base
 from oslo_versionedobjects import fields as obj_fields
 import six
 
-from neutron.db import api as db_api
 from neutron.db.qos import models as qos_db_model
 from neutron.objects import base
-from neutron.services.qos import qos_consts
 
 
 @six.add_metaclass(abc.ABCMeta)
 class QosRule(base.NeutronDbObject):
 
-    base_db_model = qos_db_model.QosRule
-
     fields = {
         'id': obj_fields.UUIDField(),
-        #TODO(QoS): We ought to kill the `type' attribute
-        'type': obj_fields.StringField(),
         'qos_policy_id': obj_fields.UUIDField()
     }
 
-    fields_no_update = ['id', 'tenant_id', 'qos_policy_id']
-
-    # each rule subclass should redefine it
-    rule_type = None
-
-    _core_fields = list(fields.keys())
-
-    _common_fields = ['id']
-
-    @classmethod
-    def _is_common_field(cls, field):
-        return field in cls._common_fields
-
-    @classmethod
-    def _is_core_field(cls, field):
-        return field in cls._core_fields
-
-    @classmethod
-    def _is_addn_field(cls, field):
-        return not cls._is_core_field(field) or cls._is_common_field(field)
-
-    @staticmethod
-    def _filter_fields(fields, func):
-        return {
-            key: val for key, val in fields.items()
-            if func(key)
-        }
-
-    def _get_changed_core_fields(self):
-        fields = self.obj_get_changes()
-        return self._filter_fields(fields, self._is_core_field)
-
-    def _get_changed_addn_fields(self):
-        fields = self.obj_get_changes()
-        return self._filter_fields(fields, self._is_addn_field)
-
-    def _copy_common_fields(self, from_, to_):
-        for field in self._common_fields:
-            to_[field] = from_[field]
-
-    @classmethod
-    def get_objects(cls, context, **kwargs):
-        # TODO(QoS): support searching for subtype fields
-        db_objs = db_api.get_objects(context, cls.base_db_model, **kwargs)
-        return [cls.get_by_id(context, db_obj['id']) for db_obj in db_objs]
-
-    @classmethod
-    def get_by_id(cls, context, id):
-        obj = super(QosRule, cls).get_by_id(context, id)
-
-        if obj:
-            # the object above does not contain fields from base QosRule yet,
-            # so fetch it and mix its fields into the object
-            base_db_obj = db_api.get_object(context, cls.base_db_model, id=id)
-            for field in cls._core_fields:
-                setattr(obj, field, base_db_obj[field])
-
-            obj.obj_reset_changes()
-            return obj
-
-    # TODO(QoS): Test that create is in single transaction
-    def create(self):
-
-        # TODO(QoS): enforce that type field value is bound to specific class
-        self.type = self.rule_type
-
-        # create base qos_rule
-        core_fields = self._get_changed_core_fields()
-
-        with db_api.autonested_transaction(self._context.session):
-            base_db_obj = db_api.create_object(
-                self._context, self.base_db_model, core_fields)
-
-            # create type specific qos_..._rule
-            addn_fields = self._get_changed_addn_fields()
-            self._copy_common_fields(core_fields, addn_fields)
-            addn_db_obj = db_api.create_object(
-                self._context, self.db_model, addn_fields)
-
-        # merge two db objects into single neutron one
-        self.from_db_object(base_db_obj, addn_db_obj)
-
-    # TODO(QoS): Test that update is in single transaction
-    def update(self):
-        updated_db_objs = []
-
-        # TODO(QoS): enforce that type field cannot be changed
-
-        # update base qos_rule, if needed
-        core_fields = self._get_changed_core_fields()
-
-        with db_api.autonested_transaction(self._context.session):
-            if core_fields:
-                base_db_obj = db_api.update_object(
-                    self._context, self.base_db_model, self.id, core_fields)
-                updated_db_objs.append(base_db_obj)
-
-            addn_fields = self._get_changed_addn_fields()
-            if addn_fields:
-                addn_db_obj = db_api.update_object(
-                    self._context, self.db_model, self.id, addn_fields)
-                updated_db_objs.append(addn_db_obj)
-
-        # update neutron object with values from both database objects
-        self.from_db_object(*updated_db_objs)
-
-    # delete is the same, additional rule object cleanup is done thru cascading
-
-    @classmethod
-    def get_rules_by_policy(cls, context, policy_id):
-        return cls.get_objects(context, qos_policy_id=policy_id)
+    fields_no_update = ['id', 'qos_policy_id']
 
 
 @obj_base.VersionedObjectRegistry.register
@@ -155,8 +39,6 @@ class QosBandwidthLimitRule(QosRule):
 
     db_model = qos_db_model.QosBandwidthLimitRule
 
-    rule_type = qos_consts.RULE_TYPE_BANDWIDTH_LIMIT
-
     fields = {
         'max_kbps': obj_fields.IntegerField(nullable=True),
         'max_burst_kbps': obj_fields.IntegerField(nullable=True)
index 6c587db10167254cb8093b2d491bf599409e6eaa..528e2d29e5a88cf60456636dc152838514ca3e88 100644 (file)
@@ -27,33 +27,17 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
     def setUp(self):
         super(QosPolicyObjectTestCase, self).setUp()
-        self.db_qos_rules = [self.get_random_fields(rule.QosRule)
-                             for _ in range(3)]
-
-        # Tie qos rules with policies
-        self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id']
-        self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id']
-        self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id']
-
+        # qos_policy_ids will be incorrect, but we don't care in this test
         self.db_qos_bandwidth_rules = [
             self.get_random_fields(rule.QosBandwidthLimitRule)
             for _ in range(3)]
 
-        # Tie qos rules with qos bandwidth limit rules
-        for i, qos_rule in enumerate(self.db_qos_rules):
-            self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id']
-
         self.model_map = {
             self._test_class.db_model: self.db_objs,
-            rule.QosRule.base_db_model: self.db_qos_rules,
             rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules}
 
-    def fake_get_objects(self, context, model, qos_policy_id=None):
-        objs = self.model_map[model]
-        if model is rule.QosRule.base_db_model and qos_policy_id:
-            return [obj for obj in objs
-                    if obj['qos_policy_id'] == qos_policy_id]
-        return objs
+    def fake_get_objects(self, context, model, **kwargs):
+        return self.model_map[model]
 
     def fake_get_object(self, context, model, id):
         objects = self.model_map[model]
@@ -76,8 +60,8 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
                     objs = self._test_class.get_objects(self.context)
                     context_mock.assert_called_once_with()
-                    get_objects_mock.assert_any_call(
-                        admin_context, self._test_class.db_model)
+            get_objects_mock.assert_any_call(
+                admin_context, self._test_class.db_model)
         self._validate_objects(self.db_objs, objs)
 
     def test_get_by_id(self):
index 6a3736e1756280de16e1015751834951ca48b85a..f42476998c379a0ea0e5d4ba2eb0bc6cb1ef56ef 100644 (file)
@@ -10,9 +10,6 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-import mock
-
-from neutron.db import api as db_api
 from neutron.objects.qos import policy
 from neutron.objects.qos import rule
 from neutron.tests.unit.objects import test_base
@@ -23,96 +20,6 @@ class QosBandwidthLimitRuleObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
     _test_class = rule.QosBandwidthLimitRule
 
-    @classmethod
-    def get_random_fields(cls):
-        # object middleware should not allow random types, so override it with
-        # proper type
-        fields = (super(QosBandwidthLimitRuleObjectTestCase, cls)
-                  .get_random_fields())
-        fields['type'] = cls._test_class.rule_type
-        return fields
-
-    def _filter_db_object(self, func):
-        return {
-            field: self.db_obj[field]
-            for field in self._test_class.fields
-            if func(field)
-        }
-
-    def _get_core_db_obj(self):
-        return self._filter_db_object(
-            lambda field: self._test_class._is_core_field(field))
-
-    def _get_addn_db_obj(self):
-        return self._filter_db_object(
-            lambda field: self._test_class._is_addn_field(field))
-
-    def test_get_by_id(self):
-        with mock.patch.object(db_api, 'get_object',
-                               return_value=self.db_obj) as get_object_mock:
-            obj = self._test_class.get_by_id(self.context, id='fake_id')
-            self.assertTrue(self._is_test_class(obj))
-            self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
-            get_object_mock.assert_has_calls([
-                mock.call(self.context, model, id='fake_id')
-                for model in (self._test_class.db_model,
-                              self._test_class.base_db_model)
-            ], any_order=True)
-
-    def test_get_objects(self):
-        with mock.patch.object(db_api, 'get_objects',
-                               return_value=self.db_objs):
-
-            @classmethod
-            def _get_by_id(cls, context, id):
-                for db_obj in self.db_objs:
-                    if db_obj['id'] == id:
-                        return self._test_class(context, **db_obj)
-
-            with mock.patch.object(rule.QosRule, 'get_by_id', new=_get_by_id):
-                objs = self._test_class.get_objects(self.context)
-                self.assertFalse(
-                    filter(lambda obj: not self._is_test_class(obj), objs))
-                self.assertEqual(
-                    sorted(self.db_objs),
-                    sorted(test_base.get_obj_db_fields(obj) for obj in objs))
-
-    def test_create(self):
-        with mock.patch.object(db_api, 'create_object',
-                               return_value=self.db_obj) as create_mock:
-            test_class = self._test_class
-            obj = test_class(self.context, **self.db_obj)
-            self._check_equal(obj, self.db_obj)
-            obj.create()
-            self._check_equal(obj, self.db_obj)
-
-            core_db_obj = self._get_core_db_obj()
-            addn_db_obj = self._get_addn_db_obj()
-            create_mock.assert_has_calls(
-                [mock.call(self.context, self._test_class.base_db_model,
-                           core_db_obj),
-                 mock.call(self.context, self._test_class.db_model,
-                           addn_db_obj)]
-            )
-
-    def test_update_changes(self):
-        with mock.patch.object(db_api, 'update_object',
-                               return_value=self.db_obj) as update_mock:
-            obj = self._test_class(self.context, **self.db_obj)
-            self._check_equal(obj, self.db_obj)
-            obj.update()
-            self._check_equal(obj, self.db_obj)
-
-            core_db_obj = self._get_core_db_obj()
-            update_mock.assert_any_call(
-                self.context, self._test_class.base_db_model, obj.id,
-                core_db_obj)
-
-            addn_db_obj = self._get_addn_db_obj()
-            update_mock.assert_any_call(
-                self.context, self._test_class.db_model, obj.id,
-                addn_db_obj)
-
 
 class QosBandwidthLimitRuleDbObjectTestCase(test_base.BaseDbObjectTestCase,
                                             testlib_api.SqlTestCase):
index 932e22ab0eb36f62ca458dd860cc9e3caa77dda1..812939956c831bef8a49097e688a09ba57535988 100644 (file)
@@ -17,6 +17,7 @@ import mock
 from oslo_versionedobjects import base as obj_base
 from oslo_versionedobjects import fields as obj_fields
 
+from neutron.common import exceptions as n_exc
 from neutron import context
 from neutron.db import api as db_api
 from neutron.objects import base
@@ -39,6 +40,8 @@ class FakeNeutronObject(base.NeutronDbObject):
         'field2': obj_fields.StringField()
     }
 
+    fields_no_update = ['id']
+
 
 def _random_string(n=10):
     return ''.join(random.choice(string.ascii_lowercase) for _ in range(n))
@@ -86,6 +89,9 @@ class _BaseObjectTestCase(object):
                 fields[field] = generator()
         return fields
 
+    def get_updatable_fields(self, fields):
+        return base.get_updatable_fields(self._test_class, fields)
+
     @classmethod
     def _is_test_class(cls, obj):
         return isinstance(obj, cls._test_class)
@@ -145,37 +151,48 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
             obj.create()
             self._check_equal(obj, self.db_obj)
 
-    def test_update_no_changes(self):
-        with mock.patch.object(db_api, 'update_object',
-                               return_value=self.db_obj) as update_mock:
-            obj = self._test_class(self.context, **self.db_obj)
-            self._check_equal(obj, self.db_obj)
+    @mock.patch.object(db_api, 'update_object')
+    def test_update_no_changes(self, update_mock):
+        with mock.patch.object(base.NeutronDbObject,
+                               '_get_changed_persistent_fields',
+                               return_value={}):
+            obj = self._test_class(self.context)
             obj.update()
-            self.assertTrue(update_mock.called)
-
-            # consequent call to update does not try to update database
-            update_mock.reset_mock()
-            obj.update()
-            self._check_equal(obj, self.db_obj)
             self.assertFalse(update_mock.called)
 
-    def test_update_changes(self):
-        with mock.patch.object(db_api, 'update_object',
-                               return_value=self.db_obj) as update_mock:
+    @mock.patch.object(db_api, 'update_object')
+    def test_update_changes(self, update_mock):
+        fields_to_update = self.get_updatable_fields(self.db_obj)
+        with mock.patch.object(base.NeutronDbObject,
+                               '_get_changed_persistent_fields',
+                               return_value=fields_to_update):
             obj = self._test_class(self.context, **self.db_obj)
-            self._check_equal(obj, self.db_obj)
             obj.update()
-            self._check_equal(obj, self.db_obj)
             update_mock.assert_called_once_with(
                 self.context, self._test_class.db_model,
-                self.db_obj['id'], self.db_obj)
+                self.db_obj['id'], fields_to_update)
+
+    @mock.patch.object(base.NeutronDbObject,
+                       '_get_changed_persistent_fields',
+                       return_value={'a': 'a', 'b': 'b', 'c': 'c'})
+    def test_update_changes_forbidden(self, *mocks):
+        with mock.patch.object(
+            self._test_class,
+            'fields_no_update',
+            new_callable=mock.PropertyMock(return_value=['a', 'c']),
+            create=True):
+            obj = self._test_class(self.context, **self.db_obj)
+            self.assertRaises(base.NeutronObjectUpdateForbidden, obj.update)
 
     def test_update_updates_from_db_object(self):
         with mock.patch.object(db_api, 'update_object',
                                return_value=self.db_obj):
             obj = self._test_class(self.context, **self.db_objs[1])
-            self._check_equal(obj, self.db_objs[1])
-            obj.update()
+            fields_to_update = self.get_updatable_fields(self.db_objs[1])
+            with mock.patch.object(base.NeutronDbObject,
+                                   '_get_changed_persistent_fields',
+                                   return_value=fields_to_update):
+                obj.update()
             self._check_equal(obj, self.db_obj)
 
     @mock.patch.object(db_api, 'delete_object')
@@ -198,9 +215,9 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
         self.assertEqual(obj, new)
 
         obj = new
-        for key, val in self.db_objs[1].items():
-            if key not in self._test_class.fields_no_update:
-                setattr(obj, key, val)
+
+        for key, val in self.get_updatable_fields(self.db_objs[1]).items():
+            setattr(obj, key, val)
         obj.update()
 
         new = self._test_class.get_by_id(self.context, id=obj.id)
@@ -211,3 +228,16 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
 
         new = self._test_class.get_by_id(self.context, id=obj.id)
         self.assertIsNone(new)
+
+    def test_update_non_existent_object_raises_not_found(self):
+        obj = self._test_class(self.context, **self.db_obj)
+        obj.obj_reset_changes()
+
+        for key, val in self.get_updatable_fields(self.db_obj).items():
+            setattr(obj, key, val)
+
+        self.assertRaises(n_exc.ObjectNotFound, obj.update)
+
+    def test_delete_non_existent_object_raises_not_found(self):
+        obj = self._test_class(self.context, **self.db_obj)
+        self.assertRaises(n_exc.ObjectNotFound, obj.delete)
index 8254da6356f3bd4a71e844bda81b1ea9be2fb14c..df26a4eaa4b55c5107469331caa60478075bb794 100644 (file)
@@ -18,6 +18,7 @@ from neutron.api.rpc.callbacks import resources
 from neutron.common import exceptions as n_exc
 from neutron import context
 from neutron import manager
+from neutron.objects import base as base_object
 from neutron.objects.qos import policy as policy_object
 from neutron.objects.qos import rule as rule_object
 from neutron.plugins.common import constants
@@ -80,8 +81,10 @@ class TestQosPlugin(base.BaseTestCase):
         self.assertFalse(self.registry_m.called)
 
     def test_update_policy(self):
+        fields = base_object.get_updatable_fields(
+            policy_object.QosPolicy, self.policy_data['policy'])
         self.qos_plugin.update_policy(
-            self.ctxt, self.policy.id, self.policy_data)
+            self.ctxt, self.policy.id, {'policy': fields})
         self._validate_registry_params(events.UPDATED)
 
     def test_delete_policy(self):