obj_base.VersionedObjectDictCompat,
obj_base.ComparableVersionedObject):
+ synthetic_fields = []
+
def __init__(self, context=None, **kwargs):
super(NeutronObject, self).__init__(context, **kwargs)
self.obj_set_defaults()
def get_by_id(cls, context, id):
raise NotImplementedError()
+ @classmethod
+ def validate_filters(cls, **kwargs):
+ bad_filters = [key for key in kwargs
+ if key not in cls.fields or key in cls.synthetic_fields]
+ if bad_filters:
+ bad_filters = ', '.join(bad_filters)
+ msg = _("'%s' is not supported for filtering") % bad_filters
+ raise exceptions.InvalidInput(error_message=msg)
+
@classmethod
@abc.abstractmethod
def get_objects(cls, context, **kwargs):
# should be overridden for all persistent objects
db_model = None
- synthetic_fields = []
-
fields_no_update = []
def from_db_object(self, *objs):
@classmethod
def get_objects(cls, context, **kwargs):
+ cls.validate_filters(**kwargs)
db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs]
for obj in objs:
# sure the tenant has permission to access the policy later on.
admin_context = context.elevated()
with db_api.autonested_transaction(admin_context.session):
- db_objs = db_api.get_objects(admin_context, cls.db_model, **kwargs)
- objs = []
- for db_obj in db_objs:
- if not cls._is_policy_accessible(context, db_obj):
+ objs = super(QosPolicy, cls).get_objects(admin_context,
+ **kwargs)
+ result = []
+ for obj in objs:
+ if not cls._is_policy_accessible(context, obj):
continue
- obj = cls(context, **db_obj)
obj.reload_rules()
- objs.append(obj)
- return objs
+ result.append(obj)
+ return result
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
# we don't receive context because we don't need db access at all
@classmethod
def get_objects(cls, **kwargs):
+ cls.validate_filters(**kwargs)
core_plugin = manager.NeutronManager.get_plugin()
return [cls(type=type_)
for type_ in core_plugin.supported_qos_rule_types]
def get_policies(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
- #TODO(QoS): Support all the optional parameters
- return policy_object.QosPolicy.get_objects(context)
+ return policy_object.QosPolicy.get_objects(context, **filters)
#TODO(QoS): Consider adding a proxy catch-all for rules, so
# we capture the API function call, and just pass
filters=None, fields=None,
sorts=None, limit=None,
marker=None, page_reverse=False):
- #TODO(QoS): Support all the optional parameters
# make sure we have access to the policy when fetching rules
with db_api.autonested_transaction(context.session):
# first, validate that we have access to the policy
self._get_policy_obj(context, policy_id)
- return rule_object.QosBandwidthLimitRule.get_objects(context)
+ return rule_object.QosBandwidthLimitRule.get_objects(context,
+ **filters)
# TODO(QoS): enforce rule types when accessing rule objects
@db_base_plugin_common.filter_fields
def get_rule_types(self, context, filters=None, fields=None,
sorts=None, limit=None,
marker=None, page_reverse=False):
- return rule_type_object.QosRuleType.get_objects()
+ return rule_type_object.QosRuleType.get_objects(**filters)
@test.idempotent_id('108fbdf7-3463-4e47-9871-d07f3dcf5bbb')
def test_create_policy(self):
policy = self.create_qos_policy(name='test-policy',
- description='test policy desc',
+ description='test policy desc1',
shared=False)
# Test 'show policy'
retrieved_policy = self.admin_client.show_qos_policy(policy['id'])
retrieved_policy = retrieved_policy['policy']
self.assertEqual('test-policy', retrieved_policy['name'])
- self.assertEqual('test policy desc', retrieved_policy['description'])
+ self.assertEqual('test policy desc1', retrieved_policy['description'])
self.assertFalse(retrieved_policy['shared'])
# Test 'list policies'
policies_ids = [p['id'] for p in policies]
self.assertIn(policy['id'], policies_ids)
+ @test.attr(type='smoke')
+ @test.idempotent_id('f8d20e92-f06d-4805-b54f-230f77715815')
+ def test_list_policy_filter_by_name(self):
+ self.create_qos_policy(name='test', description='test policy',
+ shared=False)
+ self.create_qos_policy(name='test2', description='test policy',
+ shared=False)
+
+ policies = (self.admin_client.
+ list_qos_policies(name='test')['policies'])
+ self.assertEqual(1, len(policies))
+
+ retrieved_policy = policies[0]
+ self.assertEqual('test', retrieved_policy['name'])
+
@test.attr(type='smoke')
@test.idempotent_id('8e88a54b-f0b2-4b7d-b061-a15d93c2c7d6')
def test_policy_update(self):
description='',
shared=False)
self.admin_client.update_qos_policy(policy['id'],
- description='test policy desc',
+ description='test policy desc2',
shared=True)
retrieved_policy = self.admin_client.show_qos_policy(policy['id'])
retrieved_policy = retrieved_policy['policy']
- self.assertEqual('test policy desc', retrieved_policy['description'])
+ self.assertEqual('test policy desc2', retrieved_policy['description'])
self.assertTrue(retrieved_policy['shared'])
self.assertEqual([], retrieved_policy['rules'])
import json
import time
+import urllib
+
from six.moves.urllib import parse
from tempest_lib.common.utils import misc
body = json.loads(body)
return service_client.ResponseBody(resp, body)
- def list_qos_policies(self):
- uri = '%s/qos/policies' % self.uri_prefix
+ def list_qos_policies(self, **filters):
+ if filters:
+ uri = '%s/qos/policies?%s' % (self.uri_prefix,
+ urllib.urlencode(filters))
+ else:
+ uri = '%s/qos/policies' % self.uri_prefix
resp, body = self.get(uri)
self.expected_success(200, resp.status)
body = json.loads(body)
admin_context, self._test_class.db_model)
self._validate_objects(self.db_objs, objs)
+ def test_get_objects_valid_fields(self):
+ admin_context = self.context.elevated()
+
+ with mock.patch.object(
+ db_api, 'get_objects',
+ return_value=[self.db_obj]) as get_objects_mock:
+
+ with mock.patch.object(
+ self.context,
+ 'elevated',
+ return_value=admin_context) as context_mock:
+
+ objs = self._test_class.get_objects(
+ self.context,
+ **self.valid_field_filter)
+ context_mock.assert_called_once_with()
+ get_objects_mock.assert_any_call(
+ admin_context, self._test_class.db_model,
+ **self.valid_field_filter)
+ self._validate_objects([self.db_obj], objs)
+
def test_get_by_id(self):
admin_context = self.context.elevated()
with mock.patch.object(db_api, 'get_object',
# License for the specific language governing permissions and limitations
# under the License.
+import copy
import random
import string
fields_no_update = ['id']
+ synthetic_fields = ['field2']
+
def _random_string(n=10):
return ''.join(random.choice(string.ascii_lowercase) for _ in range(n))
self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0]
+ valid_field = [f for f in self._test_class.fields
+ if f not in self._test_class.synthetic_fields][0]
+ self.valid_field_filter = {valid_field: self.db_obj[valid_field]}
+
@classmethod
def get_random_fields(cls, obj_cls=None):
obj_cls = obj_cls or cls._test_class
get_objects_mock.assert_called_once_with(
self.context, self._test_class.db_model)
+ def test_get_objects_valid_fields(self):
+ with mock.patch.object(
+ db_api, 'get_objects',
+ return_value=[self.db_obj]) as get_objects_mock:
+
+ objs = self._test_class.get_objects(self.context,
+ **self.valid_field_filter)
+ self._validate_objects([self.db_obj], objs)
+
+ get_objects_mock.assert_called_with(
+ self.context, self._test_class.db_model,
+ **self.valid_field_filter)
+
+ def test_get_objects_mixed_fields(self):
+ synthetic_fields = self._test_class.synthetic_fields
+ if not synthetic_fields:
+ self.skipTest('No synthetic fields found in test class %r' %
+ self._test_class)
+
+ filters = copy.copy(self.valid_field_filter)
+ filters[synthetic_fields[0]] = 'xxx'
+
+ with mock.patch.object(db_api, 'get_objects',
+ return_value=self.db_objs):
+ self.assertRaises(base.exceptions.InvalidInput,
+ self._test_class.get_objects, self.context,
+ **filters)
+
+ def test_get_objects_synthetic_fields(self):
+ synthetic_fields = self._test_class.synthetic_fields
+ if not synthetic_fields:
+ self.skipTest('No synthetic fields found in test class %r' %
+ self._test_class)
+
+ with mock.patch.object(db_api, 'get_objects',
+ return_value=self.db_objs):
+ self.assertRaises(base.exceptions.InvalidInput,
+ self._test_class.get_objects, self.context,
+ **{synthetic_fields[0]: 'xxx'})
+
+ def test_get_objects_invalid_fields(self):
+ with mock.patch.object(db_api, 'get_objects',
+ return_value=self.db_objs):
+ self.assertRaises(base.exceptions.InvalidInput,
+ self._test_class.get_objects, self.context,
+ fake_field='xxx')
+
def _validate_objects(self, expected, observed):
self.assertFalse(
filter(lambda obj: not self._is_test_class(obj), observed))