From: Cedric Brandily Date: Sat, 30 May 2015 18:41:29 +0000 (+0200) Subject: Refactor type_tunnel/gre/vxlan to reduce duplicate code X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=46223363bd4d41639102ae1923dd1dfb306ec808;p=openstack-build%2Fneutron-build.git Refactor type_tunnel/gre/vxlan to reduce duplicate code gre and vxlan type drivers have similar implementations for multiple methods: * get_endpoint_by_host * get_endpoint_by_ip * delete_endpoint * get_endpoints * add_endpoint This change abstracts these methods and moves the abstractions to the new class EndpointTunnelTypeDriver. Change-Id: Iab97f8283b6bf5586334958de950664f6e74202a --- diff --git a/neutron/plugins/ml2/drivers/type_gre.py b/neutron/plugins/ml2/drivers/type_gre.py index 134348b06..18d7040f7 100644 --- a/neutron/plugins/ml2/drivers/type_gre.py +++ b/neutron/plugins/ml2/drivers/type_gre.py @@ -66,10 +66,11 @@ class GreEndpoints(model_base.BASEV2): return "" % self.ip_address -class GreTypeDriver(type_tunnel.TunnelTypeDriver): +class GreTypeDriver(type_tunnel.EndpointTunnelTypeDriver): def __init__(self): - super(GreTypeDriver, self).__init__(GreAllocation) + super(GreTypeDriver, self).__init__( + GreAllocation, GreEndpoints) def get_type(self): return p_const.TYPE_GRE @@ -127,45 +128,13 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver): def get_endpoints(self): """Get every gre endpoints from database.""" - - LOG.debug("get_gre_endpoints() called") - session = db_api.get_session() - - gre_endpoints = session.query(GreEndpoints) + gre_endpoints = self._get_endpoints() return [{'ip_address': gre_endpoint.ip_address, 'host': gre_endpoint.host} for gre_endpoint in gre_endpoints] - def get_endpoint_by_host(self, host): - LOG.debug("get_endpoint_by_host() called for host %s", host) - session = db_api.get_session() - return (session.query(GreEndpoints). - filter_by(host=host).first()) - - def get_endpoint_by_ip(self, ip): - LOG.debug("get_endpoint_by_ip() called for ip %s", ip) - session = db_api.get_session() - return (session.query(GreEndpoints). - filter_by(ip_address=ip).first()) - def add_endpoint(self, ip, host): - LOG.debug("add_gre_endpoint() called for ip %s", ip) - session = db_api.get_session() - try: - gre_endpoint = GreEndpoints(ip_address=ip, host=host) - gre_endpoint.save(session) - except db_exc.DBDuplicateEntry: - gre_endpoint = (session.query(GreEndpoints). - filter_by(ip_address=ip).one()) - LOG.warning(_LW("Gre endpoint with ip %s already exists"), ip) - return gre_endpoint - - def delete_endpoint(self, ip): - LOG.debug("delete_gre_endpoint() called for ip %s", ip) - session = db_api.get_session() - - with session.begin(subtransactions=True): - session.query(GreEndpoints).filter_by(ip_address=ip).delete() + return self._add_endpoint(ip, host) def get_mtu(self, physical_network=None): mtu = super(GreTypeDriver, self).get_mtu(physical_network) diff --git a/neutron/plugins/ml2/drivers/type_tunnel.py b/neutron/plugins/ml2/drivers/type_tunnel.py index 68ffc3d3b..12dce86f4 100644 --- a/neutron/plugins/ml2/drivers/type_tunnel.py +++ b/neutron/plugins/ml2/drivers/type_tunnel.py @@ -15,10 +15,12 @@ import abc from oslo_config import cfg +from oslo_db import exception as db_exc from oslo_log import log from neutron.common import exceptions as exc from neutron.common import topics +from neutron.db import api as db_api from neutron.i18n import _LI, _LW from neutron.plugins.common import utils as plugin_utils from neutron.plugins.ml2 import driver_api as api @@ -196,6 +198,50 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver): return min(mtu) if mtu else 0 +class EndpointTunnelTypeDriver(TunnelTypeDriver): + + def __init__(self, segment_model, endpoint_model): + super(EndpointTunnelTypeDriver, self).__init__(segment_model) + self.endpoint_model = endpoint_model + self.segmentation_key = iter(self.primary_keys).next() + + def get_endpoint_by_host(self, host): + LOG.debug("get_endpoint_by_host() called for host %s", host) + session = db_api.get_session() + return (session.query(self.endpoint_model). + filter_by(host=host).first()) + + def get_endpoint_by_ip(self, ip): + LOG.debug("get_endpoint_by_ip() called for ip %s", ip) + session = db_api.get_session() + return (session.query(self.endpoint_model). + filter_by(ip_address=ip).first()) + + def delete_endpoint(self, ip): + LOG.debug("delete_endpoint() called for ip %s", ip) + session = db_api.get_session() + with session.begin(subtransactions=True): + (session.query(self.endpoint_model). + filter_by(ip_address=ip).delete()) + + def _get_endpoints(self): + LOG.debug("_get_endpoints() called") + session = db_api.get_session() + return session.query(self.endpoint_model) + + def _add_endpoint(self, ip, host, **kwargs): + LOG.debug("_add_endpoint() called for ip %s", ip) + session = db_api.get_session() + try: + endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs) + endpoint.save(session) + except db_exc.DBDuplicateEntry: + endpoint = (session.query(self.endpoint_model). + filter_by(ip_address=ip).one()) + LOG.warning(_LW("Endpoint with ip %s already exists"), ip) + return endpoint + + class TunnelRpcCallbackMixin(object): def setup_tunnel_callback_mixin(self, notifier, type_manager): diff --git a/neutron/plugins/ml2/drivers/type_vxlan.py b/neutron/plugins/ml2/drivers/type_vxlan.py index 51125701c..b8cdb003c 100644 --- a/neutron/plugins/ml2/drivers/type_vxlan.py +++ b/neutron/plugins/ml2/drivers/type_vxlan.py @@ -14,7 +14,6 @@ # under the License. from oslo_config import cfg -from oslo_db import exception as db_exc from oslo_log import log from six import moves import sqlalchemy as sa @@ -23,7 +22,7 @@ from sqlalchemy import sql from neutron.common import exceptions as n_exc from neutron.db import api as db_api from neutron.db import model_base -from neutron.i18n import _LE, _LW +from neutron.i18n import _LE from neutron.plugins.common import constants as p_const from neutron.plugins.ml2.drivers import type_tunnel @@ -70,10 +69,11 @@ class VxlanEndpoints(model_base.BASEV2): return "" % self.ip_address -class VxlanTypeDriver(type_tunnel.TunnelTypeDriver): +class VxlanTypeDriver(type_tunnel.EndpointTunnelTypeDriver): def __init__(self): - super(VxlanTypeDriver, self).__init__(VxlanAllocation) + super(VxlanTypeDriver, self).__init__( + VxlanAllocation, VxlanEndpoints) def get_type(self): return p_const.TYPE_VXLAN @@ -132,48 +132,14 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver): def get_endpoints(self): """Get every vxlan endpoints from database.""" - - LOG.debug("get_vxlan_endpoints() called") - session = db_api.get_session() - - vxlan_endpoints = session.query(VxlanEndpoints) + vxlan_endpoints = self._get_endpoints() return [{'ip_address': vxlan_endpoint.ip_address, 'udp_port': vxlan_endpoint.udp_port, 'host': vxlan_endpoint.host} for vxlan_endpoint in vxlan_endpoints] - def get_endpoint_by_host(self, host): - LOG.debug("get_endpoint_by_host() called for host %s", host) - session = db_api.get_session() - return (session.query(VxlanEndpoints). - filter_by(host=host).first()) - - def get_endpoint_by_ip(self, ip): - LOG.debug("get_endpoint_by_ip() called for ip %s", ip) - session = db_api.get_session() - return (session.query(VxlanEndpoints). - filter_by(ip_address=ip).first()) - def add_endpoint(self, ip, host, udp_port=p_const.VXLAN_UDP_PORT): - LOG.debug("add_vxlan_endpoint() called for ip %s", ip) - session = db_api.get_session() - try: - vxlan_endpoint = VxlanEndpoints(ip_address=ip, - udp_port=udp_port, - host=host) - vxlan_endpoint.save(session) - except db_exc.DBDuplicateEntry: - vxlan_endpoint = (session.query(VxlanEndpoints). - filter_by(ip_address=ip).one()) - LOG.warning(_LW("Vxlan endpoint with ip %s already exists"), ip) - return vxlan_endpoint - - def delete_endpoint(self, ip): - LOG.debug("delete_vxlan_endpoint() called for ip %s", ip) - session = db_api.get_session() - - with session.begin(subtransactions=True): - session.query(VxlanEndpoints).filter_by(ip_address=ip).delete() + return self._add_endpoint(ip, host, udp_port=udp_port) def get_mtu(self, physical_network=None): mtu = super(VxlanTypeDriver, self).get_mtu() diff --git a/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py b/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py index 41431e0c8..725fdaab1 100644 --- a/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py +++ b/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py @@ -21,6 +21,7 @@ from testtools import matchers from neutron.common import exceptions as exc from neutron.db import api as db from neutron.plugins.ml2 import driver_api as api +from neutron.plugins.ml2.drivers import type_tunnel TUNNEL_IP_ONE = "10.10.10.10" TUNNEL_IP_TWO = "10.10.10.20" @@ -33,7 +34,6 @@ UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)] class TunnelTypeTestMixin(object): - DRIVER_MODULE = None DRIVER_CLASS = None TYPE = None @@ -208,8 +208,7 @@ class TunnelTypeTestMixin(object): def test_add_endpoint_for_existing_tunnel_ip(self): self.add_endpoint() - log = getattr(self.DRIVER_MODULE, 'LOG') - with mock.patch.object(log, 'warning') as log_warn: + with mock.patch.object(type_tunnel.LOG, 'warning') as log_warn: self.add_endpoint() log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)