From: Ihar Hrachyshka Date: Wed, 8 Jul 2015 16:06:12 +0000 (+0200) Subject: objects.base: reset changes after getting objects from database X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=3edec57c2250daafdcdac88581efa1acc5acf237;p=openstack-build%2Fneutron-build.git objects.base: reset changes after getting objects from database Now all objects are comparable. We need to reset changes, otherwise an object that is constructed and .create()d is different from the one that is .get_by_id()d from database (for primitive serialization contains list of changed fields for versioned objects). Added initial sql test case for objects (just create-fetch for policy for now, but can be easily extended to other types). Change-Id: I012b5fe4e95f166f66da91274734d7184c224dfd --- diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 57f785ea4..d3f75c20d 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -23,7 +23,8 @@ from neutron.db import api as db_api @six.add_metaclass(abc.ABCMeta) class NeutronObject(obj_base.VersionedObject, - obj_base.VersionedObjectDictCompat): + obj_base.VersionedObjectDictCompat, + obj_base.ComparableVersionedObject): # should be overridden for all persistent objects db_model = None @@ -39,12 +40,16 @@ class NeutronObject(obj_base.VersionedObject, @classmethod def get_by_id(cls, context, id): db_obj = db_api.get_object(context, cls.db_model, id) - return cls(context, **db_obj) + obj = cls(context, **db_obj) + obj.obj_reset_changes() + return obj @classmethod def get_objects(cls, context): db_objs = db_api.get_objects(context, cls.db_model) objs = [cls(context, **db_obj) for db_obj in db_objs] + for obj in objs: + obj.obj_reset_changes() return objs def create(self): diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index 8997482df..e88b7915a 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -12,8 +12,20 @@ from neutron.objects.qos import policy from neutron.tests.unit.objects import test_base +from neutron.tests.unit import testlib_api -class QosPolicyObjectTestCase(test_base.BaseObjectTestCase): +class QosPolicyBaseTestCase(object): _test_class = policy.QosPolicy + + +class QosPolicyObjectTestCase(QosPolicyBaseTestCase, + test_base.BaseObjectIfaceTestCase): + pass + + +class QosPolicyDbObjectTestCase(QosPolicyBaseTestCase, + test_base.BaseDbObjectTestCase, + testlib_api.SqlTestCase): + pass diff --git a/neutron/tests/unit/objects/qos/test_rule.py b/neutron/tests/unit/objects/qos/test_rule.py index e7656e871..867a0b977 100644 --- a/neutron/tests/unit/objects/qos/test_rule.py +++ b/neutron/tests/unit/objects/qos/test_rule.py @@ -17,7 +17,7 @@ from neutron.objects.qos import rule from neutron.tests.unit.objects import test_base -class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectTestCase): +class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): _test_class = rule.QosBandwidthLimitRule diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index f0378cff1..6e6541c75 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -59,12 +59,12 @@ def get_obj_fields(obj): return {field: getattr(obj, field) for field in obj.fields} -class BaseObjectTestCase(test_base.BaseTestCase): +class _BaseObjectTestCase(object): _test_class = FakeNeutronObject def setUp(self): - super(BaseObjectTestCase, self).setUp() + super(_BaseObjectTestCase, self).setUp() self.context = context.get_admin_context() self.db_objs = list(self._get_random_fields() for _ in range(3)) self.db_obj = self.db_objs[0] @@ -81,6 +81,9 @@ class BaseObjectTestCase(test_base.BaseTestCase): def _is_test_class(cls, obj): return isinstance(obj, cls._test_class) + +class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): + def test_get_by_id(self): with mock.patch.object(db_api, 'get_object', return_value=self.db_obj) as get_object_mock: @@ -166,3 +169,13 @@ class BaseObjectTestCase(test_base.BaseTestCase): self._check_equal(obj, self.db_obj) delete_mock.assert_called_once_with( self.context, self._test_class.db_model, self.db_obj['id']) + + +class BaseDbObjectTestCase(_BaseObjectTestCase): + + def test_create(self): + obj = self._test_class(self.context, **self.db_obj) + obj.create() + + new = self._test_class.get_by_id(self.context, id=obj.id) + self.assertEqual(obj, new)