]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Limit min<=max port check to TCP/UDP in secgroup rule
authorAkihiro MOTOKI <motoki@da.jp.nec.com>
Thu, 4 Jul 2013 16:27:18 +0000 (01:27 +0900)
committerAkihiro MOTOKI <motoki@da.jp.nec.com>
Mon, 8 Jul 2013 13:42:32 +0000 (22:42 +0900)
icmp_type and icmp_code are mapped to port_min_range and port_max_range
respectively. For ICMP there is no constraint between type and code.
Thus port range min<=max check should be enforced only for TCP and UDP.

Also makes sure that ICMP type/code are 0 to 255 (both inclusive).
Previously a value with 0 to 65535 were accepted for ICMP type/code.

Fixes bug 1197760
Fixes bug 1197769

Change-Id: I70aaf6e02fee461fa97dc254db906d9efa173669

neutron/common/constants.py
neutron/db/securitygroups_db.py
neutron/extensions/securitygroup.py
neutron/tests/unit/test_extension_security_group.py

index 06b0c53d2c3aa44abb9fb9e11d0030e3e4feabb1..0fa9c50ed4e42586c6e0524d234d06c2f4694738 100644 (file)
@@ -34,7 +34,10 @@ INTERFACE_KEY = '_interfaces'
 IPv4 = 'IPv4'
 IPv6 = 'IPv6'
 
+ICMP_PROTOCOL = 1
+TCP_PROTOCOL = 6
 UDP_PROTOCOL = 17
+
 DHCP_RESPONSE_PORT = 68
 
 MIN_VLAN_TAG = 1
index 4201d42a38623d5acf01a34a5ab1bd89fcccbf64..198c231a8da1607c923198e2871283e69d8504bd 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.orm import exc
 from sqlalchemy.orm import scoped_session
 
 from neutron.api.v2 import attributes as attr
+from neutron.common import constants
 from neutron.db import db_base_plugin_v2
 from neutron.db import model_base
 from neutron.db import models_v2
@@ -29,6 +30,11 @@ from neutron.extensions import securitygroup as ext_sg
 from neutron.openstack.common import uuidutils
 
 
+IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL,
+                   'udp': constants.UDP_PROTOCOL,
+                   'icmp': constants.ICMP_PROTOCOL}
+
+
 class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
     """Represents a v2 neutron security group."""
 
@@ -284,6 +290,32 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
         return self.create_security_group_rule_bulk_native(context,
                                                            bulk_rule)[0]
 
+    def _get_ip_proto_number(self, protocol):
+        if protocol is None:
+            return
+        return IP_PROTOCOL_MAP.get(protocol, protocol)
+
+    def _validate_port_range(self, rule):
+        """Check that port_range is valid."""
+        if (rule['port_range_min'] is None and
+            rule['port_range_max'] is None):
+            return
+        if not rule['protocol']:
+            raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
+        ip_proto = self._get_ip_proto_number(rule['protocol'])
+        if ip_proto in [constants.TCP_PROTOCOL, constants.UDP_PROTOCOL]:
+            if (rule['port_range_min'] is not None and
+                rule['port_range_min'] <= rule['port_range_max']):
+                pass
+            else:
+                raise ext_sg.SecurityGroupInvalidPortRange()
+        elif ip_proto == constants.ICMP_PROTOCOL:
+            for attr, field in [('port_range_min', 'type'),
+                                ('port_range_max', 'code')]:
+                if rule[attr] > 255:
+                    raise ext_sg.SecurityGroupInvalidIcmpValue(
+                        field=field, attr=attr, value=rule[attr])
+
     def _validate_security_group_rules(self, context, security_group_rule):
         """Check that rules being installed.
 
@@ -297,16 +329,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
             rule = rules.get('security_group_rule')
             new_rules.add(rule['security_group_id'])
 
-            # Check that port_range's are valid
-            if (rule['port_range_min'] is None and
-                rule['port_range_max'] is None):
-                pass
-            elif (rule['port_range_min'] is not None and
-                  rule['port_range_min'] <= rule['port_range_max']):
-                if not rule['protocol']:
-                    raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
-            else:
-                raise ext_sg.SecurityGroupInvalidPortRange()
+            self._validate_port_range(rule)
 
             if rule['remote_ip_prefix'] and rule['remote_group_id']:
                 raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()
index 9fd4c95788649a36f26775648a8d005c22d1bd86..ebc1f780bc4a558aa96aef4f3a2e3fcffdf5323e 100644 (file)
@@ -39,6 +39,11 @@ class SecurityGroupInvalidPortValue(qexception.InvalidInput):
     message = _("Invalid value for port %(port)s")
 
 
+class SecurityGroupInvalidIcmpValue(qexception.InvalidInput):
+    message = _("Invalid value for ICMP %(field)s (%(attr)s) "
+                "%(value)s. It must be 0 to 255.")
+
+
 class SecurityGroupInUse(qexception.InUse):
     message = _("Security Group %(id)s in use.")
 
index a1df601cbacb26303d9a7abd0c6863ba7a679241..a0d3979638478ba028ef72e8760850e8e050907a 100644 (file)
@@ -633,6 +633,57 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
                     for k, v, in keys:
                         self.assertEqual(rule['security_group_rule'][k], v)
 
+    def test_create_security_group_rule_icmp_with_type_and_code(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            direction = "ingress"
+            remote_ip_prefix = "10.0.0.0/24"
+            protocol = 'icmp'
+            # port_range_min (ICMP type) is greater than port_range_max
+            # (ICMP code) in order to confirm min <= max port check is
+            # not called for ICMP.
+            port_range_min = 8
+            port_range_max = 5
+            keys = [('remote_ip_prefix', remote_ip_prefix),
+                    ('security_group_id', security_group_id),
+                    ('direction', direction),
+                    ('protocol', protocol),
+                    ('port_range_min', port_range_min),
+                    ('port_range_max', port_range_max)]
+            with self.security_group_rule(security_group_id, direction,
+                                          protocol, port_range_min,
+                                          port_range_max,
+                                          remote_ip_prefix) as rule:
+                for k, v, in keys:
+                    self.assertEqual(rule['security_group_rule'][k], v)
+
+    def test_create_security_group_rule_icmp_with_type_only(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            direction = "ingress"
+            remote_ip_prefix = "10.0.0.0/24"
+            protocol = 'icmp'
+            # ICMP type
+            port_range_min = 8
+            # ICMP code
+            port_range_max = None
+            keys = [('remote_ip_prefix', remote_ip_prefix),
+                    ('security_group_id', security_group_id),
+                    ('direction', direction),
+                    ('protocol', protocol),
+                    ('port_range_min', port_range_min),
+                    ('port_range_max', port_range_max)]
+            with self.security_group_rule(security_group_id, direction,
+                                          protocol, port_range_min,
+                                          port_range_max,
+                                          remote_ip_prefix) as rule:
+                for k, v, in keys:
+                    self.assertEqual(rule['security_group_rule'][k], v)
+
     def test_create_security_group_source_group_ip_and_ip_prefix(self):
         security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087"
         direction = "ingress"
@@ -752,26 +803,80 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
                 self.assertEqual(res.status_int, 409)
 
     def test_create_security_group_rule_min_port_greater_max(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            with self.security_group_rule(security_group_id):
+                for protocol in ['tcp', 'udp', 6, 17]:
+                    rule = self._build_security_group_rule(
+                        sg['security_group']['id'],
+                        'ingress', protocol, '50', '22')
+                    self._create_security_group_rule(self.fmt, rule)
+                    res = self._create_security_group_rule(self.fmt, rule)
+                    self.deserialize(self.fmt, res)
+                    self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_ports_but_no_protocol(self):
         name = 'webservers'
         description = 'my webservers'
         with self.security_group(name, description) as sg:
             security_group_id = sg['security_group']['id']
             with self.security_group_rule(security_group_id):
                 rule = self._build_security_group_rule(
-                    sg['security_group']['id'], 'ingress', 'tcp', '50', '22')
+                    sg['security_group']['id'], 'ingress', None, '22', '22')
                 self._create_security_group_rule(self.fmt, rule)
                 res = self._create_security_group_rule(self.fmt, rule)
                 self.deserialize(self.fmt, res)
                 self.assertEqual(res.status_int, 400)
 
-    def test_create_security_group_rule_ports_but_no_protocol(self):
+    def test_create_security_group_rule_port_range_min_only(self):
         name = 'webservers'
         description = 'my webservers'
         with self.security_group(name, description) as sg:
             security_group_id = sg['security_group']['id']
             with self.security_group_rule(security_group_id):
                 rule = self._build_security_group_rule(
-                    sg['security_group']['id'], 'ingress', None, '22', '22')
+                    sg['security_group']['id'], 'ingress', 'tcp', '22', None)
+                self._create_security_group_rule(self.fmt, rule)
+                res = self._create_security_group_rule(self.fmt, rule)
+                self.deserialize(self.fmt, res)
+                self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_port_range_max_only(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            with self.security_group_rule(security_group_id):
+                rule = self._build_security_group_rule(
+                    sg['security_group']['id'], 'ingress', 'tcp', None, '22')
+                self._create_security_group_rule(self.fmt, rule)
+                res = self._create_security_group_rule(self.fmt, rule)
+                self.deserialize(self.fmt, res)
+                self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_icmp_type_too_big(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            with self.security_group_rule(security_group_id):
+                rule = self._build_security_group_rule(
+                    sg['security_group']['id'], 'ingress', 'icmp', '256', None)
+                self._create_security_group_rule(self.fmt, rule)
+                res = self._create_security_group_rule(self.fmt, rule)
+                self.deserialize(self.fmt, res)
+                self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_icmp_code_too_big(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            with self.security_group_rule(security_group_id):
+                rule = self._build_security_group_rule(
+                    sg['security_group']['id'], 'ingress', 'icmp', '8', '256')
                 self._create_security_group_rule(self.fmt, rule)
                 res = self._create_security_group_rule(self.fmt, rule)
                 self.deserialize(self.fmt, res)