Now all objects are comparable.
We need to reset changes, otherwise an object that is constructed and
.create()d is different from the one that is .get_by_id()d from database
(for primitive serialization contains list of changed fields for
versioned objects).
Added initial sql test case for objects (just create-fetch for policy
for now, but can be easily extended to other types).
Change-Id: I012b5fe4e95f166f66da91274734d7184c224dfd
@six.add_metaclass(abc.ABCMeta)
class NeutronObject(obj_base.VersionedObject,
- obj_base.VersionedObjectDictCompat):
+ obj_base.VersionedObjectDictCompat,
+ obj_base.ComparableVersionedObject):
# should be overridden for all persistent objects
db_model = None
@classmethod
def get_by_id(cls, context, id):
db_obj = db_api.get_object(context, cls.db_model, id)
- return cls(context, **db_obj)
+ obj = cls(context, **db_obj)
+ obj.obj_reset_changes()
+ return obj
@classmethod
def get_objects(cls, context):
db_objs = db_api.get_objects(context, cls.db_model)
objs = [cls(context, **db_obj) for db_obj in db_objs]
+ for obj in objs:
+ obj.obj_reset_changes()
return objs
def create(self):
from neutron.objects.qos import policy
from neutron.tests.unit.objects import test_base
+from neutron.tests.unit import testlib_api
-class QosPolicyObjectTestCase(test_base.BaseObjectTestCase):
+class QosPolicyBaseTestCase(object):
_test_class = policy.QosPolicy
+
+
+class QosPolicyObjectTestCase(QosPolicyBaseTestCase,
+ test_base.BaseObjectIfaceTestCase):
+ pass
+
+
+class QosPolicyDbObjectTestCase(QosPolicyBaseTestCase,
+ test_base.BaseDbObjectTestCase,
+ testlib_api.SqlTestCase):
+ pass
from neutron.tests.unit.objects import test_base
-class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectTestCase):
+class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = rule.QosBandwidthLimitRule
return {field: getattr(obj, field) for field in obj.fields}
-class BaseObjectTestCase(test_base.BaseTestCase):
+class _BaseObjectTestCase(object):
_test_class = FakeNeutronObject
def setUp(self):
- super(BaseObjectTestCase, self).setUp()
+ super(_BaseObjectTestCase, self).setUp()
self.context = context.get_admin_context()
self.db_objs = list(self._get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0]
def _is_test_class(cls, obj):
return isinstance(obj, cls._test_class)
+
+class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
+
def test_get_by_id(self):
with mock.patch.object(db_api, 'get_object',
return_value=self.db_obj) as get_object_mock:
self._check_equal(obj, self.db_obj)
delete_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.db_obj['id'])
+
+
+class BaseDbObjectTestCase(_BaseObjectTestCase):
+
+ def test_create(self):
+ obj = self._test_class(self.context, **self.db_obj)
+ obj.create()
+
+ new = self._test_class.get_by_id(self.context, id=obj.id)
+ self.assertEqual(obj, new)