op.create_table(
'qos_bandwidth_limit_rules',
- sa.Column('qos_rule_id', sa.String(length=36),
+ sa.Column('id', sa.String(length=36),
sa.ForeignKey('qos_rules.id', ondelete='CASCADE'),
nullable=False,
primary_key=True),
__tablename__ = 'qos_bandwidth_limit_rules'
max_kbps = sa.Column(sa.Integer)
max_burst_kbps = sa.Column(sa.Integer)
- qos_rule_id = sa.Column(sa.String(36),
- sa.ForeignKey('qos_rules.id',
- ondelete='CASCADE'),
- nullable=False,
- primary_key=True)
+ id = sa.Column(sa.String(36),
+ sa.ForeignKey('qos_rules.id',
+ ondelete='CASCADE'),
+ nullable=False,
+ primary_key=True)
_core_fields = list(fields.keys())
+ _common_fields = ['id']
+
+ @classmethod
+ def _is_common_field(cls, field):
+ return field in cls._common_fields
+
@classmethod
def _is_core_field(cls, field):
return field in cls._core_fields
+ @classmethod
+ def _is_addn_field(cls, field):
+ return not cls._is_core_field(field) or cls._is_common_field(field)
+
@staticmethod
def _filter_fields(fields, func):
return {
def _get_changed_addn_fields(self):
fields = self.obj_get_changes()
return self._filter_fields(
- fields, lambda key: not self._is_core_field(key))
+ fields, lambda key: self._is_addn_field(key))
+
+ def _copy_common_fields(self, from_, to_):
+ for field in self._common_fields:
+ to_[field] = from_[field]
# TODO(QoS): create and update are not transactional safe
def create(self):
# create type specific qos_..._rule
addn_fields = self._get_changed_addn_fields()
- addn_fields['qos_rule_id'] = base_db_obj.id
+ self._copy_common_fields(core_fields, addn_fields)
addn_db_obj = db_api.create_object(
self._context, self.db_model, addn_fields)
# merge two db objects into single neutron one
- self.from_db_object(self._context, self, base_db_obj, addn_db_obj)
+ self.from_db_object(base_db_obj, addn_db_obj)
def update(self):
updated_db_objs = []
# update base qos_rule, if needed
core_fields = self._get_changed_core_fields()
if core_fields:
- base_db_obj = db_api.create_object(
- self._context, self.base_db_model, core_fields)
+ base_db_obj = db_api.update_object(
+ self._context, self.base_db_model, self.id, core_fields)
updated_db_objs.append(base_db_obj)
addn_fields = self._get_changed_addn_fields()
if addn_fields:
addn_db_obj = db_api.update_object(
- self._context, self.base_db_model, self.id, addn_fields)
+ self._context, self.db_model, self.id, addn_fields)
updated_db_objs.append(addn_db_obj)
# update neutron object with values from both database objects
- self.from_db_object(self._context, self, *updated_db_objs)
+ self.from_db_object(*updated_db_objs)
# delete is the same, additional rule object cleanup is done thru cascading
--- /dev/null
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import mock
+
+from neutron.db import api as db_api
+from neutron.objects.qos import rule
+from neutron.tests.unit.objects import test_base
+
+
+class QosBandwidthLimitPolicyObjectTestCase(test_base.BaseObjectTestCase):
+
+ test_class = rule.QosBandwidthLimitRule
+
+ def _filter_db_object(self, func):
+ return {
+ field: self.db_obj[field]
+ for field in self.test_class.fields
+ if func(field)
+ }
+
+ def _get_core_db_obj(self):
+ return self._filter_db_object(
+ lambda field: self.test_class._is_core_field(field))
+
+ def _get_addn_db_obj(self):
+ return self._filter_db_object(
+ lambda field: self.test_class._is_addn_field(field))
+
+ def test_create(self):
+ with mock.patch.object(db_api, 'create_object',
+ return_value=self.db_obj) as create_mock:
+ test_class = self.test_class
+ obj = test_class(self.context, **self.db_obj)
+ self._check_equal(obj, self.db_obj)
+ obj.create()
+ self._check_equal(obj, self.db_obj)
+
+ core_db_obj = self._get_core_db_obj()
+ create_mock.assert_any_call(
+ self.context, self.test_class.base_db_model, core_db_obj)
+
+ addn_db_obj = self._get_addn_db_obj()
+ create_mock.assert_any_call(
+ self.context, self.test_class.db_model,
+ addn_db_obj)
+
+ def test_update_changes(self):
+ with mock.patch.object(db_api, 'update_object',
+ return_value=self.db_obj) as update_mock:
+ obj = self.test_class(self.context, **self.db_obj)
+ self._check_equal(obj, self.db_obj)
+ obj.update()
+ self._check_equal(obj, self.db_obj)
+
+ core_db_obj = self._get_core_db_obj()
+ update_mock.assert_any_call(
+ self.context, self.test_class.base_db_model, obj.id,
+ core_db_obj)
+
+ addn_db_obj = self._get_addn_db_obj()
+ update_mock.assert_any_call(
+ self.context, self.test_class.db_model, obj.id,
+ addn_db_obj)
return bool(random.getrandbits(1))
+def _random_integer():
+ return random.randint(0, 1000)
+
+
FIELD_TYPE_VALUE_GENERATOR_MAP = {
obj_fields.BooleanField: _random_boolean,
+ obj_fields.IntegerField: _random_integer,
obj_fields.StringField: _random_string,
obj_fields.UUIDField: _random_string,
}