]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Add support Quantum Security Groups for Ryu plugin
authorYoshihiro Kaneko <ykaneko0929@gmail.com>
Thu, 14 Feb 2013 12:04:12 +0000 (21:04 +0900)
committerNachi Ueno <nachi@nttmcl.com>
Wed, 20 Feb 2013 08:38:00 +0000 (00:38 -0800)
fix bug #1124965

This patch add support security-groups extension to Ryu plugin.

Change-Id: I569ab1e48517f28a5103175fd04e848f82eb2a3c

etc/quantum/plugins/ryu/ryu.ini
quantum/db/migration/alembic_migrations/versions/3cb5d900c5de_security_groups.py
quantum/plugins/ryu/agent/ryu_quantum_agent.py
quantum/plugins/ryu/common/config.py
quantum/plugins/ryu/db/api_v2.py
quantum/plugins/ryu/ryu_quantum_plugin.py
quantum/tests/unit/ryu/test_defaults.py
quantum/tests/unit/ryu/test_ryu_agent.py
quantum/tests/unit/ryu/test_ryu_security_group.py [new file with mode: 0644]

index 3376d2bbfb4d17cc72737886a5449c67c2b2b213..288d703c11575cd4fba4b11b2c7af987fc6ce5b3 100644 (file)
@@ -45,3 +45,11 @@ tunnel_interface = eth0
 # ovsdb_ip =
 # ovsdb_interface =
 ovsdb_interface = eth0
+
+[SECURITYGROUP]
+# Firewall driver for realizing quantum security group function
+# firewall_driver = quantum.agent.linux.iptables_firewall.OVSHybridIptablesFirewallDriver
+
+[AGENT]
+# Agent's polling interval in seconds
+polling_interval = 2
index 59e006537326ac337bf14324aebe8b19396ffe59..0b5ee231b37c51abfd4b7e38253ea09c8a882135 100644 (file)
@@ -34,6 +34,7 @@ migration_for_plugins = [
     'quantum.plugins.nicira.nicira_nvp_plugin.QuantumPlugin.NvpPluginV2',
     'quantum.plugins.openvswitch.ovs_quantum_plugin.OVSQuantumPluginV2',
     'quantum.plugins.nec.nec_plugin.NECPluginV2',
+    'quantum.plugins.ryu.ryu_quantum_plugin.RyuQuantumPluginV2',
 ]
 
 from alembic import op
index dbcab7cb3fa57fc7c7e19b502b05ab6e4517675a..684b13961eb580fb33cec0550afd10bcd06a5666 100755 (executable)
@@ -23,7 +23,9 @@
 import httplib
 import socket
 import sys
+import time
 
+import eventlet
 import netifaces
 from oslo.config import cfg
 from ryu.app import client
@@ -33,11 +35,14 @@ from ryu.app import rest_nw_id
 from quantum.agent.linux import ovs_lib
 from quantum.agent.linux.ovs_lib import VifPort
 from quantum.agent import rpc as agent_rpc
+from quantum.agent import securitygroups_rpc as sg_rpc
 from quantum.common import config as logging_config
 from quantum.common import exceptions as q_exc
 from quantum.common import topics
 from quantum import context as q_context
 from quantum.openstack.common import log
+from quantum.openstack.common.rpc import dispatcher
+from quantum.extensions import securitygroup as ext_sg
 from quantum.plugins.ryu.common import config
 
 
@@ -148,7 +153,8 @@ class VifPortSet(object):
                                  port.switch.datapath_id, port.ofport)
 
 
-class RyuPluginApi(agent_rpc.PluginApi):
+class RyuPluginApi(agent_rpc.PluginApi,
+                   sg_rpc.SecurityGroupServerRpcApiMixin):
     def get_ofp_rest_api_addr(self, context):
         LOG.debug(_("Get Ryu rest API address"))
         return self.call(context,
@@ -156,17 +162,42 @@ class RyuPluginApi(agent_rpc.PluginApi):
                          topic=self.topic)
 
 
-class OVSQuantumOFPRyuAgent(object):
+class RyuSecurityGroupAgent(sg_rpc.SecurityGroupAgentRpcMixin):
+    def __init__(self, context, plugin_rpc, root_helper):
+        self.context = context
+        self.plugin_rpc = plugin_rpc
+        self.root_helper = root_helper
+        self.init_firewall()
+
+
+class OVSQuantumOFPRyuAgent(sg_rpc.SecurityGroupAgentRpcCallbackMixin):
+
+    RPC_API_VERSION = '1.1'
+
     def __init__(self, integ_br, tunnel_ip, ovsdb_ip, ovsdb_port,
-                 root_helper):
+                 polling_interval, root_helper):
         super(OVSQuantumOFPRyuAgent, self).__init__()
+        self.polling_interval = polling_interval
         self._setup_rpc()
+        self.sg_agent = RyuSecurityGroupAgent(self.context,
+                                              self.plugin_rpc,
+                                              root_helper)
         self._setup_integration_br(root_helper, integ_br, tunnel_ip,
                                    ovsdb_port, ovsdb_ip)
 
     def _setup_rpc(self):
+        self.topic = topics.AGENT
         self.plugin_rpc = RyuPluginApi(topics.PLUGIN)
         self.context = q_context.get_admin_context_without_session()
+        self.dispatcher = self._create_rpc_dispatcher()
+        consumers = [[topics.PORT, topics.UPDATE],
+                     [topics.SECURITY_GROUP, topics.UPDATE]]
+        self.connection = agent_rpc.create_consumers(self.dispatcher,
+                                                     self.topic,
+                                                     consumers)
+
+    def _create_rpc_dispatcher(self):
+        return dispatcher.RpcDispatcher([self])
 
     def _setup_integration_br(self, root_helper, integ_br,
                               tunnel_ip, ovsdb_port, ovsdb_ip):
@@ -192,13 +223,64 @@ class OVSQuantumOFPRyuAgent(object):
         sc_client.set_key(self.int_br.datapath_id, conf_switch_key.OVSDB_ADDR,
                           'tcp:%s:%d' % (ovsdb_ip, ovsdb_port))
 
+    def port_update(self, context, **kwargs):
+        LOG.debug(_("port update received"))
+        port = kwargs.get('port')
+        vif_port = self.int_br.get_vif_port_by_id(port['id'])
+        if not vif_port:
+            return
+
+        if ext_sg.SECURITYGROUPS in port:
+            self.sg_agent.refresh_firewall()
+
+    def _update_ports(self, registered_ports):
+        ports = self.int_br.get_vif_port_set()
+        if ports == registered_ports:
+            return
+        added = ports - registered_ports
+        removed = registered_ports - ports
+        return {'current': ports,
+                'added': added,
+                'removed': removed}
+
+    def _process_devices_filter(self, port_info):
+        if 'added' in port_info:
+            self.sg_agent.prepare_devices_filter(port_info['added'])
+        if 'removed' in port_info:
+            self.sg_agent.remove_devices_filter(port_info['removed'])
+
+    def daemon_loop(self):
+        ports = set()
+
+        while True:
+            start = time.time()
+            try:
+                port_info = self._update_ports(ports)
+                if port_info:
+                    LOG.debug(_("Agent loop has new device"))
+                    self._process_devices_filter(port_info)
+                    ports = port_info['current']
+            except:
+                LOG.exception(_("Error in agent event loop"))
+
+            elapsed = max(time.time() - start, 0)
+            if (elapsed < self.polling_interval):
+                time.sleep(self.polling_interval - elapsed)
+            else:
+                LOG.debug(_("Loop iteration exceeded interval "
+                            "(%(polling_interval)s vs. %(elapsed)s)!"),
+                          {'polling_interval': self.polling_interval,
+                           'elapsed': elapsed})
+
 
 def main():
+    eventlet.monkey_patch()
     cfg.CONF(project='quantum')
 
     logging_config.setup_logging(cfg.CONF)
 
     integ_br = cfg.CONF.OVS.integration_bridge
+    polling_interval = cfg.CONF.AGENT.polling_interval
     root_helper = cfg.CONF.AGENT.root_helper
 
     tunnel_ip = _get_tunnel_ip()
@@ -208,14 +290,16 @@ def main():
     ovsdb_ip = _get_ovsdb_ip()
     LOG.debug(_('ovsdb_ip %s'), ovsdb_ip)
     try:
-        OVSQuantumOFPRyuAgent(integ_br, tunnel_ip, ovsdb_ip, ovsdb_port,
-                              root_helper)
+        agent = OVSQuantumOFPRyuAgent(integ_br, tunnel_ip, ovsdb_ip,
+                                      ovsdb_port, polling_interval,
+                                      root_helper)
     except httplib.HTTPException, e:
         LOG.error(_("Initialization failed: %s"), e)
         sys.exit(1)
 
-    LOG.info(_("Ryu initialization on the node is done."
-               " Now Ryu agent exits successfully."))
+    LOG.info(_("Ryu initialization on the node is done. "
+               "Agent initialized successfully, now running..."))
+    agent.daemon_loop()
     sys.exit(0)
 
 
index e8bc99988089a7414ea52a2b6e3356515db178c1..fec0ead4dbb6583d4ef5fed0fc65e4f2dac4a5b8 100644 (file)
@@ -40,6 +40,13 @@ ovs_opts = [
                help=_("OVSDB interface to connect to")),
 ]
 
+agent_opts = [
+    cfg.IntOpt('polling_interval', default=2,
+               help=_("The number of seconds the agent will wait between "
+                      "polling for local device changes.")),
+]
+
 
 cfg.CONF.register_opts(ovs_opts, "OVS")
+cfg.CONF.register_opts(agent_opts, "AGENT")
 config.register_root_helper(cfg.CONF)
index 4ebd27eece3de6adf8a15acaafccd9ac1a4cff9c..8d172504ad00765b21daf42a9c268d22462464c5 100644 (file)
@@ -21,6 +21,9 @@ from sqlalchemy.orm import exc as orm_exc
 from quantum.common import exceptions as q_exc
 import quantum.db.api as db
 from quantum.db import models_v2
+from quantum.db import securitygroups_db as sg_db
+from quantum.extensions import securitygroup as ext_sg
+from quantum import manager
 from quantum.openstack.common import log as logging
 from quantum.plugins.ryu.db import models_v2 as ryu_models_v2
 
@@ -33,6 +36,30 @@ def network_all_tenant_list():
     return session.query(models_v2.Network).all()
 
 
+def get_port_from_device(port_id):
+    LOG.debug(_("get_port_from_device() called:port_id=%s"), port_id)
+    session = db.get_session()
+    sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
+
+    query = session.query(models_v2.Port,
+                          sg_db.SecurityGroupPortBinding.security_group_id)
+    query = query.outerjoin(sg_db.SecurityGroupPortBinding,
+                            models_v2.Port.id == sg_binding_port)
+    query = query.filter(models_v2.Port.id == port_id)
+    port_and_sgs = query.all()
+    if not port_and_sgs:
+        return None
+    port = port_and_sgs[0][0]
+    plugin = manager.QuantumManager.get_plugin()
+    port_dict = plugin._make_port_dict(port)
+    port_dict[ext_sg.SECURITYGROUPS] = [
+        sg_id for port, sg_id in port_and_sgs if sg_id]
+    port_dict['security_group_rules'] = []
+    port_dict['security_group_source_groups'] = []
+    port_dict['fixed_ips'] = [ip['ip_address'] for ip in port['fixed_ips']]
+    return port_dict
+
+
 class TunnelKey(object):
     # VLAN: 12 bits
     # GRE, VXLAN: 24bits
index 402a9d3534a0ecbeb0a934f121070757e08cb23a..58dc0e4e452204b4426467dd89a96c227d280049 100644 (file)
@@ -20,6 +20,7 @@ from oslo.config import cfg
 from ryu.app import client
 from ryu.app import rest_nw_id
 
+from quantum.agent import securitygroups_rpc as sg_rpc
 from quantum.common import constants as q_const
 from quantum.common import exceptions as q_exc
 from quantum.common import rpc as q_rpc
@@ -30,8 +31,12 @@ from quantum.db import dhcp_rpc_base
 from quantum.db import extraroute_db
 from quantum.db import l3_rpc_base
 from quantum.db import models_v2
+from quantum.db import securitygroups_rpc_base as sg_db_rpc
+from quantum.extensions import securitygroup as ext_sg
+from quantum.openstack.common import cfg
 from quantum.openstack.common import log as logging
 from quantum.openstack.common import rpc
+from quantum.openstack.common.rpc import proxy
 from quantum.plugins.ryu.common import config
 from quantum.plugins.ryu.db import api_v2 as db_api_v2
 
@@ -40,9 +45,10 @@ LOG = logging.getLogger(__name__)
 
 
 class RyuRpcCallbacks(dhcp_rpc_base.DhcpRpcCallbackMixin,
-                      l3_rpc_base.L3RpcCallbackMixin):
+                      l3_rpc_base.L3RpcCallbackMixin,
+                      sg_db_rpc.SecurityGroupServerRpcCallbackMixin):
 
-    RPC_API_VERSION = '1.0'
+    RPC_API_VERSION = '1.1'
 
     def __init__(self, ofp_rest_api_addr):
         self.ofp_rest_api_addr = ofp_rest_api_addr
@@ -54,11 +60,37 @@ class RyuRpcCallbacks(dhcp_rpc_base.DhcpRpcCallbackMixin,
         LOG.debug(_("get_ofp_rest_api: %s"), self.ofp_rest_api_addr)
         return self.ofp_rest_api_addr
 
+    @classmethod
+    def get_port_from_device(cls, device):
+        port = db_api_v2.get_port_from_device(device)
+        if port:
+            port['device'] = device
+        return port
+
+
+class AgentNotifierApi(proxy.RpcProxy,
+                       sg_rpc.SecurityGroupAgentRpcApiMixin):
+
+    BASE_RPC_API_VERSION = '1.0'
+
+    def __init__(self, topic):
+        super(AgentNotifierApi, self).__init__(
+            topic=topic, default_version=self.BASE_RPC_API_VERSION)
+        self.topic_port_update = topics.get_topic_name(topic,
+                                                       topics.PORT,
+                                                       topics.UPDATE)
+
+    def port_update(self, context, port):
+        self.fanout_cast(context,
+                         self.make_msg('port_update', port=port),
+                         topic=self.topic_port_update)
+
 
 class RyuQuantumPluginV2(db_base_plugin_v2.QuantumDbPluginV2,
-                         extraroute_db.ExtraRoute_db_mixin):
+                         extraroute_db.ExtraRoute_db_mixin,
+                         sg_db_rpc.SecurityGroupServerRpcMixin):
 
-    supported_extension_aliases = ["router", "extraroute"]
+    supported_extension_aliases = ["router", "extraroute", "security-group"]
 
     def __init__(self, configfile=None):
         db.configure_db()
@@ -82,6 +114,7 @@ class RyuQuantumPluginV2(db_base_plugin_v2.QuantumDbPluginV2,
 
     def _setup_rpc(self):
         self.conn = rpc.create_connection(new=True)
+        self.notifier = AgentNotifierApi(topics.AGENT)
         self.callbacks = RyuRpcCallbacks(self.ofp_api_host)
         self.dispatcher = self.callbacks.create_rpc_dispatcher()
         self.conn.create_consumer(topics.PLUGIN, self.dispatcher, fanout=False)
@@ -109,6 +142,11 @@ class RyuQuantumPluginV2(db_base_plugin_v2.QuantumDbPluginV2,
     def create_network(self, context, network):
         session = context.session
         with session.begin(subtransactions=True):
+            #set up default security groups
+            tenant_id = self._get_tenant_id_for_create(
+                context, network['network'])
+            self._ensure_default_security_group(context, tenant_id)
+
             net = super(RyuQuantumPluginV2, self).create_network(context,
                                                                  network)
             self._process_l3_create(context, network['network'], net['id'])
@@ -154,7 +192,19 @@ class RyuQuantumPluginV2(db_base_plugin_v2.QuantumDbPluginV2,
         return [self._fields(net, fields) for net in nets]
 
     def create_port(self, context, port):
-        port = super(RyuQuantumPluginV2, self).create_port(context, port)
+        session = context.session
+        with session.begin(subtransactions=True):
+            self._ensure_default_security_group_on_port(context, port)
+            sgids = self._get_security_groups_on_port(context, port)
+            port = super(RyuQuantumPluginV2, self).create_port(context, port)
+            self._process_port_create_security_group(
+                context, port['id'], sgids)
+            self._extend_port_dict_security_group(context, port)
+        if port['device_owner'] == q_const.DEVICE_OWNER_DHCP:
+            self.notifier.security_groups_provider_updated(context)
+        else:
+            self.notifier.security_groups_member_updated(
+                context, port.get(ext_sg.SECURITYGROUPS))
         self.iface_client.create_network_id(port['id'], port['network_id'])
         return port
 
@@ -163,13 +213,53 @@ class RyuQuantumPluginV2(db_base_plugin_v2.QuantumDbPluginV2,
         # and l3-router. If so, we should prevent deletion.
         if l3_port_check:
             self.prevent_l3_port_deletion(context, id)
-        self.disassociate_floatingips(context, id)
-        return super(RyuQuantumPluginV2, self).delete_port(context, id)
+
+        with context.session.begin(subtransactions=True):
+            self.disassociate_floatingips(context, id)
+            port = self.get_port(context, id)
+            self._delete_port_security_group_bindings(context, id)
+            super(RyuQuantumPluginV2, self).delete_port(context, id)
+
+        self.notifier.security_groups_member_updated(
+            context, port.get(ext_sg.SECURITYGROUPS))
 
     def update_port(self, context, id, port):
         deleted = port['port'].get('deleted', False)
-        port = super(RyuQuantumPluginV2, self).update_port(context, id, port)
+        session = context.session
+
+        need_port_update_notify = False
+        with session.begin(subtransactions=True):
+            original_port = super(RyuQuantumPluginV2, self).get_port(
+                context, id)
+            updated_port = super(RyuQuantumPluginV2, self).update_port(
+                context, id, port)
+            need_port_update_notify = self.update_security_group_on_port(
+                context, id, port, original_port, updated_port)
+
+        need_port_update_notify |= self.is_security_group_member_updated(
+            context, original_port, updated_port)
+
+        need_port_update_notify |= (original_port['admin_state_up'] !=
+                                    updated_port['admin_state_up'])
+
+        if need_port_update_notify:
+            self.notifier.port_update(context, updated_port)
+
         if deleted:
-            session = context.session
             db_api_v2.set_port_status(session, id, q_const.PORT_STATUS_DOWN)
-        return port
+        return updated_port
+
+    def get_port(self, context, id, fields=None):
+        with context.session.begin(subtransactions=True):
+            port = super(RyuQuantumPluginV2, self).get_port(context, id,
+                                                            fields)
+            self._extend_port_dict_security_group(context, port)
+        return self._fields(port, fields)
+
+    def get_ports(self, context, filters=None, fields=None):
+        with context.session.begin(subtransactions=True):
+            ports = super(RyuQuantumPluginV2, self).get_ports(
+                context, filters, fields)
+            for port in ports:
+                self._extend_port_dict_security_group(context, port)
+        return [self._fields(port, fields) for port in ports]
index a9d823f950dca3e8a026dcf3c998acafd2d91afd..c2d81bf49c935f3fa7cc64cc86f9c4a463dfc4cc 100644 (file)
@@ -28,6 +28,7 @@ class ConfigurationTest(unittest2.TestCase):
         self.assertEqual('br-int', cfg.CONF.OVS.integration_bridge)
         self.assertEqual(-1, cfg.CONF.DATABASE.sql_max_retries)
         self.assertEqual(2, cfg.CONF.DATABASE.reconnect_interval)
+        self.assertEqual(2, cfg.CONF.AGENT.polling_interval)
         self.assertEqual('sudo', cfg.CONF.AGENT.root_helper)
         self.assertEqual('127.0.0.1:8080', cfg.CONF.OVS.openflow_rest_api)
         self.assertEqual(1, cfg.CONF.OVS.tunnel_key_min)
index 2a22ab75c65e5d7fdc8979e18c6b1f11c1a48921..95cc4407e92700113d73b2042005076285dd2893 100644 (file)
@@ -44,12 +44,19 @@ class TestOVSQuantumOFPRyuAgent(RyuAgentTestCase):
             self._AGENT_NAME + '.VifPortSet').start()
         self.q_ctx = mock.patch(
             self._AGENT_NAME + '.q_context').start()
+        self.agent_rpc = mock.patch(
+            self._AGENT_NAME + '.agent_rpc.create_consumers').start()
+        self.sg_rpc = mock.patch(
+            self._AGENT_NAME + '.sg_rpc').start()
+        self.sg_agent = mock.patch(
+            self._AGENT_NAME + '.RyuSecurityGroupAgent').start()
 
     def mock_rest_addr(self, rest_addr):
         integ_br = 'integ_br'
         tunnel_ip = '192.168.0.1'
         ovsdb_ip = '172.16.0.1'
         ovsdb_port = 16634
+        interval = 2
         root_helper = 'helper'
 
         self.mod_agent.OVSBridge.return_value.datapath_id = '1234'
@@ -61,8 +68,8 @@ class TestOVSQuantumOFPRyuAgent(RyuAgentTestCase):
         self.plugin_api.return_value.get_ofp_rest_api_addr = mock_rest_addr
 
         # Instantiate OVSQuantumOFPRyuAgent
-        self.agent = self.mod_agent.OVSQuantumOFPRyuAgent(
-            integ_br, tunnel_ip, ovsdb_ip, ovsdb_port, root_helper)
+        return self.mod_agent.OVSQuantumOFPRyuAgent(
+            integ_br, tunnel_ip, ovsdb_ip, ovsdb_port, interval, root_helper)
 
     def test_valid_rest_addr(self):
         self.mock_rest_addr('192.168.0.1:8080')
@@ -79,6 +86,11 @@ class TestOVSQuantumOFPRyuAgent(RyuAgentTestCase):
             mock.call().get_ofp_rest_api_addr('abc')
         ])
 
+        # Agent RPC
+        self.agent_rpc.assert_has_calls([
+            mock.call(mock.ANY, 'q-agent-notifier', mock.ANY)
+        ])
+
         # OFPClient
         self.mod_agent.client.OFPClient.assert_calls([
             mock.call('192.168.0.1:8080')
@@ -93,7 +105,6 @@ class TestOVSQuantumOFPRyuAgent(RyuAgentTestCase):
         ])
 
         # SwitchConfClient
-
         self.mod_agent.client.SwitchConfClient.assert_has_calls([
             mock.call('192.168.0.1:8080'),
             mock.call().set_key('1234', 'ovs_tunnel_addr', '192.168.0.1'),
@@ -110,6 +121,108 @@ class TestOVSQuantumOFPRyuAgent(RyuAgentTestCase):
         self.assertRaises(self.mod_agent.q_exc.Invalid,
                           self.mock_rest_addr, (''))
 
+    def mock_port_update(self, **kwargs):
+        agent = self.mock_rest_addr('192.168.0.1:8080')
+        agent.port_update(mock.Mock(), **kwargs)
+
+    def test_port_update(self, **kwargs):
+        port = {'id': 1, 'security_groups': 'default'}
+
+        with mock.patch.object(self.ovsbridge.return_value,
+                               'get_vif_port_by_id',
+                               return_value=1) as get_vif:
+            self.mock_port_update(port=port)
+
+        get_vif.assert_called_once_with(1)
+        self.sg_agent.assert_calls([
+            mock.call().refresh_firewall()
+        ])
+
+    def test_port_update_not_vifport(self, **kwargs):
+        port = {'id': 1, 'security_groups': 'default'}
+
+        with mock.patch.object(self.ovsbridge.return_value,
+                               'get_vif_port_by_id',
+                               return_value=0) as get_vif:
+            self.mock_port_update(port=port)
+
+        get_vif.assert_called_once_with(1)
+        self.assertFalse(self.sg_agent.return_value.refresh_firewall.called)
+
+    def test_port_update_without_secgroup(self, **kwargs):
+        port = {'id': 1}
+
+        with mock.patch.object(self.ovsbridge.return_value,
+                               'get_vif_port_by_id',
+                               return_value=1) as get_vif:
+            self.mock_port_update(port=port)
+
+        get_vif.assert_called_once_with(1)
+        self.assertFalse(self.sg_agent.return_value.refresh_firewall.called)
+
+    def mock_update_ports(self, vif_port_set=None, registered_ports=None):
+        with mock.patch.object(self.ovsbridge.return_value,
+                               'get_vif_port_set',
+                               return_value=vif_port_set):
+            agent = self.mock_rest_addr('192.168.0.1:8080')
+            return agent._update_ports(registered_ports)
+
+    def test_update_ports_unchanged(self):
+        self.assertIsNone(self.mock_update_ports())
+
+    def test_update_ports_changed(self):
+        vif_port_set = set([1, 3])
+        registered_ports = set([1, 2])
+        expected = dict(current=vif_port_set,
+                        added=set([3]),
+                        removed=set([2]))
+
+        actual = self.mock_update_ports(vif_port_set, registered_ports)
+
+        self.assertEqual(expected, actual)
+
+    def mock_process_devices_filter(self, port_info):
+        agent = self.mock_rest_addr('192.168.0.1:8080')
+        agent._process_devices_filter(port_info)
+
+    def test_process_devices_filter_add(self):
+        port_info = {'added': 1}
+
+        self.mock_process_devices_filter(port_info)
+
+        self.sg_agent.assert_calls([
+            mock.call().prepare_devices_filter(1)
+        ])
+
+    def test_process_devices_filter_remove(self):
+        port_info = {'removed': 2}
+
+        self.mock_process_devices_filter(port_info)
+
+        self.sg_agent.assert_calls([
+            mock.call().remove_devices_filter(2)
+        ])
+
+    def test_process_devices_filter_both(self):
+        port_info = {'added': 1, 'removed': 2}
+
+        self.mock_process_devices_filter(port_info)
+
+        self.sg_agent.assert_calls([
+            mock.call().prepare_devices_filter(1),
+            mock.call().remove_devices_filter(2)
+        ])
+
+    def test_process_devices_filter_none(self):
+        port_info = {}
+
+        self.mock_process_devices_filter(port_info)
+
+        self.assertFalse(
+            self.sg_agent.return_value.prepare_devices_filter.called)
+        self.assertFalse(
+            self.sg_agent.return_value.remove_devices_filter.called)
+
 
 class TestRyuPluginApi(RyuAgentTestCase):
     def test_get_ofp_rest_api_addr(self):
@@ -468,17 +581,15 @@ class TestRyuQuantumAgent(RyuAgentTestCase):
         ])
 
     def test_main(self):
-        with nested(
-            mock.patch(self._AGENT_NAME + '.OVSQuantumOFPRyuAgent'),
-            mock.patch('sys.exit', side_effect=SystemExit(0))
-        ) as (mock_agent, mock_exit):
+        agent_attrs = {'daemon_loop.side_effect': SystemExit(0)}
+        with mock.patch(self._AGENT_NAME + '.OVSQuantumOFPRyuAgent',
+                        **agent_attrs) as mock_agent:
             self.assertRaises(SystemExit, self.mock_main)
 
         mock_agent.assert_calls([
-            mock.call('integ_br', '10.0.0.1', '172.16.0.1', 16634, 'helper')
-        ])
-        mock_exit.assert_calls([
-            mock.call(0)
+            mock.call('integ_br', '10.0.0.1', '172.16.0.1', 16634, 2,
+                      'helper'),
+            mock.call().daemon_loop()
         ])
 
     def test_main_raise(self):
@@ -490,7 +601,8 @@ class TestRyuQuantumAgent(RyuAgentTestCase):
             self.assertRaises(SystemExit, self.mock_main)
 
         mock_agent.assert_calls([
-            mock.call('integ_br', '10.0.0.1', '172.16.0.1', 16634, 'helper')
+            mock.call('integ_br', '10.0.0.1', '172.16.0.1', 16634, 2,
+                      'helper')
         ])
         mock_exit.assert_calls([
             mock.call(1)
diff --git a/quantum/tests/unit/ryu/test_ryu_security_group.py b/quantum/tests/unit/ryu/test_ryu_security_group.py
new file mode 100644 (file)
index 0000000..64f6963
--- /dev/null
@@ -0,0 +1,96 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2012, Nachi Ueno, NTT MCL, Inc.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import contextlib
+
+import mock
+
+from quantum.api.v2 import attributes
+from quantum.extensions import securitygroup as ext_sg
+from quantum import manager
+from quantum.plugins.ryu.db import api_v2 as api_db_v2
+from quantum.tests.unit import test_extension_security_group as test_sg
+from quantum.tests.unit import test_security_groups_rpc as test_sg_rpc
+
+PLUGIN_NAME = ('quantum.plugins.ryu.'
+               'ryu_quantum_plugin.RyuQuantumPluginV2')
+AGENT_NAME = ('quantum.plugins.ryu.'
+              'agent.ryu_quantum_agent.OVSQuantumOFPRyuAgent')
+NOTIFIER = ('quantum.plugins.ryu.'
+            'ryu_quantum_plugin.AgentNotifierApi')
+
+
+class RyuSecurityGroupsTestCase(test_sg.SecurityGroupDBTestCase):
+    _plugin_name = PLUGIN_NAME
+
+    def setUp(self, plugin=None):
+        self.addCleanup(mock.patch.stopall)
+        notifier_p = mock.patch(NOTIFIER)
+        notifier_cls = notifier_p.start()
+        self.notifier = mock.Mock()
+        notifier_cls.return_value = self.notifier
+        self._attribute_map_bk_ = {}
+        for item in attributes.RESOURCE_ATTRIBUTE_MAP:
+            self._attribute_map_bk_[item] = (attributes.
+                                             RESOURCE_ATTRIBUTE_MAP[item].
+                                             copy())
+        super(RyuSecurityGroupsTestCase, self).setUp(PLUGIN_NAME)
+
+    def tearDown(self):
+        super(RyuSecurityGroupsTestCase, self).tearDown()
+        attributes.RESOURCE_ATTRIBUTE_MAP = self._attribute_map_bk_
+
+
+class TestRyuSecurityGroups(RyuSecurityGroupsTestCase,
+                            test_sg.TestSecurityGroups,
+                            test_sg_rpc.SGNotificationTestMixin):
+    def test_security_group_get_port_from_device(self):
+        with contextlib.nested(self.network(),
+                               self.security_group()) as (n, sg):
+            with self.subnet(n):
+                security_group_id = sg['security_group']['id']
+                res = self._create_port(self.fmt, n['network']['id'])
+                port = self.deserialize(self.fmt, res)
+                fixed_ips = port['port']['fixed_ips']
+                data = {'port': {'fixed_ips': fixed_ips,
+                                 'name': port['port']['name'],
+                                 ext_sg.SECURITYGROUPS:
+                                 [security_group_id]}}
+
+                req = self.new_update_request('ports', data,
+                                              port['port']['id'])
+                res = self.deserialize(self.fmt,
+                                       req.get_response(self.api))
+                port_id = res['port']['id']
+                plugin = manager.QuantumManager.get_plugin()
+                port_dict = plugin.callbacks.get_port_from_device(port_id)
+                self.assertEqual(port_id, port_dict['id'])
+                self.assertEqual([security_group_id],
+                                 port_dict[ext_sg.SECURITYGROUPS])
+                self.assertEqual([], port_dict['security_group_rules'])
+                self.assertEqual([fixed_ips[0]['ip_address']],
+                                 port_dict['fixed_ips'])
+                self._delete('ports', port_id)
+
+    def test_security_group_get_port_from_device_with_no_port(self):
+        plugin = manager.QuantumManager.get_plugin()
+        port_dict = plugin.callbacks.get_port_from_device('bad_device_id')
+        self.assertEqual(None, port_dict)
+
+
+class TestRyuSecurityGroupsXML(TestRyuSecurityGroups):
+    fmt = 'xml'