# Common database operation implementations
-# TODO(QoS): consider reusing get_objects below
-# TODO(QoS): consider changing the name and making it public, officially
-def _find_object(context, model, **kwargs):
+def get_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(**kwargs)
.first())
-def get_object(context, model, id):
- # TODO(QoS): consider reusing get_objects below
- with context.session.begin(subtransactions=True):
- return (common_db_mixin.model_query(context, model)
- .filter_by(id=id)
- .first())
-
-
def get_objects(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
def update_object(context, model, id, values):
with context.session.begin(subtransactions=True):
- db_obj = get_object(context, model, id)
+ db_obj = get_object(context, model, id=id)
db_obj.update(values)
db_obj.save(session=context.session)
return db_obj.__dict__
def delete_object(context, model, id):
with context.session.begin(subtransactions=True):
- db_obj = get_object(context, model, id)
+ db_obj = get_object(context, model, id=id)
context.session.delete(db_obj)
@classmethod
def get_by_id(cls, context, id):
- db_obj = db_api.get_object(context, cls.db_model, id)
+ db_obj = db_api.get_object(context, cls.db_model, id=id)
if db_obj:
obj = cls(context, **db_obj)
obj.obj_reset_changes()
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
- # TODO(QoS): we should make sure we use public functions
- binding_db_obj = db_api._find_object(context, model, **kwargs)
+ binding_db_obj = db_api.get_object(context, model, **kwargs)
# TODO(QoS): rethink handling missing binding case
if binding_db_obj:
return cls.get_by_id(context, binding_db_obj['policy_id'])
if obj:
# the object above does not contain fields from base QosRule yet,
# so fetch it and mix its fields into the object
- base_db_obj = db_api.get_object(context, cls.base_db_model, id)
+ base_db_obj = db_api.get_object(context, cls.base_db_model, id=id)
for field in cls._core_fields:
setattr(obj, field, base_db_obj[field])
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
get_object_mock.assert_has_calls([
- mock.call(self.context, model, 'fake_id')
+ mock.call(self.context, model, id='fake_id')
for model in (self._test_class.db_model,
self._test_class.base_db_model)
], any_order=True)
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, get_obj_db_fields(obj))
get_object_mock.assert_called_once_with(
- self.context, self._test_class.db_model, 'fake_id')
+ self.context, self._test_class.db_model, id='fake_id')
def test_get_by_id_missing_object(self):
with mock.patch.object(db_api, 'get_object', return_value=None):