From: JUN JIE NAN Date: Thu, 1 Aug 2013 03:14:12 +0000 (+0800) Subject: HOT parameter validator part X-Git-Tag: 2014.1~244^2 X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=46041586891cd900e4c3931643313c8f73039dd3;p=openstack-build%2Fheat-build.git HOT parameter validator part Introduce new class ParamSchema to decouple parameter and its schema, do validation in schema, for hot template, implement HotParamSchema as it subclass. Implements blueprint hot-parameters Change-Id: I8720c62e41a1f182584c4518163d422784c23b37 --- diff --git a/heat/engine/hot.py b/heat/engine/hot.py index 5dbfa70d..caed57ec 100644 --- a/heat/engine/hot.py +++ b/heat/engine/hot.py @@ -14,6 +14,7 @@ from heat.common import exception from heat.engine import template +from heat.engine.parameters import ParamSchema from heat.openstack.common import log as logging @@ -32,6 +33,10 @@ _CFN_TO_HOT_SECTIONS = {template.VERSION: VERSION, template.OUTPUTS: OUTPUTS} +def snake_to_camel(name): + return ''.join([t.capitalize() for t in name.split('_')]) + + class HOTemplate(template.Template): """ A Heat Orchestration Template format stack template. @@ -81,15 +86,6 @@ class HOTemplate(template.Template): return default - @staticmethod - def _snake_to_camel(name): - tokens = [] - if name: - tokens = name.split('_') - for i in xrange(len(tokens)): - tokens[i] = tokens[i].capitalize() - return "".join(tokens) - def _translate_constraints(self, constraints): param = {} @@ -109,7 +105,7 @@ class HOTemplate(template.Template): for constraint in constraints: desc = constraint.get('description') for key, val in constraint.iteritems(): - key = self._snake_to_camel(key) + key = snake_to_camel(key) if key == 'Description': continue elif key == 'Range': @@ -127,9 +123,9 @@ class HOTemplate(template.Template): for name, attrs in parameters.iteritems(): param = {} for key, val in attrs.iteritems(): - key = self._snake_to_camel(key) + key = snake_to_camel(key) if key == 'Type': - val = self._snake_to_camel(val) + val = snake_to_camel(val) elif key == 'Constraints': param.update(self._translate_constraints(val)) continue @@ -221,3 +217,18 @@ class HOTemplate(template.Template): key=att) return template._resolve(match_get_attr, handle_get_attr, s) + + def param_schemata(self): + params = self[PARAMETERS].iteritems() + return dict((name, HOTParamSchema(schema)) for name, schema in params) + + +class HOTParamSchema(ParamSchema): + def do_check(self, name, val, keys): + for key in keys: + consts = self.get(key) + check = self.check(key) + if consts is None or check is None: + continue + for (const, desc) in consts: + check(name, val, const, desc) diff --git a/heat/engine/parameters.py b/heat/engine/parameters.py index be3cba32..95c4b935 100644 --- a/heat/engine/parameters.py +++ b/heat/engine/parameters.py @@ -18,7 +18,7 @@ import json import re from heat.common import exception -from heat.engine import template + PARAMETER_KEYS = ( TYPE, DEFAULT, NO_ECHO, ALLOWED_VALUES, ALLOWED_PATTERN, @@ -41,6 +41,76 @@ PSEUDO_PARAMETERS = ( ) +class ParamSchema(dict): + '''Parameter schema.''' + + def __init__(self, schema): + super(ParamSchema, self).__init__(schema) + + def do_check(self, name, value, keys): + for k in keys: + check = self.check(k) + const = self.get(k) + if check is None or const is None: + continue + check(name, value, const) + + def raise_error(self, name, message, desc=True): + if desc: + message = self.get(CONSTRAINT_DESCRIPTION) or message + raise ValueError('%s %s' % (name, message)) + + def check_allowed_values(self, name, val, const, desc=None): + vals = list(const) + if val not in vals: + err = '"%s" not in %s "%s"' % (val, ALLOWED_VALUES, vals) + self.raise_error(name, desc or err) + + def check_allowed_pattern(self, name, val, p, desc=None): + m = re.match(p, val) + if m is None or m.end() != len(val): + err = '"%s" does not match %s "%s"' % (val, ALLOWED_PATTERN, p) + self.raise_error(name, desc or err) + + def check_max_length(self, name, val, const, desc=None): + max_len = int(const) + val_len = len(val) + if val_len > max_len: + err = 'length (%d) overflows %s (%d)' % (val_len, + MAX_LENGTH, max_len) + self.raise_error(name, desc or err) + + def check_min_length(self, name, val, const, desc=None): + min_len = int(const) + val_len = len(val) + if val_len < min_len: + err = 'length (%d) underflows %s (%d)' % (val_len, + MIN_LENGTH, min_len) + self.raise_error(name, desc or err) + + def check_max_value(self, name, val, const, desc=None): + max_val = float(const) + val = float(val) + if val > max_val: + err = '%d overflows %s %d' % (val, MAX_VALUE, max_val) + self.raise_error(name, desc or err) + + def check_min_value(self, name, val, const, desc=None): + min_val = float(const) + val = float(val) + if val < min_val: + err = '%d underflows %s %d' % (val, MIN_VALUE, min_val) + self.raise_error(name, desc or err) + + def check(self, const_key): + return {ALLOWED_VALUES: self.check_allowed_values, + ALLOWED_PATTERN: self.check_allowed_pattern, + MAX_LENGTH: self.check_max_length, + MIN_LENGTH: self.check_min_length, + MAX_VALUE: self.check_max_value, + MIN_VALUE: self.check_min_value}.get(const_key) + + class Parameter(object): '''A template parameter.''' @@ -71,27 +141,15 @@ class Parameter(object): self.name = name self.schema = schema self.user_value = value - self._constraint_error = self.schema.get(CONSTRAINT_DESCRIPTION) - - if self.has_default(): - self._validate(self.default()) if validate_value: + if self.has_default(): + self.validate(self.default()) if self.user_value is not None: - self._validate(self.user_value) + self.validate(self.user_value) elif not self.has_default(): raise exception.UserParameterMissing(key=self.name) - def _error_msg(self, message): - return '%s %s' % (self.name, self._constraint_error or message) - - def _validate(self, value): - if ALLOWED_VALUES in self.schema: - allowed = list(self.schema[ALLOWED_VALUES]) - if value not in allowed: - message = '%s not in %s %s' % (value, ALLOWED_VALUES, allowed) - raise ValueError(self._error_msg(message)) - def value(self): '''Get the parameter value, optionally sanitising it for output.''' if self.user_value is not None: @@ -133,25 +191,6 @@ class Parameter(object): class NumberParam(Parameter): '''A template parameter of type "Number".''' - @staticmethod - def str_to_num(s): - '''Convert a string to an integer (if possible) or float.''' - try: - return int(s) - except ValueError: - return float(s) - - def _validate(self, value): - '''Check that the supplied value is compatible with the constraints.''' - num = self.str_to_num(value) - minn = self.str_to_num(self.schema.get(MIN_VALUE, value)) - maxn = self.str_to_num(self.schema.get(MAX_VALUE, value)) - - if num > maxn or num < minn: - raise ValueError(self._error_msg('%s is out of range' % value)) - - Parameter._validate(self, value) - def __int__(self): '''Return an integer representation of the parameter''' return int(self.value()) @@ -160,105 +199,81 @@ class NumberParam(Parameter): '''Return a float representation of the parameter''' return float(self.value()) + def validate(self, val): + self.schema.do_check(self.name, val, [ALLOWED_VALUES, + MAX_VALUE, MIN_VALUE]) + class StringParam(Parameter): '''A template parameter of type "String".''' - def _validate(self, value): - '''Check that the supplied value is compatible with the constraints.''' - if not isinstance(value, basestring): - raise ValueError(self._error_msg('value must be a string')) - - length = len(value) - if MAX_LENGTH in self.schema: - max_length = int(self.schema[MAX_LENGTH]) - if length > max_length: - message = 'length (%d) overflows %s %s' % (length, - MAX_LENGTH, - max_length) - raise ValueError(self._error_msg(message)) - - if MIN_LENGTH in self.schema: - min_length = int(self.schema[MIN_LENGTH]) - if length < min_length: - message = 'length (%d) underflows %s %d' % (length, - MIN_LENGTH, - min_length) - raise ValueError(self._error_msg(message)) - - if ALLOWED_PATTERN in self.schema: - pattern = self.schema[ALLOWED_PATTERN] - match = re.match(pattern, value) - if match is None or match.end() != length: - message = '"%s" does not match %s "%s"' % (value, - ALLOWED_PATTERN, - pattern) - raise ValueError(self._error_msg(message)) - - Parameter._validate(self, value) + def validate(self, val): + self.schema.do_check(self.name, val, + [ALLOWED_VALUES, + ALLOWED_PATTERN, MAX_LENGTH, MIN_LENGTH]) class CommaDelimitedListParam(Parameter, collections.Sequence): '''A template parameter of type "CommaDelimitedList".''' - def _validate(self, value): - '''Check that the supplied value is compatible with the constraints.''' - try: - value.split(',') - except AttributeError: - raise ValueError('Value must be a comma-delimited list string') + def __init__(self, name, schema, value=None, validate_value=True): + super(CommaDelimitedListParam, self).__init__(name, schema, value, + validate_value) + self.parsed = self.parse(self.user_value or self.default()) - for li in self: - Parameter._validate(self, li) + def parse(self, value): + try: + if value: + return value.split(',') + except (KeyError, AttributeError) as err: + message = 'Value must be a comma-delimited list string: %s' + raise ValueError(message % str(err)) + return value def __len__(self): '''Return the length of the list.''' - return len(self.value().split(',')) + return len(self.parsed) def __getitem__(self, index): '''Return an item from the list.''' - return self.value().split(',')[index] + return self.parsed[index] + + def validate(self, val): + parsed = self.parse(val) + for val in parsed: + self.schema.do_check(self.name, val, [ALLOWED_VALUES]) class JsonParam(Parameter, collections.Mapping): """A template parameter who's value is valid map.""" - def _validate(self, value): - message = 'Value must be valid JSON' - if isinstance(value, collections.Mapping): + def __init__(self, name, schema, value=None, validate_value=True): + super(JsonParam, self).__init__(name, schema, value, + validate_value) + self.parsed = self.parse(self.user_value or self.default()) + + def parse(self, value): + try: + val = value + if isinstance(val, collections.Mapping): + val = json.dumps(val) + if val: + return json.loads(val) + except (ValueError, TypeError) as err: + message = 'Value must be valid JSON: %s' % str(err) + raise ValueError(message) + return value + + def value(self): + val = super(JsonParam, self).value() + if isinstance(val, collections.Mapping): try: - self.user_value = json.dumps(value) + val = json.dumps(val) + self.user_value = val except (ValueError, TypeError) as err: + message = 'Value must be valid JSON' raise ValueError("%s: %s" % (message, str(err))) - self.parsed = value - else: - try: - self.parsed = json.loads(value) - except ValueError: - raise ValueError(message) - - # check length - my_len = len(self.parsed) - if MAX_LENGTH in self.schema: - max_length = int(self.schema[MAX_LENGTH]) - if my_len > max_length: - message = ('value length (%d) overflows %s %s' - % (my_len, MAX_LENGTH, max_length)) - raise ValueError(self._error_msg(message)) - if MIN_LENGTH in self.schema: - min_length = int(self.schema[MIN_LENGTH]) - if my_len < min_length: - message = ('value length (%d) underflows %s %s' - % (my_len, MIN_LENGTH, min_length)) - raise ValueError(self._error_msg(message)) - # check valid keys - if ALLOWED_VALUES in self.schema: - allowed = list(self.schema[ALLOWED_VALUES]) - bad_keys = [k for k in self.parsed if k not in allowed] - if bad_keys: - message = ('keys %s are not in %s %s' - % (bad_keys, ALLOWED_VALUES, allowed)) - raise ValueError(self._error_msg(message)) + return val def __getitem__(self, key): return self.parsed[key] @@ -269,6 +284,12 @@ class JsonParam(Parameter, collections.Mapping): def __len__(self): return len(self.parsed) + def validate(self, val): + val = self.parse(val) + self.schema.do_check(self.name, val, [MAX_LENGTH, MIN_LENGTH]) + for key in val: + self.schema.do_check(self.name, key, [ALLOWED_VALUES]) + class Parameters(collections.Mapping): ''' @@ -283,27 +304,30 @@ class Parameters(collections.Mapping): ''' def parameters(): yield Parameter(PARAM_STACK_ID, - {TYPE: STRING, - DESCRIPTION: 'Stack ID', - DEFAULT: str(stack_id)}) + ParamSchema({TYPE: STRING, + DESCRIPTION: 'Stack ID', + DEFAULT: str(stack_id)})) if stack_name is not None: yield Parameter(PARAM_STACK_NAME, - {TYPE: STRING, - DESCRIPTION: 'Stack Name', - DEFAULT: stack_name}) + ParamSchema({TYPE: STRING, + DESCRIPTION: 'Stack Name', + DEFAULT: stack_name})) yield Parameter(PARAM_REGION, - {TYPE: STRING, - DEFAULT: 'ap-southeast-1', - ALLOWED_VALUES: ['us-east-1', - 'us-west-1', 'us-west-2', - 'sa-east-1', - 'eu-west-1', - 'ap-southeast-1', - 'ap-northeast-1']}) - - for name, schema in tmpl[template.PARAMETERS].iteritems(): - yield Parameter(name, schema, user_params.get(name), - validate_value) + ParamSchema({TYPE: STRING, + DEFAULT: 'ap-southeast-1', + ALLOWED_VALUES: + ['us-east-1', + 'us-west-1', + 'us-west-2', + 'sa-east-1', + 'eu-west-1', + 'ap-southeast-1', + 'ap-northeast-1']})) + + schemata = self.tmpl.param_schemata().iteritems() + for name, schema in schemata: + value = user_params.get(name) + yield Parameter(name, schema, value, validate_value) self.tmpl = tmpl self._validate(user_params) @@ -340,6 +364,7 @@ class Parameters(collections.Mapping): self.params[PARAM_STACK_ID].schema[DEFAULT] = stack_id def _validate(self, user_params): + schemata = self.tmpl.param_schemata() for param in user_params: - if param not in self.tmpl[template.PARAMETERS]: + if param not in schemata: raise exception.UnknownUserParameter(key=param) diff --git a/heat/engine/template.py b/heat/engine/template.py index 93d17571..5106f18d 100644 --- a/heat/engine/template.py +++ b/heat/engine/template.py @@ -18,7 +18,7 @@ import json from heat.db import api as db_api from heat.common import exception - +from heat.engine.parameters import ParamSchema SECTIONS = (VERSION, DESCRIPTION, MAPPINGS, PARAMETERS, RESOURCES, OUTPUTS) = \ @@ -386,7 +386,8 @@ class Template(collections.Mapping): s) def param_schemata(self): - return self[PARAMETERS] + parameters = self[PARAMETERS].iteritems() + return dict((name, ParamSchema(schema)) for name, schema in parameters) def _resolve(match, handle, snippet): diff --git a/heat/tests/test_hot.py b/heat/tests/test_hot.py index f7f759e1..fd6a8548 100644 --- a/heat/tests/test_hot.py +++ b/heat/tests/test_hot.py @@ -15,6 +15,7 @@ from heat.common import template_format from heat.common import exception from heat.engine import parser from heat.engine import hot +from heat.engine import parameters from heat.engine import template from heat.tests.common import HeatTestCase @@ -290,3 +291,102 @@ class StackTest(test_parser.StackTest): hot.HOTemplate.resolve_attributes, {'Value': {'get_attr': ['resource1', 'NotThere']}}, self.stack) + + +class HOTParamValidatorTest(HeatTestCase): + "Test HOTParamValidator" + + def test_multiple_constraint_descriptions(self): + len_desc = 'string length should between 8 and 16' + pattern_desc1 = 'Value must consist of characters only' + pattern_desc2 = 'Value must start with a lowercase character' + param = { + 'db_name': { + 'Description': 'The WordPress database name', + 'Type': 'String', + 'Default': 'wordpress', + 'MinLength': [(8, len_desc)], + 'MaxLength': [(16, len_desc)], + 'AllowedPattern': [ + ('[a-zA-Z]+', pattern_desc1), + ('[a-z]+[a-zA-Z]*', pattern_desc2)]}} + + name = 'db_name' + schema = param['db_name'] + + def v(value): + hot.HOTParamSchema(schema).do_check(name, value, + [parameters.ALLOWED_VALUES, + parameters.ALLOWED_PATTERN, + parameters.MAX_LENGTH, + parameters.MIN_LENGTH]) + return True + + value = 'wp' + err = self.assertRaises(ValueError, v, value) + self.assertIn(len_desc, str(err)) + + value = 'abcdefghijklmnopq' + err = self.assertRaises(ValueError, v, value) + self.assertIn(len_desc, str(err)) + + value = 'abcdefgh1' + err = self.assertRaises(ValueError, v, value) + self.assertIn(pattern_desc1, str(err)) + + value = 'Abcdefghi' + err = self.assertRaises(ValueError, v, value) + self.assertIn(pattern_desc2, str(err)) + + value = 'abcdefghi' + self.assertTrue(v(value)) + + value = 'abcdefghI' + self.assertTrue(v(value)) + + def test_hot_template_validate_param(self): + len_desc = 'string length should between 8 and 16' + pattern_desc1 = 'Value must consist of characters only' + pattern_desc2 = 'Value must start with a lowercase character' + hot_tpl = template_format.parse(''' + heat_template_version: 2013-05-23 + parameters: + db_name: + description: The WordPress database name + type: string + default: wordpress + constraints: + - length: { min: 8, max: 16 } + description: %s + - allowed_pattern: "[a-zA-Z]+" + description: %s + - allowed_pattern: "[a-z]+[a-zA-Z]*" + description: %s + ''' % (len_desc, pattern_desc1, pattern_desc2)) + tmpl = parser.Template(hot_tpl) + + def run_parameters(value): + parameters.Parameters("stack_testit", tmpl, {'db_name': value}) + return True + + value = 'wp' + err = self.assertRaises(ValueError, run_parameters, value) + self.assertIn(len_desc, str(err)) + + value = 'abcdefghijklmnopq' + err = self.assertRaises(ValueError, run_parameters, value) + self.assertIn(len_desc, str(err)) + + value = 'abcdefgh1' + err = self.assertRaises(ValueError, run_parameters, value) + self.assertIn(pattern_desc1, str(err)) + + value = 'Abcdefghi' + err = self.assertRaises(ValueError, run_parameters, value) + self.assertIn(pattern_desc2, str(err)) + + value = 'abcdefghi' + self.assertTrue(run_parameters(value)) + + value = 'abcdefghI' + self.assertTrue(run_parameters(value))