From: Jakub Libosvar Date: Tue, 21 Jul 2015 08:04:00 +0000 (+0000) Subject: Support qos rules and fields parameters in GET requests X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=7ed1d4f61635e67d0a554ed34540a03222c3f9d3;p=openstack-build%2Fneutron-build.git Support qos rules and fields parameters in GET requests Previously we didn't load the rules into policy object. This patch adds loading the rules and defines bandwidth_limit_rules as a policy resource in a single transaction. As a part of moving towards usage of single transaction, create() and update() of rule were modified accordingly. Finally, we support types in GET requests in this patch. API tests will follow in different patch. Change-Id: I25c72aae74469b687766754bbeb749dfd1b8867c --- diff --git a/neutron/db/db_base_plugin_common.py b/neutron/db/db_base_plugin_common.py index 54257ed97..4ce5daab7 100644 --- a/neutron/db/db_base_plugin_common.py +++ b/neutron/db/db_base_plugin_common.py @@ -29,6 +29,26 @@ from neutron.db import models_v2 LOG = logging.getLogger(__name__) +def filter_fields(f): + @functools.wraps(f) + def inner_filter(*args, **kwargs): + result = f(*args, **kwargs) + fields = kwargs.get('fields') + if not fields: + pos = f.func_code.co_varnames.index('fields') + try: + fields = args[pos] + except IndexError: + return result + + do_filter = lambda d: {k: v for k, v in d.items() if k in fields} + if isinstance(result, list): + return [do_filter(obj) for obj in result] + else: + return do_filter(result) + return inner_filter + + class DbBasePluginCommon(common_db_mixin.CommonDbMixin): """Stores getters and helper methods for db_base_plugin_v2 diff --git a/neutron/extensions/qos.py b/neutron/extensions/qos.py index e845e5334..1c89acac1 100644 --- a/neutron/extensions/qos.py +++ b/neutron/extensions/qos.py @@ -61,7 +61,9 @@ RESOURCE_ATTRIBUTE_MAP = { 'convert_to': attr.convert_to_boolean}, 'tenant_id': {'allow_post': True, 'allow_put': False, 'required_by_policy': True, - 'is_visible': True} + 'is_visible': True}, + 'bandwidth_limit_rules': {'allow_post': False, 'allow_put': False, + 'is_visible': True}, }, 'rule_types': { 'type': {'allow_post': False, 'allow_put': False, diff --git a/neutron/objects/qos/policy.py b/neutron/objects/qos/policy.py index 0c1718ef4..8f2c605c8 100644 --- a/neutron/objects/qos/policy.py +++ b/neutron/objects/qos/policy.py @@ -75,12 +75,37 @@ class QosPolicy(base.NeutronObject): setattr(self, attrname, rules) self.obj_reset_changes([attrname]) + def _load_rules(self): + for attr in self.rule_fields: + self.obj_load_attr(attr) + + @classmethod + def get_by_id(cls, context, id): + with db_api.autonested_transaction(context.session): + policy_obj = super(QosPolicy, cls).get_by_id(context, id) + if policy_obj: + policy_obj._load_rules() + return policy_obj + + # TODO(QoS): Test that all objects are fetched within one transaction + @classmethod + def get_objects(cls, context, **kwargs): + with db_api.autonested_transaction(context.session): + db_objs = db_api.get_objects(context, cls.db_model, **kwargs) + objs = list() + for db_obj in db_objs: + obj = cls(context, **db_obj) + obj._load_rules() + objs.append(obj) + return objs + @classmethod def _get_object_policy(cls, context, model, **kwargs): - binding_db_obj = db_api.get_object(context, model, **kwargs) - # TODO(QoS): rethink handling missing binding case - if binding_db_obj: - return cls.get_by_id(context, binding_db_obj['policy_id']) + with db_api.autonested_transaction(context.session): + binding_db_obj = db_api.get_object(context, model, **kwargs) + # TODO(QoS): rethink handling missing binding case + if binding_db_obj: + return cls.get_by_id(context, binding_db_obj['policy_id']) @classmethod def get_network_policy(cls, context, network_id): @@ -92,6 +117,11 @@ class QosPolicy(base.NeutronObject): return cls._get_object_policy(context, cls.port_binding_model, port_id=port_id) + def create(self): + with db_api.autonested_transaction(self._context.session): + super(QosPolicy, self).create() + self._load_rules() + def attach_network(self, network_id): qos_db_api.create_policy_network_binding(self._context, policy_id=self.id, diff --git a/neutron/objects/qos/rule.py b/neutron/objects/qos/rule.py index 6269e8dbb..0b5713e73 100644 --- a/neutron/objects/qos/rule.py +++ b/neutron/objects/qos/rule.py @@ -96,7 +96,7 @@ class QosRule(base.NeutronObject): obj.obj_reset_changes() return obj - # TODO(QoS): create and update are not transactional safe + # TODO(QoS): Test that create is in single transaction def create(self): # TODO(QoS): enforce that type field value is bound to specific class @@ -104,18 +104,21 @@ class QosRule(base.NeutronObject): # create base qos_rule core_fields = self._get_changed_core_fields() - base_db_obj = db_api.create_object( - self._context, self.base_db_model, core_fields) - # create type specific qos_..._rule - addn_fields = self._get_changed_addn_fields() - self._copy_common_fields(core_fields, addn_fields) - addn_db_obj = db_api.create_object( - self._context, self.db_model, addn_fields) + with db_api.autonested_transaction(self._context.session): + base_db_obj = db_api.create_object( + self._context, self.base_db_model, core_fields) + + # create type specific qos_..._rule + addn_fields = self._get_changed_addn_fields() + self._copy_common_fields(core_fields, addn_fields) + addn_db_obj = db_api.create_object( + self._context, self.db_model, addn_fields) # merge two db objects into single neutron one self.from_db_object(base_db_obj, addn_db_obj) + # TODO(QoS): Test that update is in single transaction def update(self): updated_db_objs = [] @@ -123,16 +126,18 @@ class QosRule(base.NeutronObject): # update base qos_rule, if needed core_fields = self._get_changed_core_fields() - if core_fields: - base_db_obj = db_api.update_object( - self._context, self.base_db_model, self.id, core_fields) - updated_db_objs.append(base_db_obj) - - addn_fields = self._get_changed_addn_fields() - if addn_fields: - addn_db_obj = db_api.update_object( - self._context, self.db_model, self.id, addn_fields) - updated_db_objs.append(addn_db_obj) + + with db_api.autonested_transaction(self._context.session): + if core_fields: + base_db_obj = db_api.update_object( + self._context, self.base_db_model, self.id, core_fields) + updated_db_objs.append(base_db_obj) + + addn_fields = self._get_changed_addn_fields() + if addn_fields: + addn_db_obj = db_api.update_object( + self._context, self.db_model, self.id, addn_fields) + updated_db_objs.append(addn_db_obj) # update neutron object with values from both database objects self.from_db_object(*updated_db_objs) diff --git a/neutron/services/qos/qos_plugin.py b/neutron/services/qos/qos_plugin.py index 0b227c8a3..f1d9a1470 100644 --- a/neutron/services/qos/qos_plugin.py +++ b/neutron/services/qos/qos_plugin.py @@ -17,6 +17,7 @@ from neutron import manager from neutron.api.rpc.callbacks import registry as rpc_registry from neutron.api.rpc.callbacks import resources as rpc_resources +from neutron.db import db_base_plugin_common from neutron.extensions import qos from neutron.i18n import _LW from neutron.objects.qos import policy as policy_object @@ -134,10 +135,11 @@ class QoSPlugin(qos.QoSPluginBase): def _get_policy_obj(self, context, policy_id): return policy_object.QosPolicy.get_by_id(context, policy_id) + @db_base_plugin_common.filter_fields def get_policy(self, context, policy_id, fields=None): - #TODO(QoS): Support the fields parameter return self._get_policy_obj(context, policy_id).to_dict() + @db_base_plugin_common.filter_fields def get_policies(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): @@ -174,12 +176,13 @@ class QoSPlugin(qos.QoSPluginBase): rule.id = rule_id rule.delete() + @db_base_plugin_common.filter_fields def get_policy_bandwidth_limit_rule(self, context, rule_id, policy_id, fields=None): - #TODO(QoS): Support the fields parameter return rule_object.QosBandwidthLimitRule.get_by_id(context, rule_id).to_dict() + @db_base_plugin_common.filter_fields def get_policy_bandwidth_limit_rules(self, context, policy_id, filters=None, fields=None, sorts=None, limit=None, @@ -188,6 +191,7 @@ class QoSPlugin(qos.QoSPluginBase): return [rule_obj.to_dict() for rule_obj in rule_object.QosBandwidthLimitRule.get_objects(context)] + @db_base_plugin_common.filter_fields def get_rule_types(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): diff --git a/neutron/tests/unit/db/test_db_base_plugin_common.py b/neutron/tests/unit/db/test_db_base_plugin_common.py new file mode 100644 index 000000000..9074bf618 --- /dev/null +++ b/neutron/tests/unit/db/test_db_base_plugin_common.py @@ -0,0 +1,64 @@ +# Copyright (c) 2015 Red Hat, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from neutron.db import db_base_plugin_common +from neutron.tests import base + + +class FilterFieldsTestCase(base.BaseTestCase): + + @db_base_plugin_common.filter_fields + def method_dict(self, fields=None): + return {'one': 1, 'two': 2, 'three': 3} + + @db_base_plugin_common.filter_fields + def method_list(self, fields=None): + return [self.method_dict() for _ in range(3)] + + @db_base_plugin_common.filter_fields + def method_multiple_arguments(self, not_used, fields=None, + also_not_used=None): + return {'one': 1, 'two': 2, 'three': 3} + + def test_no_fields(self): + expected = {'one': 1, 'two': 2, 'three': 3} + observed = self.method_dict() + self.assertEqual(expected, observed) + + def test_dict(self): + expected = {'two': 2} + observed = self.method_dict(['two']) + self.assertEqual(expected, observed) + + def test_list(self): + expected = [{'two': 2}, {'two': 2}, {'two': 2}] + observed = self.method_list(['two']) + self.assertEqual(expected, observed) + + def test_multiple_arguments_positional(self): + expected = {'two': 2} + observed = self.method_multiple_arguments(list(), ['two']) + self.assertEqual(expected, observed) + + def test_multiple_arguments_positional_and_keywords(self): + expected = {'two': 2} + observed = self.method_multiple_arguments(fields=['two'], + not_used=None) + self.assertEqual(expected, observed) + + def test_multiple_arguments_keyword(self): + expected = {'two': 2} + observed = self.method_multiple_arguments(list(), fields=['two']) + self.assertEqual(expected, observed) diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index b73af22c6..afd6a7982 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -10,6 +10,8 @@ # License for the specific language governing permissions and limitations # under the License. +import mock + from neutron.db import api as db_api from neutron.db import models_v2 from neutron.objects.qos import policy @@ -22,6 +24,50 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): _test_class = policy.QosPolicy + def setUp(self): + super(QosPolicyObjectTestCase, self).setUp() + self.db_qos_rules = [self.get_random_fields(rule.QosRule) + for _ in range(3)] + + # Tie qos rules with policies + self.db_qos_rules[0]['qos_policy_id'] = self.db_objs[0]['id'] + self.db_qos_rules[1]['qos_policy_id'] = self.db_objs[0]['id'] + self.db_qos_rules[2]['qos_policy_id'] = self.db_objs[1]['id'] + + self.db_qos_bandwidth_rules = [ + self.get_random_fields(rule.QosBandwidthLimitRule) + for _ in range(3)] + + # Tie qos rules with qos bandwidth limit rules + for i, qos_rule in enumerate(self.db_qos_rules): + self.db_qos_bandwidth_rules[i]['id'] = qos_rule['id'] + + self.model_map = { + self._test_class.db_model: self.db_objs, + rule.QosRule.base_db_model: self.db_qos_rules, + rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules} + + def fake_get_objects(self, context, model, qos_policy_id=None): + objs = self.model_map[model] + if model is rule.QosRule.base_db_model and qos_policy_id: + return [obj for obj in objs + if obj['qos_policy_id'] == qos_policy_id] + return objs + + def fake_get_object(self, context, model, id): + objects = self.model_map[model] + return [obj for obj in objects if obj['id'] == id][0] + + def test_get_objects(self): + with mock.patch.object( + db_api, 'get_objects', + side_effect=self.fake_get_objects),\ + mock.patch.object( + db_api, 'get_object', + side_effect=self.fake_get_object): + objs = self._test_class.get_objects(self.context) + self._validate_objects(self.db_objs, objs) + class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, testlib_api.SqlTestCase): @@ -42,6 +88,19 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, policy_obj.create() return policy_obj + def _create_test_policy_with_rule(self): + policy_obj = self._create_test_policy() + + rule_fields = self.get_random_fields( + obj_cls=rule.QosBandwidthLimitRule) + rule_fields['qos_policy_id'] = policy_obj.id + rule_fields['tenant_id'] = policy_obj.tenant_id + + rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields) + rule_obj.create() + + return policy_obj, rule_obj + def _create_test_network(self): # TODO(ihrachys): replace with network.create() once we get an object # implementation for networks @@ -111,16 +170,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, self.assertIsNone(policy_obj) def test_synthetic_rule_fields(self): - obj = policy.QosPolicy(self.context, **self.db_obj) - obj.create() - - rule_fields = self.get_random_fields( - obj_cls=rule.QosBandwidthLimitRule) - rule_fields['qos_policy_id'] = obj.id - rule_fields['tenant_id'] = obj.tenant_id - - rule_obj = rule.QosBandwidthLimitRule(self.context, **rule_fields) - rule_obj.create() - - obj = policy.QosPolicy.get_by_id(self.context, obj.id) - self.assertEqual([rule_obj], obj.bandwidth_limit_rules) + policy_obj, rule_obj = self._create_test_policy_with_rule() + policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id) + self.assertEqual([rule_obj], policy_obj.bandwidth_limit_rules) + + def test_create_is_in_single_transaction(self): + obj = self._test_class(self.context, **self.db_obj) + with mock.patch('sqlalchemy.engine.' + 'Transaction.commit') as mock_commit,\ + mock.patch.object(obj._context.session, 'add'): + obj.create() + self.assertEqual(1, mock_commit.call_count) + + def test_get_by_id_fetches_rules_non_lazily(self): + policy_obj, rule_obj = self._create_test_policy_with_rule() + policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id) + + primitive = policy_obj.obj_to_primitive() + self.assertNotEqual([], (primitive['versioned_object.data'] + ['bandwidth_limit_rules'])) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 5e15dc797..0b1c4b239 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -23,10 +23,15 @@ from neutron.objects import base from neutron.tests import base as test_base +class FakeModel(object): + def __init__(self, *args, **kwargs): + pass + + @obj_base.VersionedObjectRegistry.register class FakeNeutronObject(base.NeutronObject): - db_model = 'fake_model' + db_model = FakeModel fields = { 'id': obj_fields.UUIDField(), @@ -106,13 +111,16 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): with mock.patch.object(db_api, 'get_objects', return_value=self.db_objs) as get_objects_mock: objs = self._test_class.get_objects(self.context) - self.assertFalse( - filter(lambda obj: not self._is_test_class(obj), objs)) - self.assertEqual( - sorted(self.db_objs), - sorted(get_obj_db_fields(obj) for obj in objs)) - get_objects_mock.assert_called_once_with( - self.context, self._test_class.db_model) + self._validate_objects(self.db_objs, objs) + get_objects_mock.assert_called_once_with( + self.context, self._test_class.db_model) + + def _validate_objects(self, expected, observed): + self.assertFalse( + filter(lambda obj: not self._is_test_class(obj), observed)) + self.assertEqual( + sorted(expected), + sorted(get_obj_db_fields(obj) for obj in observed)) def _check_equal(self, obj, db_obj): self.assertEqual(