--- /dev/null
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import collections
+import re
+
+from heat.engine import template
+
+PARAMETER_KEYS = (
+ TYPE, DEFAULT, NO_ECHO, VALUES, PATTERN,
+ MAX_LENGTH, MIN_LENGTH, MAX_VALUE, MIN_VALUE,
+ DESCRIPTION, CONSTRAINT_DESCRIPTION
+) = (
+ 'Type', 'Default', 'NoEcho', 'AllowedValues', 'AllowedPattern',
+ 'MaxLength', 'MinLength', 'MaxValue', 'MinValue',
+ 'Description', 'ConstraintDescription'
+)
+PARAMETER_TYPES = (
+ STRING, NUMBER, COMMA_DELIMITED_LIST
+) = (
+ 'String', 'Number', 'CommaDelimitedList'
+)
+(PARAM_STACK_NAME, PARAM_REGION) = ('AWS::StackName', 'AWS::Region')
+
+
+class Parameter(object):
+ '''A template parameter.'''
+
+ def __new__(cls, name, schema, value=None):
+ '''Create a new Parameter of the appropriate type.'''
+ if cls is not Parameter:
+ return super(Parameter, cls).__new__(cls)
+
+ param_type = schema[TYPE]
+ if param_type == STRING:
+ ParamClass = StringParam
+ elif param_type == NUMBER:
+ ParamClass = NumberParam
+ elif param_type == COMMA_DELIMITED_LIST:
+ ParamClass = CommaDelimitedListParam
+ else:
+ raise ValueError('Invalid Parameter type "%s"' % param_type)
+
+ return ParamClass(name, schema, value)
+
+ def __init__(self, name, schema, value=None):
+ '''
+ Initialise the Parameter with a name, schema and optional user-supplied
+ value.
+ '''
+ self.name = name
+ self.schema = schema
+ self.user_value = value
+ self._constraint_error = self.schema.get(CONSTRAINT_DESCRIPTION)
+
+ if self.has_default():
+ self._validate(self.default())
+
+ if self.user_value is not None:
+ self._validate(self.user_value)
+
+ def _error_msg(self, message):
+ return '%s %s' % (self.name, self._constraint_error or message)
+
+ def _validate(self, value):
+ if VALUES in self.schema:
+ allowed = self.schema[VALUES]
+ if value not in allowed:
+ message = '%s not in %s %s' % (value, VALUES, allowed)
+ raise ValueError(self._error_msg(message))
+
+ def value(self):
+ '''Get the parameter value, optionally sanitising it for output.'''
+ if self.user_value is not None:
+ return self.user_value
+
+ if self.has_default():
+ return self.default()
+
+ raise KeyError('Missing parameter %s' % self.name)
+
+ def description(self):
+ '''Return the description of the parameter.'''
+ return self.schema.get(DESCRIPTION, '')
+
+ def has_default(self):
+ '''Return whether the parameter has a default value.'''
+ return DEFAULT in self.schema
+
+ def default(self):
+ '''Return the default value of the parameter.'''
+ return self.schema.get(DEFAULT)
+
+ def __str__(self):
+ '''Return a string representation of the parameter'''
+ return self.value()
+
+
+class NumberParam(Parameter):
+ '''A template parameter of type "Number".'''
+
+ @staticmethod
+ def str_to_num(s):
+ '''Convert a string to an integer (if possible) or float.'''
+ try:
+ return int(s)
+ except ValueError:
+ return float(s)
+
+ def _validate(self, value):
+ '''Check that the supplied value is compatible with the constraints.'''
+ num = self.str_to_num(value)
+ minn = self.str_to_num(self.schema.get(MIN_VALUE, value))
+ maxn = self.str_to_num(self.schema.get(MAX_VALUE, value))
+
+ if num > maxn or num < minn:
+ raise ValueError(self._error_msg('%s is out of range' % value))
+
+ Parameter._validate(self, value)
+
+ def __int__(self):
+ '''Return an integer representation of the parameter'''
+ return int(self.value())
+
+ def __float__(self):
+ '''Return a float representation of the parameter'''
+ return float(self.value())
+
+
+class StringParam(Parameter):
+ '''A template parameter of type "String".'''
+
+ def _validate(self, value):
+ '''Check that the supplied value is compatible with the constraints'''
+ if not isinstance(value, basestring):
+ raise ValueError(self._error_msg('value must be a string'))
+
+ length = len(value)
+ if MAX_LENGTH in self.schema:
+ max_length = int(self.schema[MAX_LENGTH])
+ if length > max_length:
+ message = 'length (%d) overflows %s %s' % (length,
+ MAX_LENGTH,
+ max_length)
+ raise ValueError(self._error_msg(message))
+
+ if MIN_LENGTH in self.schema:
+ min_length = int(self.schema[MIN_LENGTH])
+ if length < min_length:
+ message = 'length (%d) underflows %s %d' % (length,
+ MIN_LENGTH,
+ min_length)
+ raise ValueError(self._error_msg(message))
+
+ if PATTERN in self.schema:
+ pattern = self.schema[PATTERN]
+ match = re.match(pattern, value)
+ if match is None or match.end() != length:
+ message = '"%s" does not match %s "%s"' % (value,
+ PATTERN,
+ pattern)
+ raise ValueError(self._error_msg(message))
+
+ Parameter._validate(self, value)
+
+
+class CommaDelimitedListParam(Parameter, collections.Sequence):
+ '''A template parameter of type "CommaDelimitedList".'''
+
+ def _validate(self, value):
+ '''Check that the supplied value is compatible with the constraints'''
+ try:
+ sp = value.split(',')
+ except AttributeError:
+ raise ValueError('Value must be a comma-delimited list string')
+
+ for li in self:
+ Parameter._validate(self, li)
+
+ def __len__(self):
+ '''Return the length of the list'''
+ return len(self.value().split(','))
+
+ def __getitem__(self, index):
+ '''Return an item from the list'''
+ return self.value().split(',')[index]
+
+
+class Parameters(collections.Mapping):
+ '''
+ The parameters of a stack, with type checking, defaults &c. specified by
+ the stack's template.
+ '''
+
+ def __init__(self, stack_name, tmpl, user_params={}):
+ '''
+ Create the parameter container for a stack from the stack name and
+ template, optionally setting the user-supplied parameter values.
+ '''
+ def parameters():
+ if stack_name is not None:
+ yield Parameter(PARAM_STACK_NAME,
+ {TYPE: STRING,
+ DESCRIPTION: 'Stack Name',
+ DEFAULT: stack_name})
+ yield Parameter(PARAM_REGION,
+ {TYPE: STRING,
+ DEFAULT: 'ap-southeast-1',
+ VALUES: ['us-east-1',
+ 'us-west-1', 'us-west-2',
+ 'sa-east-1',
+ 'eu-west-1',
+ 'ap-southeast-1',
+ 'ap-northeast-1']})
+
+ for name, schema in tmpl[template.PARAMETERS].iteritems():
+ yield Parameter(name, schema, user_params.get(name))
+
+ self.params = dict((p.name, p) for p in parameters())
+
+ def __contains__(self, key):
+ '''Return whether the specified parameter exists'''
+ return key in self.params
+
+ def __iter__(self):
+ '''Return an iterator over the parameter names.'''
+ return iter(self.params)
+
+ def __len__(self):
+ '''Return the number of parameters defined'''
+ return len(self.params)
+
+ def __getitem__(self, key):
+ '''Get a parameter value.'''
+ return self.params[key].value()
+
+ def map(self, func, filter_func=lambda p: True):
+ '''
+ Map the supplied filter function onto each Parameter (with an
+ optional filter function) and return the resulting dictionary.
+ '''
+ return dict((n, func(p)) for n, p in self.params.iteritems()
+ if filter_func(p))
+
+ def user_parameters(self):
+ '''
+ Return a dictionary of all the parameters passed in by the user
+ '''
+ return self.map(lambda p: p.user_value,
+ lambda p: p.user_value is not None)
import copy
from heat.common import exception
-from heat.engine import checkeddict
from heat.engine import dependencies
from heat.engine import identifier
from heat.engine import resources
from heat.engine import template
from heat.engine import timestamp
+from heat.engine.parameters import Parameters
from heat.engine.template import Template
from heat.db import api as db_api
(PARAM_STACK_NAME, PARAM_REGION) = ('AWS::StackName', 'AWS::Region')
-class Parameters(checkeddict.CheckedDict):
- '''
- The parameters of a stack, with type checking, defaults &c. specified by
- the stack's template.
- '''
-
- def __init__(self, stack_name, tmpl, user_params={}):
- '''
- Create the parameter container for a stack from the stack name and
- template, optionally setting the initial set of parameters.
- '''
- checkeddict.CheckedDict.__init__(self, template.PARAMETERS)
- self._init_schemata(tmpl[template.PARAMETERS])
-
- self[PARAM_STACK_NAME] = stack_name
- self.update(user_params)
-
- def _init_schemata(self, schemata):
- '''
- Initialise the parameter schemata with the pseudo-parameters and the
- list of schemata obtained from the template.
- '''
- self.addschema(PARAM_STACK_NAME, {"Description": "AWS StackName",
- "Type": "String"})
- self.addschema(PARAM_REGION, {
- "Description": "AWS Regions",
- "Default": "ap-southeast-1",
- "Type": "String",
- "AllowedValues": ["us-east-1", "us-west-1", "us-west-2",
- "sa-east-1", "eu-west-1", "ap-southeast-1",
- "ap-northeast-1"],
- "ConstraintDescription": "must be a valid EC2 instance type.",
- })
-
- for param, schema in schemata.items():
- self.addschema(param, copy.deepcopy(schema))
-
- def user_parameters(self):
- '''
- Return a dictionary of all the parameters passed in by the user
- '''
- return dict((k, v['Value']) for k, v in self.data.iteritems()
- if 'Value' in v)
-
-
class Stack(object):
CREATE_IN_PROGRESS = 'CREATE_IN_PROGRESS'
CREATE_FAILED = 'CREATE_FAILED'
'Parameters': []}
return response
- def format_param(p):
+ def describe_param(p):
return {'NoEcho': 'false',
- 'ParameterKey': p,
- 'Description': self.parameters.get_attr(p, 'Description'),
- 'DefaultValue': self.parameters.get_attr(p, 'Default')}
+ 'ParameterKey': p.name,
+ 'Description': p.description(),
+ 'DefaultValue': p.default()}
+
+ params = self.parameters.map(describe_param)
response = {'Description': 'Successfully validated',
- 'Parameters': [format_param(p) for p in self.parameters]}
+ 'Parameters': params.values()}
return response
--- /dev/null
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import nose
+import unittest
+from nose.plugins.attrib import attr
+import mox
+import json
+
+from heat.common import context
+from heat.common import exception
+from heat.engine import parameters
+
+
+@attr(tag=['unit', 'parameters'])
+@attr(speed='fast')
+class ParameterTest(unittest.TestCase):
+ def test_new_string(self):
+ p = parameters.Parameter('p', {'Type': 'String'})
+ self.assertTrue(isinstance(p, parameters.StringParam))
+
+ def test_new_number(self):
+ p = parameters.Parameter('p', {'Type': 'Number'})
+ self.assertTrue(isinstance(p, parameters.NumberParam))
+
+ def test_new_list(self):
+ p = parameters.Parameter('p', {'Type': 'CommaDelimitedList'})
+ self.assertTrue(isinstance(p, parameters.CommaDelimitedListParam))
+
+ def test_new_bad_type(self):
+ self.assertRaises(ValueError, parameters.Parameter,
+ 'p', {'Type': 'List'})
+
+ def test_new_no_type(self):
+ self.assertRaises(KeyError, parameters.Parameter,
+ 'p', {'Default': 'blarg'})
+
+ def test_default_no_override(self):
+ p = parameters.Parameter('defaulted', {'Type': 'String',
+ 'Default': 'blarg'})
+ self.assertTrue(p.has_default())
+ self.assertEqual(p.default(), 'blarg')
+ self.assertEqual(p.value(), 'blarg')
+
+ def test_default_override(self):
+ p = parameters.Parameter('defaulted',
+ {'Type': 'String',
+ 'Default': 'blarg'},
+ 'wibble')
+ self.assertTrue(p.has_default())
+ self.assertEqual(p.default(), 'blarg')
+ self.assertEqual(p.value(), 'wibble')
+
+ def test_default_invalid(self):
+ schema = {'Type': 'String',
+ 'AllowedValues': ['foo'],
+ 'ConstraintDescription': 'wibble',
+ 'Default': 'bar'}
+ try:
+ parameters.Parameter('p', schema, 'foo')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_description(self):
+ description = 'Description of the parameter'
+ p = parameters.Parameter('p', {'Type': 'String',
+ 'Description': description})
+ self.assertEqual(p.description(), description)
+
+ def test_no_description(self):
+ p = parameters.Parameter('p', {'Type': 'String'})
+ self.assertEqual(p.description(), '')
+
+ def test_string_len_good(self):
+ schema = {'Type': 'String',
+ 'MinLength': '3',
+ 'MaxLength': '3'}
+ p = parameters.Parameter('p', schema, 'foo')
+ self.assertEqual(p.value(), 'foo')
+
+ def test_string_underflow(self):
+ schema = {'Type': 'String',
+ 'ConstraintDescription': 'wibble',
+ 'MinLength': '4'}
+ try:
+ parameters.Parameter('p', schema, 'foo')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_string_overflow(self):
+ schema = {'Type': 'String',
+ 'ConstraintDescription': 'wibble',
+ 'MaxLength': '2'}
+ try:
+ parameters.Parameter('p', schema, 'foo')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_string_pattern_good(self):
+ schema = {'Type': 'String',
+ 'AllowedPattern': '[a-z]*'}
+ p = parameters.Parameter('p', schema, 'foo')
+ self.assertEqual(p.value(), 'foo')
+
+ def test_string_pattern_bad_prefix(self):
+ schema = {'Type': 'String',
+ 'ConstraintDescription': 'wibble',
+ 'AllowedPattern': '[a-z]*'}
+ try:
+ parameters.Parameter('p', schema, '1foo')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_string_pattern_bad_suffix(self):
+ schema = {'Type': 'String',
+ 'ConstraintDescription': 'wibble',
+ 'AllowedPattern': '[a-z]*'}
+ try:
+ parameters.Parameter('p', schema, 'foo1')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_string_value_list_good(self):
+ schema = {'Type': 'String',
+ 'AllowedValues': ['foo', 'bar', 'baz']}
+ p = parameters.Parameter('p', schema, 'bar')
+ self.assertEqual(p.value(), 'bar')
+
+ def test_string_value_list_bad(self):
+ schema = {'Type': 'String',
+ 'ConstraintDescription': 'wibble',
+ 'AllowedValues': ['foo', 'bar', 'baz']}
+ try:
+ parameters.Parameter('p', schema, 'blarg')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_number_int_good(self):
+ schema = {'Type': 'Number',
+ 'MinValue': '3',
+ 'MaxValue': '3'}
+ p = parameters.Parameter('p', schema, '3')
+ self.assertEqual(p.value(), '3')
+
+ def test_number_float_good(self):
+ schema = {'Type': 'Number',
+ 'MinValue': '3.0',
+ 'MaxValue': '3.0'}
+ p = parameters.Parameter('p', schema, '3.0')
+ self.assertEqual(p.value(), '3.0')
+
+ def test_number_low(self):
+ schema = {'Type': 'Number',
+ 'ConstraintDescription': 'wibble',
+ 'MinValue': '4'}
+ try:
+ parameters.Parameter('p', schema, '3')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_number_high(self):
+ schema = {'Type': 'Number',
+ 'ConstraintDescription': 'wibble',
+ 'MaxValue': '2'}
+ try:
+ parameters.Parameter('p', schema, '3')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_number_value_list_good(self):
+ schema = {'Type': 'Number',
+ 'AllowedValues': ['1', '3', '5']}
+ p = parameters.Parameter('p', schema, '5')
+ self.assertEqual(p.value(), '5')
+
+ def test_number_value_list_bad(self):
+ schema = {'Type': 'Number',
+ 'ConstraintDescription': 'wibble',
+ 'AllowedValues': ['1', '3', '5']}
+ try:
+ parameters.Parameter('p', schema, '2')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+ def test_list_value_list_good(self):
+ schema = {'Type': 'CommaDelimitedList',
+ 'AllowedValues': ['foo', 'bar', 'baz']}
+ p = parameters.Parameter('p', schema, 'baz,foo,bar')
+ self.assertEqual(p.value(), 'baz,foo,bar')
+
+ def test_list_value_list_bad(self):
+ schema = {'Type': 'CommaDelimitedList',
+ 'ConstraintDescription': 'wibble',
+ 'AllowedValues': ['foo', 'bar', 'baz']}
+ try:
+ parameters.Parameter('p', schema, 'foo,baz,blarg')
+ except ValueError as ve:
+ msg = str(ve)
+ self.assertNotEqual(msg.find('wibble'), -1)
+ else:
+ self.fail('ValueError not raised')
+
+
+params_schema = json.loads('''{
+ "Parameters" : {
+ "User" : { "Type": "String" },
+ "Defaulted" : {
+ "Type": "String",
+ "Default": "foobar"
+ }
+ }
+}''')
+
+
+@attr(tag=['unit', 'parameters'])
+@attr(speed='fast')
+class ParametersTest(unittest.TestCase):
+ def test_pseudo_params(self):
+ params = parameters.Parameters('test_stack', {"Parameters": {}})
+
+ self.assertEqual(params['AWS::StackName'], 'test_stack')
+ self.assertTrue('AWS::Region' in params)
+
+ def test_user_param(self):
+ user_params = {'User': 'wibble'}
+ params = parameters.Parameters('test', params_schema, user_params)
+ self.assertEqual(params.user_parameters(), user_params)
+
+ def test_user_param_nonexist(self):
+ params = parameters.Parameters('test', params_schema)
+ self.assertEqual(params.user_parameters(), {})
+
+ def test_schema_invariance(self):
+ params1 = parameters.Parameters('test', params_schema,
+ {'Defaulted': 'wibble'})
+ self.assertEqual(params1['Defaulted'], 'wibble')
+
+ params2 = parameters.Parameters('test', params_schema)
+ self.assertEqual(params2['Defaulted'], 'foobar')
+
+ def test_to_dict(self):
+ template = {'Parameters': {'Foo': {'Type': 'String'},
+ 'Bar': {'Type': 'Number', 'Default': '42'}}}
+ params = parameters.Parameters('test_params', template, {'Foo': 'foo'})
+
+ as_dict = dict(params)
+ self.assertEqual(as_dict['Foo'], 'foo')
+ self.assertEqual(as_dict['Bar'], '42')
+ self.assertEqual(as_dict['AWS::StackName'], 'test_params')
+ self.assertTrue('AWS::Region' in as_dict)
+
+ def test_map(self):
+ template = {'Parameters': {'Foo': {'Type': 'String'},
+ 'Bar': {'Type': 'Number', 'Default': '42'}}}
+ params = parameters.Parameters('test_params', template, {'Foo': 'foo'})
+
+ expected = {'Foo': False,
+ 'Bar': True,
+ 'AWS::Region': True,
+ 'AWS::StackName': True}
+
+ self.assertEqual(params.map(lambda p: p.has_default()), expected)
+
+
+# allows testing of the test directly, shown below
+if __name__ == '__main__':
+ sys.argv.append(__file__)
+ nose.main()
dict_snippet)
-params_schema = json.loads('''{
- "Parameters" : {
- "User" : { "Type": "String" },
- "Defaulted" : {
- "Type": "String",
- "Default": "foobar"
- }
- }
-}''')
-
-
-@attr(tag=['unit', 'parser', 'parameters'])
-@attr(speed='fast')
-class ParametersTest(unittest.TestCase):
- def test_pseudo_params(self):
- params = parser.Parameters('test_stack', {"Parameters": {}})
-
- self.assertEqual(params['AWS::StackName'], 'test_stack')
- self.assertTrue('AWS::Region' in params)
-
- def test_user_param(self):
- params = parser.Parameters('test', params_schema, {'User': 'wibble'})
- user_params = params.user_parameters()
- self.assertEqual(user_params['User'], 'wibble')
-
- def test_user_param_default(self):
- params = parser.Parameters('test', params_schema)
- user_params = params.user_parameters()
- self.assertTrue('Defaulted' not in user_params)
-
- def test_user_param_nonexist(self):
- params = parser.Parameters('test', params_schema)
- user_params = params.user_parameters()
- self.assertTrue('User' not in user_params)
-
- def test_schema_invariance(self):
- params1 = parser.Parameters('test', params_schema)
- params1['Defaulted'] = "wibble"
- self.assertEqual(params1['Defaulted'], 'wibble')
-
- params2 = parser.Parameters('test', params_schema)
- self.assertEqual(params2['Defaulted'], 'foobar')
-
-
@attr(tag=['unit', 'parser', 'stack'])
@attr(speed='fast')
class StackTest(unittest.TestCase):