]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
objects.base: reset changes after getting objects from database
authorIhar Hrachyshka <ihrachys@redhat.com>
Wed, 8 Jul 2015 16:06:12 +0000 (18:06 +0200)
committerIhar Hrachyshka <ihrachys@redhat.com>
Fri, 10 Jul 2015 18:00:48 +0000 (18:00 +0000)
Now all objects are comparable.

We need to reset changes, otherwise an object that is constructed and
.create()d is different from the one that is .get_by_id()d from database
(for primitive serialization contains list of changed fields for
versioned objects).

Added initial sql test case for objects (just create-fetch for policy
for now, but can be easily extended to other types).

Change-Id: I012b5fe4e95f166f66da91274734d7184c224dfd

neutron/objects/base.py
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/objects/qos/test_rule.py
neutron/tests/unit/objects/test_base.py

index 57f785ea41fc24b845a51d40e1b48b72df77476d..d3f75c20deabab4f372f2c67f1ca671d1ffc0954 100644 (file)
@@ -23,7 +23,8 @@ from neutron.db import api as db_api
 
 @six.add_metaclass(abc.ABCMeta)
 class NeutronObject(obj_base.VersionedObject,
-                    obj_base.VersionedObjectDictCompat):
+                    obj_base.VersionedObjectDictCompat,
+                    obj_base.ComparableVersionedObject):
 
     # should be overridden for all persistent objects
     db_model = None
@@ -39,12 +40,16 @@ class NeutronObject(obj_base.VersionedObject,
     @classmethod
     def get_by_id(cls, context, id):
         db_obj = db_api.get_object(context, cls.db_model, id)
-        return cls(context, **db_obj)
+        obj = cls(context, **db_obj)
+        obj.obj_reset_changes()
+        return obj
 
     @classmethod
     def get_objects(cls, context):
         db_objs = db_api.get_objects(context, cls.db_model)
         objs = [cls(context, **db_obj) for db_obj in db_objs]
+        for obj in objs:
+            obj.obj_reset_changes()
         return objs
 
     def create(self):
index 8997482dff1a38625bcd8948125e430b3e4514da..e88b7915a7d440cfb721dc6a1e4c120e2506f5e4 100644 (file)
 
 from neutron.objects.qos import policy
 from neutron.tests.unit.objects import test_base
+from neutron.tests.unit import testlib_api
 
 
-class QosPolicyObjectTestCase(test_base.BaseObjectTestCase):
+class QosPolicyBaseTestCase(object):
 
     _test_class = policy.QosPolicy
+
+
+class QosPolicyObjectTestCase(QosPolicyBaseTestCase,
+                              test_base.BaseObjectIfaceTestCase):
+    pass
+
+
+class QosPolicyDbObjectTestCase(QosPolicyBaseTestCase,
+                                test_base.BaseDbObjectTestCase,
+                                testlib_api.SqlTestCase):
+    pass
index e7656e871f4917098e3f1b754af6f2db536005b1..867a0b97744b42a69768b654366d57775eebb986 100644 (file)
@@ -17,7 +17,7 @@ from neutron.objects.qos import rule
 from neutron.tests.unit.objects import test_base
 
 
-class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectTestCase):
+class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
 
     _test_class = rule.QosBandwidthLimitRule
 
index f0378cff12fdd96be68bbd32e903001bb1b82fbe..6e6541c75ff1cf84eec43e9bf0a8fefc4c4017cc 100644 (file)
@@ -59,12 +59,12 @@ def get_obj_fields(obj):
     return {field: getattr(obj, field) for field in obj.fields}
 
 
-class BaseObjectTestCase(test_base.BaseTestCase):
+class _BaseObjectTestCase(object):
 
     _test_class = FakeNeutronObject
 
     def setUp(self):
-        super(BaseObjectTestCase, self).setUp()
+        super(_BaseObjectTestCase, self).setUp()
         self.context = context.get_admin_context()
         self.db_objs = list(self._get_random_fields() for _ in range(3))
         self.db_obj = self.db_objs[0]
@@ -81,6 +81,9 @@ class BaseObjectTestCase(test_base.BaseTestCase):
     def _is_test_class(cls, obj):
         return isinstance(obj, cls._test_class)
 
+
+class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
+
     def test_get_by_id(self):
         with mock.patch.object(db_api, 'get_object',
                                return_value=self.db_obj) as get_object_mock:
@@ -166,3 +169,13 @@ class BaseObjectTestCase(test_base.BaseTestCase):
         self._check_equal(obj, self.db_obj)
         delete_mock.assert_called_once_with(
             self.context, self._test_class.db_model, self.db_obj['id'])
+
+
+class BaseDbObjectTestCase(_BaseObjectTestCase):
+
+    def test_create(self):
+        obj = self._test_class(self.context, **self.db_obj)
+        obj.create()
+
+        new = self._test_class.get_by_id(self.context, id=obj.id)
+        self.assertEqual(obj, new)