import re
from heat.common import exception
-from heat.engine import template
+
PARAMETER_KEYS = (
TYPE, DEFAULT, NO_ECHO, ALLOWED_VALUES, ALLOWED_PATTERN,
)
+class ParamSchema(dict):
+ '''Parameter schema.'''
+
+ def __init__(self, schema):
+ super(ParamSchema, self).__init__(schema)
+
+ def do_check(self, name, value, keys):
+ for k in keys:
+ check = self.check(k)
+ const = self.get(k)
+ if check is None or const is None:
+ continue
+ check(name, value, const)
+
+ def raise_error(self, name, message, desc=True):
+ if desc:
+ message = self.get(CONSTRAINT_DESCRIPTION) or message
+ raise ValueError('%s %s' % (name, message))
+
+ def check_allowed_values(self, name, val, const, desc=None):
+ vals = list(const)
+ if val not in vals:
+ err = '"%s" not in %s "%s"' % (val, ALLOWED_VALUES, vals)
+ self.raise_error(name, desc or err)
+
+ def check_allowed_pattern(self, name, val, p, desc=None):
+ m = re.match(p, val)
+ if m is None or m.end() != len(val):
+ err = '"%s" does not match %s "%s"' % (val, ALLOWED_PATTERN, p)
+ self.raise_error(name, desc or err)
+
+ def check_max_length(self, name, val, const, desc=None):
+ max_len = int(const)
+ val_len = len(val)
+ if val_len > max_len:
+ err = 'length (%d) overflows %s (%d)' % (val_len,
+ MAX_LENGTH, max_len)
+ self.raise_error(name, desc or err)
+
+ def check_min_length(self, name, val, const, desc=None):
+ min_len = int(const)
+ val_len = len(val)
+ if val_len < min_len:
+ err = 'length (%d) underflows %s (%d)' % (val_len,
+ MIN_LENGTH, min_len)
+ self.raise_error(name, desc or err)
+
+ def check_max_value(self, name, val, const, desc=None):
+ max_val = float(const)
+ val = float(val)
+ if val > max_val:
+ err = '%d overflows %s %d' % (val, MAX_VALUE, max_val)
+ self.raise_error(name, desc or err)
+
+ def check_min_value(self, name, val, const, desc=None):
+ min_val = float(const)
+ val = float(val)
+ if val < min_val:
+ err = '%d underflows %s %d' % (val, MIN_VALUE, min_val)
+ self.raise_error(name, desc or err)
+
+ def check(self, const_key):
+ return {ALLOWED_VALUES: self.check_allowed_values,
+ ALLOWED_PATTERN: self.check_allowed_pattern,
+ MAX_LENGTH: self.check_max_length,
+ MIN_LENGTH: self.check_min_length,
+ MAX_VALUE: self.check_max_value,
+ MIN_VALUE: self.check_min_value}.get(const_key)
+
+
class Parameter(object):
'''A template parameter.'''
self.name = name
self.schema = schema
self.user_value = value
- self._constraint_error = self.schema.get(CONSTRAINT_DESCRIPTION)
-
- if self.has_default():
- self._validate(self.default())
if validate_value:
+ if self.has_default():
+ self.validate(self.default())
if self.user_value is not None:
- self._validate(self.user_value)
+ self.validate(self.user_value)
elif not self.has_default():
raise exception.UserParameterMissing(key=self.name)
- def _error_msg(self, message):
- return '%s %s' % (self.name, self._constraint_error or message)
-
- def _validate(self, value):
- if ALLOWED_VALUES in self.schema:
- allowed = list(self.schema[ALLOWED_VALUES])
- if value not in allowed:
- message = '%s not in %s %s' % (value, ALLOWED_VALUES, allowed)
- raise ValueError(self._error_msg(message))
-
def value(self):
'''Get the parameter value, optionally sanitising it for output.'''
if self.user_value is not None:
class NumberParam(Parameter):
'''A template parameter of type "Number".'''
- @staticmethod
- def str_to_num(s):
- '''Convert a string to an integer (if possible) or float.'''
- try:
- return int(s)
- except ValueError:
- return float(s)
-
- def _validate(self, value):
- '''Check that the supplied value is compatible with the constraints.'''
- num = self.str_to_num(value)
- minn = self.str_to_num(self.schema.get(MIN_VALUE, value))
- maxn = self.str_to_num(self.schema.get(MAX_VALUE, value))
-
- if num > maxn or num < minn:
- raise ValueError(self._error_msg('%s is out of range' % value))
-
- Parameter._validate(self, value)
-
def __int__(self):
'''Return an integer representation of the parameter'''
return int(self.value())
'''Return a float representation of the parameter'''
return float(self.value())
+ def validate(self, val):
+ self.schema.do_check(self.name, val, [ALLOWED_VALUES,
+ MAX_VALUE, MIN_VALUE])
+
class StringParam(Parameter):
'''A template parameter of type "String".'''
- def _validate(self, value):
- '''Check that the supplied value is compatible with the constraints.'''
- if not isinstance(value, basestring):
- raise ValueError(self._error_msg('value must be a string'))
-
- length = len(value)
- if MAX_LENGTH in self.schema:
- max_length = int(self.schema[MAX_LENGTH])
- if length > max_length:
- message = 'length (%d) overflows %s %s' % (length,
- MAX_LENGTH,
- max_length)
- raise ValueError(self._error_msg(message))
-
- if MIN_LENGTH in self.schema:
- min_length = int(self.schema[MIN_LENGTH])
- if length < min_length:
- message = 'length (%d) underflows %s %d' % (length,
- MIN_LENGTH,
- min_length)
- raise ValueError(self._error_msg(message))
-
- if ALLOWED_PATTERN in self.schema:
- pattern = self.schema[ALLOWED_PATTERN]
- match = re.match(pattern, value)
- if match is None or match.end() != length:
- message = '"%s" does not match %s "%s"' % (value,
- ALLOWED_PATTERN,
- pattern)
- raise ValueError(self._error_msg(message))
-
- Parameter._validate(self, value)
+ def validate(self, val):
+ self.schema.do_check(self.name, val,
+ [ALLOWED_VALUES,
+ ALLOWED_PATTERN, MAX_LENGTH, MIN_LENGTH])
class CommaDelimitedListParam(Parameter, collections.Sequence):
'''A template parameter of type "CommaDelimitedList".'''
- def _validate(self, value):
- '''Check that the supplied value is compatible with the constraints.'''
- try:
- value.split(',')
- except AttributeError:
- raise ValueError('Value must be a comma-delimited list string')
+ def __init__(self, name, schema, value=None, validate_value=True):
+ super(CommaDelimitedListParam, self).__init__(name, schema, value,
+ validate_value)
+ self.parsed = self.parse(self.user_value or self.default())
- for li in self:
- Parameter._validate(self, li)
+ def parse(self, value):
+ try:
+ if value:
+ return value.split(',')
+ except (KeyError, AttributeError) as err:
+ message = 'Value must be a comma-delimited list string: %s'
+ raise ValueError(message % str(err))
+ return value
def __len__(self):
'''Return the length of the list.'''
- return len(self.value().split(','))
+ return len(self.parsed)
def __getitem__(self, index):
'''Return an item from the list.'''
- return self.value().split(',')[index]
+ return self.parsed[index]
+
+ def validate(self, val):
+ parsed = self.parse(val)
+ for val in parsed:
+ self.schema.do_check(self.name, val, [ALLOWED_VALUES])
class JsonParam(Parameter, collections.Mapping):
"""A template parameter who's value is valid map."""
- def _validate(self, value):
- message = 'Value must be valid JSON'
- if isinstance(value, collections.Mapping):
+ def __init__(self, name, schema, value=None, validate_value=True):
+ super(JsonParam, self).__init__(name, schema, value,
+ validate_value)
+ self.parsed = self.parse(self.user_value or self.default())
+
+ def parse(self, value):
+ try:
+ val = value
+ if isinstance(val, collections.Mapping):
+ val = json.dumps(val)
+ if val:
+ return json.loads(val)
+ except (ValueError, TypeError) as err:
+ message = 'Value must be valid JSON: %s' % str(err)
+ raise ValueError(message)
+ return value
+
+ def value(self):
+ val = super(JsonParam, self).value()
+ if isinstance(val, collections.Mapping):
try:
- self.user_value = json.dumps(value)
+ val = json.dumps(val)
+ self.user_value = val
except (ValueError, TypeError) as err:
+ message = 'Value must be valid JSON'
raise ValueError("%s: %s" % (message, str(err)))
- self.parsed = value
- else:
- try:
- self.parsed = json.loads(value)
- except ValueError:
- raise ValueError(message)
-
- # check length
- my_len = len(self.parsed)
- if MAX_LENGTH in self.schema:
- max_length = int(self.schema[MAX_LENGTH])
- if my_len > max_length:
- message = ('value length (%d) overflows %s %s'
- % (my_len, MAX_LENGTH, max_length))
- raise ValueError(self._error_msg(message))
- if MIN_LENGTH in self.schema:
- min_length = int(self.schema[MIN_LENGTH])
- if my_len < min_length:
- message = ('value length (%d) underflows %s %s'
- % (my_len, MIN_LENGTH, min_length))
- raise ValueError(self._error_msg(message))
- # check valid keys
- if ALLOWED_VALUES in self.schema:
- allowed = list(self.schema[ALLOWED_VALUES])
- bad_keys = [k for k in self.parsed if k not in allowed]
- if bad_keys:
- message = ('keys %s are not in %s %s'
- % (bad_keys, ALLOWED_VALUES, allowed))
- raise ValueError(self._error_msg(message))
+ return val
def __getitem__(self, key):
return self.parsed[key]
def __len__(self):
return len(self.parsed)
+ def validate(self, val):
+ val = self.parse(val)
+ self.schema.do_check(self.name, val, [MAX_LENGTH, MIN_LENGTH])
+ for key in val:
+ self.schema.do_check(self.name, key, [ALLOWED_VALUES])
+
class Parameters(collections.Mapping):
'''
'''
def parameters():
yield Parameter(PARAM_STACK_ID,
- {TYPE: STRING,
- DESCRIPTION: 'Stack ID',
- DEFAULT: str(stack_id)})
+ ParamSchema({TYPE: STRING,
+ DESCRIPTION: 'Stack ID',
+ DEFAULT: str(stack_id)}))
if stack_name is not None:
yield Parameter(PARAM_STACK_NAME,
- {TYPE: STRING,
- DESCRIPTION: 'Stack Name',
- DEFAULT: stack_name})
+ ParamSchema({TYPE: STRING,
+ DESCRIPTION: 'Stack Name',
+ DEFAULT: stack_name}))
yield Parameter(PARAM_REGION,
- {TYPE: STRING,
- DEFAULT: 'ap-southeast-1',
- ALLOWED_VALUES: ['us-east-1',
- 'us-west-1', 'us-west-2',
- 'sa-east-1',
- 'eu-west-1',
- 'ap-southeast-1',
- 'ap-northeast-1']})
-
- for name, schema in tmpl[template.PARAMETERS].iteritems():
- yield Parameter(name, schema, user_params.get(name),
- validate_value)
+ ParamSchema({TYPE: STRING,
+ DEFAULT: 'ap-southeast-1',
+ ALLOWED_VALUES:
+ ['us-east-1',
+ 'us-west-1',
+ 'us-west-2',
+ 'sa-east-1',
+ 'eu-west-1',
+ 'ap-southeast-1',
+ 'ap-northeast-1']}))
+
+ schemata = self.tmpl.param_schemata().iteritems()
+ for name, schema in schemata:
+ value = user_params.get(name)
+ yield Parameter(name, schema, value, validate_value)
self.tmpl = tmpl
self._validate(user_params)
self.params[PARAM_STACK_ID].schema[DEFAULT] = stack_id
def _validate(self, user_params):
+ schemata = self.tmpl.param_schemata()
for param in user_params:
- if param not in self.tmpl[template.PARAMETERS]:
+ if param not in schemata:
raise exception.UnknownUserParameter(key=param)
from heat.common import exception
from heat.engine import parser
from heat.engine import hot
+from heat.engine import parameters
from heat.engine import template
from heat.tests.common import HeatTestCase
hot.HOTemplate.resolve_attributes,
{'Value': {'get_attr': ['resource1', 'NotThere']}},
self.stack)
+
+
+class HOTParamValidatorTest(HeatTestCase):
+ "Test HOTParamValidator"
+
+ def test_multiple_constraint_descriptions(self):
+ len_desc = 'string length should between 8 and 16'
+ pattern_desc1 = 'Value must consist of characters only'
+ pattern_desc2 = 'Value must start with a lowercase character'
+ param = {
+ 'db_name': {
+ 'Description': 'The WordPress database name',
+ 'Type': 'String',
+ 'Default': 'wordpress',
+ 'MinLength': [(8, len_desc)],
+ 'MaxLength': [(16, len_desc)],
+ 'AllowedPattern': [
+ ('[a-zA-Z]+', pattern_desc1),
+ ('[a-z]+[a-zA-Z]*', pattern_desc2)]}}
+
+ name = 'db_name'
+ schema = param['db_name']
+
+ def v(value):
+ hot.HOTParamSchema(schema).do_check(name, value,
+ [parameters.ALLOWED_VALUES,
+ parameters.ALLOWED_PATTERN,
+ parameters.MAX_LENGTH,
+ parameters.MIN_LENGTH])
+ return True
+
+ value = 'wp'
+ err = self.assertRaises(ValueError, v, value)
+ self.assertIn(len_desc, str(err))
+
+ value = 'abcdefghijklmnopq'
+ err = self.assertRaises(ValueError, v, value)
+ self.assertIn(len_desc, str(err))
+
+ value = 'abcdefgh1'
+ err = self.assertRaises(ValueError, v, value)
+ self.assertIn(pattern_desc1, str(err))
+
+ value = 'Abcdefghi'
+ err = self.assertRaises(ValueError, v, value)
+ self.assertIn(pattern_desc2, str(err))
+
+ value = 'abcdefghi'
+ self.assertTrue(v(value))
+
+ value = 'abcdefghI'
+ self.assertTrue(v(value))
+
+ def test_hot_template_validate_param(self):
+ len_desc = 'string length should between 8 and 16'
+ pattern_desc1 = 'Value must consist of characters only'
+ pattern_desc2 = 'Value must start with a lowercase character'
+ hot_tpl = template_format.parse('''
+ heat_template_version: 2013-05-23
+ parameters:
+ db_name:
+ description: The WordPress database name
+ type: string
+ default: wordpress
+ constraints:
+ - length: { min: 8, max: 16 }
+ description: %s
+ - allowed_pattern: "[a-zA-Z]+"
+ description: %s
+ - allowed_pattern: "[a-z]+[a-zA-Z]*"
+ description: %s
+ ''' % (len_desc, pattern_desc1, pattern_desc2))
+ tmpl = parser.Template(hot_tpl)
+
+ def run_parameters(value):
+ parameters.Parameters("stack_testit", tmpl, {'db_name': value})
+ return True
+
+ value = 'wp'
+ err = self.assertRaises(ValueError, run_parameters, value)
+ self.assertIn(len_desc, str(err))
+
+ value = 'abcdefghijklmnopq'
+ err = self.assertRaises(ValueError, run_parameters, value)
+ self.assertIn(len_desc, str(err))
+
+ value = 'abcdefgh1'
+ err = self.assertRaises(ValueError, run_parameters, value)
+ self.assertIn(pattern_desc1, str(err))
+
+ value = 'Abcdefghi'
+ err = self.assertRaises(ValueError, run_parameters, value)
+ self.assertIn(pattern_desc2, str(err))
+
+ value = 'abcdefghi'
+ self.assertTrue(run_parameters(value))
+
+ value = 'abcdefghI'
+ self.assertTrue(run_parameters(value))