]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Make protocol and ethertype case insensitive for security groups
authorAaron Rosen <arosen@nicira.com>
Thu, 24 Jan 2013 23:45:04 +0000 (15:45 -0800)
committerAaron Rosen <arosen@nicira.com>
Tue, 29 Jan 2013 16:47:58 +0000 (08:47 -0800)
Fixes bug 1104495

Change-Id: I0d93f5e849ebe0be72fff8c1d82f5825540df338

quantum/extensions/securitygroup.py
quantum/tests/unit/test_extension_security_group.py

index 488890abe5145bc4297e14b95b06add4dd672643..c48e7f890e400c4980bc7346ab265a5fa0e70748 100644 (file)
@@ -54,9 +54,9 @@ class SecurityGroupDefaultAlreadyExists(qexception.InUse):
     message = _("Default security group already exists.")
 
 
-class SecurityGroupRuleInvalidProtocol(qexception.InUse):
-    message = _("Security group rule protocol %(protocol)s not supported "
-                "only protocol values %(values)s supported.")
+class SecurityGroupRuleInvalidProtocol(qexception.InvalidInput):
+    message = _("Security group rule protocol %(protocol)s not supported. "
+                "Only protocol values %(values)s supported.")
 
 
 class SecurityGroupRulesNotSingleTenant(qexception.InvalidInput):
@@ -114,6 +114,23 @@ class SecurityGroupInvalidExternalID(qexception.InvalidInput):
     message = _("external_id wrong type %(data)s")
 
 
+def convert_protocol_to_case_insensitive(value):
+    if value is None:
+        return value
+    try:
+        return value.lower()
+    except AttributeError:
+        raise SecurityGroupRuleInvalidProtocol(
+            protocol=value, values=sg_supported_protocols)
+
+
+def convert_ethertype_to_case_insensitive(value):
+    if isinstance(value, basestring):
+        for ethertype in sg_supported_ethertypes:
+            if ethertype.lower() == value.lower():
+                return ethertype
+
+
 def convert_validate_port_value(port):
     if port is None:
         return port
@@ -199,6 +216,7 @@ RESOURCE_ATTRIBUTE_MAP = {
                       'validate': {'type:values': ['ingress', 'egress']}},
         'protocol': {'allow_post': True, 'allow_put': False,
                      'is_visible': True, 'default': None,
+                     'convert_to': convert_protocol_to_case_insensitive,
                      'validate': {'type:values': sg_supported_protocols}},
         'port_range_min': {'allow_post': True, 'allow_put': False,
                            'convert_to': convert_validate_port_value,
@@ -208,6 +226,7 @@ RESOURCE_ATTRIBUTE_MAP = {
                            'default': None, 'is_visible': True},
         'ethertype': {'allow_post': True, 'allow_put': False,
                       'is_visible': True, 'default': 'IPv4',
+                      'convert_to': convert_ethertype_to_case_insensitive,
                       'validate': {'type:values': sg_supported_ethertypes}},
         'source_ip_prefix': {'allow_post': True, 'allow_put': False,
                              'default': None, 'is_visible': True},
index 879bee8c2a7a7a81786a6508bfc8d44e053b3274..f24c01b456bf0a38914a90a336961c13603ba928 100644 (file)
@@ -287,6 +287,55 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
                 else:
                     self.assertEquals(len(group['security_group_rules']), 0)
 
+    def test_create_security_group_rule_ethertype_invalid_as_number(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            ethertype = 2
+            rule = self._build_security_group_rule(
+                security_group_id, 'ingress', 'tcp', '22', '22', None, None,
+                ethertype=ethertype)
+            res = self._create_security_group_rule('json', rule)
+            self.deserialize('json', res)
+            self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_protocol_invalid_as_number(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            protocol = 2
+            rule = self._build_security_group_rule(
+                security_group_id, 'ingress', protocol, '22', '22',
+                None, None)
+            res = self._create_security_group_rule('json', rule)
+            self.deserialize('json', res)
+            self.assertEqual(res.status_int, 400)
+
+    def test_create_security_group_rule_case_insensitive(self):
+        name = 'webservers'
+        description = 'my webservers'
+        with self.security_group(name, description) as sg:
+            security_group_id = sg['security_group']['id']
+            direction = "ingress"
+            source_ip_prefix = "10.0.0.0/24"
+            protocol = 'TCP'
+            port_range_min = 22
+            port_range_max = 22
+            ethertype = 'ipV4'
+            with self.security_group_rule(security_group_id, direction,
+                                          protocol, port_range_min,
+                                          port_range_max,
+                                          source_ip_prefix,
+                                          ethertype=ethertype) as rule:
+
+                # the lower case value will be return
+                self.assertEquals(rule['security_group_rule']['protocol'],
+                                  protocol.lower())
+                self.assertEquals(rule['security_group_rule']['ethertype'],
+                                  'IPv4')
+
     def test_get_security_group(self):
         name = 'webservers'
         description = 'my webservers'