b1f1e5d8e3f06d383f76f2a554d3350c1d1585aa
[openstack-build/neutron-build.git] / neutron / tests / unit / plugins / ml2 / test_security_group.py
1 # Copyright (c) 2013 OpenStack Foundation
2 # Copyright 2013, Nachi Ueno, NTT MCL, Inc.
3 # All Rights Reserved.
4 #
5 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
6 #    not use this file except in compliance with the License. You may obtain
7 #    a copy of the License at
8 #
9 #         http://www.apache.org/licenses/LICENSE-2.0
10 #
11 #    Unless required by applicable law or agreed to in writing, software
12 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 #    License for the specific language governing permissions and limitations
15 #    under the License.
16
17 import math
18 import mock
19
20 from neutron.common import constants as const
21 from neutron import context
22 from neutron.extensions import securitygroup as ext_sg
23 from neutron import manager
24 from neutron.tests import tools
25 from neutron.tests.unit.agent import test_securitygroups_rpc as test_sg_rpc
26 from neutron.tests.unit.api.v2 import test_base
27 from neutron.tests.unit.extensions import test_securitygroup as test_sg
28
# Dotted path of the ML2 plugin class that the test cases in this module
# load (passed to the base test case's setUp).
PLUGIN_NAME = 'neutron.plugins.ml2.plugin.Ml2Plugin'
# Dotted path of the ML2 agent notifier class; Ml2SecurityGroupsTestCase.setUp
# patches it so the plugin gets a mock notifier instead of the real one.
NOTIFIER = 'neutron.plugins.ml2.rpc.AgentNotifierApi'
31
32
class Ml2SecurityGroupsTestCase(test_sg.SecurityGroupDBTestCase):
    """Security group DB test base that always runs against ML2.

    The ML2 agent notifier class is patched so that instantiating it yields
    ``self.notifier``, a plain ``mock.Mock`` the tests can inspect.
    """

    _plugin_name = PLUGIN_NAME

    def setUp(self, plugin=None):
        # ``plugin`` is accepted only to match the parent signature; this
        # test case always loads PLUGIN_NAME regardless of its value.
        test_sg_rpc.set_firewall_driver(test_sg_rpc.FIREWALL_HYBRID_DRIVER)
        self.notifier = mock.Mock()
        # Extra keyword arguments to mock.patch configure the replacement
        # mock, so calling the patched class returns self.notifier.
        mock.patch(NOTIFIER, return_value=self.notifier).start()
        # Snapshot the global attribute map so extension changes made by
        # the plugin are restored after each test.
        self.useFixture(tools.AttributeMapMemento())
        super(Ml2SecurityGroupsTestCase, self).setUp(PLUGIN_NAME)
47
48
class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
                            test_sg.TestSecurityGroups,
                            test_sg_rpc.SGNotificationTestMixin):
    """Security group API tests plus ML2 port/SG lookup tests."""

    def setUp(self):
        super(TestMl2SecurityGroups, self).setUp()
        self.ctx = context.get_admin_context()
        plugin = manager.NeutronManager.get_plugin()
        plugin.start_rpc_listeners()

    def _make_port_with_new_sec_group(self, net_id):
        """Create a port on ``net_id`` bound to a newly created group.

        :returns: the ``port`` dict from the create response.
        """
        sg = self._make_security_group(self.fmt, 'name', 'desc')
        port = self._make_port(
            self.fmt, net_id, security_groups=[sg['security_group']['id']])
        return port['port']

    def _make_port_without_sec_group(self, net_id):
        """Create a port on ``net_id`` with an explicitly empty group list.

        :returns: the ``port`` dict from the create response.
        """
        port = self._make_port(
            self.fmt, net_id, security_groups=[])
        return port['port']

    def test_security_group_get_ports_from_devices(self):
        with self.network() as n:
            with self.subnet(n):
                orig_ports = [
                    self._make_port_with_new_sec_group(n['network']['id']),
                    self._make_port_with_new_sec_group(n['network']['id']),
                    self._make_port_without_sec_group(n['network']['id'])
                ]
                orig_ports_by_id = {p['id']: p for p in orig_ports}
                plugin = manager.NeutronManager.get_plugin()
                # should match full ID and starting chars
                ports = plugin.get_ports_from_devices(self.ctx,
                    [orig_ports[0]['id'], orig_ports[1]['id'][0:8],
                     orig_ports[2]['id']])
                self.assertEqual(len(orig_ports), len(ports))
                for port_dict in ports:
                    # assertIn turns an unexpected port into a readable
                    # failure instead of an uncaught StopIteration.
                    self.assertIn(port_dict['id'], orig_ports_by_id)
                    p = orig_ports_by_id[port_dict['id']]
                    self.assertEqual(p['security_groups'],
                                     port_dict[ext_sg.SECURITYGROUPS])
                    self.assertEqual([], port_dict['security_group_rules'])
                    self.assertEqual([p['fixed_ips'][0]['ip_address']],
                                     port_dict['fixed_ips'])
                    self._delete('ports', p['id'])

    def test_security_group_get_ports_from_devices_with_bad_id(self):
        plugin = manager.NeutronManager.get_plugin()
        ports = plugin.get_ports_from_devices(self.ctx, ['bad_device_id'])
        self.assertFalse(ports)

    def test_security_group_no_db_calls_with_no_ports(self):
        plugin = manager.NeutronManager.get_plugin()
        with mock.patch(
            'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
        ) as get_mock:
            self.assertFalse(plugin.get_ports_from_devices(self.ctx, []))
            self.assertFalse(get_mock.called)

    def test_large_port_count_broken_into_parts(self):
        plugin = manager.NeutronManager.get_plugin()
        ports_to_query = 73
        devices = ['%s%s' % (const.TAP_DEVICE_PREFIX, i)
                   for i in range(ports_to_query)]
        for max_ports_per_query in (1, 2, 5, 7, 9, 31):
            with mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY',
                            new=max_ports_per_query),\
                    mock.patch(
                        'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
                        return_value={}) as get_mock:
                plugin.get_ports_from_devices(self.ctx, devices)
                # the second positional argument of each direct call is
                # the chunk of port IDs being queried
                all_call_args = [c[0][1] for c in get_mock.call_args_list]
                last_call_args = all_call_args.pop()
                # all but last should be getting MAX_PORTS_PER_QUERY ports
                for call_args in all_call_args:
                    self.assertEqual(max_ports_per_query, len(call_args))
                # the last chunk holds the remainder, or a full chunk when
                # the total divides evenly
                remaining = ports_to_query % max_ports_per_query
                self.assertEqual(remaining or max_ports_per_query,
                                 len(last_call_args))
                # should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls
                self.assertEqual(
                    math.ceil(ports_to_query / float(max_ports_per_query)),
                    get_mock.call_count
                )

    def test_full_uuids_skip_port_id_lookup(self):
        plugin = manager.NeutronManager.get_plugin()
        # when full UUIDs are provided, the or_ statement should only
        # have one matching 'IN' criteria for all of the IDs
        with mock.patch('neutron.plugins.ml2.db.or_') as or_mock,\
                mock.patch('sqlalchemy.orm.Session.query') as qmock:
            fmock = qmock.return_value.outerjoin.return_value.filter
            # return no ports to exit the method early since we are mocking
            # the query
            fmock.return_value = []
            plugin.get_ports_from_devices(self.ctx,
                                          [test_base._uuid(),
                                           test_base._uuid()])
            # the or_ function should only have one argument
            or_mock.assert_called_once_with(mock.ANY)
151
152
class TestMl2SGServerRpcCallBack(Ml2SecurityGroupsTestCase,
                                 test_sg_rpc.SGServerRpcCallBackTestCase):
    """Run the shared security group server RPC callback tests on ML2."""