LOG = logging.getLogger(__name__)
+def filter_fields(f):
+ @functools.wraps(f)
+ def inner_filter(*args, **kwargs):
+ result = f(*args, **kwargs)
+ fields = kwargs.get('fields')
+ if not fields:
+ pos = f.func_code.co_varnames.index('fields')
+ try:
+ fields = args[pos]
+ except IndexError:
+ return result
+
+ do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
+ if isinstance(result, list):
+ return [do_filter(obj) for obj in result]
+ else:
+ return do_filter(result)
+ return inner_filter
+
+
class DbBasePluginCommon(common_db_mixin.CommonDbMixin):
"""Stores getters and helper methods for db_base_plugin_v2
'convert_to': attr.convert_to_boolean},
'tenant_id': {'allow_post': True, 'allow_put': False,
'required_by_policy': True,
- 'is_visible': True}
+ 'is_visible': True},
+ 'bandwidth_limit_rules': {'allow_post': False, 'allow_put': False,
+ 'is_visible': True},
},
'rule_types': {
'type': {'allow_post': False, 'allow_put': False,
setattr(self, attrname, rules)
self.obj_reset_changes([attrname])
+ def _load_rules(self):
+ for attr in self.rule_fields:
+ self.obj_load_attr(attr)
+
+ @classmethod
+ def get_by_id(cls, context, id):
+ with db_api.autonested_transaction(context.session):
+ policy_obj = super(QosPolicy, cls).get_by_id(context, id)
+ if policy_obj:
+ policy_obj._load_rules()
+ return policy_obj
+
+ # TODO(QoS): Test that all objects are fetched within one transaction
+ @classmethod
+ def get_objects(cls, context, **kwargs):
+ with db_api.autonested_transaction(context.session):
+ db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
+ objs = list()
+ for db_obj in db_objs:
+ obj = cls(context, **db_obj)
+ obj._load_rules()
+ objs.append(obj)
+ return objs
+
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
- binding_db_obj = db_api.get_object(context, model, **kwargs)
- # TODO(QoS): rethink handling missing binding case
- if binding_db_obj:
- return cls.get_by_id(context, binding_db_obj['policy_id'])
+ with db_api.autonested_transaction(context.session):
+ binding_db_obj = db_api.get_object(context, model, **kwargs)
+ # TODO(QoS): rethink handling missing binding case
+ if binding_db_obj:
+ return cls.get_by_id(context, binding_db_obj['policy_id'])
@classmethod
def get_network_policy(cls, context, network_id):
return cls._get_object_policy(context, cls.port_binding_model,
port_id=port_id)
+ def create(self):
+ with db_api.autonested_transaction(self._context.session):
+ super(QosPolicy, self).create()
+ self._load_rules()
+
def attach_network(self, network_id):
qos_db_api.create_policy_network_binding(self._context,
policy_id=self.id,
obj.obj_reset_changes()
return obj
- # TODO(QoS): create and update are not transactional safe
+ # TODO(QoS): Test that create is in single transaction
def create(self):
# TODO(QoS): enforce that type field value is bound to specific class
# create base qos_rule
core_fields = self._get_changed_core_fields()
- base_db_obj = db_api.create_object(
- self._context, self.base_db_model, core_fields)
- # create type specific qos_..._rule
- addn_fields = self._get_changed_addn_fields()
- self._copy_common_fields(core_fields, addn_fields)
- addn_db_obj = db_api.create_object(
- self._context, self.db_model, addn_fields)
+ with db_api.autonested_transaction(self._context.session):
+ base_db_obj = db_api.create_object(
+ self._context, self.base_db_model, core_fields)
+
+ # create type specific qos_..._rule
+ addn_fields = self._get_changed_addn_fields()
+ self._copy_common_fields(core_fields, addn_fields)
+ addn_db_obj = db_api.create_object(
+ self._context, self.db_model, addn_fields)
# merge two db objects into single neutron one
self.from_db_object(base_db_obj, addn_db_obj)
+ # TODO(QoS): Test that update is in single transaction
def update(self):
updated_db_objs = []
# update base qos_rule, if needed
core_fields = self._get_changed_core_fields()
- if core_fields:
- base_db_obj = db_api.update_object(
- self._context, self.base_db_model, self.id, core_fields)
- updated_db_objs.append(base_db_obj)
-
- addn_fields = self._get_changed_addn_fields()
- if addn_fields:
- addn_db_obj = db_api.update_object(
- self._context, self.db_model, self.id, addn_fields)
- updated_db_objs.append(addn_db_obj)
+
+ with db_api.autonested_transaction(self._context.session):
+ if core_fields:
+ base_db_obj = db_api.update_object(
+ self._context, self.base_db_model, self.id, core_fields)
+ updated_db_objs.append(base_db_obj)
+
+ addn_fields = self._get_changed_addn_fields()
+ if addn_fields:
+ addn_db_obj = db_api.update_object(
+ self._context, self.db_model, self.id, addn_fields)
+ updated_db_objs.append(addn_db_obj)
# update neutron object with values from both database objects
self.from_db_object(*updated_db_objs)
from neutron.api.rpc.callbacks import registry as rpc_registry
from neutron.api.rpc.callbacks import resources as rpc_resources
+from neutron.db import db_base_plugin_common
from neutron.extensions import qos
from neutron.i18n import _LW
from neutron.objects.qos import policy as policy_object
def _get_policy_obj(self, context, policy_id):
return policy_object.QosPolicy.get_by_id(context, policy_id)
+ @db_base_plugin_common.filter_fields
def get_policy(self, context, policy_id, fields=None):
- #TODO(QoS): Support the fields parameter
return self._get_policy_obj(context, policy_id).to_dict()
+ @db_base_plugin_common.filter_fields
def get_policies(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
rule.id = rule_id
rule.delete()
+ @db_base_plugin_common.filter_fields
def get_policy_bandwidth_limit_rule(self, context, rule_id,
policy_id, fields=None):
- #TODO(QoS): Support the fields parameter
return rule_object.QosBandwidthLimitRule.get_by_id(context,
rule_id).to_dict()
+ @db_base_plugin_common.filter_fields
def get_policy_bandwidth_limit_rules(self, context, policy_id,
filters=None, fields=None,
sorts=None, limit=None,
return [rule_obj.to_dict() for rule_obj in
rule_object.QosBandwidthLimitRule.get_objects(context)]
+ @db_base_plugin_common.filter_fields
def get_rule_types(self, context, filters=None, fields=None,
sorts=None, limit=None,
marker=None, page_reverse=False):
--- /dev/null
+# Copyright (c) 2015 Red Hat, Inc.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from neutron.db import db_base_plugin_common
+from neutron.tests import base
+
+
+class FilterFieldsTestCase(base.BaseTestCase):
+
+ @db_base_plugin_common.filter_fields
+ def method_dict(self, fields=None):
+ return {'one': 1, 'two': 2, 'three': 3}
+
+ @db_base_plugin_common.filter_fields
+ def method_list(self, fields=None):
+ return [self.method_dict() for _ in range(3)]
+
+ @db_base_plugin_common.filter_fields
+ def method_multiple_arguments(self, not_used, fields=None,
+ also_not_used=None):
+ return {'one': 1, 'two': 2, 'three': 3}
+
+ def test_no_fields(self):
+ expected = {'one': 1, 'two': 2, 'three': 3}
+ observed = self.method_dict()
+ self.assertEqual(expected, observed)
+
+ def test_dict(self):
+ expected = {'two': 2}
+ observed = self.method_dict(['two'])
+ self.assertEqual(expected, observed)
+
+ def test_list(self):
+ expected = [{'two': 2}, {'two': 2}, {'two': 2}]
+ observed = self.method_list(['two'])
+ self.assertEqual(expected, observed)
+
+ def test_multiple_arguments_positional(self):
+ expected = {'two': 2}
+ observed = self.method_multiple_arguments(list(), ['two'])
+ self.assertEqual(expected, observed)
+
+ def test_multiple_arguments_positional_and_keywords(self):
+ expected = {'two': 2}
+ observed = self.method_multiple_arguments(fields=['two'],
+ not_used=None)
+ self.assertEqual(expected, observed)
+
+ def test_multiple_arguments_keyword(self):
+ expected = {'two': 2}
+ observed = self.method_multiple_arguments(list(), fields=['two'])
+ self.assertEqual(expected, observed)
# License for the specific language governing permissions and limitations
# under the License.
+import mock
+
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.objects.qos import policy
_test_class = policy.QosPolicy
+ def setUp(self):
+ super(QosPolicyObjectTestCase, self).setUp()
+ self.db_qos_rules = [self.get_random_fields(rule.QosRule)
+ for _ in range(3)]
+
+ # Tie qos rules with policies
+ self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id']
+ self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id']
+ self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id']
+
+ self.db_qos_bandwidth_rules = [
+ self.get_random_fields(rule.QosBandwidthLimitRule)
+ for _ in range(3)]
+
+ # Tie qos rules with qos bandwidth limit rules
+ for i, qos_rule in enumerate(self.db_qos_rules):
+ self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id']
+
+ self.model_map = {
+ self._test_class.db_model: self.db_objs,
+ rule.QosRule.base_db_model: self.db_qos_rules,
+ rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules}
+
+ def fake_get_objects(self, context, model, qos_policy_id=None):
+ objs = self.model_map[model]
+ if model is rule.QosRule.base_db_model and qos_policy_id:
+ return [obj for obj in objs
+ if obj['qos_policy_id'] == qos_policy_id]
+ return objs
+
+ def fake_get_object(self, context, model, id):
+ objects = self.model_map[model]
+ return [obj for obj in objects if obj['id'] == id][0]
+
+ def test_get_objects(self):
+ with mock.patch.object(
+ db_api, 'get_objects',
+ side_effect=self.fake_get_objects),\
+ mock.patch.object(
+ db_api, 'get_object',
+ side_effect=self.fake_get_object):
+ objs = self._test_class.get_objects(self.context)
+ self._validate_objects(self.db_objs, objs)
+
class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
testlib_api.SqlTestCase):
policy_obj.create()
return policy_obj
+ def _create_test_policy_with_rule(self):
+ policy_obj = self._create_test_policy()
+
+ rule_fields = self.get_random_fields(
+ obj_cls=rule.QosBandwidthLimitRule)
+ rule_fields['qos_policy_id'] = policy_obj.id
+ rule_fields['tenant_id'] = policy_obj.tenant_id
+
+ rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
+ rule_obj.create()
+
+ return policy_obj, rule_obj
+
def _create_test_network(self):
# TODO(ihrachys): replace with network.create() once we get an object
# implementation for networks
self.assertIsNone(policy_obj)
def test_synthetic_rule_fields(self):
- obj = policy.QosPolicy(self.context, **self.db_obj)
- obj.create()
-
- rule_fields = self.get_random_fields(
- obj_cls=rule.QosBandwidthLimitRule)
- rule_fields['qos_policy_id'] = obj.id
- rule_fields['tenant_id'] = obj.tenant_id
-
- rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields)
- rule_obj.create()
-
- obj = policy.QosPolicy.get_by_id(self.context, obj.id)
- self.assertEqual([rule_obj], obj.bandwidth_limit_rules)
+ policy_obj, rule_obj = self._create_test_policy_with_rule()
+ policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
+ self.assertEqual([rule_obj], policy_obj.bandwidth_limit_rules)
+
+ def test_create_is_in_single_transaction(self):
+ obj = self._test_class(self.context, **self.db_obj)
+ with mock.patch('sqlalchemy.engine.'
+ 'Transaction.commit') as mock_commit,\
+ mock.patch.object(obj._context.session, 'add'):
+ obj.create()
+ self.assertEqual(1, mock_commit.call_count)
+
+ def test_get_by_id_fetches_rules_non_lazily(self):
+ policy_obj, rule_obj = self._create_test_policy_with_rule()
+ policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
+
+ primitive = policy_obj.obj_to_primitive()
+ self.assertNotEqual([], (primitive['versioned_object.data']
+ ['bandwidth_limit_rules']))
from neutron.tests import base as test_base
+class FakeModel(object):
+ def __init__(self, *args, **kwargs):
+ pass
+
+
@obj_base.VersionedObjectRegistry.register
class FakeNeutronObject(base.NeutronObject):
- db_model = 'fake_model'
+ db_model = FakeModel
fields = {
'id': obj_fields.UUIDField(),
with mock.patch.object(db_api, 'get_objects',
return_value=self.db_objs) as get_objects_mock:
objs = self._test_class.get_objects(self.context)
- self.assertFalse(
- filter(lambda obj: not self._is_test_class(obj), objs))
- self.assertEqual(
- sorted(self.db_objs),
- sorted(get_obj_db_fields(obj) for obj in objs))
- get_objects_mock.assert_called_once_with(
- self.context, self._test_class.db_model)
+ self._validate_objects(self.db_objs, objs)
+ get_objects_mock.assert_called_once_with(
+ self.context, self._test_class.db_model)
+
+ def _validate_objects(self, expected, observed):
+ self.assertFalse(
+ filter(lambda obj: not self._is_test_class(obj), observed))
+ self.assertEqual(
+ sorted(expected),
+ sorted(get_obj_db_fields(obj) for obj in observed))
def _check_equal(self, obj, db_obj):
self.assertEqual(