]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Fix accessing shared policies, add assoc tests
authorJohn Schwarz <jschwarz@redhat.com>
Sun, 26 Jul 2015 13:00:12 +0000 (16:00 +0300)
committerIhar Hrachyshka <ihrachys@redhat.com>
Sat, 1 Aug 2015 17:03:44 +0000 (19:03 +0200)
This patch is two-fold:
1. Previously, policies that were created using the 'shared=True' flag
   were not accessible to other tenants, since the context used to
   search the policies was not elevated. This patch elevates the context
   prior to retrieving the policy, and if a match was found, makes sure
   that the user has permissions to access it (either the policy is
   shared or it's from the same tenant id).
2. Tests for both associations and disassociations of policies to both
   networks and ports are added in this patch, to make sure coverage is
   good and that the problem is fixed.

Change-Id: Idec13ff4ec575b6d0c0a455c1b3bd9d9700ff7fb

neutron/objects/qos/policy.py
neutron/services/qos/qos_extension.py
neutron/tests/api/base.py
neutron/tests/api/test_qos.py
neutron/tests/tempest/services/network/json/network_client.py
neutron/tests/unit/objects/qos/test_policy.py

index fb2fca2226bddec412745bb67238989e0cbaf5fd..cc7cdc981a18141a4301a56ca9451c942ef03b3c 100644 (file)
@@ -86,21 +86,40 @@ class QosPolicy(base.NeutronDbObject):
         for attr in self.rule_fields:
             self.obj_load_attr(attr)
 
+    @staticmethod
+    def _is_policy_accessible(context, db_obj):
+        #TODO(QoS): Look at I3426b13eede8bfa29729cf3efea3419fb91175c4 for
+        #           other possible solutions to this.
+        return (context.is_admin or
+                db_obj.shared or
+                db_obj.tenant_id == context.tenant_id)
+
     @classmethod
     def get_by_id(cls, context, id):
-        with db_api.autonested_transaction(context.session):
-            policy_obj = super(QosPolicy, cls).get_by_id(context, id)
-            if policy_obj:
-                policy_obj._load_rules()
-        return policy_obj
+        # We want to get the policy regardless of its tenant id. We'll make
+        # sure the tenant has permission to access the policy later on.
+        admin_context = context.elevated()
+        with db_api.autonested_transaction(admin_context.session):
+            policy_obj = super(QosPolicy, cls).get_by_id(admin_context, id)
+            if (not policy_obj or
+                not cls._is_policy_accessible(context, policy_obj)):
+                return
+
+            policy_obj._load_rules()
+            return policy_obj
 
     # TODO(QoS): Test that all objects are fetched within one transaction
     @classmethod
     def get_objects(cls, context, **kwargs):
-        with db_api.autonested_transaction(context.session):
-            db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
-            objs = list()
+        # We want to get the policy regardless of its tenant id. We'll make
+        # sure the tenant has permission to access the policy later on.
+        admin_context = context.elevated()
+        with db_api.autonested_transaction(admin_context.session):
+            db_objs = db_api.get_objects(admin_context, cls.db_model, **kwargs)
+            objs = []
             for db_obj in db_objs:
+                if not cls._is_policy_accessible(context, db_obj):
+                    continue
                 obj = cls(context, **db_obj)
                 obj._load_rules()
                 objs.append(obj)
index 2cae032cac09b1480bd829ebeb38b834dbf90ca4..518b2adc5ccbbf512ce5900d67988c6af7c2784c 100644 (file)
@@ -49,6 +49,9 @@ class QosResourceExtensionHandler(object):
         qos_policy_id = port_changes.get(qos.QOS_POLICY_ID)
         if qos_policy_id is not None:
             policy = self._get_policy_obj(context, qos_policy_id)
+            #TODO(QoS): If the policy doesn't exist (or if it is not shared and
+            #           the tenant id doesn't match the context's), this will
+            #           raise an exception (policy is None).
             policy.attach_port(port['id'])
             port[qos.QOS_POLICY_ID] = qos_policy_id
 
@@ -61,6 +64,9 @@ class QosResourceExtensionHandler(object):
         qos_policy_id = network_changes.get(qos.QOS_POLICY_ID)
         if qos_policy_id:
             policy = self._get_policy_obj(context, qos_policy_id)
+            #TODO(QoS): If the policy doesn't exist (or if it is not shared and
+            #           the tenant id doesn't match the context's), this will
+            #           raise an exception (policy is None).
             policy.attach_network(network['id'])
             network[qos.QOS_POLICY_ID] = qos_policy_id
 
index 0e8b6fffda873e31404576f736b6cb954752d75a..57847862922f709e70818191297ca425de14eb1a 100644 (file)
@@ -231,9 +231,9 @@ class BaseNetworkTest(neutron.tests.tempest.test.BaseTestCase):
         return network
 
     @classmethod
-    def create_shared_network(cls, network_name=None):
+    def create_shared_network(cls, network_name=None, **post_body):
         network_name = network_name or data_utils.rand_name('sharednetwork-')
-        post_body = {'name': network_name, 'shared': True}
+        post_body.update({'name': network_name, 'shared': True})
         body = cls.admin_client.create_network(**post_body)
         network = body['network']
         cls.shared_networks.append(network)
index 5332b45d19ae9b5070646faa1b109ed06ad61f70..e4b05321d82f3e17765ea215fbaa23178da72e21 100644 (file)
@@ -42,7 +42,7 @@ class QosTestJSON(base.BaseAdminNetworkTest):
         retrieved_policy = retrieved_policy['policy']
         self.assertEqual('test-policy', retrieved_policy['name'])
         self.assertEqual('test policy desc', retrieved_policy['description'])
-        self.assertEqual(False, retrieved_policy['shared'])
+        self.assertFalse(retrieved_policy['shared'])
 
         # Test 'list policies'
         policies = self.admin_client.list_qos_policies()['policies']
@@ -62,7 +62,7 @@ class QosTestJSON(base.BaseAdminNetworkTest):
         retrieved_policy = self.admin_client.show_qos_policy(policy['id'])
         retrieved_policy = retrieved_policy['policy']
         self.assertEqual('test policy desc', retrieved_policy['description'])
-        self.assertEqual(True, retrieved_policy['shared'])
+        self.assertTrue(retrieved_policy['shared'])
         self.assertEqual([], retrieved_policy['bandwidth_limit_rules'])
 
     @test.attr(type='smoke')
@@ -79,9 +79,156 @@ class QosTestJSON(base.BaseAdminNetworkTest):
         self.assertRaises(exceptions.NotFound,
                           self.admin_client.show_qos_policy, policy['id'])
 
+    @test.attr(type='smoke')
+    @test.idempotent_id('cf776f77-8d3d-49f2-8572-12d6a1557224')
+    def test_list_rule_types(self):
+        # List supported rule types
+        expected_rule_types = qos_consts.VALID_RULE_TYPES
+        expected_rule_details = ['type']
+
+        rule_types = self.admin_client.list_qos_rule_types()
+        actual_list_rule_types = rule_types['rule_types']
+        actual_rule_types = [rule['type'] for rule in actual_list_rule_types]
+
+        # Verify that only required fields present in rule details
+        for rule in actual_list_rule_types:
+            self.assertEqual(tuple(rule.keys()), tuple(expected_rule_details))
+
+        # Verify if expected rules are present in the actual rules list
+        for rule in expected_rule_types:
+            self.assertIn(rule, actual_rule_types)
+
+    def _disassociate_network(self, client, network_id):
+        client.update_network(network_id, qos_policy_id=None)
+        updated_network = self.admin_client.show_network(network_id)
+        self.assertIsNone(updated_network['network']['qos_policy_id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('65b9ef75-1911-406a-bbdb-ca1d68d528b0')
+    def test_policy_association_with_admin_network(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=False)
+        network = self.create_shared_network('test network',
+                                             qos_policy_id=policy['id'])
+
+        retrieved_network = self.admin_client.show_network(network['id'])
+        self.assertEqual(
+            policy['id'], retrieved_network['network']['qos_policy_id'])
+
+        self._disassociate_network(self.admin_client, network['id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('1738de5d-0476-4163-9022-5e1b548c208e')
+    def test_policy_association_with_tenant_network(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=True)
+        network = self.create_network('test network',
+                                      qos_policy_id=policy['id'])
+
+        retrieved_network = self.admin_client.show_network(network['id'])
+        self.assertEqual(
+            policy['id'], retrieved_network['network']['qos_policy_id'])
+
+        self._disassociate_network(self.client, network['id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('1aa55a79-324f-47d9-a076-894a8fc2448b')
+    def test_policy_association_with_network_non_shared_policy(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=False)
+        #TODO(QoS): This currently raises an exception on the server side. See
+        #           services/qos/qos_extension.py for comments on this subject.
+        network = self.create_network('test network',
+                                      qos_policy_id=policy['id'])
+
+        retrieved_network = self.admin_client.show_network(network['id'])
+        self.assertIsNone(retrieved_network['network']['qos_policy_id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('09a9392c-1359-4cbb-989f-fb768e5834a8')
+    def test_policy_update_association_with_admin_network(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=False)
+        network = self.create_shared_network('test network')
+        retrieved_network = self.admin_client.show_network(network['id'])
+        self.assertIsNone(retrieved_network['network']['qos_policy_id'])
+
+        self.admin_client.update_network(network['id'],
+                                         qos_policy_id=policy['id'])
+        retrieved_network = self.admin_client.show_network(network['id'])
+        self.assertEqual(
+            policy['id'], retrieved_network['network']['qos_policy_id'])
+
+        self._disassociate_network(self.admin_client, network['id'])
+
+    def _disassociate_port(self, port_id):
+        self.client.update_port(port_id, qos_policy_id=None)
+        updated_port = self.admin_client.show_port(port_id)
+        self.assertIsNone(updated_port['port']['qos_policy_id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('98fcd95e-84cf-4746-860e-44692e674f2e')
+    def test_policy_association_with_port_shared_policy(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=True)
+        network = self.create_shared_network('test network')
+        port = self.create_port(network, qos_policy_id=policy['id'])
+
+        retrieved_port = self.admin_client.show_port(port['id'])
+        self.assertEqual(
+            policy['id'], retrieved_port['port']['qos_policy_id'])
+
+        self._disassociate_port(port['id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('f53d961c-9fe5-4422-8b66-7add972c6031')
+    def test_policy_association_with_port_non_shared_policy(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=False)
+        network = self.create_shared_network('test network')
+        #TODO(QoS): This currently raises an exception on the server side. See
+        #           services/qos/qos_extension.py for comments on this subject.
+        port = self.create_port(network, qos_policy_id=policy['id'])
+
+        retrieved_port = self.admin_client.show_port(port['id'])
+        self.assertIsNone(retrieved_port['port']['qos_policy_id'])
+
+    @test.attr(type='smoke')
+    @test.idempotent_id('f8163237-fba9-4db5-9526-bad6d2343c76')
+    def test_policy_update_association_with_port_shared_policy(self):
+        policy = self.create_qos_policy(name='test-policy',
+                                        description='test policy',
+                                        shared=True)
+        network = self.create_shared_network('test network')
+        port = self.create_port(network)
+        retrieved_port = self.admin_client.show_port(port['id'])
+        self.assertIsNone(retrieved_port['port']['qos_policy_id'])
+
+        self.client.update_port(port['id'], qos_policy_id=policy['id'])
+        retrieved_port = self.admin_client.show_port(port['id'])
+        self.assertEqual(
+            policy['id'], retrieved_port['port']['qos_policy_id'])
+
+        self._disassociate_port(port['id'])
+
+
+class QosBandwidthLimitRuleTestJSON(base.BaseAdminNetworkTest):
+    @classmethod
+    def resource_setup(cls):
+        super(QosBandwidthLimitRuleTestJSON, cls).resource_setup()
+        if not test.is_extension_enabled('qos', 'network'):
+            msg = "qos extension not enabled."
+            raise cls.skipException(msg)
+
     @test.attr(type='smoke')
     @test.idempotent_id('8a59b00b-3e9c-4787-92f8-93a5cdf5e378')
-    def test_bandwidth_limit_rule_create(self):
+    def test_rule_create(self):
         policy = self.create_qos_policy(name='test-policy',
                                         description='test policy',
                                         shared=False)
@@ -109,8 +256,9 @@ class QosTestJSON(base.BaseAdminNetworkTest):
         self.assertEqual(1, len(policy_rules))
         self.assertEqual(rule['id'], policy_rules[0]['id'])
 
+    @test.attr(type='smoke')
     @test.idempotent_id('149a6988-2568-47d2-931e-2dbc858943b3')
-    def test_bandwidth_limit_rule_update(self):
+    def test_rule_update(self):
         policy = self.create_qos_policy(name='test-policy',
                                         description='test policy',
                                         shared=False)
@@ -132,7 +280,7 @@ class QosTestJSON(base.BaseAdminNetworkTest):
     #TODO(QoS): Uncomment once the rule-delete logic is fixed.
 #    @test.attr(type='smoke')
 #    @test.idempotent_id('67ee6efd-7b33-4a68-927d-275b4f8ba958')
-#    def test_bandwidth_limit_rule_delete(self):
+#    def test_rule_delete(self):
 #        policy = self.create_qos_policy(name='test-policy',
 #                                        description='test policy',
 #                                        shared=False)
@@ -149,26 +297,5 @@ class QosTestJSON(base.BaseAdminNetworkTest):
 #                          self.admin_client.show_bandwidth_limit_rule,
 #                          policy['id'], rule['id'])
 
-    @test.attr(type='smoke')
-    @test.idempotent_id('cf776f77-8d3d-49f2-8572-12d6a1557224')
-    def test_list_rule_types(self):
-        # List supported rule types
-        expected_rule_types = qos_consts.VALID_RULE_TYPES
-        expected_rule_details = ['type']
-
-        rule_types = self.admin_client.list_qos_rule_types()
-        actual_list_rule_types = rule_types['rule_types']
-        actual_rule_types = [rule['type'] for rule in actual_list_rule_types]
-
-        # Verify that only required fields present in rule details
-        for rule in actual_list_rule_types:
-            self.assertEqual(tuple(rule.keys()), tuple(expected_rule_details))
-
-        # Verify if expected rules are present in the actual rules list
-        for rule in expected_rule_types:
-            self.assertIn(rule, actual_rule_types)
-
     #TODO(QoS): create several bandwidth-limit rules (not sure it makes sense,
     #           but to test more than one rule)
-    #TODO(QoS): associate/disassociate policy with network
-    #TODO(QoS): associate/disassociate policy with port
index bc8eaa2c04bfb78885a7c7443dcbcde8660fcca8..c01c83c706aa6df08d75e88fd384ab35be0b5185 100644 (file)
@@ -653,12 +653,6 @@ class NetworkClientJSON(service_client.ServiceClient):
         self.expected_success(200, resp.status)
         return service_client.ResponseBody(resp, body)
 
-    def get_qos_policy(self, policy_id):
-        uri = '%s/qos/policies/%s' % (self.uri_prefix, policy_id)
-        resp, body = self.get(uri)
-        self.expected_success(200, resp.status)
-        return service_client.ResponseBody(resp, body)
-
     def create_bandwidth_limit_rule(self, policy_id, max_kbps, max_burst_kbps):
         uri = '%s/qos/policies/%s/bandwidth_limit_rules' % (
             self.uri_prefix, policy_id)
index c3c747b90b91488c167cd39375f3fb050ad1a351..6c587db10167254cb8093b2d491bf599409e6eaa 100644 (file)
@@ -60,15 +60,40 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
         return [obj for obj in objects if obj['id'] == id][0]
 
     def test_get_objects(self):
+        admin_context = self.context.elevated()
         with mock.patch.object(
-                    db_api, 'get_objects',
-                    side_effect=self.fake_get_objects),\
-                mock.patch.object(
-                    db_api, 'get_object',
-                    side_effect=self.fake_get_object):
-            objs = self._test_class.get_objects(self.context)
+            db_api, 'get_objects',
+            side_effect=self.fake_get_objects) as get_objects_mock:
+
+            with mock.patch.object(
+                db_api, 'get_object',
+                side_effect=self.fake_get_object):
+
+                with mock.patch.object(
+                    self.context,
+                    'elevated',
+                    return_value=admin_context) as context_mock:
+
+                    objs = self._test_class.get_objects(self.context)
+                    context_mock.assert_called_once_with()
+                    get_objects_mock.assert_any_call(
+                        admin_context, self._test_class.db_model)
         self._validate_objects(self.db_objs, objs)
 
+    def test_get_by_id(self):
+        admin_context = self.context.elevated()
+        with mock.patch.object(db_api, 'get_object',
+                               return_value=self.db_obj) as get_object_mock:
+            with mock.patch.object(self.context,
+                                   'elevated',
+                                   return_value=admin_context) as context_mock:
+                obj = self._test_class.get_by_id(self.context, id='fake_id')
+                self.assertTrue(self._is_test_class(obj))
+                self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
+                context_mock.assert_called_once_with()
+                get_object_mock.assert_called_once_with(
+                    admin_context, self._test_class.db_model, id='fake_id')
+
 
 class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
                                 testlib_api.SqlTestCase):