This is a significant piece of work.
It enables neutron objects to define fields that are lazily loaded on
field access. To achieve that,
- field should be mentioned in cls.synthetic_fields
- obj_load_attr should be extended to lazily fetch and cache the field
Based on this work, we define per type rule fields that are lists of
appropriate neutron objects. (At the moment, we have only single type
supported, but I tried hard to make it easily extendable, with little or
no coding needed when a new rule type object definition is added to
rule.py: for example, we inspect object definitions based on
VALID_RULE_TYPES, and define appropriate fields for the policy object).
To implement lazy loading for those fields, I redefined get_by_id for
rules that now meld fields from both base and subtype db models into the
corresponding neutron object.
Added a simple test that checks bandwidth_rules attribute behaves for
policies.
Some objects unit test framework rework was needed to accomodate
synthetic fields that are not propagated to db layer.
Change-Id: Ia16393453b1ed48651fbd778bbe0ac6427560117
class NetworkSubnetPoolAffinityError(BadRequest):
message = _("Subnets hosted on the same network must be allocated from "
"the same subnet pool")
+
+
+class ObjectActionError(NeutronException):
+ message = _('Object action %(action)s failed because: %(reason)s')
def __str__(self):
return str(self.function(*self.args, **self.kwargs))
+
+
+def camelize(s):
+ return ''.join(s.replace('_', ' ').title().split())
# Common database operation implementations
-# TODO(QoS): consider handling multiple objects found, or no objects at all
+# TODO(QoS): consider reusing get_objects below
# TODO(QoS): consider changing the name and making it public, officially
def _find_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
def get_object(context, model, id):
+ # TODO(QoS): consider reusing get_objects below
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(id=id)
.first())
-def get_objects(context, model):
+def get_objects(context, model, **kwargs):
with context.session.begin(subtransactions=True):
- return common_db_mixin.model_query(context, model).all()
+ return (common_db_mixin.model_query(context, model)
+ .filter_by(**kwargs)
+ .all())
def create_object(context, model, values):
# fields that are not allowed to update
fields_no_update = []
+ synthetic_fields = []
+
def from_db_object(self, *objs):
for field in self.fields:
for db_obj in objs:
return obj
@classmethod
- def get_objects(cls, context):
- db_objs = db_api.get_objects(context, cls.db_model)
+ def get_objects(cls, context, **kwargs):
+ db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs]
for obj in objs:
obj.obj_reset_changes()
return objs
- def create(self):
+ def _get_changed_persistent_fields(self):
fields = self.obj_get_changes()
+ for field in self.synthetic_fields:
+ if field in fields:
+ del fields[field]
+ return fields
+
+ def create(self):
+ fields = self._get_changed_persistent_fields()
db_obj = db_api.create_object(self._context, self.db_model, fields)
self.from_db_object(db_obj)
def update(self):
- # TODO(QoS): enforce fields_no_update
- updates = self.obj_get_changes()
+ updates = self._get_changed_persistent_fields()
if updates:
db_obj = db_api.update_object(self._context, self.db_model,
self.id, updates)
# License for the specific language governing permissions and limitations
# under the License.
+import abc
+
from oslo_versionedobjects import base as obj_base
from oslo_versionedobjects import fields as obj_fields
+import six
+from neutron.common import exceptions
+from neutron.common import utils
from neutron.db import api as db_api
from neutron.db.qos import api as qos_db_api
from neutron.db.qos import models as qos_db_model
+from neutron.extensions import qos as qos_extension
from neutron.objects import base
+from neutron.objects.qos import rule as rule_obj_impl
+
+
+class QosRulesExtenderMeta(abc.ABCMeta):
+
+ def __new__(cls, *args, **kwargs):
+ cls_ = super(QosRulesExtenderMeta, cls).__new__(cls, *args, **kwargs)
+ cls_.rule_fields = {}
+ for rule in qos_extension.VALID_RULE_TYPES:
+ rule_cls_name = 'Qos%sRule' % utils.camelize(rule)
+ field = '%s_rules' % rule
+ cls_.fields[field] = obj_fields.ListOfObjectsField(rule_cls_name)
+ cls_.rule_fields[field] = rule_cls_name
-# TODO(QoS): add rule lists to object fields
-# TODO(QoS): implement something for binding networks and ports with policies
+ cls_.synthetic_fields = list(cls_.rule_fields.keys())
+
+ return cls_
@obj_base.VersionedObjectRegistry.register
+@six.add_metaclass(QosRulesExtenderMeta)
class QosPolicy(base.NeutronObject):
db_model = qos_db_model.QosPolicy
fields_no_update = ['id', 'tenant_id']
+ def obj_load_attr(self, attrname):
+ if attrname not in self.rule_fields:
+ raise exceptions.ObjectActionError(
+ action='obj_load_attr', reason='unable to load %s' % attrname)
+
+ rule_cls = getattr(rule_obj_impl, self.rule_fields[attrname])
+ rules = rule_cls.get_rules_by_policy(self._context, self.id)
+ setattr(self, attrname, rules)
+ self.obj_reset_changes([attrname])
+
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
# TODO(QoS): we should make sure we use public functions
from neutron.db import api as db_api
from neutron.db.qos import models as qos_db_model
+from neutron.extensions import qos as qos_extension
from neutron.objects import base
fields_no_update = ['id', 'tenant_id', 'qos_policy_id']
+ # each rule subclass should redefine it
+ rule_type = None
+
_core_fields = list(fields.keys())
_common_fields = ['id']
if func(key)
}
- # TODO(QoS): reimplement get_by_id to merge both core and addn fields
-
def _get_changed_core_fields(self):
fields = self.obj_get_changes()
return self._filter_fields(fields, self._is_core_field)
for field in self._common_fields:
to_[field] = from_[field]
+ @classmethod
+ def get_objects(cls, context, **kwargs):
+ # TODO(QoS): support searching for subtype fields
+ db_objs = db_api.get_objects(context, cls.base_db_model, **kwargs)
+ return [cls.get_by_id(context, db_obj['id']) for db_obj in db_objs]
+
+ @classmethod
+ def get_by_id(cls, context, id):
+ obj = super(QosRule, cls).get_by_id(context, id)
+
+ if obj:
+ # the object above does not contain fields from base QosRule yet,
+ # so fetch it and mix its fields into the object
+ base_db_obj = db_api.get_object(context, cls.base_db_model, id)
+ for field in cls._core_fields:
+ setattr(obj, field, base_db_obj[field])
+
+ obj.obj_reset_changes()
+ return obj
+
# TODO(QoS): create and update are not transactional safe
def create(self):
+ # TODO(QoS): enforce that type field value is bound to specific class
+ self.type = self.rule_type
+
# create base qos_rule
core_fields = self._get_changed_core_fields()
base_db_obj = db_api.create_object(
def update(self):
updated_db_objs = []
+ # TODO(QoS): enforce that type field cannot be changed
+
# update base qos_rule, if needed
core_fields = self._get_changed_core_fields()
if core_fields:
# delete is the same, additional rule object cleanup is done thru cascading
+ @classmethod
+ def get_rules_by_policy(cls, context, policy_id):
+ return cls.get_objects(context, qos_policy_id=policy_id)
+
@obj_base.VersionedObjectRegistry.register
class QosBandwidthLimitRule(QosRule):
db_model = qos_db_model.QosBandwidthLimitRule
+ rule_type = qos_extension.RULE_TYPE_BANDWIDTH_LIMIT
+
fields = {
- 'max_kbps': obj_fields.IntegerField(),
- 'max_burst_kbps': obj_fields.IntegerField()
+ 'max_kbps': obj_fields.IntegerField(nullable=True),
+ 'max_burst_kbps': obj_fields.IntegerField(nullable=True)
}
LOG.logger.setLevel(logging.logging.DEBUG)
LOG.debug("Hello %s", delayed)
self.assertTrue(my_func.called)
+
+
+class TestCamelize(base.BaseTestCase):
+ def test_camelize(self):
+ data = {'bandwidth_limit': 'BandwidthLimit',
+ 'test': 'Test',
+ 'some__more__dashes': 'SomeMoreDashes',
+ 'a_penguin_walks_into_a_bar': 'APenguinWalksIntoABar'}
+
+ for s, expected in data.items():
+ self.assertEqual(expected, utils.camelize(s))
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.objects.qos import policy
+from neutron.objects.qos import rule
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
policy_obj = policy.QosPolicy.get_network_policy(self.context,
self._network['id'])
self.assertIsNone(policy_obj)
+
+ def test_synthetic_rule_fields(self):
+ obj = policy.QosPolicy(self.context, **self.db_obj)
+ obj.create()
+
+ rule_fields = self.get_random_fields(
+ obj_cls=rule.QosBandwidthLimitRule)
+ rule_fields['qos_policy_id'] = obj.id
+ rule_fields['tenant_id'] = obj.tenant_id
+
+ rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
+ rule_obj.create()
+
+ obj = policy.QosPolicy.get_by_id(self.context, obj.id)
+ self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
_test_class = rule.QosBandwidthLimitRule
+ @classmethod
+ def get_random_fields(cls):
+ # object middleware should not allow random types, so override it with
+ # proper type
+ fields = (super(QosBandwidthLimitPolicyObjectTestCase, cls)
+ .get_random_fields())
+ fields['type'] = cls._test_class.rule_type
+ return fields
+
def _filter_db_object(self, func):
return {
field: self.db_obj[field]
return self._filter_db_object(
lambda field: self._test_class._is_addn_field(field))
+ def test_get_by_id(self):
+ with mock.patch.object(db_api, 'get_object',
+ return_value=self.db_obj) as get_object_mock:
+ obj = self._test_class.get_by_id(self.context, id='fake_id')
+ self.assertTrue(self._is_test_class(obj))
+ self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
+ get_object_mock.assert_has_calls([
+ mock.call(self.context, model, 'fake_id')
+ for model in (self._test_class.db_model,
+ self._test_class.base_db_model)
+ ], any_order=True)
+
+ def test_get_objects(self):
+ with mock.patch.object(db_api, 'get_objects',
+ return_value=self.db_objs):
+
+ @classmethod
+ def _get_by_id(cls, context, id):
+ for db_obj in self.db_objs:
+ if db_obj['id'] == id:
+ return self._test_class(context, **db_obj)
+
+ with mock.patch.object(rule.QosRule, 'get_by_id', new=_get_by_id):
+ objs = self._test_class.get_objects(self.context)
+ self.assertFalse(
+ filter(lambda obj: not self._is_test_class(obj), objs))
+ self.assertEqual(
+ sorted(self.db_objs),
+ sorted(test_base.get_obj_db_fields(obj) for obj in objs))
+
def test_create(self):
with mock.patch.object(db_api, 'create_object',
return_value=self.db_obj) as create_mock:
self._check_equal(obj, self.db_obj)
core_db_obj = self._get_core_db_obj()
- create_mock.assert_any_call(
- self.context, self._test_class.base_db_model, core_db_obj)
-
addn_db_obj = self._get_addn_db_obj()
- create_mock.assert_any_call(
- self.context, self._test_class.db_model,
- addn_db_obj)
+ create_mock.assert_has_calls(
+ [mock.call(self.context, self._test_class.base_db_model,
+ core_db_obj),
+ mock.call(self.context, self._test_class.db_model,
+ addn_db_obj)]
+ )
def test_update_changes(self):
with mock.patch.object(db_api, 'update_object',
obj_fields.IntegerField: _random_integer,
obj_fields.StringField: _random_string,
obj_fields.UUIDField: _random_string,
+ obj_fields.ListOfObjectsField: lambda: []
}
-def get_obj_fields(obj):
- return {field: getattr(obj, field) for field in obj.fields}
+def get_obj_db_fields(obj):
+ return {field: getattr(obj, field) for field in obj.fields
+ if field not in obj.synthetic_fields}
class _BaseObjectTestCase(object):
def setUp(self):
super(_BaseObjectTestCase, self).setUp()
self.context = context.get_admin_context()
- self.db_objs = list(self._get_random_fields() for _ in range(3))
+ self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0]
@classmethod
- def _get_random_fields(cls):
+ def get_random_fields(cls, obj_cls=None):
+ obj_cls = obj_cls or cls._test_class
fields = {}
- for field in cls._test_class.fields:
- field_obj = cls._test_class.fields[field]
- fields[field] = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]()
+ for field, field_obj in obj_cls.fields.items():
+ if field not in obj_cls.synthetic_fields:
+ generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
+ fields[field] = generator()
return fields
@classmethod
return_value=self.db_obj) as get_object_mock:
obj = self._test_class.get_by_id(self.context, id='fake_id')
self.assertTrue(self._is_test_class(obj))
- self.assertEqual(self.db_obj, get_obj_fields(obj))
+ self.assertEqual(self.db_obj, get_obj_db_fields(obj))
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model, 'fake_id')
filter(lambda obj: not self._is_test_class(obj), objs))
self.assertEqual(
sorted(self.db_objs),
- sorted(get_obj_fields(obj) for obj in objs))
+ sorted(get_obj_db_fields(obj) for obj in objs))
get_objects_mock.assert_called_once_with(
self.context, self._test_class.db_model)
def _check_equal(self, obj, db_obj):
self.assertEqual(
sorted(db_obj),
- sorted(get_obj_fields(obj)))
+ sorted(get_obj_db_fields(obj)))
def test_create(self):
with mock.patch.object(db_api, 'create_object',