# under the License.
#
# @author: Aaron Rosen, Nicira, Inc
-#
import sqlalchemy as sa
from sqlalchemy import orm
else:
return default_group[0]['id']
- def _validate_security_groups_on_port(self, context, port):
+ def _get_security_groups_on_port(self, context, port):
+ """Check that all security groups on port belong to tenant.
+
+ :returns: all security groups IDs on port belonging to tenant.
+ """
p = port['port']
if not attr.is_attr_set(p.get(ext_sg.SECURITYGROUPS)):
return
if p.get('device_owner') and p['device_owner'].startswith('network:'):
return
- valid_groups = self.get_security_groups(context, fields={'id': None})
- valid_groups_set = set([x['id'] for x in valid_groups])
- req_sg_set = set(p[ext_sg.SECURITYGROUPS])
- invalid_sg_set = req_sg_set - valid_groups_set
- if invalid_sg_set:
- msg = ' '.join(str(x) for x in invalid_sg_set)
- raise ext_sg.SecurityGroupNotFound(id=msg)
+ valid_groups = self.get_security_groups(
+ context, fields=['external_id', 'id'])
+ valid_group_map = dict((g['id'], g['id']) for g in valid_groups)
+ valid_group_map.update((g['external_id'], g['id'])
+ for g in valid_groups if g.get('external_id'))
+ try:
+ return set([valid_group_map[sg_id]
+ for sg_id in p.get(ext_sg.SECURITYGROUPS, [])])
+ except KeyError as e:
+ raise ext_sg.SecurityGroupNotFound(id=str(e))
def _ensure_default_security_group_on_port(self, context, port):
+ # return if proxy_mode is enabled since nova will handle adding
+ # the port to the default security group.
+ if cfg.CONF.SECURITYGROUP.proxy_mode:
+ return
# we don't apply security groups for dhcp, router
if (port['port'].get('device_owner') and
port['port']['device_owner'].startswith('network:')):
default_sg = self._ensure_default_security_group(context, tenant_id)
if not port['port'].get(ext_sg.SECURITYGROUPS):
port['port'][ext_sg.SECURITYGROUPS] = [default_sg]
- self._validate_security_groups_on_port(context, port)
session = context.session
with session.begin(subtransactions=True):
- sgids = port['port'].get(ext_sg.SECURITYGROUPS)
+ sgids = self._get_security_groups_on_port(context, port)
port = super(SecurityGroupTestPlugin, self).create_port(context,
port)
self._process_port_create_security_group(context, port['id'],
session = context.session
with session.begin(subtransactions=True):
if ext_sg.SECURITYGROUPS in port['port']:
- self._validate_security_groups_on_port(context, port)
+ port['port'][ext_sg.SECURITYGROUPS] = (
+ self._get_security_groups_on_port(context, port))
# delete the port binding and read it with the new rules
self._delete_port_security_group_bindings(context, id)
self._process_port_create_security_group(
self.assertEqual(security_group['security_group'][k], v)
def test_create_security_group_external_id(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
name = 'webservers'
description = 'my webservers'
external_id = 10
self.assertEqual(len(groups['security_groups']), 1)
def test_create_security_group_proxy_mode_not_admin(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
res = self._create_security_group('json', 'webservers',
'webservers', '1',
tenant_id='bad_tenant',
self.assertEqual(res.status_int, 403)
def test_create_security_group_no_external_id_proxy_mode(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
res = self._create_security_group('json', 'webservers',
'webservers')
self.deserialize('json', res)
self.assertEqual(res.status_int, 409)
def test_create_security_group_duplicate_external_id(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
name = 'webservers'
description = 'my webservers'
external_id = 1
self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_exteral_id_proxy_mode(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
with self.security_group(external_id=1) as sg:
rule = {'security_group_rule':
{'security_group_id': sg['security_group']['id'],
self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_not_admin(self):
- cfg.CONF.SECURITYGROUP.proxy_mode = True
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
with self.security_group(external_id='1') as sg:
rule = {'security_group_rule':
{'security_group_id': sg['security_group']['id'],
port['port'][ext_sg.SECURITYGROUPS]), 2)
self._delete('ports', port['port']['id'])
- def test_update_port_remove_security_group(self):
+ def test_update_port_remove_security_group_empty_list(self):
with self.network() as n:
with self.subnet(n):
with self.security_group() as sg:
[])
self._delete('ports', port['port']['id'])
+ def test_update_port_remove_security_group_none(self):
+ with self.network() as n:
+ with self.subnet(n):
+ with self.security_group() as sg:
+ res = self._create_port('json', n['network']['id'],
+ security_groups=(
+ [sg['security_group']['id']]))
+ port = self.deserialize('json', res)
+
+ data = {'port': {'fixed_ips': port['port']['fixed_ips'],
+ 'name': port['port']['name'],
+ 'security_groups': None}}
+
+ req = self.new_update_request('ports', data,
+ port['port']['id'])
+ res = self.deserialize('json', req.get_response(self.api))
+ self.assertEqual(res['port'].get(ext_sg.SECURITYGROUPS),
+ [])
+ self._delete('ports', port['port']['id'])
+
def test_create_port_with_bad_security_group(self):
with self.network() as n:
with self.subnet(n):
security_groups=['bad_id'])
self.deserialize('json', res)
- self.assertEqual(res.status_int, 404)
+ self.assertEqual(res.status_int, 400)
def test_create_delete_security_group_port_in_use(self):
with self.network() as n:
res = self._create_security_group_rule('json', rule)
self.deserialize('json', res)
self.assertEqual(res.status_int, 400)
+
+ def test_validate_port_external_id_quantum_id(self):
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
+ with self.network() as n:
+ with self.subnet(n):
+ sg1 = (self.deserialize('json',
+ self._create_security_group('json', 'foo', 'bar', '1')))
+ sg2 = (self.deserialize('json',
+ self._create_security_group('json', 'foo', 'bar', '2')))
+ res = self._create_port(
+ 'json', n['network']['id'],
+ security_groups=[sg1['security_group']['id']])
+
+ port = self.deserialize('json', res)
+ # This request updates the port sending the quantum security
+ # group id in and a nova security group id.
+ data = {'port': {'fixed_ips': port['port']['fixed_ips'],
+ 'name': port['port']['name'],
+ ext_sg.SECURITYGROUPS:
+ [sg1['security_group']['external_id'],
+ sg2['security_group']['id']]}}
+ req = self.new_update_request('ports', data,
+ port['port']['id'])
+ res = self.deserialize('json', req.get_response(self.api))
+ self.assertEquals(len(res['port'][ext_sg.SECURITYGROUPS]), 2)
+ for sg_id in res['port'][ext_sg.SECURITYGROUPS]:
+ # only security group id's should be
+ # returned and not external_ids
+ self.assertEquals(len(sg_id), 36)
+ self._delete('ports', port['port']['id'])
+
+ def test_validate_port_external_id_string_or_int(self):
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
+ with self.network() as n:
+ with self.subnet(n):
+ string_id = '1'
+ int_id = 2
+ self.deserialize(
+ 'json', self._create_security_group('json', 'foo', 'bar',
+ string_id))
+ self.deserialize(
+ 'json', self._create_security_group('json', 'foo', 'bar',
+ int_id))
+ res = self._create_port(
+ 'json', n['network']['id'],
+ security_groups=[string_id, int_id])
+
+ port = self.deserialize('json', res)
+ self._delete('ports', port['port']['id'])
+
+ def test_create_port_with_non_uuid_or_int(self):
+ with self.network() as n:
+ with self.subnet(n):
+ res = self._create_port('json', n['network']['id'],
+ security_groups=['not_valid'])
+
+ self.deserialize('json', res)
+ self.assertEqual(res.status_int, 400)
+
+ def test_validate_port_external_id_fail(self):
+ cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
+ with self.network() as n:
+ with self.subnet(n):
+ bad_id = 1
+ res = self._create_port(
+ 'json', n['network']['id'],
+ security_groups=[bad_id])
+
+ self.deserialize('json', res)
+ self.assertEqual(res.status_int, 404)