]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Propagate notifications to agent consumers callbacks
authorJakub Libosvar <libosvar@redhat.com>
Wed, 5 Aug 2015 18:15:26 +0000 (18:15 +0000)
committerIhar Hrachyshka <ihrachys@redhat.com>
Sat, 8 Aug 2015 08:41:32 +0000 (10:41 +0200)
The update policy works. We still need to track down the deletes which
don't work currently.

Change-Id: I48e04b42c07c34cf1daa17e7a29a6950453946ff
Partially-Implements: blueprint quantum-qos-api

19 files changed:
neutron/agent/l2/extensions/manager.py
neutron/agent/l2/extensions/qos.py
neutron/api/rpc/callbacks/consumer/registry.py
neutron/api/rpc/handlers/resources_rpc.py
neutron/objects/qos/policy.py
neutron/plugins/ml2/drivers/openvswitch/agent/extension_drivers/qos_driver.py
neutron/plugins/ml2/drivers/openvswitch/agent/ovs_neutron_agent.py
neutron/services/qos/notification_drivers/manager.py
neutron/services/qos/notification_drivers/message_queue.py
neutron/services/qos/notification_drivers/qos_base.py
neutron/services/qos/qos_plugin.py
neutron/tests/unit/agent/l2/extensions/test_manager.py
neutron/tests/unit/agent/l2/extensions/test_qos.py
neutron/tests/unit/api/rpc/callbacks/consumer/test_registry.py
neutron/tests/unit/api/rpc/handlers/test_resources_rpc.py
neutron/tests/unit/objects/qos/test_policy.py
neutron/tests/unit/services/qos/notification_drivers/test_manager.py
neutron/tests/unit/services/qos/notification_drivers/test_message_queue.py
neutron/tests/unit/services/qos/test_qos_plugin.py

index 6e1aa63709410cfb06467cd36b800ca288418ea2..2c77adbf8e935194bd81c6671767a2849fcf0862 100644 (file)
@@ -43,11 +43,11 @@ class AgentExtensionsManager(stevedore.named.NamedExtensionManager):
             invoke_on_load=True, name_order=True)
         LOG.info(_LI("Loaded agent extensions: %s"), self.names())
 
-    def initialize(self):
+    def initialize(self, connection):
         # Initialize each agent extension in the list.
         for extension in self:
             LOG.info(_LI("Initializing agent extension '%s'"), extension.name)
-            extension.obj.initialize()
+            extension.obj.initialize(connection)
 
     def handle_port(self, context, data):
         """Notify all agent extensions to handle port."""
index 6483d5aa9f0e163974dc90c56e8cf7f4bd3a7577..736cc1458a743353bd5fc6c9051eb5a051efc5af 100644 (file)
@@ -20,6 +20,8 @@ from oslo_config import cfg
 import six
 
 from neutron.agent.l2 import agent_extension
+from neutron.api.rpc.callbacks.consumer import registry
+from neutron.api.rpc.callbacks import events
 from neutron.api.rpc.callbacks import resources
 from neutron.api.rpc.handlers import resources_rpc
 from neutron import manager
@@ -70,7 +72,9 @@ class QosAgentDriver(object):
 
 
 class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
-    def initialize(self):
+    SUPPORTED_RESOURCES = [resources.QOS_POLICY]
+
+    def initialize(self, connection):
         """Perform Agent Extension initialization.
 
         """
@@ -80,22 +84,40 @@ class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
         self.qos_driver = manager.NeutronManager.load_class_for_provider(
             'neutron.qos.agent_drivers', cfg.CONF.qos.agent_driver)()
         self.qos_driver.initialize()
+
+        # we cannot use a dict of sets here because port dicts are not hashable
         self.qos_policy_ports = collections.defaultdict(dict)
         self.known_ports = set()
 
+        registry.subscribe(self._handle_notification, resources.QOS_POLICY)
+        self._register_rpc_consumers(connection)
+
+    def _register_rpc_consumers(self, connection):
+        endpoints = [resources_rpc.ResourcesPushRpcCallback()]
+        for resource_type in self.SUPPORTED_RESOURCES:
+            # we assume that neutron-server always broadcasts the latest
+            # version known to the agent
+            topic = resources_rpc.resource_type_versioned_topic(resource_type)
+            connection.create_consumer(topic, endpoints, fanout=True)
+
+    def _handle_notification(self, qos_policy, event_type):
+        # server does not allow to remove a policy that is attached to any
+        # port, so we ignore DELETED events. Also, if we receive a CREATED
+        # event for a policy, it means that there are no ports so far that are
+        # attached to it. That's why we are interested in UPDATED events only
+        if event_type == events.UPDATED:
+            self._process_update_policy(qos_policy)
+
     def handle_port(self, context, port):
         """Handle agent QoS extension for port.
 
-        This method subscribes to qos_policy_id changes
-        with a callback and get all the qos_policy_ports and apply
-        them using the QoS driver.
-        Updates and delete event should be handle by the registered
-        callback.
+        This method applies a new policy to a port using the QoS driver.
+        Update events are handled in _handle_notification.
         """
         port_id = port['port_id']
         qos_policy_id = port.get('qos_policy_id')
         if qos_policy_id is None:
-            #TODO(QoS):  we should also handle removing policy
+            self._process_reset_port(port)
             return
 
         #Note(moshele) check if we have seen this port
@@ -104,23 +126,26 @@ class QosAgentExtension(agent_extension.AgentCoreResourceExtension):
                 port_id in self.qos_policy_ports[qos_policy_id]):
             return
 
+        # TODO(QoS): handle race condition between push and pull APIs
         self.qos_policy_ports[qos_policy_id][port_id] = port
         self.known_ports.add(port_id)
-        #TODO(QoS): handle updates when implemented
-        # we have two options:
-        # 1. to add new api for subscribe
-        #    registry.subscribe(self._process_policy_updates,
-        #                   resources.QOS_POLICY, qos_policy_id)
-        # 2. combine pull rpc to also subscribe to the resource
         qos_policy = self.resource_rpc.pull(
-            context,
-            resources.QOS_POLICY,
-            qos_policy_id)
-        self._process_policy_updates(
-            port, resources.QOS_POLICY, qos_policy_id,
-            qos_policy, 'create')
-
-    def _process_policy_updates(
-            self, port, resource_type, resource_id,
-            qos_policy, action_type):
-        getattr(self.qos_driver, action_type)(port, qos_policy)
+            context, resources.QOS_POLICY, qos_policy_id)
+        self.qos_driver.create(port, qos_policy)
+
+    def _process_update_policy(self, qos_policy):
+        for port_id, port in self.qos_policy_ports[qos_policy.id].items():
+            # TODO(QoS): for now, just reflush the rules on the port. Later, we
+            # may want to apply the difference between the rules lists only.
+            self.qos_driver.delete(port, None)
+            self.qos_driver.update(port, qos_policy)
+
+    def _process_reset_port(self, port):
+        port_id = port['port_id']
+        if port_id in self.known_ports:
+            self.known_ports.remove(port_id)
+            for qos_policy_id, port_dict in self.qos_policy_ports.items():
+                if port_id in port_dict:
+                    del port_dict[port_id]
+                    self.qos_driver.delete(port, None)
+                    return
index 454e423a083c3c6370a523f567873b66c926a017..3f6c5754f05cd5427379f8d2c3ca47dfe0424680 100644 (file)
@@ -37,7 +37,7 @@ def push(resource_type, resource, event_type):
 
     callbacks = _get_manager().get_callbacks(resource_type)
     for callback in callbacks:
-        callback(resource_type, resource, event_type)
+        callback(resource, event_type)
 
 
 def clear():
index dd20eb3c60bbb3db11b5feab1597815d132e4a13..c3c9afe0454729deb8927cbd280d912aaf08d382 100755 (executable)
@@ -48,6 +48,13 @@ def _validate_resource_type(resource_type):
         raise InvalidResourceTypeClass(resource_type=resource_type)
 
 
+def resource_type_versioned_topic(resource_type):
+    _validate_resource_type(resource_type)
+    cls = resources.get_resource_cls(resource_type)
+    return topics.RESOURCE_TOPIC_PATTERN % {'resource_type': resource_type,
+                                            'version': cls.VERSION}
+
+
 class ResourcesPullRpcApi(object):
     """Agent-side RPC (stub) for agent-to-plugin interaction.
 
@@ -113,12 +120,6 @@ class ResourcesPullRpcCallback(object):
             return obj.obj_to_primitive(target_version=version)
 
 
-def _object_topic(obj):
-    resource_type = resources.get_resource_type(obj)
-    return topics.RESOURCE_TOPIC_PATTERN % {
-        'resource_type': resource_type, 'version': obj.VERSION}
-
-
 class ResourcesPushRpcApi(object):
     """Plugin-side RPC for plugin-to-agents interaction.
 
@@ -137,7 +138,7 @@ class ResourcesPushRpcApi(object):
 
     def _prepare_object_fanout_context(self, obj):
         """Prepare fanout context, one topic per object type."""
-        obj_topic = _object_topic(obj)
+        obj_topic = resource_type_versioned_topic(obj.obj_name())
         return self.client.prepare(fanout=True, topic=obj_topic)
 
     @log_helpers.log_method_call
index b3b7a44e375018ffe73b059a5f9b4da2ee8e4b46..96d1536e8da56afb6bac7504e509ee3fedde999b 100644 (file)
@@ -56,12 +56,13 @@ class QosPolicy(base.NeutronDbObject):
             raise exceptions.ObjectActionError(
                 action='obj_load_attr', reason='unable to load %s' % attrname)
 
-        rules = rule_obj_impl.get_rules(self._context, self.id)
-        setattr(self, attrname, rules)
-        self.obj_reset_changes([attrname])
+        if not hasattr(self, attrname):
+            self.reload_rules()
 
-    def _load_rules(self):
-        self.obj_load_attr('rules')
+    def reload_rules(self):
+        rules = rule_obj_impl.get_rules(self._context, self.id)
+        setattr(self, 'rules', rules)
+        self.obj_reset_changes(['rules'])
 
     @staticmethod
     def _is_policy_accessible(context, db_obj):
@@ -82,7 +83,7 @@ class QosPolicy(base.NeutronDbObject):
                 not cls._is_policy_accessible(context, policy_obj)):
                 return
 
-            policy_obj._load_rules()
+            policy_obj.reload_rules()
             return policy_obj
 
     @classmethod
@@ -97,7 +98,7 @@ class QosPolicy(base.NeutronDbObject):
                 if not cls._is_policy_accessible(context, db_obj):
                     continue
                 obj = cls(context, **db_obj)
-                obj._load_rules()
+                obj.reload_rules()
                 objs.append(obj)
         return objs
 
@@ -122,7 +123,7 @@ class QosPolicy(base.NeutronDbObject):
     def create(self):
         with db_api.autonested_transaction(self._context.session):
             super(QosPolicy, self).create()
-            self._load_rules()
+            self.reload_rules()
 
     def delete(self):
         models = (
index c947748115694e2c18cf5c0df44006d6f281c41a..2584611d5f7c40d22dbc4dab3375251fb45f3e00 100644 (file)
@@ -46,7 +46,9 @@ class QosOVSAgentDriver(qos.QosAgentDriver):
         self._handle_rules('update', port, qos_policy)
 
     def delete(self, port, qos_policy):
-        self._handle_rules('delete', port, qos_policy)
+        # TODO(QoS): consider optimizing flushing of all QoS rules from the
+        # port by inspecting qos_policy.rules contents
+        self._delete_bandwidth_limit(port)
 
     def _handle_rules(self, action, port, qos_policy):
         for rule in qos_policy.rules:
@@ -76,7 +78,7 @@ class QosOVSAgentDriver(qos.QosAgentDriver):
                                                  max_kbps,
                                                  max_burst_kbps)
 
-    def _delete_bandwidth_limit(self, port, rule):
+    def _delete_bandwidth_limit(self, port):
         port_name = port['vif_port'].port_name
         current_max_kbps, current_max_burst = (
             self.br_int.get_qos_bw_limit_for_port(port_name))
index d07532bad9bf31ef91f7951b060a984e8d86cea4..a5190f9a39657756fcaf463ba780eea13f809e4d 100644 (file)
@@ -226,7 +226,7 @@ class OVSNeutronAgent(sg_rpc.SecurityGroupAgentRpcCallbackMixin,
         # keeps association between ports and ofports to detect ofport change
         self.vifname_to_ofport_map = {}
         self.setup_rpc()
-        self.init_extension_manager()
+        self.init_extension_manager(self.connection)
         self.bridge_mappings = bridge_mappings
         self.setup_physical_bridges(self.bridge_mappings)
         self.local_vlan_map = {}
@@ -367,11 +367,11 @@ class OVSNeutronAgent(sg_rpc.SecurityGroupAgentRpcCallbackMixin,
                                                      consumers,
                                                      start_listening=False)
 
-    def init_extension_manager(self):
+    def init_extension_manager(self, connection):
         ext_manager.register_opts(self.conf)
         self.ext_manager = (
             ext_manager.AgentExtensionsManager(self.conf))
-        self.ext_manager.initialize()
+        self.ext_manager.initialize(connection)
 
     def get_net_uuid(self, vif_id):
         for network_id, vlan_mapping in six.iteritems(self.local_vlan_map):
index 2dd5e11977ba363b73177a5dac50278a7cb426b0..d027c1945c7cf2482c638b59a1bc00c6f11ff2e2 100644 (file)
@@ -33,17 +33,17 @@ class QosServiceNotificationDriverManager(object):
         self.notification_drivers = []
         self._load_drivers(cfg.CONF.qos.notification_drivers)
 
-    def update_policy(self, qos_policy):
+    def update_policy(self, context, qos_policy):
         for driver in self.notification_drivers:
-            driver.update_policy(qos_policy)
+            driver.update_policy(context, qos_policy)
 
-    def delete_policy(self, qos_policy):
+    def delete_policy(self, context, qos_policy):
         for driver in self.notification_drivers:
-            driver.delete_policy(qos_policy)
+            driver.delete_policy(context, qos_policy)
 
-    def create_policy(self, qos_policy):
+    def create_policy(self, context, qos_policy):
         for driver in self.notification_drivers:
-            driver.create_policy(qos_policy)
+            driver.create_policy(context, qos_policy)
 
     def _load_drivers(self, notification_drivers):
         """Load all the instances of the configured QoS notification drivers
index aa804f72306c96a0d255e2c005b93d930ff98f1f..1af63f9ac3c2459d6dc0057465f3f4df81204494 100644 (file)
 
 from oslo_log import log as logging
 
+from neutron.api.rpc.callbacks import events
 from neutron.api.rpc.callbacks.producer import registry
 from neutron.api.rpc.callbacks import resources
+from neutron.api.rpc.handlers import resources_rpc
 from neutron.i18n import _LW
 from neutron.objects.qos import policy as policy_object
 from neutron.services.qos.notification_drivers import qos_base
@@ -40,19 +42,18 @@ class RpcQosServiceNotificationDriver(
     """RPC message queue service notification driver for QoS."""
 
     def __init__(self):
+        self.notification_api = resources_rpc.ResourcesPushRpcApi()
         registry.provide(_get_qos_policy_cb, resources.QOS_POLICY)
 
     def get_description(self):
         return "Message queue updates"
 
-    def create_policy(self, policy):
+    def create_policy(self, context, policy):
         #No need to update agents on create
         pass
 
-    def update_policy(self, policy):
-        # TODO(QoS): implement notification
-        pass
+    def update_policy(self, context, policy):
+        self.notification_api.push(context, policy, events.UPDATED)
 
-    def delete_policy(self, policy):
-        # TODO(QoS): implement notification
-        pass
+    def delete_policy(self, context, policy):
+        self.notification_api.push(context, policy, events.DELETED)
index d87870272f4d9c5c8be246f9d97b5944d6500eb2..50f98f0c4b43df17842c0569f06b55fc51b908ca 100644 (file)
@@ -24,18 +24,18 @@ class QosServiceNotificationDriverBase(object):
         """
 
     @abc.abstractmethod
-    def create_policy(self, policy):
+    def create_policy(self, context, policy):
         """Create the QoS policy."""
 
     @abc.abstractmethod
-    def update_policy(self, policy):
+    def update_policy(self, context, policy):
         """Update the QoS policy.
 
         Apply changes to the QoS policy.
         """
 
     @abc.abstractmethod
-    def delete_policy(self, policy):
+    def delete_policy(self, context, policy):
         """Delete the QoS policy.
 
         Remove all rules for this policy and free up all the resources.
index 0b91d46b9c21f3723206953fd9366049fbb73d6e..7111c4e94b3df131227dce27e053b83091df497b 100644 (file)
@@ -16,6 +16,7 @@ from oslo_log import log as logging
 
 
 from neutron.common import exceptions as n_exc
+from neutron.db import api as db_api
 from neutron.db import db_base_plugin_common
 from neutron.extensions import qos
 from neutron.objects.qos import policy as policy_object
@@ -46,7 +47,7 @@ class QoSPlugin(qos.QoSPluginBase):
     def create_policy(self, context, policy):
         policy = policy_object.QosPolicy(context, **policy['policy'])
         policy.create()
-        self.notification_driver_manager.create_policy(policy)
+        self.notification_driver_manager.create_policy(context, policy)
         return policy
 
     @db_base_plugin_common.convert_result_to_dict
@@ -54,14 +55,14 @@ class QoSPlugin(qos.QoSPluginBase):
         policy = policy_object.QosPolicy(context, **policy['policy'])
         policy.id = policy_id
         policy.update()
-        self.notification_driver_manager.update_policy(policy)
+        self.notification_driver_manager.update_policy(context, policy)
         return policy
 
     def delete_policy(self, context, policy_id):
         policy = policy_object.QosPolicy(context)
         policy.id = policy_id
+        self.notification_driver_manager.delete_policy(context, policy)
         policy.delete()
-        self.notification_driver_manager.delete_policy(policy)
 
     def _get_policy_obj(self, context, policy_id):
         obj = policy_object.QosPolicy.get_by_id(context, policy_id)
@@ -89,42 +90,54 @@ class QoSPlugin(qos.QoSPluginBase):
     @db_base_plugin_common.convert_result_to_dict
     def create_policy_bandwidth_limit_rule(self, context, policy_id,
                                            bandwidth_limit_rule):
-        # validate that we have access to the policy
-        policy = self._get_policy_obj(context, policy_id)
-        rule = rule_object.QosBandwidthLimitRule(
-            context, qos_policy_id=policy_id,
-            **bandwidth_limit_rule['bandwidth_limit_rule'])
-        rule.create()
-        self.notification_driver_manager.update_policy(policy)
+        # make sure we will have a policy object to push resource update
+        with db_api.autonested_transaction(context.session):
+            # first, validate that we have access to the policy
+            policy = self._get_policy_obj(context, policy_id)
+            rule = rule_object.QosBandwidthLimitRule(
+                context, qos_policy_id=policy_id,
+                **bandwidth_limit_rule['bandwidth_limit_rule'])
+            rule.create()
+            policy.reload_rules()
+        self.notification_driver_manager.update_policy(context, policy)
         return rule
 
     @db_base_plugin_common.convert_result_to_dict
     def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
                                            bandwidth_limit_rule):
-        # validate that we have access to the policy
-        policy = self._get_policy_obj(context, policy_id)
-        rule = rule_object.QosBandwidthLimitRule(
-            context, **bandwidth_limit_rule['bandwidth_limit_rule'])
-        rule.id = rule_id
-        rule.update()
-        self.notification_driver_manager.update_policy(policy)
+        # make sure we will have a policy object to push resource update
+        with db_api.autonested_transaction(context.session):
+            # first, validate that we have access to the policy
+            policy = self._get_policy_obj(context, policy_id)
+            rule = rule_object.QosBandwidthLimitRule(
+                context, **bandwidth_limit_rule['bandwidth_limit_rule'])
+            rule.id = rule_id
+            rule.update()
+            policy.reload_rules()
+        self.notification_driver_manager.update_policy(context, policy)
         return rule
 
     def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
-        # validate that we have access to the policy
-        policy = self._get_policy_obj(context, policy_id)
-        rule = rule_object.QosBandwidthLimitRule(context)
-        rule.id = rule_id
-        rule.delete()
-        self.notification_driver_manager.update_policy(policy)
+        # make sure we will have a policy object to push resource update
+        with db_api.autonested_transaction(context.session):
+            # first, validate that we have access to the policy
+            policy = self._get_policy_obj(context, policy_id)
+            rule = rule_object.QosBandwidthLimitRule(context)
+            rule.id = rule_id
+            rule.delete()
+            policy.reload_rules()
+        self.notification_driver_manager.update_policy(context, policy)
 
     @db_base_plugin_common.filter_fields
     @db_base_plugin_common.convert_result_to_dict
     def get_policy_bandwidth_limit_rule(self, context, rule_id,
                                         policy_id, fields=None):
-        # validate that we have access to the policy
-        self._get_policy_obj(context, policy_id)
-        rule = rule_object.QosBandwidthLimitRule.get_by_id(context, rule_id)
+        # make sure we have access to the policy when fetching the rule
+        with db_api.autonested_transaction(context.session):
+            # first, validate that we have access to the policy
+            self._get_policy_obj(context, policy_id)
+            rule = rule_object.QosBandwidthLimitRule.get_by_id(
+                context, rule_id)
         if not rule:
             raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id)
         return rule
@@ -136,9 +149,11 @@ class QoSPlugin(qos.QoSPluginBase):
                                          sorts=None, limit=None,
                                          marker=None, page_reverse=False):
         #TODO(QoS): Support all the optional parameters
-        # validate that we have access to the policy
-        self._get_policy_obj(context, policy_id)
-        return rule_object.QosBandwidthLimitRule.get_objects(context)
+        # make sure we have access to the policy when fetching rules
+        with db_api.autonested_transaction(context.session):
+            # first, validate that we have access to the policy
+            self._get_policy_obj(context, policy_id)
+            return rule_object.QosBandwidthLimitRule.get_objects(context)
 
     # TODO(QoS): enforce rule types when accessing rule objects
     @db_base_plugin_common.filter_fields
index 54dd0603d545b1da8e514f0c65a4de364ebd04cd..3aa8ea58ba162d40450efda7c7b99a2c58b5720a 100644 (file)
@@ -32,9 +32,10 @@ class TestAgentExtensionsManager(base.BaseTestCase):
         return self.manager.extensions[0].obj
 
     def test_initialize(self):
-        self.manager.initialize()
+        connection = object()
+        self.manager.initialize(connection)
         ext = self._get_extension()
-        self.assertTrue(ext.initialize.called)
+        ext.initialize.assert_called_once_with(connection)
 
     def test_handle_port(self):
         context = object()
index 006044bf36998c371e8713291fd7568aa1ba2332..d78fc3121b18b68197f1f15901f03376e3193c45 100755 (executable)
@@ -17,21 +17,25 @@ import mock
 from oslo_utils import uuidutils
 
 from neutron.agent.l2.extensions import qos
+from neutron.api.rpc.callbacks.consumer import registry
+from neutron.api.rpc.callbacks import events
 from neutron.api.rpc.callbacks import resources
+from neutron.api.rpc.handlers import resources_rpc
 from neutron import context
+from neutron.plugins.ml2.drivers.openvswitch.agent.common import config  # noqa
 from neutron.tests import base
 
-# This is a minimalistic mock of rules to be passed/checked around
-# which should be exteneded as needed to make real rules
-TEST_GET_RESOURCE_RULES = ['rule1', 'rule2']
 
+TEST_POLICY = object()
 
-class QosAgentExtensionTestCase(base.BaseTestCase):
+
+class QosExtensionBaseTestCase(base.BaseTestCase):
 
     def setUp(self):
-        super(QosAgentExtensionTestCase, self).setUp()
+        super(QosExtensionBaseTestCase, self).setUp()
         self.qos_ext = qos.QosAgentExtension()
         self.context = context.get_admin_context()
+        self.connection = mock.Mock()
 
         # Don't rely on used driver
         mock.patch(
@@ -39,11 +43,16 @@ class QosAgentExtensionTestCase(base.BaseTestCase):
             return_value=lambda: mock.Mock(spec=qos.QosAgentDriver)
         ).start()
 
-        self.qos_ext.initialize()
+
+class QosExtensionRpcTestCase(QosExtensionBaseTestCase):
+
+    def setUp(self):
+        super(QosExtensionRpcTestCase, self).setUp()
+        self.qos_ext.initialize(self.connection)
 
         self.pull_mock = mock.patch.object(
             self.qos_ext.resource_rpc, 'pull',
-            return_value=TEST_GET_RESOURCE_RULES).start()
+            return_value=TEST_POLICY).start()
 
     def _create_test_port_dict(self):
         return {'port_id': uuidutils.generate_uuid(),
@@ -52,9 +61,9 @@ class QosAgentExtensionTestCase(base.BaseTestCase):
     def test_handle_port_with_no_policy(self):
         port = self._create_test_port_dict()
         del port['qos_policy_id']
-        self.qos_ext._process_rules_updates = mock.Mock()
+        self.qos_ext._process_reset_port = mock.Mock()
         self.qos_ext.handle_port(self.context, port)
-        self.assertFalse(self.qos_ext._process_rules_updates.called)
+        self.qos_ext._process_reset_port.assert_called_with(port)
 
     def test_handle_unknown_port(self):
         port = self._create_test_port_dict()
@@ -64,7 +73,7 @@ class QosAgentExtensionTestCase(base.BaseTestCase):
         # we make sure the underlaying qos driver is called with the
         # right parameters
         self.qos_ext.qos_driver.create.assert_called_once_with(
-            port, TEST_GET_RESOURCE_RULES)
+            port, TEST_POLICY)
         self.assertEqual(port,
             self.qos_ext.qos_policy_ports[qos_policy_id][port_id])
         self.assertTrue(port_id in self.qos_ext.known_ports)
@@ -88,3 +97,73 @@ class QosAgentExtensionTestCase(base.BaseTestCase):
              port['qos_policy_id'])
         #TODO(QoS): handle qos_driver.update call check when
         #           we do that
+
+    def test__handle_notification_ignores_all_event_types_except_updated(self):
+        with mock.patch.object(
+            self.qos_ext, '_process_update_policy') as update_mock:
+
+            for event_type in set(events.VALID) - {events.UPDATED}:
+                self.qos_ext._handle_notification(object(), event_type)
+                self.assertFalse(update_mock.called)
+
+    def test__handle_notification_passes_update_events(self):
+        with mock.patch.object(
+            self.qos_ext, '_process_update_policy') as update_mock:
+
+            policy = mock.Mock()
+            self.qos_ext._handle_notification(policy, events.UPDATED)
+            update_mock.assert_called_with(policy)
+
+    def test__process_update_policy(self):
+        port1 = self._create_test_port_dict()
+        port2 = self._create_test_port_dict()
+        self.qos_ext.qos_policy_ports = {
+            port1['qos_policy_id']: {port1['port_id']: port1},
+            port2['qos_policy_id']: {port2['port_id']: port2},
+        }
+        policy = mock.Mock()
+        policy.id = port1['qos_policy_id']
+        self.qos_ext._process_update_policy(policy)
+        self.qos_ext.qos_driver.update.assert_called_with(port1, policy)
+
+        self.qos_ext.qos_driver.update.reset_mock()
+        policy.id = port2['qos_policy_id']
+        self.qos_ext._process_update_policy(policy)
+        self.qos_ext.qos_driver.update.assert_called_with(port2, policy)
+
+    def test__process_reset_port(self):
+        port1 = self._create_test_port_dict()
+        port2 = self._create_test_port_dict()
+        port1_id = port1['port_id']
+        port2_id = port2['port_id']
+        self.qos_ext.qos_policy_ports = {
+            port1['qos_policy_id']: {port1_id: port1},
+            port2['qos_policy_id']: {port2_id: port2},
+        }
+        self.qos_ext.known_ports = {port1_id, port2_id}
+
+        self.qos_ext._process_reset_port(port1)
+        self.qos_ext.qos_driver.delete.assert_called_with(port1, None)
+        self.assertNotIn(port1_id, self.qos_ext.known_ports)
+        self.assertIn(port2_id, self.qos_ext.known_ports)
+
+        self.qos_ext.qos_driver.delete.reset_mock()
+        self.qos_ext._process_reset_port(port2)
+        self.qos_ext.qos_driver.delete.assert_called_with(port2, None)
+        self.assertNotIn(port2_id, self.qos_ext.known_ports)
+
+
+class QosExtensionInitializeTestCase(QosExtensionBaseTestCase):
+
+    @mock.patch.object(registry, 'subscribe')
+    @mock.patch.object(resources_rpc, 'ResourcesPushRpcCallback')
+    def test_initialize_subscribed_to_rpc(self, rpc_mock, subscribe_mock):
+        self.qos_ext.initialize(self.connection)
+        self.connection.create_consumer.assert_has_calls(
+            [mock.call(
+                 resources_rpc.resource_type_versioned_topic(resource_type),
+                 [rpc_mock()],
+                 fanout=True)
+             for resource_type in self.qos_ext.SUPPORTED_RESOURCES]
+        )
+        subscribe_mock.assert_called_with(mock.ANY, resources.QOS_POLICY)
index 5d18e539fd7561ce8495c54a04e10e416fcf1221..d07b49c2fd5dac8e095f54c97abca21b742de229 100644 (file)
@@ -53,4 +53,4 @@ class ConsumerRegistryTestCase(base.BaseTestCase):
         manager_mock().get_callbacks.return_value = callbacks
         registry.push(resource_type_, resource_, event_type_)
         for callback in callbacks:
-            callback.assert_called_with(resource_type_, resource_, event_type_)
+            callback.assert_called_with(resource_, event_type_)
index 9a6ccd4a6f077968ab09a217ff3516eaa578592e..4fd58afa2656841e2e4ca533dead3531d796872c 100755 (executable)
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import mock
-from oslo_utils import uuidutils
 from oslo_versionedobjects import base as obj_base
 from oslo_versionedobjects import fields as obj_fields
 import testtools
@@ -27,6 +26,18 @@ from neutron.objects import base as objects_base
 from neutron.tests import base
 
 
+def _create_test_dict():
+    return {'id': 'uuid',
+            'field': 'foo'}
+
+
+def _create_test_resource(context=None):
+    resource_dict = _create_test_dict()
+    resource = FakeResource(context, **resource_dict)
+    resource.obj_reset_changes()
+    return resource
+
+
 @obj_base.VersionedObjectRegistry.register
 class FakeResource(objects_base.NeutronObject):
 
@@ -46,15 +57,6 @@ class ResourcesRpcBaseTestCase(base.BaseTestCase):
         super(ResourcesRpcBaseTestCase, self).setUp()
         self.context = context.get_admin_context()
 
-    def _create_test_dict(self):
-        return {'id': uuidutils.generate_uuid(),
-                'field': 'foo'}
-
-    def _create_test_resource(self, **kwargs):
-        resource = FakeResource(self.context, **kwargs)
-        resource.obj_reset_changes()
-        return resource
-
 
 class _ValidateResourceTypeTestCase(base.BaseTestCase):
     def setUp(self):
@@ -73,6 +75,19 @@ class _ValidateResourceTypeTestCase(base.BaseTestCase):
             resources_rpc._validate_resource_type('foo')
 
 
+class _ResourceTypeVersionedTopicTestCase(base.BaseTestCase):
+
+    @mock.patch.object(resources_rpc, '_validate_resource_type')
+    def test_resource_type_versioned_topic(self, validate_mock):
+        obj_name = FakeResource.obj_name()
+        expected = topics.RESOURCE_TOPIC_PATTERN % {
+            'resource_type': 'FakeResource', 'version': '1.0'}
+        with mock.patch.object(resources_rpc.resources, 'get_resource_cls',
+                return_value=FakeResource):
+            observed = resources_rpc.resource_type_versioned_topic(obj_name)
+        self.assertEqual(expected, observed)
+
+
 class ResourcesPullRpcApiTestCase(ResourcesRpcBaseTestCase):
 
     def setUp(self):
@@ -85,13 +100,11 @@ class ResourcesPullRpcApiTestCase(ResourcesRpcBaseTestCase):
         self.cctxt_mock = self.rpc.client.prepare.return_value
 
     def test_is_singleton(self):
-        self.assertEqual(id(self.rpc),
-                         id(resources_rpc.ResourcesPullRpcApi()))
+        self.assertIs(self.rpc, resources_rpc.ResourcesPullRpcApi())
 
     def test_pull(self):
-        resource_dict = self._create_test_dict()
-        expected_obj = self._create_test_resource(**resource_dict)
-        resource_id = resource_dict['id']
+        expected_obj = _create_test_resource(self.context)
+        resource_id = expected_obj.id
         self.cctxt_mock.call.return_value = expected_obj.obj_to_primitive()
 
         result = self.rpc.pull(
@@ -103,7 +116,7 @@ class ResourcesPullRpcApiTestCase(ResourcesRpcBaseTestCase):
         self.assertEqual(expected_obj, result)
 
     def test_pull_resource_not_found(self):
-        resource_dict = self._create_test_dict()
+        resource_dict = _create_test_dict()
         resource_id = resource_dict['id']
         self.cctxt_mock.call.return_value = None
         with testtools.ExpectedException(resources_rpc.ResourceNotFound):
@@ -116,20 +129,20 @@ class ResourcesPullRpcCallbackTestCase(ResourcesRpcBaseTestCase):
     def setUp(self):
         super(ResourcesPullRpcCallbackTestCase, self).setUp()
         self.callbacks = resources_rpc.ResourcesPullRpcCallback()
-        self.resource_dict = self._create_test_dict()
-        self.resource_obj = self._create_test_resource(**self.resource_dict)
+        self.resource_obj = _create_test_resource(self.context)
 
     def test_pull(self):
+        resource_dict = _create_test_dict()
         with mock.patch.object(
                 resources_rpc.prod_registry, 'pull',
                 return_value=self.resource_obj) as registry_mock:
             primitive = self.callbacks.pull(
                 self.context, resource_type=FakeResource.obj_name(),
                 version=FakeResource.VERSION,
-                resource_id=self.resource_dict['id'])
+                resource_id=self.resource_obj.id)
         registry_mock.assert_called_once_with(
-            'FakeResource', self.resource_dict['id'], context=self.context)
-        self.assertEqual(self.resource_dict,
+            'FakeResource', self.resource_obj.id, context=self.context)
+        self.assertEqual(resource_dict,
                          primitive['versioned_object.data'])
         self.assertEqual(self.resource_obj.obj_to_primitive(), primitive)
 
@@ -150,7 +163,7 @@ class ResourcesPullRpcCallbackTestCase(ResourcesRpcBaseTestCase):
             self.callbacks.pull(
                 self.context, resource_type=FakeResource.obj_name(),
                 version='0.9',  # less than initial version 1.0
-                resource_id=self.resource_dict['id'])
+                resource_id=self.resource_obj.id)
             to_prim_mock.assert_called_with(target_version='0.9')
 
 
@@ -162,23 +175,27 @@ class ResourcesPushRpcApiTestCase(ResourcesRpcBaseTestCase):
         mock.patch.object(resources_rpc, '_validate_resource_type').start()
         self.rpc = resources_rpc.ResourcesPushRpcApi()
         self.cctxt_mock = self.rpc.client.prepare.return_value
-        resource_dict = self._create_test_dict()
-        self.resource_obj = self._create_test_resource(**resource_dict)
+        self.resource_obj = _create_test_resource(self.context)
 
     def test__prepare_object_fanout_context(self):
         expected_topic = topics.RESOURCE_TOPIC_PATTERN % {
             'resource_type': resources.get_resource_type(self.resource_obj),
             'version': self.resource_obj.VERSION}
 
-        observed = self.rpc._prepare_object_fanout_context(self.resource_obj)
+        with mock.patch.object(resources_rpc.resources, 'get_resource_cls',
+                return_value=FakeResource):
+            observed = self.rpc._prepare_object_fanout_context(
+                self.resource_obj)
 
         self.rpc.client.prepare.assert_called_once_with(
             fanout=True, topic=expected_topic)
         self.assertEqual(self.cctxt_mock, observed)
 
-    def test_push(self):
-        self.rpc.push(
-            self.context, self.resource_obj, 'TYPE')
+    def test_pushy(self):
+        with mock.patch.object(resources_rpc.resources, 'get_resource_cls',
+                return_value=FakeResource):
+            self.rpc.push(
+                self.context, self.resource_obj, 'TYPE')
 
         self.cctxt_mock.cast.assert_called_once_with(
             self.context, 'push',
@@ -194,8 +211,7 @@ class ResourcesPushRpcCallbackTestCase(ResourcesRpcBaseTestCase):
         mock.patch.object(
             resources_rpc.resources,
             'get_resource_cls', return_value=FakeResource).start()
-        resource_dict = self._create_test_dict()
-        self.resource_obj = self._create_test_resource(**resource_dict)
+        self.resource_obj = _create_test_resource(self.context)
         self.resource_prim = self.resource_obj.obj_to_primitive()
         self.callbacks = resources_rpc.ResourcesPushRpcCallback()
 
index 0af07e9d1b106bb6fc4aca0c00ac38ab18f770f7..97af37bbb2f9035e6b5b8e407b1fb924ef74c1c7 100644 (file)
@@ -265,3 +265,10 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
 
         obj.detach_network(self._network['id'])
         obj.delete()
+
+    def test_reload_rules_reloads_rules(self):
+        policy_obj, rule_obj = self._create_test_policy_with_rule()
+        self.assertEqual([], policy_obj.rules)
+
+        policy_obj.reload_rules()
+        self.assertEqual([rule_obj], policy_obj.rules)
index efc1cbbbb030b5d2b2fd08d0c9f6069207fbc172..c46e99a24db13d94bd18cdb8424952d3e5e3c5e5 100644 (file)
@@ -46,7 +46,8 @@ class TestQosDriversManagerBase(base.BaseQosTestCase):
                             'description': 'test policy description',
                             'shared': True}}
 
-        self.policy = policy_object.QosPolicy(context,
+        self.context = context.get_admin_context()
+        self.policy = policy_object.QosPolicy(self.context,
                         **self.policy_data['policy'])
         ctxt = None
         self.kwargs = {'context': ctxt}
@@ -56,24 +57,30 @@ class TestQosDriversManager(TestQosDriversManagerBase):
 
     def setUp(self):
         super(TestQosDriversManager, self).setUp()
+        #TODO(Qos): Fix this unittest to test manager and not message_queue
+        #           notification driver
+        rpc_api_cls = mock.patch('neutron.api.rpc.handlers.resources_rpc'
+                                 '.ResourcesPushRpcApi').start()
+        self.rpc_api = rpc_api_cls.return_value
         self.driver_manager = driver_mgr.QosServiceNotificationDriverManager()
 
     def _validate_registry_params(self, event_type, policy):
-        #TODO(QoS): actually validate the notification once implemented
-        pass
+        self.rpc_api.push.assert_called_with(self.context, policy,
+                                             event_type)
 
     def test_create_policy_default_configuration(self):
         #RPC driver should be loaded by default
-        self.driver_manager.create_policy(self.policy)
+        self.driver_manager.create_policy(self.context, self.policy)
+        self.assertFalse(self.rpc_api.push.called)
 
     def test_update_policy_default_configuration(self):
         #RPC driver should be loaded by default
-        self.driver_manager.update_policy(self.policy)
+        self.driver_manager.update_policy(self.context, self.policy)
         self._validate_registry_params(events.UPDATED, self.policy)
 
     def test_delete_policy_default_configuration(self):
         #RPC driver should be loaded by default
-        self.driver_manager.delete_policy(self.policy)
+        self.driver_manager.delete_policy(self.context, self.policy)
         self._validate_registry_params(events.DELETED, self.policy)
 
 
@@ -86,9 +93,9 @@ class TestQosDriversManagerMulti(TestQosDriversManagerBase):
         with mock.patch('.'.join([DUMMY_DRIVER, handler])) as dummy_mock:
             rpc_driver = message_queue.RpcQosServiceNotificationDriver
             with mock.patch.object(rpc_driver, handler) as rpc_mock:
-                getattr(driver_manager, handler)(self.policy)
+                getattr(driver_manager, handler)(self.context, self.policy)
         for mock_ in (dummy_mock, rpc_mock):
-            mock_.assert_called_with(self.policy)
+            mock_.assert_called_with(self.context, self.policy)
 
     def test_multi_drivers_configuration_create(self):
         self._test_multi_drivers_configuration_op('create')
index 710451307a9f6ddf9b2ce366e17a8a582387ce26..0a95cae4108da211e2b909badfdb2cefa682c886 100644 (file)
@@ -10,6 +10,8 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import mock
+
 from neutron.api.rpc.callbacks import events
 from neutron import context
 from neutron.objects.qos import policy as policy_object
@@ -24,6 +26,9 @@ class TestQosRpcNotificationDriver(base.BaseQosTestCase):
 
     def setUp(self):
         super(TestQosRpcNotificationDriver, self).setUp()
+        rpc_api_cls = mock.patch('neutron.api.rpc.handlers.resources_rpc'
+                                 '.ResourcesPushRpcApi').start()
+        self.rpc_api = rpc_api_cls.return_value
         self.driver = message_queue.RpcQosServiceNotificationDriver()
 
         self.policy_data = {'policy': {
@@ -38,25 +43,26 @@ class TestQosRpcNotificationDriver(base.BaseQosTestCase):
                             'max_kbps': 100,
                             'max_burst_kbps': 150}}
 
-        self.policy = policy_object.QosPolicy(context,
+        self.context = context.get_admin_context()
+        self.policy = policy_object.QosPolicy(self.context,
                         **self.policy_data['policy'])
 
         self.rule = rule_object.QosBandwidthLimitRule(
-                                context,
+                                self.context,
                                 **self.rule_data['bandwidth_limit_rule'])
 
     def _validate_push_params(self, event_type, policy):
-        # TODO(QoS): actually validate push works once implemented
-        pass
+        self.rpc_api.push.assert_called_once_with(self.context, policy,
+                                                  event_type)
 
     def test_create_policy(self):
-        self.driver.create_policy(self.policy)
-        self._validate_push_params(events.CREATED, self.policy)
+        self.driver.create_policy(self.context, self.policy)
+        self.assertFalse(self.rpc_api.push.called)
 
     def test_update_policy(self):
-        self.driver.update_policy(self.policy)
+        self.driver.update_policy(self.context, self.policy)
         self._validate_push_params(events.UPDATED, self.policy)
 
     def test_delete_policy(self):
-        self.driver.delete_policy(self.policy)
+        self.driver.delete_policy(self.context, self.policy)
         self._validate_push_params(events.DELETED, self.policy)
index 1f530512a19b8d4d2d1789c1fd8807f4ba24555e..a44d27381a7b17708786af63ff980c4d479ed169 100644 (file)
@@ -46,9 +46,8 @@ class TestQosPlugin(base.BaseQosTestCase):
         self.qos_plugin = mgr.get_service_plugins().get(
             constants.QOS)
 
-        self.notif_driver_p = mock.patch.object(
-            self.qos_plugin, 'notification_driver_manager')
-        self.notif_driver_m = self.notif_driver_p.start()
+        self.notif_driver_m = mock.patch.object(
+            self.qos_plugin, 'notification_driver_manager').start()
 
         self.ctxt = context.Context('fake_user', 'fake_tenant')
         self.policy_data = {
@@ -64,16 +63,16 @@ class TestQosPlugin(base.BaseQosTestCase):
                                      'max_burst_kbps': 150}}
 
         self.policy = policy_object.QosPolicy(
-            context, **self.policy_data['policy'])
+            self.ctxt, **self.policy_data['policy'])
 
         self.rule = rule_object.QosBandwidthLimitRule(
-            context, **self.rule_data['bandwidth_limit_rule'])
+            self.ctxt, **self.rule_data['bandwidth_limit_rule'])
 
     def _validate_notif_driver_params(self, method_name):
         method = getattr(self.notif_driver_m, method_name)
         self.assertTrue(method.called)
         self.assertIsInstance(
-            method.call_args[0][0], policy_object.QosPolicy)
+            method.call_args[0][1], policy_object.QosPolicy)
 
     def test_add_policy(self):
         self.qos_plugin.create_policy(self.ctxt, self.policy_data)