--- /dev/null
+# Copyright 2015 Intel Corporation.
+# Copyright 2015 Isaku Yamahata <isaku.yamahata at intel com>
+# <isaku.yamahata at gmail com>
+# All Rights Reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy as sa
+from sqlalchemy import orm
+
+import oslo_db.sqlalchemy.session
+
+from neutron.api import extensions
+from neutron.api.v2 import attributes
+from neutron.db import model_base
+from neutron.db import models_v2
+from neutron.plugins.ml2 import driver_api
+from neutron.tests.unit.ml2 import extensions as test_extensions
+
+
+class TestExtensionDriverBase(driver_api.ExtensionDriver):
+ _supported_extension_aliases = 'test_extension'
+
+ def initialize(self):
+ extensions.append_api_extensions_path(test_extensions.__path__)
+
+ @property
+ def extension_alias(self):
+ return self._supported_extension_aliases
+
+
+class TestExtensionDriver(TestExtensionDriverBase):
+ def initialize(self):
+ super(TestExtensionDriver, self).initialize()
+ self.network_extension = 'Test_Network_Extension'
+ self.subnet_extension = 'Test_Subnet_Extension'
+ self.port_extension = 'Test_Port_Extension'
+
+ def _check_create(self, session, data, result):
+ assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+ assert(isinstance(data, dict))
+ assert('id' not in data)
+ assert(isinstance(result, dict))
+ assert(result['id'] is not None)
+
+ def _check_update(self, session, data, result):
+ assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+ assert(isinstance(data, dict))
+ assert(isinstance(result, dict))
+ assert(result['id'] is not None)
+
+ def _check_extend(self, session, result, db_entry,
+ expected_db_entry_class):
+ assert(isinstance(session, oslo_db.sqlalchemy.session.Session))
+ assert(isinstance(result, dict))
+ assert(result['id'] is not None)
+ assert(isinstance(db_entry, expected_db_entry_class))
+ assert(db_entry.id == result['id'])
+
+ def process_create_network(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_create(session, data, result)
+ result['network_extension'] = self.network_extension + '_create'
+
+ def process_update_network(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_update(session, data, result)
+ self.network_extension = data['network_extension']
+ result['network_extension'] = self.network_extension + '_update'
+
+ def extend_network_dict(self, session, net_db, result):
+ self._check_extend(session, result, net_db, models_v2.Network)
+ result['network_extension'] = self.network_extension + '_extend'
+
+ def process_create_subnet(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_create(session, data, result)
+ result['subnet_extension'] = self.subnet_extension + '_create'
+
+ def process_update_subnet(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_update(session, data, result)
+ self.subnet_extension = data['subnet_extension']
+ result['subnet_extension'] = self.subnet_extension + '_update'
+
+ def extend_subnet_dict(self, session, subnet_db, result):
+ self._check_extend(session, result, subnet_db, models_v2.Subnet)
+ result['subnet_extension'] = self.subnet_extension + '_extend'
+
+ def process_create_port(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_create(session, data, result)
+ result['port_extension'] = self.port_extension + '_create'
+
+ def process_update_port(self, plugin_context, data, result):
+ session = plugin_context.session
+ self._check_update(session, data, result)
+ self.port_extension = data['port_extension']
+ result['port_extension'] = self.port_extension + '_update'
+
+ def extend_port_dict(self, session, port_db, result):
+ self._check_extend(session, result, port_db, models_v2.Port)
+ result['port_extension'] = self.port_extension + '_extend'
+
+
+class TestNetworkExtension(model_base.BASEV2):
+ network_id = sa.Column(sa.String(36),
+ sa.ForeignKey('networks.id', ondelete="CASCADE"),
+ primary_key=True)
+ value = sa.Column(sa.String(64))
+ network = orm.relationship(
+ models_v2.Network,
+ backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestSubnetExtension(model_base.BASEV2):
+ subnet_id = sa.Column(sa.String(36),
+ sa.ForeignKey('subnets.id', ondelete="CASCADE"),
+ primary_key=True)
+ value = sa.Column(sa.String(64))
+ subnet = orm.relationship(
+ models_v2.Subnet,
+ backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestPortExtension(model_base.BASEV2):
+ port_id = sa.Column(sa.String(36),
+ sa.ForeignKey('ports.id', ondelete="CASCADE"),
+ primary_key=True)
+ value = sa.Column(sa.String(64))
+ port = orm.relationship(
+ models_v2.Port,
+ backref=orm.backref('extension', cascade='delete', uselist=False))
+
+
+class TestDBExtensionDriver(TestExtensionDriverBase):
+ def _get_value(self, data, key):
+ value = data[key]
+ if not attributes.is_attr_set(value):
+ value = ''
+ return value
+
+ def process_create_network(self, plugin_context, data, result):
+ session = plugin_context.session
+ value = self._get_value(data, 'network_extension')
+ record = TestNetworkExtension(network_id=result['id'], value=value)
+ session.add(record)
+ result['network_extension'] = value
+
+ def process_update_network(self, plugin_context, data, result):
+ session = plugin_context.session
+ record = (session.query(TestNetworkExtension).
+ filter_by(network_id=result['id']).one())
+ value = data.get('network_extension')
+ if value and value != record.value:
+ record.value = value
+ result['network_extension'] = record.value
+
+ def extend_network_dict(self, session, net_db, result):
+ result['network_extension'] = net_db.extension.value
+
+ def process_create_subnet(self, plugin_context, data, result):
+ session = plugin_context.session
+ value = self._get_value(data, 'subnet_extension')
+ record = TestSubnetExtension(subnet_id=result['id'], value=value)
+ session.add(record)
+ result['subnet_extension'] = value
+
+ def process_update_subnet(self, plugin_context, data, result):
+ session = plugin_context.session
+ record = (session.query(TestSubnetExtension).
+ filter_by(subnet_id=result['id']).one())
+ value = data.get('subnet_extension')
+ if value and value != record.value:
+ record.value = value
+ result['subnet_extension'] = record.value
+
+ def extend_subnet_dict(self, session, subnet_db, result):
+ value = subnet_db.extension.value if subnet_db.extension else ''
+ result['subnet_extension'] = value
+
+ def process_create_port(self, plugin_context, data, result):
+ session = plugin_context.session
+ value = self._get_value(data, 'port_extension')
+ record = TestPortExtension(port_id=result['id'], value=value)
+ session.add(record)
+ result['port_extension'] = value
+
+ def process_update_port(self, plugin_context, data, result):
+ session = plugin_context.session
+ record = (session.query(TestPortExtension).
+ filter_by(port_id=result['id']).one())
+ value = data.get('port_extension')
+ if value and value != record.value:
+ record.value = value
+ result['port_extension'] = record.value
+
+ def extend_port_dict(self, session, port_db, result):
+ value = port_db.extension.value if port_db.extension else ''
+ result['port_extension'] = value
# License for the specific language governing permissions and limitations
# under the License.
-from neutron.api import extensions
+import contextlib
+import mock
+
+from neutron import context
+from neutron import manager
from neutron.plugins.ml2 import config
-from neutron.plugins.ml2 import driver_api as api
-from neutron.tests.unit.ml2 import extensions as test_extensions
+from neutron.tests.unit.ml2.drivers import ext_test
from neutron.tests.unit.ml2 import test_ml2_plugin
self._extension_drivers,
group='ml2')
super(ExtensionDriverTestCase, self).setUp()
+ self._plugin = manager.NeutronManager.get_plugin()
+ self._ctxt = context.get_admin_context()
def test_network_attr(self):
with self.network() as network:
# Test list networks
res = self._list('networks')
val = res['networks'][0].get('network_extension')
- self.assertEqual('Test_Network_Extension', val)
+ self.assertEqual('Test_Network_Extension_extend', val)
# Test network update
data = {'network':
{'network_extension': 'Test_Network_Extension_Update'}}
res = self._update('networks', network['network']['id'], data)
val = res['network'].get('network_extension')
- self.assertEqual('Test_Network_Extension_Update', val)
+ self.assertEqual('Test_Network_Extension_Update_update', val)
def test_subnet_attr(self):
with self.subnet() as subnet:
# Test list subnets
res = self._list('subnets')
val = res['subnets'][0].get('subnet_extension')
- self.assertEqual('Test_Subnet_Extension', val)
+ self.assertEqual('Test_Subnet_Extension_extend', val)
# Test subnet update
data = {'subnet':
{'subnet_extension': 'Test_Subnet_Extension_Update'}}
res = self._update('subnets', subnet['subnet']['id'], data)
val = res['subnet'].get('subnet_extension')
- self.assertEqual('Test_Subnet_Extension_Update', val)
+ self.assertEqual('Test_Subnet_Extension_Update_update', val)
def test_port_attr(self):
with self.port() as port:
# Test list ports
res = self._list('ports')
val = res['ports'][0].get('port_extension')
- self.assertEqual('Test_Port_Extension', val)
+ self.assertEqual('Test_Port_Extension_extend', val)
# Test port update
data = {'port': {'port_extension': 'Test_Port_Extension_Update'}}
res = self._update('ports', port['port']['id'], data)
val = res['port'].get('port_extension')
- self.assertEqual('Test_Port_Extension_Update', val)
+ self.assertEqual('Test_Port_Extension_Update_update', val)
+
+ def test_extend_network_dict(self):
+ with contextlib.nested(
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'process_update_network'),
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'extend_network_dict'),
+ self.network()
+ ) as (ext_update_net, ext_net_dict, network):
+ net_id = network['network']['id']
+ net_data = {'network': {'id': net_id}}
+ self._plugin.update_network(self._ctxt, net_id, net_data)
+ self.assertTrue(ext_update_net.called)
+ self.assertTrue(ext_net_dict.called)
+
+ def test_extend_subnet_dict(self):
+ with contextlib.nested(
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'process_update_subnet'),
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'extend_subnet_dict'),
+ self.subnet()
+ ) as (ext_update_subnet, ext_subnet_dict, subnet):
+ subnet_id = subnet['subnet']['id']
+ subnet_data = {'subnet': {'id': subnet_id}}
+ self._plugin.update_subnet(self._ctxt, subnet_id, subnet_data)
+ self.assertTrue(ext_update_subnet.called)
+ self.assertTrue(ext_subnet_dict.called)
+ def test_extend_port_dict(self):
+ with contextlib.nested(
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'process_update_port'),
+ mock.patch.object(ext_test.TestExtensionDriver,
+ 'extend_port_dict'),
+ self.port()
+ ) as (ext_update_port, ext_port_dict, port):
+ port_id = port['port']['id']
+ port_data = {'port': {'id': port_id}}
+ self._plugin.update_port(self._ctxt, port_id, port_data)
+ self.assertTrue(ext_update_port.called)
+ self.assertTrue(ext_port_dict.called)
-class TestExtensionDriver(api.ExtensionDriver):
- _supported_extension_alias = 'test_extension'
- def initialize(self):
- self.network_extension = 'Test_Network_Extension'
- self.subnet_extension = 'Test_Subnet_Extension'
- self.port_extension = 'Test_Port_Extension'
- extensions.append_api_extensions_path(test_extensions.__path__)
+class DBExtensionDriverTestCase(test_ml2_plugin.Ml2PluginV2TestCase):
+ _extension_drivers = ['testdb']
+
+ def setUp(self):
+ config.cfg.CONF.set_override('extension_drivers',
+ self._extension_drivers,
+ group='ml2')
+ super(DBExtensionDriverTestCase, self).setUp()
+ self._plugin = manager.NeutronManager.get_plugin()
+ self._ctxt = context.get_admin_context()
+
+ def test_network_attr(self):
+ with self.network() as network:
+ # Test create with default value.
+ net_id = network['network']['id']
+ val = network['network']['network_extension']
+ self.assertEqual("", val)
+ res = self._show('networks', net_id)
+ val = res['network']['network_extension']
+ self.assertEqual("", val)
+
+ # Test list.
+ res = self._list('networks')
+ val = res['networks'][0]['network_extension']
+ self.assertEqual("", val)
- @property
- def extension_alias(self):
- return self._supported_extension_alias
+ # Test create with explict value.
+ res = self._create_network(self.fmt,
+ 'test-network', True,
+ arg_list=('network_extension', ),
+ network_extension="abc")
+ network = self.deserialize(self.fmt, res)
+ net_id = network['network']['id']
+ val = network['network']['network_extension']
+ self.assertEqual("abc", val)
+ res = self._show('networks', net_id)
+ val = res['network']['network_extension']
+ self.assertEqual("abc", val)
- def process_create_network(self, plugin_context, data, result):
- result['network_extension'] = self.network_extension
+ # Test update.
+ data = {'network': {'network_extension': "def"}}
+ res = self._update('networks', net_id, data)
+ val = res['network']['network_extension']
+ self.assertEqual("def", val)
+ res = self._show('networks', net_id)
+ val = res['network']['network_extension']
+ self.assertEqual("def", val)
- def process_create_subnet(self, plugin_context, data, result):
- result['subnet_extension'] = self.subnet_extension
+ def test_subnet_attr(self):
+ with self.subnet() as subnet:
+ # Test create with default value.
+ net_id = subnet['subnet']['id']
+ val = subnet['subnet']['subnet_extension']
+ self.assertEqual("", val)
+ res = self._show('subnets', net_id)
+ val = res['subnet']['subnet_extension']
+ self.assertEqual("", val)
- def process_create_port(self, plugin_context, data, result):
- result['port_extension'] = self.port_extension
+ # Test list.
+ res = self._list('subnets')
+ val = res['subnets'][0]['subnet_extension']
+ self.assertEqual("", val)
- def process_update_network(self, plugin_context, data, result):
- self.network_extension = data['network']['network_extension']
- result['network_extension'] = self.network_extension
+ with self.network() as network:
+ # Test create with explict value.
+ data = {'subnet':
+ {'network_id': network['network']['id'],
+ 'cidr': '10.1.0.0/24',
+ 'ip_version': '4',
+ 'tenant_id': self._tenant_id,
+ 'subnet_extension': 'abc'}}
+ req = self.new_create_request('subnets', data, self.fmt)
+ res = req.get_response(self.api)
+ subnet = self.deserialize(self.fmt, res)
+ subnet_id = subnet['subnet']['id']
+ val = subnet['subnet']['subnet_extension']
+ self.assertEqual("abc", val)
+ res = self._show('subnets', subnet_id)
+ val = res['subnet']['subnet_extension']
+ self.assertEqual("abc", val)
- def process_update_subnet(self, plugin_context, data, result):
- self.subnet_extension = data['subnet']['subnet_extension']
- result['subnet_extension'] = self.subnet_extension
+ # Test update.
+ data = {'subnet': {'subnet_extension': "def"}}
+ res = self._update('subnets', subnet_id, data)
+ val = res['subnet']['subnet_extension']
+ self.assertEqual("def", val)
+ res = self._show('subnets', subnet_id)
+ val = res['subnet']['subnet_extension']
+ self.assertEqual("def", val)
- def process_update_port(self, plugin_context, data, result):
- self.port_extension = data['port_extension']
- result['port_extension'] = self.port_extension
+ def test_port_attr(self):
+ with self.port() as port:
+ # Test create with default value.
+ net_id = port['port']['id']
+ val = port['port']['port_extension']
+ self.assertEqual("", val)
+ res = self._show('ports', net_id)
+ val = res['port']['port_extension']
+ self.assertEqual("", val)
- def extend_network_dict(self, session, base_model, result):
- if self._supported_extension_alias is 'test_extension':
- result['network_extension'] = self.network_extension
+ # Test list.
+ res = self._list('ports')
+ val = res['ports'][0]['port_extension']
+ self.assertEqual("", val)
- def extend_subnet_dict(self, session, base_model, result):
- if self._supported_extension_alias is 'test_extension':
- result['subnet_extension'] = self.subnet_extension
+ with self.network() as network:
+ # Test create with explict value.
+ res = self._create_port(self.fmt,
+ network['network']['id'],
+ arg_list=('port_extension', ),
+ port_extension="abc")
+ port = self.deserialize(self.fmt, res)
+ port_id = port['port']['id']
+ val = port['port']['port_extension']
+ self.assertEqual("abc", val)
+ res = self._show('ports', port_id)
+ val = res['port']['port_extension']
+ self.assertEqual("abc", val)
- def extend_port_dict(self, session, base_model, result):
- if self._supported_extension_alias is 'test_extension':
- result['port_extension'] = self.port_extension
+ # Test update.
+ data = {'port': {'port_extension': "def"}}
+ res = self._update('ports', port_id, data)
+ val = res['port']['port_extension']
+ self.assertEqual("def", val)
+ res = self._show('ports', port_id)
+ val = res['port']['port_extension']
+ self.assertEqual("def", val)