]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Refactor type_tunnel/gre/vxlan to reduce duplicate code
authorCedric Brandily <zzelle@gmail.com>
Sat, 30 May 2015 18:41:29 +0000 (20:41 +0200)
committerCedric Brandily <zzelle@gmail.com>
Tue, 2 Jun 2015 08:03:48 +0000 (08:03 +0000)
gre and vxlan type drivers have similar implementations for multiple
methods:
* get_endpoint_by_host
* get_endpoint_by_ip
* delete_endpoint
* get_endpoints
* add_endpoint

This change abstracts these methods and moves the abstractions to the
new class EndpointTunnelTypeDriver.

Change-Id: Iab97f8283b6bf5586334958de950664f6e74202a

neutron/plugins/ml2/drivers/type_gre.py
neutron/plugins/ml2/drivers/type_tunnel.py
neutron/plugins/ml2/drivers/type_vxlan.py
neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py

index 134348b0697aea7692f4e82479e992d47e205128..18d7040f79ae2027c498fcc4d82a88dc3b1103f5 100644 (file)
@@ -66,10 +66,11 @@ class GreEndpoints(model_base.BASEV2):
         return "<GreTunnelEndpoint(%s)>" % self.ip_address
 
 
-class GreTypeDriver(type_tunnel.TunnelTypeDriver):
+class GreTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
 
     def __init__(self):
-        super(GreTypeDriver, self).__init__(GreAllocation)
+        super(GreTypeDriver, self).__init__(
+            GreAllocation, GreEndpoints)
 
     def get_type(self):
         return p_const.TYPE_GRE
@@ -127,45 +128,13 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver):
 
     def get_endpoints(self):
         """Get every gre endpoints from database."""
-
-        LOG.debug("get_gre_endpoints() called")
-        session = db_api.get_session()
-
-        gre_endpoints = session.query(GreEndpoints)
+        gre_endpoints = self._get_endpoints()
         return [{'ip_address': gre_endpoint.ip_address,
                  'host': gre_endpoint.host}
                 for gre_endpoint in gre_endpoints]
 
-    def get_endpoint_by_host(self, host):
-        LOG.debug("get_endpoint_by_host() called for host %s", host)
-        session = db_api.get_session()
-        return (session.query(GreEndpoints).
-                filter_by(host=host).first())
-
-    def get_endpoint_by_ip(self, ip):
-        LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
-        session = db_api.get_session()
-        return (session.query(GreEndpoints).
-                filter_by(ip_address=ip).first())
-
     def add_endpoint(self, ip, host):
-        LOG.debug("add_gre_endpoint() called for ip %s", ip)
-        session = db_api.get_session()
-        try:
-            gre_endpoint = GreEndpoints(ip_address=ip, host=host)
-            gre_endpoint.save(session)
-        except db_exc.DBDuplicateEntry:
-            gre_endpoint = (session.query(GreEndpoints).
-                            filter_by(ip_address=ip).one())
-            LOG.warning(_LW("Gre endpoint with ip %s already exists"), ip)
-        return gre_endpoint
-
-    def delete_endpoint(self, ip):
-        LOG.debug("delete_gre_endpoint() called for ip %s", ip)
-        session = db_api.get_session()
-
-        with session.begin(subtransactions=True):
-            session.query(GreEndpoints).filter_by(ip_address=ip).delete()
+        return self._add_endpoint(ip, host)
 
     def get_mtu(self, physical_network=None):
         mtu = super(GreTypeDriver, self).get_mtu(physical_network)
index 68ffc3d3b06550ccdaf64db6cf56a437652b4a29..12dce86f48fd42e262fddf41e68bae18cef8609f 100644 (file)
 import abc
 
 from oslo_config import cfg
+from oslo_db import exception as db_exc
 from oslo_log import log
 
 from neutron.common import exceptions as exc
 from neutron.common import topics
+from neutron.db import api as db_api
 from neutron.i18n import _LI, _LW
 from neutron.plugins.common import utils as plugin_utils
 from neutron.plugins.ml2 import driver_api as api
@@ -196,6 +198,50 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver):
         return min(mtu) if mtu else 0
 
 
+class EndpointTunnelTypeDriver(TunnelTypeDriver):
+
+    def __init__(self, segment_model, endpoint_model):
+        super(EndpointTunnelTypeDriver, self).__init__(segment_model)
+        self.endpoint_model = endpoint_model
+        self.segmentation_key = iter(self.primary_keys).next()
+
+    def get_endpoint_by_host(self, host):
+        LOG.debug("get_endpoint_by_host() called for host %s", host)
+        session = db_api.get_session()
+        return (session.query(self.endpoint_model).
+                filter_by(host=host).first())
+
+    def get_endpoint_by_ip(self, ip):
+        LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
+        session = db_api.get_session()
+        return (session.query(self.endpoint_model).
+                filter_by(ip_address=ip).first())
+
+    def delete_endpoint(self, ip):
+        LOG.debug("delete_endpoint() called for ip %s", ip)
+        session = db_api.get_session()
+        with session.begin(subtransactions=True):
+            (session.query(self.endpoint_model).
+             filter_by(ip_address=ip).delete())
+
+    def _get_endpoints(self):
+        LOG.debug("_get_endpoints() called")
+        session = db_api.get_session()
+        return session.query(self.endpoint_model)
+
+    def _add_endpoint(self, ip, host, **kwargs):
+        LOG.debug("_add_endpoint() called for ip %s", ip)
+        session = db_api.get_session()
+        try:
+            endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
+            endpoint.save(session)
+        except db_exc.DBDuplicateEntry:
+            endpoint = (session.query(self.endpoint_model).
+                        filter_by(ip_address=ip).one())
+            LOG.warning(_LW("Endpoint with ip %s already exists"), ip)
+        return endpoint
+
+
 class TunnelRpcCallbackMixin(object):
 
     def setup_tunnel_callback_mixin(self, notifier, type_manager):
index 51125701c226d4a0ea69715d879906096f108721..b8cdb003c33f02002ca5312c227df56dbefaa992 100644 (file)
@@ -14,7 +14,6 @@
 #    under the License.
 
 from oslo_config import cfg
-from oslo_db import exception as db_exc
 from oslo_log import log
 from six import moves
 import sqlalchemy as sa
@@ -23,7 +22,7 @@ from sqlalchemy import sql
 from neutron.common import exceptions as n_exc
 from neutron.db import api as db_api
 from neutron.db import model_base
-from neutron.i18n import _LE, _LW
+from neutron.i18n import _LE
 from neutron.plugins.common import constants as p_const
 from neutron.plugins.ml2.drivers import type_tunnel
 
@@ -70,10 +69,11 @@ class VxlanEndpoints(model_base.BASEV2):
         return "<VxlanTunnelEndpoint(%s)>" % self.ip_address
 
 
-class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
+class VxlanTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
 
     def __init__(self):
-        super(VxlanTypeDriver, self).__init__(VxlanAllocation)
+        super(VxlanTypeDriver, self).__init__(
+            VxlanAllocation, VxlanEndpoints)
 
     def get_type(self):
         return p_const.TYPE_VXLAN
@@ -132,48 +132,14 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
 
     def get_endpoints(self):
         """Get every vxlan endpoints from database."""
-
-        LOG.debug("get_vxlan_endpoints() called")
-        session = db_api.get_session()
-
-        vxlan_endpoints = session.query(VxlanEndpoints)
+        vxlan_endpoints = self._get_endpoints()
         return [{'ip_address': vxlan_endpoint.ip_address,
                  'udp_port': vxlan_endpoint.udp_port,
                  'host': vxlan_endpoint.host}
                 for vxlan_endpoint in vxlan_endpoints]
 
-    def get_endpoint_by_host(self, host):
-        LOG.debug("get_endpoint_by_host() called for host %s", host)
-        session = db_api.get_session()
-        return (session.query(VxlanEndpoints).
-                filter_by(host=host).first())
-
-    def get_endpoint_by_ip(self, ip):
-        LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
-        session = db_api.get_session()
-        return (session.query(VxlanEndpoints).
-                filter_by(ip_address=ip).first())
-
     def add_endpoint(self, ip, host, udp_port=p_const.VXLAN_UDP_PORT):
-        LOG.debug("add_vxlan_endpoint() called for ip %s", ip)
-        session = db_api.get_session()
-        try:
-            vxlan_endpoint = VxlanEndpoints(ip_address=ip,
-                                            udp_port=udp_port,
-                                            host=host)
-            vxlan_endpoint.save(session)
-        except db_exc.DBDuplicateEntry:
-            vxlan_endpoint = (session.query(VxlanEndpoints).
-                              filter_by(ip_address=ip).one())
-            LOG.warning(_LW("Vxlan endpoint with ip %s already exists"), ip)
-        return vxlan_endpoint
-
-    def delete_endpoint(self, ip):
-        LOG.debug("delete_vxlan_endpoint() called for ip %s", ip)
-        session = db_api.get_session()
-
-        with session.begin(subtransactions=True):
-            session.query(VxlanEndpoints).filter_by(ip_address=ip).delete()
+        return self._add_endpoint(ip, host, udp_port=udp_port)
 
     def get_mtu(self, physical_network=None):
         mtu = super(VxlanTypeDriver, self).get_mtu()
index 41431e0c898ed4d2bbbf8ce30ed44c89a702e85d..725fdaab18e51c7d8e387d4d84b2e55954a26062 100644 (file)
@@ -21,6 +21,7 @@ from testtools import matchers
 from neutron.common import exceptions as exc
 from neutron.db import api as db
 from neutron.plugins.ml2 import driver_api as api
+from neutron.plugins.ml2.drivers import type_tunnel
 
 TUNNEL_IP_ONE = "10.10.10.10"
 TUNNEL_IP_TWO = "10.10.10.20"
@@ -33,7 +34,6 @@ UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)]
 
 
 class TunnelTypeTestMixin(object):
-    DRIVER_MODULE = None
     DRIVER_CLASS = None
     TYPE = None
 
@@ -208,8 +208,7 @@ class TunnelTypeTestMixin(object):
     def test_add_endpoint_for_existing_tunnel_ip(self):
         self.add_endpoint()
 
-        log = getattr(self.DRIVER_MODULE, 'LOG')
-        with mock.patch.object(log, 'warning') as log_warn:
+        with mock.patch.object(type_tunnel.LOG, 'warning') as log_warn:
             self.add_endpoint()
             log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)