]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Stale VXLAN & GRE tunnel endpoint deletion from DB
authorRomil Gupta <romilg@hp.com>
Fri, 12 Sep 2014 06:26:57 +0000 (23:26 -0700)
committerRomil Gupta <romilg@hp.com>
Tue, 30 Dec 2014 13:01:55 +0000 (13:01 +0000)
Description:
Stale GRE and VXLAN tunnel endpoints persists in neutron db this should be
deleted from the database. Also, if local_ip of L2 agent changes the
stale tunnel ports and flows persists on br-tun on other Compute Nodes and
Network Nodes for that remote ip this should also be removed.

Implementation

Plugin changes:
Added host column in 'ml2_vxlan_endpoints' and 'ml2_gre_endpoints' table.
Added delete_endpoint method for deleting the stale endpoints from db.
Modified tunnel_sync() method to accommodate these changes.
Modified testcases in test_type_vxlan.py
Modified testcases in test_type_gre.py

Agent changes:
Added tunnel_delete rpc for removing stale ports and flows from br-tun.
tunnel_sync rpc signature upgrade to obtain 'host'.
Added testcases for TunnelRpcCallbackMixin().

This patch-set only deals with plugin side changes.

Partial-Bug: #1179223

Change-Id: I75c6581fcc9f47a68bde29cbefcaa1a2a082344e

neutron/db/migration/alembic_migrations/versions/38495dc99731_ml2_tunnel_endpoints_table.py [new file with mode: 0644]
neutron/db/migration/alembic_migrations/versions/HEAD
neutron/plugins/ml2/drivers/type_gre.py
neutron/plugins/ml2/drivers/type_tunnel.py
neutron/plugins/ml2/drivers/type_vxlan.py
neutron/tests/unit/ml2/test_type_gre.py
neutron/tests/unit/ml2/test_type_vxlan.py

diff --git a/neutron/db/migration/alembic_migrations/versions/38495dc99731_ml2_tunnel_endpoints_table.py b/neutron/db/migration/alembic_migrations/versions/38495dc99731_ml2_tunnel_endpoints_table.py
new file mode 100644 (file)
index 0000000..485b872
--- /dev/null
@@ -0,0 +1,68 @@
+# Copyright 2014 OpenStack Foundation
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+#
+
+"""ml2_tunnel_endpoints_table
+
+Revision ID: 38495dc99731
+Revises: 57086602ca0a
+Create Date: 2014-12-22 00:03:33.643799
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '38495dc99731'
+down_revision = '57086602ca0a'
+
+from alembic import op
+import sqlalchemy as sa
+
+CONSTRAINT_NAME_GRE = 'unique_ml2_gre_endpoints0host'
+CONSTRAINT_NAME_VXLAN = 'unique_ml2_vxlan_endpoints0host'
+
+
+def upgrade():
+
+    op.add_column('ml2_gre_endpoints',
+                  sa.Column('host', sa.String(length=255), nullable=True))
+    op.create_unique_constraint(
+        name=CONSTRAINT_NAME_GRE,
+        source='ml2_gre_endpoints',
+        local_cols=['host']
+    )
+
+    op.add_column('ml2_vxlan_endpoints',
+                  sa.Column('host', sa.String(length=255), nullable=True))
+    op.create_unique_constraint(
+        name=CONSTRAINT_NAME_VXLAN,
+        source='ml2_vxlan_endpoints',
+        local_cols=['host']
+    )
+
+
+def downgrade():
+
+    op.drop_constraint(
+        name=CONSTRAINT_NAME_VXLAN,
+        table_name='ml2_vxlan_endpoints',
+        type_='unique'
+    )
+    op.drop_column('ml2_vxlan_endpoints', 'host')
+
+    op.drop_constraint(
+        name=CONSTRAINT_NAME_GRE,
+        table_name='ml2_gre_endpoints',
+        type_='unique'
+    )
+    op.drop_column('ml2_gre_endpoints', 'host')
index 4a5339a75de360882ddf371718b9a249e4176e75..f7415c8a2d62a0a94437d9e9c0913e993fde2f8f 100644 (file)
@@ -1 +1 @@
-57086602ca0a
+38495dc99731
\ No newline at end of file
index c0565b53391f4165a67e65c4c954781a694fe15f..076bd2b2699e4f5a386df40f4d0e1d6769a6ac6d 100644 (file)
@@ -19,7 +19,7 @@ from six import moves
 import sqlalchemy as sa
 from sqlalchemy import sql
 
-from neutron.common import exceptions as exc
+from neutron.common import exceptions as n_exc
 from neutron.db import api as db_api
 from neutron.db import model_base
 from neutron.i18n import _LE, _LW
@@ -52,9 +52,14 @@ class GreAllocation(model_base.BASEV2):
 
 class GreEndpoints(model_base.BASEV2):
     """Represents tunnel endpoint in RPC mode."""
-    __tablename__ = 'ml2_gre_endpoints'
 
+    __tablename__ = 'ml2_gre_endpoints'
+    __table_args__ = (
+        sa.UniqueConstraint('host',
+                            name='unique_ml2_gre_endpoints0host'),
+    )
     ip_address = sa.Column(sa.String(64), primary_key=True)
+    host = sa.Column(sa.String(255), nullable=True)
 
     def __repr__(self):
         return "<GreTunnelEndpoint(%s)>" % self.ip_address
@@ -71,7 +76,7 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver):
     def initialize(self):
         try:
             self._initialize(cfg.CONF.ml2_type_gre.tunnel_id_ranges)
-        except exc.NetworkTunnelRangeError:
+        except n_exc.NetworkTunnelRangeError:
             LOG.exception(_LE("Failed to parse tunnel_id_ranges. "
                               "Service terminated!"))
             raise SystemExit()
@@ -115,19 +120,42 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver):
         LOG.debug("get_gre_endpoints() called")
         session = db_api.get_session()
 
-        with session.begin(subtransactions=True):
-            gre_endpoints = session.query(GreEndpoints)
-            return [{'ip_address': gre_endpoint.ip_address}
-                    for gre_endpoint in gre_endpoints]
+        gre_endpoints = session.query(GreEndpoints)
+        return [{'ip_address': gre_endpoint.ip_address,
+                 'host': gre_endpoint.host}
+                for gre_endpoint in gre_endpoints]
 
-    def add_endpoint(self, ip):
+    def get_endpoint_by_host(self, host):
+        LOG.debug("get_endpoint_by_host() called for host %s", host)
+        session = db_api.get_session()
+
+        host_endpoint = (session.query(GreEndpoints).
+                         filter_by(host=host).first())
+        return host_endpoint
+
+    def get_endpoint_by_ip(self, ip):
+        LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
+        session = db_api.get_session()
+
+        ip_endpoint = (session.query(GreEndpoints).
+                       filter_by(ip_address=ip).first())
+        return ip_endpoint
+
+    def add_endpoint(self, ip, host):
         LOG.debug("add_gre_endpoint() called for ip %s", ip)
         session = db_api.get_session()
         try:
-            gre_endpoint = GreEndpoints(ip_address=ip)
+            gre_endpoint = GreEndpoints(ip_address=ip, host=host)
             gre_endpoint.save(session)
         except db_exc.DBDuplicateEntry:
             gre_endpoint = (session.query(GreEndpoints).
                             filter_by(ip_address=ip).one())
             LOG.warning(_LW("Gre endpoint with ip %s already exists"), ip)
         return gre_endpoint
+
+    def delete_endpoint(self, ip):
+        LOG.debug("delete_gre_endpoint() called for ip %s", ip)
+        session = db_api.get_session()
+
+        with session.begin(subtransactions=True):
+            session.query(GreEndpoints).filter_by(ip_address=ip).delete()
index b0bad17834c719465b093cab89c3166f06de484e..c64e24e8e185957dd04beee7ed2571055aec20d5 100644 (file)
@@ -43,10 +43,11 @@ class TunnelTypeDriver(helpers.TypeDriverHelper):
         """Synchronize type_driver allocation table with configured ranges."""
 
     @abc.abstractmethod
-    def add_endpoint(self, ip):
+    def add_endpoint(self, ip, host):
         """Register the endpoint in the type_driver database.
 
-        param ip: the ip of the endpoint
+        param ip: the IP address of the endpoint
+        param host: the Host name of the endpoint
         """
         pass
 
@@ -54,7 +55,42 @@ class TunnelTypeDriver(helpers.TypeDriverHelper):
     def get_endpoints(self):
         """Get every endpoint managed by the type_driver
 
-        :returns a list of dict [{id:endpoint_id, ip_address:endpoint_ip},..]
+        :returns a list of dict [{ip_address:endpoint_ip, host:endpoint_host},
+        ..]
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_endpoint_by_host(self, host):
+        """Get endpoint for a given host managed by the type_driver
+
+        param host: the Host name of the endpoint
+
+        if host found in type_driver database
+           :returns db object for that particular host
+        else
+           :returns None
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_endpoint_by_ip(self, ip):
+        """Get endpoint for a given tunnel ip managed by the type_driver
+
+        param ip: the IP address of the endpoint
+
+        if ip found in type_driver database
+           :returns db object for that particular ip
+        else
+           :returns None
+        """
+        pass
+
+    @abc.abstractmethod
+    def delete_endpoint(self, ip):
+        """Delete the endpoint in the type_driver database.
+
+        param ip: the IP address of the endpoint
         """
         pass
 
@@ -160,13 +196,50 @@ class TunnelRpcCallbackMixin(object):
         be notified about the new tunnel IP.
         """
         tunnel_ip = kwargs.get('tunnel_ip')
+        if not tunnel_ip:
+            msg = _("Tunnel IP value needed by the ML2 plugin")
+            raise exc.InvalidInput(error_message=msg)
+
         tunnel_type = kwargs.get('tunnel_type')
         if not tunnel_type:
-            msg = _("Network_type value needed by the ML2 plugin")
+            msg = _("Network type value needed by the ML2 plugin")
             raise exc.InvalidInput(error_message=msg)
+
+        host = kwargs.get('host')
         driver = self._type_manager.drivers.get(tunnel_type)
         if driver:
-            tunnel = driver.obj.add_endpoint(tunnel_ip)
+            # The given conditional statements will verify the following
+            # things:
+            # 1. If host is not passed from an agent, it is a legacy mode.
+            # 2. If passed host and tunnel_ip are not found in the DB,
+            #    it is a new endpoint.
+            # 3. If host is passed from an agent and it is not found in DB
+            #    but the passed tunnel_ip is found, delete the endpoint
+            #    from DB and add the endpoint with (tunnel_ip, host),
+            #    it is an upgrade case.
+            # 4. If passed host is found in DB and passed tunnel ip is not
+            #    found, delete the endpoint belonging to that host and
+            #    add endpoint with latest (tunnel_ip, host), it is a case
+            #    where local_ip of an agent got changed.
+            if host:
+                host_endpoint = driver.obj.get_endpoint_by_host(host)
+                ip_endpoint = driver.obj.get_endpoint_by_ip(tunnel_ip)
+
+                if (ip_endpoint and ip_endpoint.host is None
+                    and host_endpoint is None):
+                    driver.obj.delete_endpoint(ip_endpoint.ip_address)
+                elif (ip_endpoint and ip_endpoint.host != host):
+                    msg = (_("Tunnel IP %(ip)s in use with host %(host)s"),
+                           {'ip': ip_endpoint.ip_address,
+                            'host': ip_endpoint.host})
+                    raise exc.InvalidInput(error_message=msg)
+                elif (host_endpoint and host_endpoint.ip_address != tunnel_ip):
+                    # Notify all other listening agents to delete stale tunnels
+                    self._notifier.tunnel_delete(rpc_context,
+                        host_endpoint.ip_address, tunnel_type)
+                    driver.obj.delete_endpoint(host_endpoint.ip_address)
+
+            tunnel = driver.obj.add_endpoint(tunnel_ip, host)
             tunnels = driver.obj.get_endpoints()
             entry = {'tunnels': tunnels}
             # Notify all other listening agents
@@ -175,7 +248,7 @@ class TunnelRpcCallbackMixin(object):
             # Return the list of tunnels IP's to the agent
             return entry
         else:
-            msg = _("network_type value '%s' not supported") % tunnel_type
+            msg = _("Network type value '%s' not supported") % tunnel_type
             raise exc.InvalidInput(error_message=msg)
 
 
@@ -191,3 +264,5 @@ class TunnelAgentRpcApiMixin(object):
                                     fanout=True)
         cctxt.cast(context, 'tunnel_update', tunnel_ip=tunnel_ip,
                    tunnel_type=tunnel_type)
+
+    # TODO(romilg): Add tunnel_delete rpc in dependent patch-set
index 834cd6b9b4c672219fa65c20eac1143bcd7f74ad..5c38d63eb07ee460cfbf3222b9aeeb80fb0f1aad 100644 (file)
@@ -19,7 +19,7 @@ from six import moves
 import sqlalchemy as sa
 from sqlalchemy import sql
 
-from neutron.common import exceptions as exc
+from neutron.common import exceptions as n_exc
 from neutron.db import api as db_api
 from neutron.db import model_base
 from neutron.i18n import _LE, _LW
@@ -58,10 +58,15 @@ class VxlanAllocation(model_base.BASEV2):
 
 class VxlanEndpoints(model_base.BASEV2):
     """Represents tunnel endpoint in RPC mode."""
-    __tablename__ = 'ml2_vxlan_endpoints'
 
+    __tablename__ = 'ml2_vxlan_endpoints'
+    __table_args__ = (
+        sa.UniqueConstraint('host',
+                            name='unique_ml2_vxlan_endpoints0host'),
+    )
     ip_address = sa.Column(sa.String(64), primary_key=True)
     udp_port = sa.Column(sa.Integer, nullable=False)
+    host = sa.Column(sa.String(255), nullable=True)
 
     def __repr__(self):
         return "<VxlanTunnelEndpoint(%s)>" % self.ip_address
@@ -78,7 +83,7 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
     def initialize(self):
         try:
             self._initialize(cfg.CONF.ml2_type_vxlan.vni_ranges)
-        except exc.NetworkTunnelRangeError:
+        except n_exc.NetworkTunnelRangeError:
             LOG.exception(_LE("Failed to parse vni_ranges. "
                               "Service terminated!"))
             raise SystemExit()
@@ -132,21 +137,45 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
         LOG.debug("get_vxlan_endpoints() called")
         session = db_api.get_session()
 
-        with session.begin(subtransactions=True):
-            vxlan_endpoints = session.query(VxlanEndpoints)
-            return [{'ip_address': vxlan_endpoint.ip_address,
-                     'udp_port': vxlan_endpoint.udp_port}
-                    for vxlan_endpoint in vxlan_endpoints]
+        vxlan_endpoints = session.query(VxlanEndpoints)
+        return [{'ip_address': vxlan_endpoint.ip_address,
+                 'udp_port': vxlan_endpoint.udp_port,
+                 'host': vxlan_endpoint.host}
+                for vxlan_endpoint in vxlan_endpoints]
 
-    def add_endpoint(self, ip, udp_port=VXLAN_UDP_PORT):
+    def get_endpoint_by_host(self, host):
+        LOG.debug("get_endpoint_by_host() called for host %s", host)
+        session = db_api.get_session()
+
+        host_endpoint = (session.query(VxlanEndpoints).
+                         filter_by(host=host).first())
+        return host_endpoint
+
+    def get_endpoint_by_ip(self, ip):
+        LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
+        session = db_api.get_session()
+
+        ip_endpoint = (session.query(VxlanEndpoints).
+                       filter_by(ip_address=ip).first())
+        return ip_endpoint
+
+    def add_endpoint(self, ip, host, udp_port=VXLAN_UDP_PORT):
         LOG.debug("add_vxlan_endpoint() called for ip %s", ip)
         session = db_api.get_session()
         try:
             vxlan_endpoint = VxlanEndpoints(ip_address=ip,
-                                            udp_port=udp_port)
+                                            udp_port=udp_port,
+                                            host=host)
             vxlan_endpoint.save(session)
         except db_exc.DBDuplicateEntry:
             vxlan_endpoint = (session.query(VxlanEndpoints).
                               filter_by(ip_address=ip).one())
             LOG.warning(_LW("Vxlan endpoint with ip %s already exists"), ip)
         return vxlan_endpoint
+
+    def delete_endpoint(self, ip):
+        LOG.debug("delete_vxlan_endpoint() called for ip %s", ip)
+        session = db_api.get_session()
+
+        with session.begin(subtransactions=True):
+            session.query(VxlanEndpoints).filter_by(ip_address=ip).delete()
index c7b9117e53b591f7bd3d21ecaa9f072a0226a361..a967132852076a3dd12c49aace2b1848769f2167 100644 (file)
@@ -23,6 +23,8 @@ from neutron.tests.unit import testlib_api
 
 TUNNEL_IP_ONE = "10.10.10.10"
 TUNNEL_IP_TWO = "10.10.10.20"
+HOST_ONE = 'fake_host_one'
+HOST_TWO = 'fake_host_two'
 
 
 class GreTypeTest(test_type_tunnel.TunnelTypeTestMixin,
@@ -30,23 +32,56 @@ class GreTypeTest(test_type_tunnel.TunnelTypeTestMixin,
     DRIVER_CLASS = type_gre.GreTypeDriver
     TYPE = p_const.TYPE_GRE
 
-    def test_endpoints(self):
-        tun_1 = self.driver.add_endpoint(TUNNEL_IP_ONE)
-        tun_2 = self.driver.add_endpoint(TUNNEL_IP_TWO)
-        self.assertEqual(TUNNEL_IP_ONE, tun_1.ip_address)
-        self.assertEqual(TUNNEL_IP_TWO, tun_2.ip_address)
+    def test_add_endpoint(self):
+        endpoint = self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+        self.assertEqual(TUNNEL_IP_ONE, endpoint.ip_address)
+        self.assertEqual(HOST_ONE, endpoint.host)
+
+    def test_add_endpoint_for_existing_tunnel_ip(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+
+        with mock.patch.object(type_gre.LOG, 'warning') as log_warn:
+            self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+            log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)
+
+    def test_get_endpoint_by_host(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+
+        host_endpoint = self.driver.get_endpoint_by_host(HOST_ONE)
+        self.assertEqual(TUNNEL_IP_ONE, host_endpoint.ip_address)
+
+    def test_get_endpoint_by_host_for_not_existing_host(self):
+        ip_endpoint = self.driver.get_endpoint_by_host(HOST_TWO)
+        self.assertIsNone(ip_endpoint)
+
+    def test_get_endpoint_by_ip(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+
+        ip_endpoint = self.driver.get_endpoint_by_ip(TUNNEL_IP_ONE)
+        self.assertEqual(HOST_ONE, ip_endpoint.host)
+
+    def test_get_endpoint_by_ip_for_not_existing_tunnel_ip(self):
+        ip_endpoint = self.driver.get_endpoint_by_ip(TUNNEL_IP_TWO)
+        self.assertIsNone(ip_endpoint)
+
+    def test_get_endpoints(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+        self.driver.add_endpoint(TUNNEL_IP_TWO, HOST_TWO)
 
-        # Get all the endpoints
         endpoints = self.driver.get_endpoints()
         for endpoint in endpoints:
-            self.assertIn(endpoint['ip_address'],
-                          [TUNNEL_IP_ONE, TUNNEL_IP_TWO])
+            if endpoint['ip_address'] == TUNNEL_IP_ONE:
+                self.assertEqual(HOST_ONE, endpoint['host'])
+            elif endpoint['ip_address'] == TUNNEL_IP_TWO:
+                self.assertEqual(HOST_TWO, endpoint['host'])
 
-    def test_add_same_ip_endpoints(self):
-        self.driver.add_endpoint(TUNNEL_IP_ONE)
-        with mock.patch.object(type_gre.LOG, 'warning') as log_warn:
-            self.driver.add_endpoint(TUNNEL_IP_ONE)
-        log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)
+    def test_delete_endpoint(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE)
+
+        self.assertIsNone(self.driver.delete_endpoint(TUNNEL_IP_ONE))
+        # Get all the endpoints and verify its empty
+        endpoints = self.driver.get_endpoints()
+        self.assertNotIn(TUNNEL_IP_ONE, endpoints)
 
 
 class GreTypeMultiRangeTest(test_type_tunnel.TunnelTypeMultiRangeTestMixin,
index 4d316e3666e1862c499f767ca65bcea9c6d779e1..8b60706a34b5e530f07a16673673da1db2d1bcdf 100644 (file)
@@ -23,6 +23,8 @@ from neutron.tests.unit import testlib_api
 
 TUNNEL_IP_ONE = "10.10.10.10"
 TUNNEL_IP_TWO = "10.10.10.20"
+HOST_ONE = 'fake_host_one'
+HOST_TWO = 'fake_host_two'
 VXLAN_UDP_PORT_ONE = 9999
 VXLAN_UDP_PORT_TWO = 8888
 
@@ -32,34 +34,63 @@ class VxlanTypeTest(test_type_tunnel.TunnelTypeTestMixin,
     DRIVER_CLASS = type_vxlan.VxlanTypeDriver
     TYPE = p_const.TYPE_VXLAN
 
-    def test_endpoints(self):
-        # Set first endpoint, verify it gets VXLAN VNI 1
-        vxlan1_endpoint = self.driver.add_endpoint(TUNNEL_IP_ONE,
-                                                   VXLAN_UDP_PORT_ONE)
-        self.assertEqual(TUNNEL_IP_ONE, vxlan1_endpoint.ip_address)
-        self.assertEqual(VXLAN_UDP_PORT_ONE, vxlan1_endpoint.udp_port)
+    def test_add_endpoint(self):
+        endpoint = self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE,
+                                            VXLAN_UDP_PORT_ONE)
+        self.assertEqual(TUNNEL_IP_ONE, endpoint.ip_address)
+        self.assertEqual(VXLAN_UDP_PORT_ONE, endpoint.udp_port)
+        self.assertEqual(HOST_ONE, endpoint.host)
 
-        # Set second endpoint, verify it gets VXLAN VNI 2
-        vxlan2_endpoint = self.driver.add_endpoint(TUNNEL_IP_TWO,
-                                                   VXLAN_UDP_PORT_TWO)
-        self.assertEqual(TUNNEL_IP_TWO, vxlan2_endpoint.ip_address)
-        self.assertEqual(VXLAN_UDP_PORT_TWO, vxlan2_endpoint.udp_port)
+    def test_add_endpoint_for_existing_tunnel_ip(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE, VXLAN_UDP_PORT_ONE)
+
+        with mock.patch.object(type_vxlan.LOG, 'warning') as log_warn:
+            self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE,
+                                     VXLAN_UDP_PORT_ONE)
+            log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)
+
+    def test_get_endpoint_by_host(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE, VXLAN_UDP_PORT_ONE)
+
+        host_endpoint = self.driver.get_endpoint_by_host(HOST_ONE)
+        self.assertEqual(TUNNEL_IP_ONE, host_endpoint.ip_address)
+        self.assertEqual(VXLAN_UDP_PORT_ONE, host_endpoint.udp_port)
+
+    def test_get_endpoint_by_host_for_not_existing_host(self):
+        ip_endpoint = self.driver.get_endpoint_by_host(HOST_TWO)
+        self.assertIsNone(ip_endpoint)
+
+    def test_get_endpoint_by_ip(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE, VXLAN_UDP_PORT_ONE)
+
+        ip_endpoint = self.driver.get_endpoint_by_ip(TUNNEL_IP_ONE)
+        self.assertEqual(HOST_ONE, ip_endpoint.host)
+        self.assertEqual(VXLAN_UDP_PORT_ONE, ip_endpoint.udp_port)
+
+    def test_get_endpoint_by_ip_for_not_existing_tunnel_ip(self):
+        ip_endpoint = self.driver.get_endpoint_by_ip(TUNNEL_IP_TWO)
+        self.assertIsNone(ip_endpoint)
+
+    def test_get_endpoints(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE, VXLAN_UDP_PORT_ONE)
+        self.driver.add_endpoint(TUNNEL_IP_TWO, HOST_TWO, VXLAN_UDP_PORT_TWO)
 
-        # Get all the endpoints
         endpoints = self.driver.get_endpoints()
         for endpoint in endpoints:
             if endpoint['ip_address'] == TUNNEL_IP_ONE:
                 self.assertEqual(VXLAN_UDP_PORT_ONE, endpoint['udp_port'])
+                self.assertEqual(HOST_ONE, endpoint['host'])
             elif endpoint['ip_address'] == TUNNEL_IP_TWO:
                 self.assertEqual(VXLAN_UDP_PORT_TWO, endpoint['udp_port'])
+                self.assertEqual(HOST_TWO, endpoint['host'])
 
-    def test_add_same_ip_endpoints(self):
-        self.driver.add_endpoint(TUNNEL_IP_ONE, VXLAN_UDP_PORT_ONE)
-        with mock.patch.object(type_vxlan.LOG, 'warning') as log_warn:
-            observed = self.driver.add_endpoint(TUNNEL_IP_ONE,
-                                                VXLAN_UDP_PORT_TWO)
-            self.assertEqual(VXLAN_UDP_PORT_ONE, observed['udp_port'])
-            log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)
+    def test_delete_endpoint(self):
+        self.driver.add_endpoint(TUNNEL_IP_ONE, HOST_ONE, VXLAN_UDP_PORT_ONE)
+
+        self.assertIsNone(self.driver.delete_endpoint(TUNNEL_IP_ONE))
+        # Get all the endpoints and verify its empty
+        endpoints = self.driver.get_endpoints()
+        self.assertNotIn(TUNNEL_IP_ONE, endpoints)
 
 
 class VxlanTypeMultiRangeTest(test_type_tunnel.TunnelTypeMultiRangeTestMixin,