]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
QoS agent extension and driver refactoring
authorMoshe Levi <moshele@mellanox.com>
Mon, 10 Aug 2015 09:25:59 +0000 (12:25 +0300)
committerMiguel Angel Ajo <mangelajo@redhat.com>
Tue, 15 Sep 2015 14:29:25 +0000 (16:29 +0200)
Moved some code common to all drivers into base
qos driver abstract class, so related bugfixes go all in one
place and we simplify the logic for every qos drivers.

Port/Policy mapping moved out to a separate class.

Support delete per rule_type or delete all rules.

Related-bug: #1486039

Co-Authored-By: Miguel Angel Ajo <mangelajo@redhat.com>
Partially-Implements: blueprint ml2-qos
Change-Id: Ia9d8638b9268b5aa8512cbb9d001413751f82649

neutron/agent/l2/extensions/qos.py
neutron/plugins/ml2/drivers/mech_sriov/agent/extension_drivers/qos_driver.py
neutron/plugins/ml2/drivers/openvswitch/agent/extension_drivers/qos_driver.py
neutron/tests/unit/agent/l2/extensions/test_qos.py

index 13e94cb290d8477ddaebd1f2515025eb81862947..06d47161dd7a73928b92dafa6eed2bf948abce9a 100644 (file)
@@ -17,6 +17,7 @@ import abc
 import collections
 
 from oslo_concurrency import lockutils
+from oslo_log import log as logging
 import six
 
 from neutron.agent.l2 import agent_extension
@@ -24,8 +25,12 @@ from neutron.api.rpc.callbacks.consumer import registry
 from neutron.api.rpc.callbacks import events
 from neutron.api.rpc.callbacks import resources
 from neutron.api.rpc.handlers import resources_rpc
+from neutron.common import exceptions
+from neutron.i18n import _LW, _LI
 from neutron import manager
 
+LOG = logging.getLogger(__name__)
+
 
 @six.add_metaclass(abc.ABCMeta)
 class QosAgentDriver(object):
@@ -35,36 +40,126 @@ class QosAgentDriver(object):
     for applying QoS Rules on a port.
     """
 
+    # Each QoS driver should define the set of rule types that it supports, and
+    # correspoding handlers that has the following names:
+    #
+    # create_<type>
+    # update_<type>
+    # delete_<type>
+    #
+    # where <type> is one of VALID_RULE_TYPES
+    SUPPORTED_RULES = set()
+
     @abc.abstractmethod
     def initialize(self):
         """Perform QoS agent driver initialization.
         """
 
-    @abc.abstractmethod
     def create(self, port, qos_policy):
         """Apply QoS rules on port for the first time.
 
         :param port: port object.
         :param qos_policy: the QoS policy to be applied on port.
         """
-        #TODO(QoS) we may want to provide default implementations of calling
-        #delete and then update
+        self._handle_update_create_rules('create', port, qos_policy)
 
-    @abc.abstractmethod
     def update(self, port, qos_policy):
         """Apply QoS rules on port.
 
         :param port: port object.
         :param qos_policy: the QoS policy to be applied on port.
         """
+        self._handle_update_create_rules('update', port, qos_policy)
 
-    @abc.abstractmethod
-    def delete(self, port, qos_policy):
+    def delete(self, port, qos_policy=None):
         """Remove QoS rules from port.
 
         :param port: port object.
         :param qos_policy: the QoS policy to be removed from port.
         """
+        if qos_policy is None:
+            rule_types = self.SUPPORTED_RULES
+        else:
+            rule_types = set(
+                [rule.rule_type
+                 for rule in self._iterate_rules(qos_policy.rules)])
+
+        for rule_type in rule_types:
+            self._handle_rule_delete(port, rule_type)
+
+    def _iterate_rules(self, rules):
+        for rule in rules:
+            rule_type = rule.rule_type
+            if rule_type in self.SUPPORTED_RULES:
+                yield rule
+            else:
+                LOG.warning(_LW('Unsupported QoS rule type for %(rule_id)s: '
+                                '%(rule_type)s; skipping'),
+                            {'rule_id': rule.id, 'rule_type': rule_type})
+
+    def _handle_rule_delete(self, port, rule_type):
+        handler_name = "".join(("delete_", rule_type))
+        handler = getattr(self, handler_name)
+        handler(port)
+
+    def _handle_update_create_rules(self, action, port, qos_policy):
+        for rule in self._iterate_rules(qos_policy.rules):
+            handler_name = "".join((action, "_", rule.rule_type))
+            handler = getattr(self, handler_name)
+            handler(port, rule)
+
+
+class PortPolicyMap(object):
+    def __init__(self):
+        # we cannot use a dict of sets here because port dicts are not hashable
+        self.qos_policy_ports = collections.defaultdict(dict)
+        self.known_policies = {}
+        self.port_policies = {}
+
+    def get_ports(self, policy):
+        return self.qos_policy_ports[policy.id].values()
+
+    def get_policy(self, policy_id):
+        return self.known_policies.get(policy_id)
+
+    def update_policy(self, policy):
+        self.known_policies[policy.id] = policy
+
+    def has_policy_changed(self, port, policy_id):
+        return self.port_policies.get(port['port_id']) != policy_id
+
+    def get_port_policy(self, port):
+        policy_id = self.port_policies.get(port['port_id'])
+        if policy_id:
+            return self.get_policy(policy_id)
+
+    def set_port_policy(self, port, policy):
+        """Attach a port to policy and return any previous policy on port."""
+        port_id = port['port_id']
+        old_policy = self.get_port_policy(port)
+        self.known_policies[policy.id] = policy
+        self.port_policies[port_id] = policy.id
+        self.qos_policy_ports[policy.id][port_id] = port
+        if old_policy and old_policy.id != policy.id:
+            del self.qos_policy_ports[old_policy.id][port_id]
+        return old_policy
+
+    def clean_by_port(self, port):
+        """Detach port from policy and cleanup data we don't need anymore."""
+        port_id = port['port_id']
+        if port_id in self.port_policies:
+            del self.port_policies[port_id]
+            for qos_policy_id, port_dict in self.qos_policy_ports.items():
+                if port_id in port_dict:
+                    del port_dict[port_id]
+                    if not port_dict:
+                        self._clean_policy_info(qos_policy_id)
+                    return
+        raise exceptions.PortNotFound(port_id=port['port_id'])
+
+    def _clean_policy_info(self, qos_policy_id):
+        del self.qos_policy_ports[qos_policy_id]
+        del self.known_policies[qos_policy_id]
 
 
 class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
@@ -79,9 +174,7 @@ class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
             'neutron.qos.agent_drivers', driver_type)()
         self.qos_driver.initialize()
 
-        # we cannot use a dict of sets here because port dicts are not hashable
-        self.qos_policy_ports = collections.defaultdict(dict)
-        self.known_ports = set()
+        self.policy_map = PortPolicyMap()
 
         registry.subscribe(self._handle_notification, resources.QOS_POLICY)
         self._register_rpc_consumers(connection)
@@ -116,34 +209,43 @@ class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
             self._process_reset_port(port)
             return
 
-        #Note(moshele) check if we have seen this port
-        #and it has the same policy we do nothing.
-        if (port_id in self.known_ports and
-                port_id in self.qos_policy_ports[qos_policy_id]):
+        if not self.policy_map.has_policy_changed(port, qos_policy_id):
             return
 
-        self.qos_policy_ports[qos_policy_id][port_id] = port
-        self.known_ports.add(port_id)
         qos_policy = self.resource_rpc.pull(
             context, resources.QOS_POLICY, qos_policy_id)
-        self.qos_driver.create(port, qos_policy)
+        if qos_policy is None:
+            LOG.info(_LI("QoS policy %(qos_policy_id)s applied to port "
+                         "%(port_id)s is not available on server, "
+                         "it has been deleted. Skipping."),
+                     {'qos_policy_id': qos_policy_id, 'port_id': port_id})
+            self._process_reset_port(port)
+        else:
+            old_qos_policy = self.policy_map.set_port_policy(port, qos_policy)
+            if old_qos_policy:
+                self.qos_driver.delete(port, old_qos_policy)
+                self.qos_driver.update(port, qos_policy)
+            else:
+                self.qos_driver.create(port, qos_policy)
 
     def delete_port(self, context, port):
         self._process_reset_port(port)
 
     def _process_update_policy(self, qos_policy):
-        for port_id, port in self.qos_policy_ports[qos_policy.id].items():
-            # TODO(QoS): for now, just reflush the rules on the port. Later, we
-            # may want to apply the difference between the rules lists only.
-            self.qos_driver.delete(port, None)
+        old_qos_policy = self.policy_map.get_policy(qos_policy.id)
+        for port in self.policy_map.get_ports(qos_policy):
+            #NOTE(QoS): for now, just reflush the rules on the port. Later, we
+            #           may want to apply the difference between the old and
+            #           new rule lists.
+            self.qos_driver.delete(port, old_qos_policy)
             self.qos_driver.update(port, qos_policy)
+            self.policy_map.update_policy(qos_policy)
 
     def _process_reset_port(self, port):
-        port_id = port['port_id']
-        if port_id in self.known_ports:
-            self.known_ports.remove(port_id)
-            for qos_policy_id, port_dict in self.qos_policy_ports.items():
-                if port_id in port_dict:
-                    del port_dict[port_id]
-                    self.qos_driver.delete(port, None)
-                    return
+        try:
+            self.policy_map.clean_by_port(port)
+            self.qos_driver.delete(port)
+        except exceptions.PortNotFound:
+            LOG.info(_LI("QoS extension did have no information about the "
+                         "port %s that we were trying to reset"),
+                     port['port_id'])
index 4822360a717679819fecb25679d7ff59c070cc13..b20f06dfbf3e14b4ba78f748fb58ad9a922c0dee 100755 (executable)
@@ -15,7 +15,7 @@
 from oslo_log import log as logging
 
 from neutron.agent.l2.extensions import qos
-from neutron.i18n import _LE, _LI, _LW
+from neutron.i18n import _LE, _LI
 from neutron.plugins.ml2.drivers.mech_sriov.agent.common import (
     exceptions as exc)
 from neutron.plugins.ml2.drivers.mech_sriov.agent import eswitch_manager as esm
@@ -27,7 +27,7 @@ LOG = logging.getLogger(__name__)
 
 class QosSRIOVAgentDriver(qos.QosAgentDriver):
 
-    _SUPPORTED_RULES = (
+    SUPPORTED_RULES = (
         mech_driver.SriovNicSwitchMechanismDriver.supported_qos_rule_types)
 
     def __init__(self):
@@ -37,37 +37,15 @@ class QosSRIOVAgentDriver(qos.QosAgentDriver):
     def initialize(self):
         self.eswitch_mgr = esm.ESwitchManager()
 
-    def create(self, port, qos_policy):
-        self._handle_rules('create', port, qos_policy)
+    def create_bandwidth_limit(self, port, rule):
+        self.update_bandwidth_limit(port, rule)
 
-    def update(self, port, qos_policy):
-        self._handle_rules('update', port, qos_policy)
-
-    def delete(self, port, qos_policy):
-        # TODO(QoS): consider optimizing flushing of all QoS rules from the
-        # port by inspecting qos_policy.rules contents
-        self._delete_bandwidth_limit(port)
-
-    def _handle_rules(self, action, port, qos_policy):
-        for rule in qos_policy.rules:
-            if rule.rule_type in self._SUPPORTED_RULES:
-                handler_name = ("".join(("_", action, "_", rule.rule_type)))
-                handler = getattr(self, handler_name)
-                handler(port, rule)
-            else:
-                LOG.warning(_LW('Unsupported QoS rule type for %(rule_id)s: '
-                            '%(rule_type)s; skipping'),
-                            {'rule_id': rule.id, 'rule_type': rule.rule_type})
-
-    def _create_bandwidth_limit(self, port, rule):
-        self._update_bandwidth_limit(port, rule)
-
-    def _update_bandwidth_limit(self, port, rule):
+    def update_bandwidth_limit(self, port, rule):
         pci_slot = port['profile'].get('pci_slot')
         device = port['device']
         self._set_vf_max_rate(device, pci_slot, rule.max_kbps)
 
-    def _delete_bandwidth_limit(self, port):
+    def delete_bandwidth_limit(self, port):
         pci_slot = port['profile'].get('pci_slot')
         self.eswitch_mgr.clear_max_rate(pci_slot)
 
index ce9f28687808b2048b37545f93eedccd3847d228..5977083794cbe2b292731187cb8e3ddad4a0d4cd 100644 (file)
 #    under the License.
 
 from oslo_config import cfg
-from oslo_log import log as logging
 
 from neutron.agent.common import ovs_lib
 from neutron.agent.l2.extensions import qos
-from neutron.i18n import _LW
 from neutron.plugins.ml2.drivers.openvswitch.mech_driver import (
     mech_openvswitch)
 
-LOG = logging.getLogger(__name__)
-
 
 class QosOVSAgentDriver(qos.QosAgentDriver):
 
-    _SUPPORTED_RULES = (
+    SUPPORTED_RULES = (
         mech_openvswitch.OpenvswitchMechanismDriver.supported_qos_rule_types)
 
     def __init__(self):
@@ -37,32 +33,10 @@ class QosOVSAgentDriver(qos.QosAgentDriver):
     def initialize(self):
         self.br_int = ovs_lib.OVSBridge(self.br_int_name)
 
-    def create(self, port, qos_policy):
-        self._handle_rules('create', port, qos_policy)
-
-    def update(self, port, qos_policy):
-        self._handle_rules('update', port, qos_policy)
-
-    def delete(self, port, qos_policy):
-        # TODO(QoS): consider optimizing flushing of all QoS rules from the
-        # port by inspecting qos_policy.rules contents
-        self._delete_bandwidth_limit(port)
-
-    def _handle_rules(self, action, port, qos_policy):
-        for rule in qos_policy.rules:
-            if rule.rule_type in self._SUPPORTED_RULES:
-                handler_name = ("".join(("_", action, "_", rule.rule_type)))
-                handler = getattr(self, handler_name)
-                handler(port, rule)
-            else:
-                LOG.warning(_LW('Unsupported QoS rule type for %(rule_id)s: '
-                            '%(rule_type)s; skipping'),
-                            {'rule_id': rule.id, 'rule_type': rule.rule_type})
-
-    def _create_bandwidth_limit(self, port, rule):
-        self._update_bandwidth_limit(port, rule)
+    def create_bandwidth_limit(self, port, rule):
+        self.update_bandwidth_limit(port, rule)
 
-    def _update_bandwidth_limit(self, port, rule):
+    def update_bandwidth_limit(self, port, rule):
         port_name = port['vif_port'].port_name
         max_kbps = rule.max_kbps
         max_burst_kbps = rule.max_burst_kbps
@@ -71,6 +45,6 @@ class QosOVSAgentDriver(qos.QosAgentDriver):
                                                     max_kbps,
                                                     max_burst_kbps)
 
-    def _delete_bandwidth_limit(self, port):
+    def delete_bandwidth_limit(self, port):
         port_name = port['vif_port'].port_name
         self.br_int.delete_egress_bw_limit_for_port(port_name)
index 0ff6175c56056bf34b46aa316f1f2882b684bf67..47d42fc16521cae8fa531ee82b9493defa6b9162 100755 (executable)
@@ -21,12 +21,82 @@ from neutron.api.rpc.callbacks.consumer import registry
 from neutron.api.rpc.callbacks import events
 from neutron.api.rpc.callbacks import resources
 from neutron.api.rpc.handlers import resources_rpc
+from neutron.common import exceptions
 from neutron import context
+from neutron.objects.qos import policy
+from neutron.objects.qos import rule
 from neutron.plugins.ml2.drivers.openvswitch.agent.common import constants
+from neutron.services.qos import qos_consts
 from neutron.tests import base
 
 
-TEST_POLICY = object()
+TEST_POLICY = policy.QosPolicy(context=None,
+                               name='test1', id='fake_policy_id')
+TEST_POLICY2 = policy.QosPolicy(context=None,
+                                name='test2', id='fake_policy_id_2')
+
+TEST_PORT = {'port_id': 'test_port_id',
+             'qos_policy_id': TEST_POLICY.id}
+
+TEST_PORT2 = {'port_id': 'test_port_id_2',
+             'qos_policy_id': TEST_POLICY2.id}
+
+
+class FakeDriver(qos.QosAgentDriver):
+
+    SUPPORTED_RULES = {qos_consts.RULE_TYPE_BANDWIDTH_LIMIT}
+
+    def __init__(self):
+        super(FakeDriver, self).__init__()
+        self.create_bandwidth_limit = mock.Mock()
+        self.update_bandwidth_limit = mock.Mock()
+        self.delete_bandwidth_limit = mock.Mock()
+
+    def initialize(self):
+        pass
+
+
+class QosFakeRule(rule.QosRule):
+
+    rule_type = 'fake_type'
+
+
+class QosAgentDriverTestCase(base.BaseTestCase):
+
+    def setUp(self):
+        super(QosAgentDriverTestCase, self).setUp()
+        self.driver = FakeDriver()
+        self.policy = TEST_POLICY
+        self.rule = (
+            rule.QosBandwidthLimitRule(context=None, id='fake_rule_id',
+                                       max_kbps=100, max_burst_kbps=200))
+        self.policy.rules = [self.rule]
+        self.port = object()
+        self.fake_rule = QosFakeRule(context=None, id='really_fake_rule_id')
+
+    def test_create(self):
+        self.driver.create(self.port, self.policy)
+        self.driver.create_bandwidth_limit.assert_called_with(
+            self.port, self.rule)
+
+    def test_update(self):
+        self.driver.update(self.port, self.policy)
+        self.driver.update_bandwidth_limit.assert_called_with(
+            self.port, self.rule)
+
+    def test_delete(self):
+        self.driver.delete(self.port, self.policy)
+        self.driver.delete_bandwidth_limit.assert_called_with(self.port)
+
+    def test_delete_no_policy(self):
+        self.driver.delete(self.port, qos_policy=None)
+        self.driver.delete_bandwidth_limit.assert_called_with(self.port)
+
+    def test__iterate_rules_with_unknown_rule_type(self):
+        self.policy.rules.append(self.fake_rule)
+        rules = list(self.driver._iterate_rules(self.policy.rules))
+        self.assertEqual(1, len(rules))
+        self.assertIsInstance(rules[0], rule.QosBandwidthLimitRule)
 
 
 class QosExtensionBaseTestCase(base.BaseTestCase):
@@ -55,9 +125,9 @@ class QosExtensionRpcTestCase(QosExtensionBaseTestCase):
             self.qos_ext.resource_rpc, 'pull',
             return_value=TEST_POLICY).start()
 
-    def _create_test_port_dict(self):
+    def _create_test_port_dict(self, qos_policy_id=None):
         return {'port_id': uuidutils.generate_uuid(),
-                'qos_policy_id': uuidutils.generate_uuid()}
+                'qos_policy_id': qos_policy_id or TEST_POLICY.id}
 
     def test_handle_port_with_no_policy(self):
         port = self._create_test_port_dict()
@@ -76,8 +146,10 @@ class QosExtensionRpcTestCase(QosExtensionBaseTestCase):
         self.qos_ext.qos_driver.create.assert_called_once_with(
             port, TEST_POLICY)
         self.assertEqual(port,
-            self.qos_ext.qos_policy_ports[qos_policy_id][port_id])
-        self.assertTrue(port_id in self.qos_ext.known_ports)
+            self.qos_ext.policy_map.qos_policy_ports[qos_policy_id][port_id])
+        self.assertIn(port_id, self.qos_ext.policy_map.port_policies)
+        self.assertEqual(TEST_POLICY,
+            self.qos_ext.policy_map.known_policies[qos_policy_id])
 
     def test_handle_known_port(self):
         port_obj1 = self._create_test_port_dict()
@@ -96,24 +168,20 @@ class QosExtensionRpcTestCase(QosExtensionBaseTestCase):
         self.pull_mock.assert_called_once_with(
              self.context, resources.QOS_POLICY,
              port['qos_policy_id'])
-        #TODO(QoS): handle qos_driver.update call check when
-        #           we do that
 
     def test_delete_known_port(self):
         port = self._create_test_port_dict()
-        port_id = port['port_id']
         self.qos_ext.handle_port(self.context, port)
         self.qos_ext.qos_driver.reset_mock()
         self.qos_ext.delete_port(self.context, port)
-        self.qos_ext.qos_driver.delete.assert_called_with(port, None)
-        self.assertNotIn(port_id, self.qos_ext.known_ports)
+        self.qos_ext.qos_driver.delete.assert_called_with(port)
+        self.assertIsNone(self.qos_ext.policy_map.get_port_policy(port))
 
     def test_delete_unknown_port(self):
         port = self._create_test_port_dict()
-        port_id = port['port_id']
         self.qos_ext.delete_port(self.context, port)
         self.assertFalse(self.qos_ext.qos_driver.delete.called)
-        self.assertNotIn(port_id, self.qos_ext.known_ports)
+        self.assertIsNone(self.qos_ext.policy_map.get_port_policy(port))
 
     def test__handle_notification_ignores_all_event_types_except_updated(self):
         with mock.patch.object(
@@ -127,47 +195,41 @@ class QosExtensionRpcTestCase(QosExtensionBaseTestCase):
         with mock.patch.object(
             self.qos_ext, '_process_update_policy') as update_mock:
 
-            policy = mock.Mock()
-            self.qos_ext._handle_notification(policy, events.UPDATED)
-            update_mock.assert_called_with(policy)
+            policy_obj = mock.Mock()
+            self.qos_ext._handle_notification(policy_obj, events.UPDATED)
+            update_mock.assert_called_with(policy_obj)
 
     def test__process_update_policy(self):
-        port1 = self._create_test_port_dict()
-        port2 = self._create_test_port_dict()
-        self.qos_ext.qos_policy_ports = {
-            port1['qos_policy_id']: {port1['port_id']: port1},
-            port2['qos_policy_id']: {port2['port_id']: port2},
-        }
-        policy = mock.Mock()
-        policy.id = port1['qos_policy_id']
-        self.qos_ext._process_update_policy(policy)
-        self.qos_ext.qos_driver.update.assert_called_with(port1, policy)
+        port1 = self._create_test_port_dict(qos_policy_id=TEST_POLICY.id)
+        port2 = self._create_test_port_dict(qos_policy_id=TEST_POLICY2.id)
+        self.qos_ext.policy_map.set_port_policy(port1, TEST_POLICY)
+        self.qos_ext.policy_map.set_port_policy(port2, TEST_POLICY2)
+
+        policy_obj = mock.Mock()
+        policy_obj.id = port1['qos_policy_id']
+        self.qos_ext._process_update_policy(policy_obj)
+        self.qos_ext.qos_driver.update.assert_called_with(port1, policy_obj)
 
         self.qos_ext.qos_driver.update.reset_mock()
-        policy.id = port2['qos_policy_id']
-        self.qos_ext._process_update_policy(policy)
-        self.qos_ext.qos_driver.update.assert_called_with(port2, policy)
+        policy_obj.id = port2['qos_policy_id']
+        self.qos_ext._process_update_policy(policy_obj)
+        self.qos_ext.qos_driver.update.assert_called_with(port2, policy_obj)
 
     def test__process_reset_port(self):
-        port1 = self._create_test_port_dict()
-        port2 = self._create_test_port_dict()
-        port1_id = port1['port_id']
-        port2_id = port2['port_id']
-        self.qos_ext.qos_policy_ports = {
-            port1['qos_policy_id']: {port1_id: port1},
-            port2['qos_policy_id']: {port2_id: port2},
-        }
-        self.qos_ext.known_ports = {port1_id, port2_id}
+        port1 = self._create_test_port_dict(qos_policy_id=TEST_POLICY.id)
+        port2 = self._create_test_port_dict(qos_policy_id=TEST_POLICY2.id)
+        self.qos_ext.policy_map.set_port_policy(port1, TEST_POLICY)
+        self.qos_ext.policy_map.set_port_policy(port2, TEST_POLICY2)
 
         self.qos_ext._process_reset_port(port1)
-        self.qos_ext.qos_driver.delete.assert_called_with(port1, None)
-        self.assertNotIn(port1_id, self.qos_ext.known_ports)
-        self.assertIn(port2_id, self.qos_ext.known_ports)
+        self.qos_ext.qos_driver.delete.assert_called_with(port1)
+        self.assertIsNone(self.qos_ext.policy_map.get_port_policy(port1))
+        self.assertIsNotNone(self.qos_ext.policy_map.get_port_policy(port2))
 
         self.qos_ext.qos_driver.delete.reset_mock()
         self.qos_ext._process_reset_port(port2)
-        self.qos_ext.qos_driver.delete.assert_called_with(port2, None)
-        self.assertNotIn(port2_id, self.qos_ext.known_ports)
+        self.qos_ext.qos_driver.delete.assert_called_with(port2)
+        self.assertIsNone(self.qos_ext.policy_map.get_port_policy(port2))
 
 
 class QosExtensionInitializeTestCase(QosExtensionBaseTestCase):
@@ -185,3 +247,60 @@ class QosExtensionInitializeTestCase(QosExtensionBaseTestCase):
              for resource_type in self.qos_ext.SUPPORTED_RESOURCES]
         )
         subscribe_mock.assert_called_with(mock.ANY, resources.QOS_POLICY)
+
+
+class PortPolicyMapTestCase(base.BaseTestCase):
+
+    def setUp(self):
+        super(PortPolicyMapTestCase, self).setUp()
+        self.policy_map = qos.PortPolicyMap()
+
+    def test_update_policy(self):
+        self.policy_map.update_policy(TEST_POLICY)
+        self.assertEqual(TEST_POLICY,
+                         self.policy_map.known_policies[TEST_POLICY.id])
+
+    def _set_ports(self):
+        self.policy_map.set_port_policy(TEST_PORT, TEST_POLICY)
+        self.policy_map.set_port_policy(TEST_PORT2, TEST_POLICY2)
+
+    def test_set_port_policy(self):
+        self._set_ports()
+        self.assertEqual(TEST_POLICY,
+                         self.policy_map.known_policies[TEST_POLICY.id])
+        self.assertIn(TEST_PORT['port_id'],
+                      self.policy_map.qos_policy_ports[TEST_POLICY.id])
+
+    def test_get_port_policy(self):
+        self._set_ports()
+        self.assertEqual(TEST_POLICY,
+                         self.policy_map.get_port_policy(TEST_PORT))
+        self.assertEqual(TEST_POLICY2,
+                         self.policy_map.get_port_policy(TEST_PORT2))
+
+    def test_get_ports(self):
+        self._set_ports()
+        self.assertEqual([TEST_PORT],
+                         list(self.policy_map.get_ports(TEST_POLICY)))
+
+        self.assertEqual([TEST_PORT2],
+                         list(self.policy_map.get_ports(TEST_POLICY2)))
+
+    def test_clean_by_port(self):
+        self._set_ports()
+        self.policy_map.clean_by_port(TEST_PORT)
+        self.assertNotIn(TEST_POLICY.id, self.policy_map.known_policies)
+        self.assertNotIn(TEST_PORT['port_id'], self.policy_map.port_policies)
+        self.assertIn(TEST_POLICY2.id, self.policy_map.known_policies)
+
+    def test_clean_by_port_raises_exception_for_unknown_port(self):
+        self.assertRaises(exceptions.PortNotFound,
+                          self.policy_map.clean_by_port, TEST_PORT)
+
+    def test_has_policy_changed(self):
+        self._set_ports()
+        self.assertTrue(
+            self.policy_map.has_policy_changed(TEST_PORT, 'a_new_policy_id'))
+
+        self.assertFalse(
+            self.policy_map.has_policy_changed(TEST_PORT, TEST_POLICY.id))