]> review.fuel-infra Code Review - openstack-build/heat-build.git/commitdiff
HOT parameter validator part
authorJUN JIE NAN <nanjj@cn.ibm.com>
Thu, 1 Aug 2013 03:14:12 +0000 (11:14 +0800)
committerJUN JIE NAN <nanjj@cn.ibm.com>
Wed, 7 Aug 2013 06:09:41 +0000 (14:09 +0800)
Introduce new class ParamSchema to decouple parameter and its schema,
do validation in schema, for hot template, implement HotParamSchema as
it subclass.

Implements blueprint hot-parameters

Change-Id: I8720c62e41a1f182584c4518163d422784c23b37

heat/engine/hot.py
heat/engine/parameters.py
heat/engine/template.py
heat/tests/test_hot.py

index 5dbfa70df54f441c33730ab9e2807bde565ced72..caed57ec476042757ab959fd4241a695f8bd20fe 100644 (file)
@@ -14,6 +14,7 @@
 
 from heat.common import exception
 from heat.engine import template
+from heat.engine.parameters import ParamSchema
 from heat.openstack.common import log as logging
 
 
@@ -32,6 +33,10 @@ _CFN_TO_HOT_SECTIONS = {template.VERSION: VERSION,
                         template.OUTPUTS: OUTPUTS}
 
 
+def snake_to_camel(name):
+    return ''.join([t.capitalize() for t in name.split('_')])
+
+
 class HOTemplate(template.Template):
     """
     A Heat Orchestration Template format stack template.
@@ -81,15 +86,6 @@ class HOTemplate(template.Template):
 
         return default
 
-    @staticmethod
-    def _snake_to_camel(name):
-        tokens = []
-        if name:
-            tokens = name.split('_')
-            for i in xrange(len(tokens)):
-                tokens[i] = tokens[i].capitalize()
-        return "".join(tokens)
-
     def _translate_constraints(self, constraints):
         param = {}
 
@@ -109,7 +105,7 @@ class HOTemplate(template.Template):
         for constraint in constraints:
             desc = constraint.get('description')
             for key, val in constraint.iteritems():
-                key = self._snake_to_camel(key)
+                key = snake_to_camel(key)
                 if key == 'Description':
                     continue
                 elif key == 'Range':
@@ -127,9 +123,9 @@ class HOTemplate(template.Template):
         for name, attrs in parameters.iteritems():
             param = {}
             for key, val in attrs.iteritems():
-                key = self._snake_to_camel(key)
+                key = snake_to_camel(key)
                 if key == 'Type':
-                    val = self._snake_to_camel(val)
+                    val = snake_to_camel(val)
                 elif key == 'Constraints':
                     param.update(self._translate_constraints(val))
                     continue
@@ -221,3 +217,18 @@ class HOTemplate(template.Template):
                                                          key=att)
 
         return template._resolve(match_get_attr, handle_get_attr, s)
+
+    def param_schemata(self):
+        params = self[PARAMETERS].iteritems()
+        return dict((name, HOTParamSchema(schema)) for name, schema in params)
+
+
+class HOTParamSchema(ParamSchema):
+    def do_check(self, name, val, keys):
+        for key in keys:
+            consts = self.get(key)
+            check = self.check(key)
+            if consts is None or check is None:
+                continue
+            for (const, desc) in consts:
+                check(name, val, const, desc)
index be3cba32219cf5ee89228158134b2c55fa375dd6..95c4b935e305dfe293917172fa4ddf9a82074298 100644 (file)
@@ -18,7 +18,7 @@ import json
 import re
 
 from heat.common import exception
-from heat.engine import template
+
 
 PARAMETER_KEYS = (
     TYPE, DEFAULT, NO_ECHO, ALLOWED_VALUES, ALLOWED_PATTERN,
@@ -41,6 +41,76 @@ PSEUDO_PARAMETERS = (
 )
 
 
+class ParamSchema(dict):
+    '''Parameter schema.'''
+
+    def __init__(self, schema):
+        super(ParamSchema, self).__init__(schema)
+
+    def do_check(self, name, value, keys):
+        for k in keys:
+            check = self.check(k)
+            const = self.get(k)
+            if check is None or const is None:
+                continue
+            check(name, value, const)
+
+    def raise_error(self, name, message, desc=True):
+        if desc:
+            message = self.get(CONSTRAINT_DESCRIPTION) or message
+        raise ValueError('%s %s' % (name, message))
+
+    def check_allowed_values(self, name, val, const, desc=None):
+        vals = list(const)
+        if val not in vals:
+            err = '"%s" not in %s "%s"' % (val, ALLOWED_VALUES, vals)
+            self.raise_error(name, desc or err)
+
+    def check_allowed_pattern(self, name, val, p, desc=None):
+        m = re.match(p, val)
+        if m is None or m.end() != len(val):
+            err = '"%s" does not match %s "%s"' % (val, ALLOWED_PATTERN, p)
+            self.raise_error(name, desc or err)
+
+    def check_max_length(self, name, val, const, desc=None):
+        max_len = int(const)
+        val_len = len(val)
+        if val_len > max_len:
+            err = 'length (%d) overflows %s (%d)' % (val_len,
+                                                     MAX_LENGTH, max_len)
+            self.raise_error(name, desc or err)
+
+    def check_min_length(self, name, val, const, desc=None):
+        min_len = int(const)
+        val_len = len(val)
+        if val_len < min_len:
+            err = 'length (%d) underflows %s (%d)' % (val_len,
+                                                      MIN_LENGTH, min_len)
+            self.raise_error(name, desc or err)
+
+    def check_max_value(self, name, val, const, desc=None):
+        max_val = float(const)
+        val = float(val)
+        if val > max_val:
+            err = '%d overflows %s %d' % (val, MAX_VALUE, max_val)
+            self.raise_error(name, desc or err)
+
+    def check_min_value(self, name, val, const, desc=None):
+        min_val = float(const)
+        val = float(val)
+        if val < min_val:
+            err = '%d underflows %s %d' % (val, MIN_VALUE, min_val)
+            self.raise_error(name, desc or err)
+
+    def check(self, const_key):
+        return {ALLOWED_VALUES: self.check_allowed_values,
+                ALLOWED_PATTERN: self.check_allowed_pattern,
+                MAX_LENGTH: self.check_max_length,
+                MIN_LENGTH: self.check_min_length,
+                MAX_VALUE: self.check_max_value,
+                MIN_VALUE: self.check_min_value}.get(const_key)
+
+
 class Parameter(object):
     '''A template parameter.'''
 
@@ -71,27 +141,15 @@ class Parameter(object):
         self.name = name
         self.schema = schema
         self.user_value = value
-        self._constraint_error = self.schema.get(CONSTRAINT_DESCRIPTION)
-
-        if self.has_default():
-            self._validate(self.default())
 
         if validate_value:
+            if self.has_default():
+                self.validate(self.default())
             if self.user_value is not None:
-                self._validate(self.user_value)
+                self.validate(self.user_value)
             elif not self.has_default():
                 raise exception.UserParameterMissing(key=self.name)
 
-    def _error_msg(self, message):
-        return '%s %s' % (self.name, self._constraint_error or message)
-
-    def _validate(self, value):
-        if ALLOWED_VALUES in self.schema:
-            allowed = list(self.schema[ALLOWED_VALUES])
-            if value not in allowed:
-                message = '%s not in %s %s' % (value, ALLOWED_VALUES, allowed)
-                raise ValueError(self._error_msg(message))
-
     def value(self):
         '''Get the parameter value, optionally sanitising it for output.'''
         if self.user_value is not None:
@@ -133,25 +191,6 @@ class Parameter(object):
 class NumberParam(Parameter):
     '''A template parameter of type "Number".'''
 
-    @staticmethod
-    def str_to_num(s):
-        '''Convert a string to an integer (if possible) or float.'''
-        try:
-            return int(s)
-        except ValueError:
-            return float(s)
-
-    def _validate(self, value):
-        '''Check that the supplied value is compatible with the constraints.'''
-        num = self.str_to_num(value)
-        minn = self.str_to_num(self.schema.get(MIN_VALUE, value))
-        maxn = self.str_to_num(self.schema.get(MAX_VALUE, value))
-
-        if num > maxn or num < minn:
-            raise ValueError(self._error_msg('%s is out of range' % value))
-
-        Parameter._validate(self, value)
-
     def __int__(self):
         '''Return an integer representation of the parameter'''
         return int(self.value())
@@ -160,105 +199,81 @@ class NumberParam(Parameter):
         '''Return a float representation of the parameter'''
         return float(self.value())
 
+    def validate(self, val):
+        self.schema.do_check(self.name, val, [ALLOWED_VALUES,
+                                              MAX_VALUE, MIN_VALUE])
+
 
 class StringParam(Parameter):
     '''A template parameter of type "String".'''
 
-    def _validate(self, value):
-        '''Check that the supplied value is compatible with the constraints.'''
-        if not isinstance(value, basestring):
-            raise ValueError(self._error_msg('value must be a string'))
-
-        length = len(value)
-        if MAX_LENGTH in self.schema:
-            max_length = int(self.schema[MAX_LENGTH])
-            if length > max_length:
-                message = 'length (%d) overflows %s %s' % (length,
-                                                           MAX_LENGTH,
-                                                           max_length)
-                raise ValueError(self._error_msg(message))
-
-        if MIN_LENGTH in self.schema:
-            min_length = int(self.schema[MIN_LENGTH])
-            if length < min_length:
-                message = 'length (%d) underflows %s %d' % (length,
-                                                            MIN_LENGTH,
-                                                            min_length)
-                raise ValueError(self._error_msg(message))
-
-        if ALLOWED_PATTERN in self.schema:
-            pattern = self.schema[ALLOWED_PATTERN]
-            match = re.match(pattern, value)
-            if match is None or match.end() != length:
-                message = '"%s" does not match %s "%s"' % (value,
-                                                           ALLOWED_PATTERN,
-                                                           pattern)
-                raise ValueError(self._error_msg(message))
-
-        Parameter._validate(self, value)
+    def validate(self, val):
+        self.schema.do_check(self.name, val,
+                             [ALLOWED_VALUES,
+                              ALLOWED_PATTERN, MAX_LENGTH, MIN_LENGTH])
 
 
 class CommaDelimitedListParam(Parameter, collections.Sequence):
     '''A template parameter of type "CommaDelimitedList".'''
 
-    def _validate(self, value):
-        '''Check that the supplied value is compatible with the constraints.'''
-        try:
-            value.split(',')
-        except AttributeError:
-            raise ValueError('Value must be a comma-delimited list string')
+    def __init__(self, name, schema, value=None, validate_value=True):
+        super(CommaDelimitedListParam, self).__init__(name, schema, value,
+                                                      validate_value)
+        self.parsed = self.parse(self.user_value or self.default())
 
-        for li in self:
-            Parameter._validate(self, li)
+    def parse(self, value):
+        try:
+            if value:
+                return value.split(',')
+        except (KeyError, AttributeError) as err:
+            message = 'Value must be a comma-delimited list string: %s'
+            raise ValueError(message % str(err))
+        return value
 
     def __len__(self):
         '''Return the length of the list.'''
-        return len(self.value().split(','))
+        return len(self.parsed)
 
     def __getitem__(self, index):
         '''Return an item from the list.'''
-        return self.value().split(',')[index]
+        return self.parsed[index]
+
+    def validate(self, val):
+        parsed = self.parse(val)
+        for val in parsed:
+            self.schema.do_check(self.name, val, [ALLOWED_VALUES])
 
 
 class JsonParam(Parameter, collections.Mapping):
     """A template parameter who's value is valid map."""
 
-    def _validate(self, value):
-        message = 'Value must be valid JSON'
-        if isinstance(value, collections.Mapping):
+    def __init__(self, name, schema, value=None, validate_value=True):
+        super(JsonParam, self).__init__(name, schema, value,
+                                        validate_value)
+        self.parsed = self.parse(self.user_value or self.default())
+
+    def parse(self, value):
+        try:
+            val = value
+            if isinstance(val, collections.Mapping):
+                val = json.dumps(val)
+            if val:
+                return json.loads(val)
+        except (ValueError, TypeError) as err:
+            message = 'Value must be valid JSON: %s' % str(err)
+            raise ValueError(message)
+        return value
+
+    def value(self):
+        val = super(JsonParam, self).value()
+        if isinstance(val, collections.Mapping):
             try:
-                self.user_value = json.dumps(value)
+                val = json.dumps(val)
+                self.user_value = val
             except (ValueError, TypeError) as err:
+                message = 'Value must be valid JSON'
                 raise ValueError("%s: %s" % (message, str(err)))
-            self.parsed = value
-        else:
-            try:
-                self.parsed = json.loads(value)
-            except ValueError:
-                raise ValueError(message)
-
-        # check length
-        my_len = len(self.parsed)
-        if MAX_LENGTH in self.schema:
-            max_length = int(self.schema[MAX_LENGTH])
-            if my_len > max_length:
-                message = ('value length (%d) overflows %s %s'
-                           % (my_len, MAX_LENGTH, max_length))
-                raise ValueError(self._error_msg(message))
-        if MIN_LENGTH in self.schema:
-            min_length = int(self.schema[MIN_LENGTH])
-            if my_len < min_length:
-                message = ('value length (%d) underflows %s %s'
-                           % (my_len, MIN_LENGTH, min_length))
-                raise ValueError(self._error_msg(message))
-        # check valid keys
-        if ALLOWED_VALUES in self.schema:
-            allowed = list(self.schema[ALLOWED_VALUES])
-            bad_keys = [k for k in self.parsed if k not in allowed]
-            if bad_keys:
-                message = ('keys %s are not in %s %s'
-                           % (bad_keys, ALLOWED_VALUES, allowed))
-                raise ValueError(self._error_msg(message))
+        return val
 
     def __getitem__(self, key):
         return self.parsed[key]
@@ -269,6 +284,12 @@ class JsonParam(Parameter, collections.Mapping):
     def __len__(self):
         return len(self.parsed)
 
+    def validate(self, val):
+        val = self.parse(val)
+        self.schema.do_check(self.name, val, [MAX_LENGTH, MIN_LENGTH])
+        for key in val:
+            self.schema.do_check(self.name, key, [ALLOWED_VALUES])
+
 
 class Parameters(collections.Mapping):
     '''
@@ -283,27 +304,30 @@ class Parameters(collections.Mapping):
         '''
         def parameters():
             yield Parameter(PARAM_STACK_ID,
-                            {TYPE: STRING,
-                             DESCRIPTION: 'Stack ID',
-                             DEFAULT: str(stack_id)})
+                            ParamSchema({TYPE: STRING,
+                                         DESCRIPTION: 'Stack ID',
+                                         DEFAULT: str(stack_id)}))
             if stack_name is not None:
                 yield Parameter(PARAM_STACK_NAME,
-                                {TYPE: STRING,
-                                 DESCRIPTION: 'Stack Name',
-                                 DEFAULT: stack_name})
+                                ParamSchema({TYPE: STRING,
+                                             DESCRIPTION: 'Stack Name',
+                                             DEFAULT: stack_name}))
                 yield Parameter(PARAM_REGION,
-                                {TYPE: STRING,
-                                 DEFAULT: 'ap-southeast-1',
-                                 ALLOWED_VALUES: ['us-east-1',
-                                                  'us-west-1', 'us-west-2',
-                                                  'sa-east-1',
-                                                  'eu-west-1',
-                                                  'ap-southeast-1',
-                                                  'ap-northeast-1']})
-
-            for name, schema in tmpl[template.PARAMETERS].iteritems():
-                yield Parameter(name, schema, user_params.get(name),
-                                validate_value)
+                                ParamSchema({TYPE: STRING,
+                                             DEFAULT: 'ap-southeast-1',
+                                             ALLOWED_VALUES:
+                                             ['us-east-1',
+                                              'us-west-1',
+                                              'us-west-2',
+                                              'sa-east-1',
+                                              'eu-west-1',
+                                              'ap-southeast-1',
+                                              'ap-northeast-1']}))
+
+            schemata = self.tmpl.param_schemata().iteritems()
+            for name, schema in schemata:
+                value = user_params.get(name)
+                yield Parameter(name, schema, value, validate_value)
 
         self.tmpl = tmpl
         self._validate(user_params)
@@ -340,6 +364,7 @@ class Parameters(collections.Mapping):
         self.params[PARAM_STACK_ID].schema[DEFAULT] = stack_id
 
     def _validate(self, user_params):
+        schemata = self.tmpl.param_schemata()
         for param in user_params:
-            if param not in self.tmpl[template.PARAMETERS]:
+            if param not in schemata:
                 raise exception.UnknownUserParameter(key=param)
index 93d17571a6967be65d6b29069a84b40b93d28625..5106f18d0950d791b35e47a0614cb4886f909bf1 100644 (file)
@@ -18,7 +18,7 @@ import json
 
 from heat.db import api as db_api
 from heat.common import exception
-
+from heat.engine.parameters import ParamSchema
 
 SECTIONS = (VERSION, DESCRIPTION, MAPPINGS,
             PARAMETERS, RESOURCES, OUTPUTS) = \
@@ -386,7 +386,8 @@ class Template(collections.Mapping):
                         s)
 
     def param_schemata(self):
-        return self[PARAMETERS]
+        parameters = self[PARAMETERS].iteritems()
+        return dict((name, ParamSchema(schema)) for name, schema in parameters)
 
 
 def _resolve(match, handle, snippet):
index f7f759e19512dfbce21d804cf218888a8c0243d9..fd6a85481c9e2adf60441987e93f8ac66a19c895 100644 (file)
@@ -15,6 +15,7 @@ from heat.common import template_format
 from heat.common import exception
 from heat.engine import parser
 from heat.engine import hot
+from heat.engine import parameters
 from heat.engine import template
 
 from heat.tests.common import HeatTestCase
@@ -290,3 +291,102 @@ class StackTest(test_parser.StackTest):
                           hot.HOTemplate.resolve_attributes,
                           {'Value': {'get_attr': ['resource1', 'NotThere']}},
                           self.stack)
+
+
+class HOTParamValidatorTest(HeatTestCase):
+    "Test HOTParamValidator"
+
+    def test_multiple_constraint_descriptions(self):
+        len_desc = 'string length should between 8 and 16'
+        pattern_desc1 = 'Value must consist of characters only'
+        pattern_desc2 = 'Value must start with a lowercase character'
+        param = {
+            'db_name': {
+                'Description': 'The WordPress database name',
+                'Type': 'String',
+                'Default': 'wordpress',
+                'MinLength': [(8, len_desc)],
+                'MaxLength': [(16, len_desc)],
+                'AllowedPattern': [
+                    ('[a-zA-Z]+', pattern_desc1),
+                    ('[a-z]+[a-zA-Z]*', pattern_desc2)]}}
+
+        name = 'db_name'
+        schema = param['db_name']
+
+        def v(value):
+            hot.HOTParamSchema(schema).do_check(name, value,
+                                                [parameters.ALLOWED_VALUES,
+                                                 parameters.ALLOWED_PATTERN,
+                                                 parameters.MAX_LENGTH,
+                                                 parameters.MIN_LENGTH])
+            return True
+
+        value = 'wp'
+        err = self.assertRaises(ValueError, v, value)
+        self.assertIn(len_desc, str(err))
+
+        value = 'abcdefghijklmnopq'
+        err = self.assertRaises(ValueError, v, value)
+        self.assertIn(len_desc, str(err))
+
+        value = 'abcdefgh1'
+        err = self.assertRaises(ValueError, v, value)
+        self.assertIn(pattern_desc1, str(err))
+
+        value = 'Abcdefghi'
+        err = self.assertRaises(ValueError, v, value)
+        self.assertIn(pattern_desc2, str(err))
+
+        value = 'abcdefghi'
+        self.assertTrue(v(value))
+
+        value = 'abcdefghI'
+        self.assertTrue(v(value))
+
+    def test_hot_template_validate_param(self):
+        len_desc = 'string length should between 8 and 16'
+        pattern_desc1 = 'Value must consist of characters only'
+        pattern_desc2 = 'Value must start with a lowercase character'
+        hot_tpl = template_format.parse('''
+        heat_template_version: 2013-05-23
+        parameters:
+          db_name:
+            description: The WordPress database name
+            type: string
+            default: wordpress
+            constraints:
+              - length: { min: 8, max: 16 }
+                description: %s
+              - allowed_pattern: "[a-zA-Z]+"
+                description: %s
+              - allowed_pattern: "[a-z]+[a-zA-Z]*"
+                description: %s
+        ''' % (len_desc, pattern_desc1, pattern_desc2))
+        tmpl = parser.Template(hot_tpl)
+
+        def run_parameters(value):
+            parameters.Parameters("stack_testit", tmpl, {'db_name': value})
+            return True
+
+        value = 'wp'
+        err = self.assertRaises(ValueError, run_parameters, value)
+        self.assertIn(len_desc, str(err))
+
+        value = 'abcdefghijklmnopq'
+        err = self.assertRaises(ValueError, run_parameters, value)
+        self.assertIn(len_desc, str(err))
+
+        value = 'abcdefgh1'
+        err = self.assertRaises(ValueError, run_parameters, value)
+        self.assertIn(pattern_desc1, str(err))
+
+        value = 'Abcdefghi'
+        err = self.assertRaises(ValueError, run_parameters, value)
+        self.assertIn(pattern_desc2, str(err))
+
+        value = 'abcdefghi'
+        self.assertTrue(run_parameters(value))
+
+        value = 'abcdefghI'
+        self.assertTrue(run_parameters(value))