from sqlalchemy.orm import scoped_session
from neutron.api.v2 import attributes as attr
+from neutron.common import constants
from neutron.db import db_base_plugin_v2
from neutron.db import model_base
from neutron.db import models_v2
from neutron.openstack.common import uuidutils
+IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL,
+ 'udp': constants.UDP_PROTOCOL,
+ 'icmp': constants.ICMP_PROTOCOL}
+
+
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
"""Represents a v2 neutron security group."""
return self.create_security_group_rule_bulk_native(context,
bulk_rule)[0]
+ def _get_ip_proto_number(self, protocol):
+ if protocol is None:
+ return
+ return IP_PROTOCOL_MAP.get(protocol, protocol)
+
+ def _validate_port_range(self, rule):
+ """Check that port_range is valid."""
+ if (rule['port_range_min'] is None and
+ rule['port_range_max'] is None):
+ return
+ if not rule['protocol']:
+ raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
+ ip_proto = self._get_ip_proto_number(rule['protocol'])
+ if ip_proto in [constants.TCP_PROTOCOL, constants.UDP_PROTOCOL]:
+ if (rule['port_range_min'] is not None and
+ rule['port_range_min'] <= rule['port_range_max']):
+ pass
+ else:
+ raise ext_sg.SecurityGroupInvalidPortRange()
+ elif ip_proto == constants.ICMP_PROTOCOL:
+ for attr, field in [('port_range_min', 'type'),
+ ('port_range_max', 'code')]:
+ if rule[attr] > 255:
+ raise ext_sg.SecurityGroupInvalidIcmpValue(
+ field=field, attr=attr, value=rule[attr])
+
def _validate_security_group_rules(self, context, security_group_rule):
"""Check that rules being installed.
rule = rules.get('security_group_rule')
new_rules.add(rule['security_group_id'])
- # Check that port_range's are valid
- if (rule['port_range_min'] is None and
- rule['port_range_max'] is None):
- pass
- elif (rule['port_range_min'] is not None and
- rule['port_range_min'] <= rule['port_range_max']):
- if not rule['protocol']:
- raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
- else:
- raise ext_sg.SecurityGroupInvalidPortRange()
+ self._validate_port_range(rule)
if rule['remote_ip_prefix'] and rule['remote_group_id']:
raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()
for k, v, in keys:
self.assertEqual(rule['security_group_rule'][k], v)
+ def test_create_security_group_rule_icmp_with_type_and_code(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ direction = "ingress"
+ remote_ip_prefix = "10.0.0.0/24"
+ protocol = 'icmp'
+ # port_range_min (ICMP type) is greater than port_range_max
+ # (ICMP code) in order to confirm min <= max port check is
+ # not called for ICMP.
+ port_range_min = 8
+ port_range_max = 5
+ keys = [('remote_ip_prefix', remote_ip_prefix),
+ ('security_group_id', security_group_id),
+ ('direction', direction),
+ ('protocol', protocol),
+ ('port_range_min', port_range_min),
+ ('port_range_max', port_range_max)]
+ with self.security_group_rule(security_group_id, direction,
+ protocol, port_range_min,
+ port_range_max,
+ remote_ip_prefix) as rule:
+ for k, v, in keys:
+ self.assertEqual(rule['security_group_rule'][k], v)
+
+ def test_create_security_group_rule_icmp_with_type_only(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ direction = "ingress"
+ remote_ip_prefix = "10.0.0.0/24"
+ protocol = 'icmp'
+ # ICMP type
+ port_range_min = 8
+ # ICMP code
+ port_range_max = None
+ keys = [('remote_ip_prefix', remote_ip_prefix),
+ ('security_group_id', security_group_id),
+ ('direction', direction),
+ ('protocol', protocol),
+ ('port_range_min', port_range_min),
+ ('port_range_max', port_range_max)]
+ with self.security_group_rule(security_group_id, direction,
+ protocol, port_range_min,
+ port_range_max,
+ remote_ip_prefix) as rule:
+ for k, v, in keys:
+ self.assertEqual(rule['security_group_rule'][k], v)
+
def test_create_security_group_source_group_ip_and_ip_prefix(self):
security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087"
direction = "ingress"
self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_min_port_greater_max(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ with self.security_group_rule(security_group_id):
+ for protocol in ['tcp', 'udp', 6, 17]:
+ rule = self._build_security_group_rule(
+ sg['security_group']['id'],
+ 'ingress', protocol, '50', '22')
+ self._create_security_group_rule(self.fmt, rule)
+ res = self._create_security_group_rule(self.fmt, rule)
+ self.deserialize(self.fmt, res)
+ self.assertEqual(res.status_int, 400)
+
+ def test_create_security_group_rule_ports_but_no_protocol(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
- sg['security_group']['id'], 'ingress', 'tcp', '50', '22')
+ sg['security_group']['id'], 'ingress', None, '22', '22')
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
- def test_create_security_group_rule_ports_but_no_protocol(self):
+ def test_create_security_group_rule_port_range_min_only(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule(
- sg['security_group']['id'], 'ingress', None, '22', '22')
+ sg['security_group']['id'], 'ingress', 'tcp', '22', None)
+ self._create_security_group_rule(self.fmt, rule)
+ res = self._create_security_group_rule(self.fmt, rule)
+ self.deserialize(self.fmt, res)
+ self.assertEqual(res.status_int, 400)
+
+ def test_create_security_group_rule_port_range_max_only(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ with self.security_group_rule(security_group_id):
+ rule = self._build_security_group_rule(
+ sg['security_group']['id'], 'ingress', 'tcp', None, '22')
+ self._create_security_group_rule(self.fmt, rule)
+ res = self._create_security_group_rule(self.fmt, rule)
+ self.deserialize(self.fmt, res)
+ self.assertEqual(res.status_int, 400)
+
+ def test_create_security_group_rule_icmp_type_too_big(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ with self.security_group_rule(security_group_id):
+ rule = self._build_security_group_rule(
+ sg['security_group']['id'], 'ingress', 'icmp', '256', None)
+ self._create_security_group_rule(self.fmt, rule)
+ res = self._create_security_group_rule(self.fmt, rule)
+ self.deserialize(self.fmt, res)
+ self.assertEqual(res.status_int, 400)
+
+ def test_create_security_group_rule_icmp_code_too_big(self):
+ name = 'webservers'
+ description = 'my webservers'
+ with self.security_group(name, description) as sg:
+ security_group_id = sg['security_group']['id']
+ with self.security_group_rule(security_group_id):
+ rule = self._build_security_group_rule(
+ sg['security_group']['id'], 'ingress', 'icmp', '8', '256')
self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)