]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
ml2 extension driver: more tests, fix data argument inconsistency
authorIsaku Yamahata <isaku.yamahata@intel.com>
Fri, 17 Oct 2014 08:35:55 +0000 (17:35 +0900)
committerIsaku Yamahata <isaku.yamahata@intel.com>
Tue, 3 Mar 2015 01:02:20 +0000 (17:02 -0800)
This patch adds more tests for ML2 extension driver.
It also fixes a minor bug which was revealed.

The data argument for process/update method of extension driver
was inconsist. some are given data like
{'resource': {'arg': 'value'...}}. But some are given one like
{'arg': 'value'}.
This inconsistency needs to be fixed so that argument is
{'arg': 'value'}. Given the argument is known to be network,
there is no point to carry outer dictionary.

Partially Implements: blueprint ml2-ovs-portsecurity
Change-Id: I4614c3ba5eff0ace46cc928517e31c14b7b2e448

neutron/plugins/ml2/plugin.py
neutron/tests/unit/ml2/drivers/ext_test.py [new file with mode: 0644]
neutron/tests/unit/ml2/test_extension_driver_api.py
setup.cfg

index e208281e5f23edd7dc9a092f3ad4991143caf92f..ffaa8ba789289fcaa973c18aa80d62b1bff08d98 100644 (file)
@@ -619,7 +619,8 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
         return [obj['result'] for obj in objects]
 
     def update_network(self, context, id, network):
-        provider._raise_if_updates_provider_attributes(network['network'])
+        net_data = network[attributes.NETWORK]
+        provider._raise_if_updates_provider_attributes(net_data)
 
         session = context.session
         with session.begin(subtransactions=True):
@@ -627,10 +628,9 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             updated_network = super(Ml2Plugin, self).update_network(context,
                                                                     id,
                                                                     network)
-            self.extension_manager.process_update_network(context, network,
+            self.extension_manager.process_update_network(context, net_data,
                                                           updated_network)
-            self._process_l3_update(context, updated_network,
-                                    network['network'])
+            self._process_l3_update(context, updated_network, net_data)
             self.type_manager.extend_network_dict_provider(context,
                                                            updated_network)
             mech_context = driver_context.NetworkContext(
@@ -780,8 +780,8 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
         session = context.session
         with session.begin(subtransactions=True):
             result = super(Ml2Plugin, self).create_subnet(context, subnet)
-            self.extension_manager.process_create_subnet(context, subnet,
-                                                         result)
+            self.extension_manager.process_create_subnet(
+                context, subnet[attributes.SUBNET], result)
             mech_context = driver_context.SubnetContext(self, context, result)
             self.mechanism_manager.create_subnet_precommit(mech_context)
 
@@ -808,8 +808,8 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             original_subnet = super(Ml2Plugin, self).get_subnet(context, id)
             updated_subnet = super(Ml2Plugin, self).update_subnet(
                 context, id, subnet)
-            self.extension_manager.process_update_subnet(context, subnet,
-                                                         updated_subnet)
+            self.extension_manager.process_update_subnet(
+                context, subnet[attributes.SUBNET], updated_subnet)
             mech_context = driver_context.SubnetContext(
                 self, context, updated_subnet, original_subnet=original_subnet)
             self.mechanism_manager.update_subnet_precommit(mech_context)
diff --git a/neutron/tests/unit/ml2/drivers/ext_test.py b/neutron/tests/unit/ml2/drivers/ext_test.py
new file mode 100644 (file)
index 0000000..45fa4fd
--- /dev/null
@@ -0,0 +1,211 @@
+# Copyright 2015 Intel Corporation.
+# Copyright 2015 Isaku Yamahata <isaku.yamahata at intel com>
+#                               <isaku.yamahata at gmail com>
+# All Rights Reserved.
+#
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import sqlalchemy as sa
+from sqlalchemy import orm
+
+import oslo_db.sqlalchemy.session
+
+from neutron.api import extensions
+from neutron.api.v2 import attributes
+from neutron.db import model_base
+from neutron.db import models_v2
+from neutron.plugins.ml2 import driver_api
+from neutron.tests.unit.ml2 import extensions as test_extensions
+
+
+class TestExtensionDriverBase(driver_api.ExtensionDriver):
+    _supported_extension_aliases = 'test_extension'
+
+    def initialize(self):
+        extensions.append_api_extensions_path(test_extensions.__path__)
+
+    @property
+    def extension_alias(self):
+        return self._supported_extension_aliases
+
+
+class TestExtensionDriver(TestExtensionDriverBase):
+    def initialize(self):
+        super(TestExtensionDriver, self).initialize()
+        self.network_extension = 'Test_Network_Extension'
+        self.subnet_extension = 'Test_Subnet_Extension'
+        self.port_extension = 'Test_Port_Extension'
+
+    def _check_create(self, session, data, result):
+        assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+        assert(isinstance(data, dict))
+        assert('id' not in data)
+        assert(isinstance(result, dict))
+        assert(result['id'] is not None)
+
+    def _check_update(self, session, data, result):
+        assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+        assert(isinstance(data, dict))
+        assert(isinstance(result, dict))
+        assert(result['id'] is not None)
+
+    def _check_extend(self, session, result, db_entry,
+                      expected_db_entry_class):
+        assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+        assert(isinstance(result, dict))
+        assert(result['id'] is not None)
+        assert(isinstance(db_entry, expected_db_entry_class))
+        assert(db_entry.id == result['id'])
+
+    def process_create_network(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_create(session, data, result)
+        result['network_extension'] = self.network_extension + '_create'
+
+    def process_update_network(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_update(session, data, result)
+        self.network_extension = data['network_extension']
+        result['network_extension'] = self.network_extension + '_update'
+
+    def extend_network_dict(self, session, net_db, result):
+        self._check_extend(session, result, net_db, models_v2.Network)
+        result['network_extension'] = self.network_extension + '_extend'
+
+    def process_create_subnet(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_create(session, data, result)
+        result['subnet_extension'] = self.subnet_extension + '_create'
+
+    def process_update_subnet(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_update(session, data, result)
+        self.subnet_extension = data['subnet_extension']
+        result['subnet_extension'] = self.subnet_extension + '_update'
+
+    def extend_subnet_dict(self, session, subnet_db, result):
+        self._check_extend(session, result, subnet_db, models_v2.Subnet)
+        result['subnet_extension'] = self.subnet_extension + '_extend'
+
+    def process_create_port(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_create(session, data, result)
+        result['port_extension'] = self.port_extension + '_create'
+
+    def process_update_port(self, plugin_context, data, result):
+        session = plugin_context.session
+        self._check_update(session, data, result)
+        self.port_extension = data['port_extension']
+        result['port_extension'] = self.port_extension + '_update'
+
+    def extend_port_dict(self, session, port_db, result):
+        self._check_extend(session, result, port_db, models_v2.Port)
+        result['port_extension'] = self.port_extension + '_extend'
+
+
+class TestNetworkExtension(model_base.BASEV2):
+    network_id = sa.Column(sa.String(36),
+                           sa.ForeignKey('networks.id', ondelete="CASCADE"),
+                           primary_key=True)
+    value = sa.Column(sa.String(64))
+    network = orm.relationship(
+        models_v2.Network,
+        backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestSubnetExtension(model_base.BASEV2):
+    subnet_id = sa.Column(sa.String(36),
+                          sa.ForeignKey('subnets.id', ondelete="CASCADE"),
+                          primary_key=True)
+    value = sa.Column(sa.String(64))
+    subnet = orm.relationship(
+        models_v2.Subnet,
+        backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestPortExtension(model_base.BASEV2):
+    port_id = sa.Column(sa.String(36),
+                        sa.ForeignKey('ports.id', ondelete="CASCADE"),
+                        primary_key=True)
+    value = sa.Column(sa.String(64))
+    port = orm.relationship(
+        models_v2.Port,
+        backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestDBExtensionDriver(TestExtensionDriverBase):
+    def _get_value(self, data, key):
+        value = data[key]
+        if not attributes.is_attr_set(value):
+            value = ''
+        return value
+
+    def process_create_network(self, plugin_context, data, result):
+        session = plugin_context.session
+        value = self._get_value(data, 'network_extension')
+        record = TestNetworkExtension(network_id=result['id'], value=value)
+        session.add(record)
+        result['network_extension'] = value
+
+    def process_update_network(self, plugin_context, data, result):
+        session = plugin_context.session
+        record = (session.query(TestNetworkExtension).
+                  filter_by(network_id=result['id']).one())
+        value = data.get('network_extension')
+        if value and value != record.value:
+            record.value = value
+        result['network_extension'] = record.value
+
+    def extend_network_dict(self, session, net_db, result):
+        result['network_extension'] = net_db.extension.value
+
+    def process_create_subnet(self, plugin_context, data, result):
+        session = plugin_context.session
+        value = self._get_value(data, 'subnet_extension')
+        record = TestSubnetExtension(subnet_id=result['id'], value=value)
+        session.add(record)
+        result['subnet_extension'] = value
+
+    def process_update_subnet(self, plugin_context, data, result):
+        session = plugin_context.session
+        record = (session.query(TestSubnetExtension).
+                  filter_by(subnet_id=result['id']).one())
+        value = data.get('subnet_extension')
+        if value and value != record.value:
+            record.value = value
+        result['subnet_extension'] = record.value
+
+    def extend_subnet_dict(self, session, subnet_db, result):
+        value = subnet_db.extension.value if subnet_db.extension else ''
+        result['subnet_extension'] = value
+
+    def process_create_port(self, plugin_context, data, result):
+        session = plugin_context.session
+        value = self._get_value(data, 'port_extension')
+        record = TestPortExtension(port_id=result['id'], value=value)
+        session.add(record)
+        result['port_extension'] = value
+
+    def process_update_port(self, plugin_context, data, result):
+        session = plugin_context.session
+        record = (session.query(TestPortExtension).
+                  filter_by(port_id=result['id']).one())
+        value = data.get('port_extension')
+        if value and value != record.value:
+            record.value = value
+        result['port_extension'] = record.value
+
+    def extend_port_dict(self, session, port_db, result):
+        value = port_db.extension.value if port_db.extension else ''
+        result['port_extension'] = value
index c3c637d79faf40251d89553d568f29a2ce94a234..689dd79a99d67c2766c04e1958091eec3e2cbcd1 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-from neutron.api import extensions
+import contextlib
+import mock
+
+from neutron import context
+from neutron import manager
 from neutron.plugins.ml2 import config
-from neutron.plugins.ml2 import driver_api as api
-from neutron.tests.unit.ml2 import extensions as test_extensions
+from neutron.tests.unit.ml2.drivers import ext_test
 from neutron.tests.unit.ml2 import test_ml2_plugin
 
 
@@ -26,6 +29,8 @@ class ExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
                                      self._extension_drivers,
                                      group='ml2')
         super(ExtensionDriverTestCase, self).setUp()
+        self._plugin = manager.NeutronManager.get_plugin()
+        self._ctxt = context.get_admin_context()
 
     def test_network_attr(self):
         with self.network() as network:
@@ -36,14 +41,14 @@ class ExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
             # Test list networks
             res = self._list('networks')
             val = res['networks'][0].get('network_extension')
-            self.assertEqual('Test_Network_Extension', val)
+            self.assertEqual('Test_Network_Extension_extend', val)
 
             # Test network update
             data = {'network':
                     {'network_extension': 'Test_Network_Extension_Update'}}
             res = self._update('networks', network['network']['id'], data)
             val = res['network'].get('network_extension')
-            self.assertEqual('Test_Network_Extension_Update', val)
+            self.assertEqual('Test_Network_Extension_Update_update', val)
 
     def test_subnet_attr(self):
         with self.subnet() as subnet:
@@ -54,14 +59,14 @@ class ExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
             # Test list subnets
             res = self._list('subnets')
             val = res['subnets'][0].get('subnet_extension')
-            self.assertEqual('Test_Subnet_Extension', val)
+            self.assertEqual('Test_Subnet_Extension_extend', val)
 
             # Test subnet update
             data = {'subnet':
                     {'subnet_extension': 'Test_Subnet_Extension_Update'}}
             res = self._update('subnets', subnet['subnet']['id'], data)
             val = res['subnet'].get('subnet_extension')
-            self.assertEqual('Test_Subnet_Extension_Update', val)
+            self.assertEqual('Test_Subnet_Extension_Update_update', val)
 
     def test_port_attr(self):
         with self.port() as port:
@@ -72,57 +77,181 @@ class ExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
             # Test list ports
             res = self._list('ports')
             val = res['ports'][0].get('port_extension')
-            self.assertEqual('Test_Port_Extension', val)
+            self.assertEqual('Test_Port_Extension_extend', val)
 
             # Test port update
             data = {'port': {'port_extension': 'Test_Port_Extension_Update'}}
             res = self._update('ports', port['port']['id'], data)
             val = res['port'].get('port_extension')
-            self.assertEqual('Test_Port_Extension_Update', val)
+            self.assertEqual('Test_Port_Extension_Update_update', val)
+
+    def test_extend_network_dict(self):
+        with contextlib.nested(
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'process_update_network'),
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'extend_network_dict'),
+            self.network()
+        ) as (ext_update_net, ext_net_dict, network):
+            net_id = network['network']['id']
+            net_data = {'network': {'id': net_id}}
+            self._plugin.update_network(self._ctxt, net_id, net_data)
+            self.assertTrue(ext_update_net.called)
+            self.assertTrue(ext_net_dict.called)
+
+    def test_extend_subnet_dict(self):
+        with contextlib.nested(
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'process_update_subnet'),
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'extend_subnet_dict'),
+            self.subnet()
+        ) as (ext_update_subnet, ext_subnet_dict, subnet):
+            subnet_id = subnet['subnet']['id']
+            subnet_data = {'subnet': {'id': subnet_id}}
+            self._plugin.update_subnet(self._ctxt, subnet_id, subnet_data)
+            self.assertTrue(ext_update_subnet.called)
+            self.assertTrue(ext_subnet_dict.called)
 
+    def test_extend_port_dict(self):
+        with contextlib.nested(
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'process_update_port'),
+            mock.patch.object(ext_test.TestExtensionDriver,
+                              'extend_port_dict'),
+            self.port()
+        ) as (ext_update_port, ext_port_dict, port):
+            port_id = port['port']['id']
+            port_data = {'port': {'id': port_id}}
+            self._plugin.update_port(self._ctxt, port_id, port_data)
+            self.assertTrue(ext_update_port.called)
+            self.assertTrue(ext_port_dict.called)
 
-class TestExtensionDriver(api.ExtensionDriver):
-    _supported_extension_alias = 'test_extension'
 
-    def initialize(self):
-        self.network_extension = 'Test_Network_Extension'
-        self.subnet_extension = 'Test_Subnet_Extension'
-        self.port_extension = 'Test_Port_Extension'
-        extensions.append_api_extensions_path(test_extensions.__path__)
+class DBExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
+    _extension_drivers = ['testdb']
+
+    def setUp(self):
+        config.cfg.CONF.set_override('extension_drivers',
+                                     self._extension_drivers,
+                                     group='ml2')
+        super(DBExtensionDriverTestCase, self).setUp()
+        self._plugin = manager.NeutronManager.get_plugin()
+        self._ctxt = context.get_admin_context()
+
+    def test_network_attr(self):
+        with self.network() as network:
+            # Test create with default value.
+            net_id = network['network']['id']
+            val = network['network']['network_extension']
+            self.assertEqual("", val)
+            res = self._show('networks', net_id)
+            val = res['network']['network_extension']
+            self.assertEqual("", val)
+
+            # Test list.
+            res = self._list('networks')
+            val = res['networks'][0]['network_extension']
+            self.assertEqual("", val)
 
-    @property
-    def extension_alias(self):
-        return self._supported_extension_alias
+        # Test create with explict value.
+        res = self._create_network(self.fmt,
+                                   'test-network', True,
+                                   arg_list=('network_extension', ),
+                                   network_extension="abc")
+        network = self.deserialize(self.fmt, res)
+        net_id = network['network']['id']
+        val = network['network']['network_extension']
+        self.assertEqual("abc", val)
+        res = self._show('networks', net_id)
+        val = res['network']['network_extension']
+        self.assertEqual("abc", val)
 
-    def process_create_network(self, plugin_context, data, result):
-        result['network_extension'] = self.network_extension
+        # Test update.
+        data = {'network': {'network_extension': "def"}}
+        res = self._update('networks', net_id, data)
+        val = res['network']['network_extension']
+        self.assertEqual("def", val)
+        res = self._show('networks', net_id)
+        val = res['network']['network_extension']
+        self.assertEqual("def", val)
 
-    def process_create_subnet(self, plugin_context, data, result):
-        result['subnet_extension'] = self.subnet_extension
+    def test_subnet_attr(self):
+        with self.subnet() as subnet:
+            # Test create with default value.
+            net_id = subnet['subnet']['id']
+            val = subnet['subnet']['subnet_extension']
+            self.assertEqual("", val)
+            res = self._show('subnets', net_id)
+            val = res['subnet']['subnet_extension']
+            self.assertEqual("", val)
 
-    def process_create_port(self, plugin_context, data, result):
-        result['port_extension'] = self.port_extension
+            # Test list.
+            res = self._list('subnets')
+            val = res['subnets'][0]['subnet_extension']
+            self.assertEqual("", val)
 
-    def process_update_network(self, plugin_context, data, result):
-        self.network_extension = data['network']['network_extension']
-        result['network_extension'] = self.network_extension
+        with self.network() as network:
+            # Test create with explict value.
+            data = {'subnet':
+                    {'network_id': network['network']['id'],
+                     'cidr': '10.1.0.0/24',
+                     'ip_version': '4',
+                     'tenant_id': self._tenant_id,
+                     'subnet_extension': 'abc'}}
+            req = self.new_create_request('subnets', data, self.fmt)
+            res = req.get_response(self.api)
+            subnet = self.deserialize(self.fmt, res)
+            subnet_id = subnet['subnet']['id']
+            val = subnet['subnet']['subnet_extension']
+            self.assertEqual("abc", val)
+            res = self._show('subnets', subnet_id)
+            val = res['subnet']['subnet_extension']
+            self.assertEqual("abc", val)
 
-    def process_update_subnet(self, plugin_context, data, result):
-        self.subnet_extension = data['subnet']['subnet_extension']
-        result['subnet_extension'] = self.subnet_extension
+            # Test update.
+            data = {'subnet': {'subnet_extension': "def"}}
+            res = self._update('subnets', subnet_id, data)
+            val = res['subnet']['subnet_extension']
+            self.assertEqual("def", val)
+            res = self._show('subnets', subnet_id)
+            val = res['subnet']['subnet_extension']
+            self.assertEqual("def", val)
 
-    def process_update_port(self, plugin_context, data, result):
-        self.port_extension = data['port_extension']
-        result['port_extension'] = self.port_extension
+    def test_port_attr(self):
+        with self.port() as port:
+            # Test create with default value.
+            net_id = port['port']['id']
+            val = port['port']['port_extension']
+            self.assertEqual("", val)
+            res = self._show('ports', net_id)
+            val = res['port']['port_extension']
+            self.assertEqual("", val)
 
-    def extend_network_dict(self, session, base_model, result):
-        if self._supported_extension_alias is 'test_extension':
-            result['network_extension'] = self.network_extension
+            # Test list.
+            res = self._list('ports')
+            val = res['ports'][0]['port_extension']
+            self.assertEqual("", val)
 
-    def extend_subnet_dict(self, session, base_model, result):
-        if self._supported_extension_alias is 'test_extension':
-            result['subnet_extension'] = self.subnet_extension
+        with self.network() as network:
+            # Test create with explict value.
+            res = self._create_port(self.fmt,
+                                    network['network']['id'],
+                                    arg_list=('port_extension', ),
+                                    port_extension="abc")
+            port = self.deserialize(self.fmt, res)
+            port_id = port['port']['id']
+            val = port['port']['port_extension']
+            self.assertEqual("abc", val)
+            res = self._show('ports', port_id)
+            val = res['port']['port_extension']
+            self.assertEqual("abc", val)
 
-    def extend_port_dict(self, session, base_model, result):
-        if self._supported_extension_alias is 'test_extension':
-            result['port_extension'] = self.port_extension
+            # Test update.
+            data = {'port': {'port_extension': "def"}}
+            res = self._update('ports', port_id, data)
+            val = res['port']['port_extension']
+            self.assertEqual("def", val)
+            res = self._show('ports', port_id)
+            val = res['port']['port_extension']
+            self.assertEqual("def", val)
index 6a5a5e5d472e0b28f9e29bcac715c3c7bf6c397a..297875087adb1e7f00bea00d71ddc52bc86e69e0 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -186,7 +186,8 @@ neutron.ml2.mechanism_drivers =
     nuage = neutron.plugins.ml2.drivers.mech_nuage.driver:NuageMechanismDriver
     fake_agent = neutron.tests.unit.ml2.drivers.mech_fake_agent:FakeAgentMechanismDriver
 neutron.ml2.extension_drivers =
-    test = neutron.tests.unit.ml2.test_extension_driver_api:TestExtensionDriver
+    test = neutron.tests.unit.ml2.drivers.ext_test:TestExtensionDriver
+    testdb = neutron.tests.unit.ml2.drivers.ext_test:TestDBExtensionDriver
 neutron.openstack.common.cache.backends =
     memory = neutron.openstack.common.cache._backends.memory:MemoryBackend
 # These are for backwards compat with Icehouse notification_driver configuration values