]> review.fuel-infra Code Review - openstack-build/heat-build.git/commitdiff
Make Stacks the owners of their own DB representation
authorZane Bitter <zbitter@redhat.com>
Tue, 3 Jul 2012 11:02:03 +0000 (13:02 +0200)
committerZane Bitter <zbitter@redhat.com>
Wed, 4 Jul 2012 15:20:39 +0000 (17:20 +0200)
Change the way that Stack objects are initialised, so that they know how to
load and store themselves in the database (they were already updating their
own data, but the actual creation was being done externally). This
consolidates a lot of existing Stack database creation code (in the
manager, nested stacks and unit tests) into one place.

Also, split the template and parameter handling out into separate classes,
and pass these through the constructor for easier testing.

Change-Id: I65bec175191713d0a4a6aa1d3d5442a1b64042f8
Signed-off-by: Zane Bitter <zbitter@redhat.com>
heat/engine/manager.py
heat/engine/parser.py
heat/engine/stack.py
heat/tests/functional/test_bin_heat.py
heat/tests/test_parser.py
heat/tests/test_resources.py
heat/tests/test_stacks.py
heat/tests/test_validate.py
heat/tests/test_waitcondition.py

index 117d22d24b66eaa693316f06bedec9c4dcb2a6a9..cce074779215636a8649848baf464b50f1de6163 100644 (file)
@@ -105,15 +105,12 @@ class EngineManager(manager.Manager):
         if stacks is None:
             return res
         for s in stacks:
-            ps = parser.Stack(context, s.name,
-                              s.raw_template.template,
-                              s.id, s.parameters)
+            stack = parser.Stack.load(context, s.id)
             mem = {}
             mem['StackId'] = "/".join([s.name, str(s.id)])
             mem['StackName'] = s.name
             mem['CreationTime'] = heat_utils.strtime(s.created_at)
-            mem['TemplateDescription'] = ps.t.get('Description',
-                                                   'No description')
+            mem['TemplateDescription'] = stack.t[parser.DESCRIPTION]
             mem['StackStatus'] = s.status
             res['stacks'].append(mem)
 
@@ -144,24 +141,21 @@ class EngineManager(manager.Manager):
             logging.debug("Processing show_stack for %s" % stack)
             s = db_api.stack_get_by_name(context, stack)
             if s:
-                ps = parser.Stack(context, s.name,
-                                  s.raw_template.template,
-                                  s.id, s.parameters)
+                stack = parser.Stack.load(context, s.id)
                 mem = {}
                 mem['StackId'] = "/".join([s.name, str(s.id)])
                 mem['StackName'] = s.name
                 mem['CreationTime'] = heat_utils.strtime(s.created_at)
                 mem['LastUpdatedTimestamp'] = heat_utils.strtime(s.updated_at)
                 mem['NotificationARNs'] = 'TODO'
-                mem['Parameters'] = ps.t['Parameters']
-                mem['Description'] = ps.t.get('Description',
-                                              'No description')
+                mem['Parameters'] = stack.t[parser.PARAMETERS]
+                mem['Description'] = stack.t[parser.DESCRIPTION]
                 mem['StackStatus'] = s.status
                 mem['StackStatusReason'] = s.status_reason
 
                 # only show the outputs on a completely created stack
-                if s.status == ps.CREATE_COMPLETE:
-                    mem['Outputs'] = ps.get_outputs()
+                if s.status == stack.CREATE_COMPLETE:
+                    mem['Outputs'] = stack.get_outputs()
 
                 res['stacks'].append(mem)
 
@@ -185,40 +179,19 @@ class EngineManager(manager.Manager):
         if db_api.stack_get_by_name(None, stack_name):
             return {'Error': 'Stack already exists with that name.'}
 
-        user_params = _extract_user_params(params)
-        # We don't want to reset the stack template, so we are making
-        # an instance just for validation.
-        template_copy = deepcopy(template)
-        stack_validator = parser.Stack(context, stack_name,
-                                       template_copy, 0,
-                                       user_params)
-        response = stack_validator.validate()
-        stack_validator = None
-        template_copy = None
-        if 'Malformed Query Response' in \
-                response['ValidateTemplateResult']['Description']:
-            return response
-
-        stack = parser.Stack(context, stack_name, template, 0, user_params)
-        rt = {}
-        rt['template'] = template
-        rt['StackName'] = stack_name
-        new_rt = db_api.raw_template_create(None, rt)
-
-        new_creds = db_api.user_creds_create(context.to_dict())
+        tmpl = parser.Template(template)
+        user_params = parser.Parameters(stack_name, tmpl,
+                                        _extract_user_params(params))
+        stack = parser.Stack(context, stack_name, tmpl, user_params)
 
-        s = {}
-        s['name'] = stack_name
-        s['raw_template_id'] = new_rt.id
-        s['user_creds_id'] = new_creds.id
-        s['username'] = context.username
-        s['parameters'] = user_params
-        new_s = db_api.stack_create(context, s)
-        stack.id = new_s.id
+        response = stack.validate()
+        if response['Description'] != 'Successfully validated':
+            return response
 
+        stack_id = stack.store()
         greenpool.spawn_n(stack.create, **_extract_args(params))
 
-        return {'StackId': "/".join([new_s.name, str(new_s.id)])}
+        return {'StackId': "/".join([stack.name, str(stack.id)])}
 
     def validate_template(self, context, template, params):
         """
@@ -237,21 +210,22 @@ class EngineManager(manager.Manager):
             msg = _("No Template provided.")
             return webob.exc.HTTPBadRequest(explanation=msg)
 
+        stack_name = 'validate'
         try:
-            s = parser.Stack(context, 'validate', template, 0,
-                             _extract_user_params(params))
+            tmpl = parser.Template(template)
+            user_params = parser.Parameters(stack_name, tmpl,
+                                            _extract_user_params(params))
+            s = parser.Stack(context, stack_name, tmpl, user_params)
         except KeyError as ex:
             res = ('A Fn::FindInMap operation referenced '
                    'a non-existent map [%s]' % str(ex))
 
-            response = {'ValidateTemplateResult': {
-                        'Description': 'Malformed Query Response [%s]' % (res),
-                        'Parameters': []}}
-            return response
-
-        res = s.validate()
+            result = {'Description': 'Malformed Query Response [%s]' % (res),
+                      'Parameters': []}
+        else:
+            result = s.validate()
 
-        return res
+        return {'ValidateTemplateResult': result}
 
     def get_template(self, context, stack_name, params):
         """
@@ -282,10 +256,8 @@ class EngineManager(manager.Manager):
 
         logger.info('deleting stack %s' % stack_name)
 
-        ps = parser.Stack(context, st.name,
-                          st.raw_template.template,
-                          st.id, st.parameters)
-        greenpool.spawn_n(ps.delete)
+        stack = parser.Stack.load(context, st.id)
+        greenpool.spawn_n(stack.delete)
         return None
 
     # Helper for list_events.  It's here so we can use it in tests.
@@ -356,38 +328,43 @@ class EngineManager(manager.Manager):
     def describe_stack_resource(self, context, stack_name, resource_name):
         auth.authenticate(context)
 
-        stack = db_api.stack_get_by_name(context, stack_name)
-        if not stack:
+        s = db_api.stack_get_by_name(context, stack_name)
+        if not s:
             raise AttributeError('Unknown stack name')
-        resource = db_api.resource_get_by_name_and_stack(context,
-                                                         resource_name,
-                                                         stack.id)
-        if not resource:
+
+        stack = parser.Stack.load(context, s.id)
+        if resource_name not in stack:
             raise AttributeError('Unknown resource name')
-        return format_resource_attributes(stack, resource)
+
+        resource = stack[resource_name]
+        if resource.id is None:
+            raise AttributeError('Resource not created')
+
+        return format_stack_resource(stack[resource_name])
 
     def describe_stack_resources(self, context, stack_name,
                                  physical_resource_id, logical_resource_id):
         auth.authenticate(context)
 
-        if stack_name:
-            stack = db_api.stack_get_by_name(context, stack_name)
+        if stack_name is not None:
+            s = db_api.stack_get_by_name(context, stack_name)
         else:
-            resource = db_api.resource_get_by_physical_resource_id(context,
+            rs = db_api.resource_get_by_physical_resource_id(context,
                     physical_resource_id)
-            if not resource:
+            if not rs:
                 msg = "The specified PhysicalResourceId doesn't exist"
                 raise AttributeError(msg)
-            stack = resource.stack
+            s = rs.stack
 
-        if not stack:
+        if not s:
             raise AttributeError("The specified stack doesn't exist")
 
+        stack = parser.Stack.load(context, s.id)
         resources = []
-        for r in stack.resources:
-            if logical_resource_id and r.name != logical_resource_id:
+        for resource in stack:
+            if logical_resource_id and resource.name != logical_resource_id:
                 continue
-            formatted = format_resource_attributes(stack, r)
+            formatted = format_stack_resource(resource)
             # this API call uses Timestamp instead of LastUpdatedTimestamp
             formatted['Timestamp'] = formatted['LastUpdatedTimestamp']
             del formatted['LastUpdatedTimestamp']
@@ -398,16 +375,18 @@ class EngineManager(manager.Manager):
     def list_stack_resources(self, context, stack_name):
         auth.authenticate(context)
 
-        stack = db_api.stack_get_by_name(context, stack_name)
-        if not stack:
+        s = db_api.stack_get_by_name(context, stack_name)
+        if not s:
             raise AttributeError('Unknown stack name')
 
+        stack = parser.Stack.load(context, s.id)
+
         resources = []
         response_keys = ('ResourceStatus', 'LogicalResourceId',
                          'LastUpdatedTimestamp', 'PhysicalResourceId',
                          'ResourceType')
-        for r in stack.resources:
-            formatted = format_resource_attributes(stack, r)
+        for resource in stack:
+            formatted = format_stack_resource(resource)
             for key in formatted.keys():
                 if not key in response_keys:
                     del formatted[key]
@@ -499,11 +478,9 @@ class EngineManager(manager.Manager):
                 if s:
                     user_creds = db_api.user_creds_get(s.user_creds_id)
                     ctxt = ctxtlib.RequestContext.from_dict(dict(user_creds))
-                    ps = parser.Stack(ctxt, s.name,
-                                      s.raw_template.template,
-                                      s.id, s.parameters)
+                    stack = parser.Stack.load(ctxt, s.id)
                     for a in wr.rule[action_map[new_state]]:
-                        greenpool.spawn_n(ps[a].alarm)
+                        greenpool.spawn_n(stack[a].alarm)
 
         wr.last_evaluated = now
 
@@ -534,22 +511,20 @@ class EngineManager(manager.Manager):
         return [None, wd.data]
 
 
-def format_resource_attributes(stack, resource):
+def format_stack_resource(resource):
     """
     Return a representation of the given resource that mathes the API output
     expectations.
     """
-    template = resource.parsed_template.template
-    template_resources = template.get('Resources', {})
-    resource_type = template_resources.get(resource.name, {}).get('Type', '')
-    last_updated_time = resource.updated_at or resource.created_at
+    rs = db_api.resource_get(resource.stack.context, resource.id)
+    last_updated_time = rs.updated_at or rs.created_at
     return {
-        'StackId': stack.id,
-        'StackName': stack.name,
+        'StackId': resource.stack.id,
+        'StackName': resource.stack.name,
         'LogicalResourceId': resource.name,
-        'PhysicalResourceId': resource.nova_instance or '',
-        'ResourceType': resource_type,
+        'PhysicalResourceId': resource.instance_id or '',
+        'ResourceType': resource.t['Type'],
         'LastUpdatedTimestamp': heat_utils.strtime(last_updated_time),
-        'ResourceStatus': resource.state,
-        'ResourceStatusReason': resource.state_description,
+        'ResourceStatus': rs.state,
+        'ResourceStatusReason': rs.state_description,
     }
index 355b68c6f680f3e12d58c079b2410dce070b2150..ec770a8c48a3bb308488d6c87d5616cab42da5ca 100644 (file)
@@ -15,7 +15,8 @@
 
 import eventlet
 import json
-import itertools
+import functools
+import copy
 import logging
 
 from heat.common import exception
@@ -24,8 +25,206 @@ from heat.engine import dependencies
 from heat.engine.resources import Resource
 from heat.db import api as db_api
 
+
 logger = logging.getLogger('heat.engine.parser')
 
+SECTIONS = (VERSION, DESCRIPTION, MAPPINGS,
+            PARAMETERS, RESOURCES, OUTPUTS) = \
+           ('AWSTemplateFormatVersion', 'Description', 'Mappings',
+            'Parameters', 'Resources', 'Outputs')
+
+(PARAM_STACK_NAME, PARAM_REGION) = ('AWS::StackName', 'AWS::Region')
+
+
+class Parameters(checkeddict.CheckedDict):
+    '''
+    The parameters of a stack, with type checking, defaults &c. specified by
+    the stack's template.
+    '''
+
+    def __init__(self, stack_name, template, user_params={}):
+        '''
+        Create the parameter container for a stack from the stack name and
+        template, optionally setting the initial set of parameters.
+        '''
+        checkeddict.CheckedDict.__init__(self, PARAMETERS)
+        self._init_schemata(template[PARAMETERS])
+
+        self[PARAM_STACK_NAME] = stack_name
+        self.update(user_params)
+
+    def _init_schemata(self, schemata):
+        '''
+        Initialise the parameter schemata with the pseudo-parameters and the
+        list of schemata obtained from the template.
+        '''
+        self.addschema(PARAM_STACK_NAME, {"Description": "AWS StackName",
+                                          "Type": "String"})
+        self.addschema(PARAM_REGION, {
+            "Description": "AWS Regions",
+            "Default": "ap-southeast-1",
+            "Type": "String",
+            "AllowedValues": ["us-east-1", "us-west-1", "us-west-2",
+                              "sa-east-1", "eu-west-1", "ap-southeast-1",
+                              "ap-northeast-1"],
+            "ConstraintDescription": "must be a valid EC2 instance type.",
+        })
+
+        for param, schema in schemata.items():
+            self.addschema(param, copy.deepcopy(schema))
+
+    def user_parameters(self):
+        '''
+        Return a dictionary of all the parameters passed in by the user
+        '''
+        return dict((k, v['Value']) for k, v in self.data.iteritems()
+                                    if 'Value' in v)
+
+
+class Template(object):
+    '''A stack template.'''
+
+    def __init__(self, template, template_id=None):
+        '''
+        Initialise the template with a JSON object and a set of Parameters
+        '''
+        self.id = template_id
+        self.t = template
+        self.maps = self[MAPPINGS]
+
+    @classmethod
+    def load(cls, context, template_id):
+        '''Retrieve a Template with the given ID from the database'''
+        t = db_api.raw_template_get(context, template_id)
+        return cls(t.template, template_id)
+
+    def store(self, context=None):
+        '''Store the Template in the database and return its ID'''
+        if self.id is None:
+            rt = {'template': self.t}
+            new_rt = db_api.raw_template_create(context, rt)
+            self.id = new_rt.id
+        return self.id
+
+    def __getitem__(self, section):
+        '''Get the relevant section in the template'''
+        if section not in SECTIONS:
+            raise KeyError('"%s" is not a valid template section' % section)
+        if section == VERSION:
+            return self.t[section]
+
+        if section == DESCRIPTION:
+            default = 'No description'
+        else:
+            default = {}
+
+        return self.t.get(section, default)
+
+    def resolve_find_in_map(self, s):
+        '''
+        Resolve constructs of the form { "Fn::FindInMap" : [ "mapping",
+                                                             "key",
+                                                             "value" ] }
+        '''
+        def handle_find_in_map(args):
+            try:
+                name, key, value = args
+                return self.maps[name][key][value]
+            except (ValueError, TypeError) as ex:
+                raise KeyError(str(ex))
+
+        return _resolve(lambda k, v: k == 'Fn::FindInMap',
+                        handle_find_in_map, s)
+
+    @staticmethod
+    def resolve_availability_zones(s):
+        '''
+            looking for { "Fn::GetAZs" : "str" }
+        '''
+        def match_get_az(key, value):
+            return (key == 'Fn::GetAZs' and
+                    isinstance(value, basestring))
+
+        def handle_get_az(ref):
+            return ['nova']
+
+        return _resolve(match_get_az, handle_get_az, s)
+
+    @staticmethod
+    def resolve_param_refs(s, parameters):
+        '''
+        Resolve constructs of the form { "Ref" : "string" }
+        '''
+        def match_param_ref(key, value):
+            return (key == 'Ref' and
+                    isinstance(value, basestring) and
+                    value in parameters)
+
+        def handle_param_ref(ref):
+            try:
+                return parameters[ref]
+            except (KeyError, ValueError):
+                raise exception.UserParameterMissing(key=ref)
+
+        return _resolve(match_param_ref, handle_param_ref, s)
+
+    @staticmethod
+    def resolve_resource_refs(s, resources):
+        '''
+        Resolve constructs of the form { "Ref" : "resource" }
+        '''
+        def match_resource_ref(key, value):
+            return key == 'Ref' and value in resources
+
+        def handle_resource_ref(arg):
+            return resources[arg].FnGetRefId()
+
+        return _resolve(match_resource_ref, handle_resource_ref, s)
+
+    @staticmethod
+    def resolve_attributes(s, resources):
+        '''
+        Resolve constructs of the form { "Fn::GetAtt" : [ "WebServer",
+                                                          "PublicIp" ] }
+        '''
+        def handle_getatt(args):
+            resource, att = args
+            try:
+                return resources[resource].FnGetAtt(att)
+            except KeyError:
+                raise exception.InvalidTemplateAttribute(resource=resource,
+                                                         key=att)
+
+        return _resolve(lambda k, v: k == 'Fn::GetAtt', handle_getatt, s)
+
+    @staticmethod
+    def resolve_joins(s):
+        '''
+        Resolve constructs of the form { "Fn::Join" : [ "delim", [ "str1",
+                                                                   "str2" ] }
+        '''
+        def handle_join(args):
+            if not isinstance(args, (list, tuple)):
+                raise TypeError('Arguments to "Fn::Join" must be a list')
+            delim, strings = args
+            if not isinstance(strings, (list, tuple)):
+                raise TypeError('Arguments to "Fn::Join" not fully resolved')
+            return delim.join(strings)
+
+        return _resolve(lambda k, v: k == 'Fn::Join', handle_join, s)
+
+    @staticmethod
+    def resolve_base64(s):
+        '''
+        Resolve constructs of the form { "Fn::Base64" : "string" }
+        '''
+        def handle_base64(string):
+            if not isinstance(string, basestring):
+                raise TypeError('Arguments to "Fn::Base64" not fully resolved')
+            return string
+
+        return _resolve(lambda k, v: k == 'Fn::Base64', handle_base64, s)
+
 
 class Stack(object):
     IN_PROGRESS = 'IN_PROGRESS'
@@ -35,46 +234,65 @@ class Stack(object):
     DELETE_FAILED = 'DELETE_FAILED'
     DELETE_COMPLETE = 'DELETE_COMPLETE'
 
-    def __init__(self, context, stack_name, template, stack_id=0, parms=None):
+    def __init__(self, context, stack_name, template, parameters=None,
+                 stack_id=None):
+        '''
+        Initialise from a context, name, Template object and (optionally)
+        Parameters object. The database ID may also be initialised, if the
+        stack is already in the database.
+        '''
         self.id = stack_id
         self.context = context
         self.t = template
-        self.maps = self.t.get('Mappings', {})
-        self.res = {}
-        self.doc = None
         self.name = stack_name
 
-        # Default Parameters
-        self.parms = checkeddict.CheckedDict('Parameters')
-        self.parms.addschema('AWS::StackName', {"Description": "AWS StackName",
-                                                "Type": "String"})
-        self.parms['AWS::StackName'] = stack_name
-        self.parms.addschema('AWS::Region', {"Description": "AWS Regions",
-            "Default": "ap-southeast-1",
-            "Type": "String",
-            "AllowedValues": ["us-east-1", "us-west-1", "us-west-2",
-                              "sa-east-1", "eu-west-1", "ap-southeast-1",
-                              "ap-northeast-1"],
-            "ConstraintDescription": "must be a valid EC2 instance type."})
-
-        # template Parameters
-        ps = self.t.get('Parameters', {})
-        for p in ps:
-            self.parms.addschema(p, ps[p])
-
-        # user Parameters
-        if parms is not None:
-            self.parms.update(parms)
+        if parameters is None:
+            parameters = Parameters(stack_name, template)
+        self.parameters = parameters
 
-        self.outputs = self.resolve_static_data(self.t.get('Outputs', {}))
+        self.outputs = self.resolve_static_data(self.t[OUTPUTS])
 
         self.resources = dict((name,
                                Resource(name, data, self))
-                              for (name, data) in self.t['Resources'].items())
+                              for (name, data) in self.t[RESOURCES].items())
+
+        self.dependencies = self._get_dependencies(self.resources.itervalues())
+
+    @staticmethod
+    def _get_dependencies(resources):
+        '''Return the dependency graph for a list of resources'''
+        deps = dependencies.Dependencies()
+        for resource in resources:
+            resource.add_dependencies(deps)
+
+        return deps
+
+    @classmethod
+    def load(cls, context, stack_id):
+        '''Retrieve a Stack from the database'''
+        s = db_api.stack_get(context, stack_id)
+
+        template = Template.load(context, s.raw_template_id)
+        params = Parameters(s.name, template, s.parameters)
+        stack = cls(context, s.name, template, params, stack_id)
+
+        return stack
 
-        self.dependencies = dependencies.Dependencies()
-        for resource in self.resources.values():
-            resource.add_dependencies(self.dependencies)
+    def store(self, owner=None):
+        '''Store the stack in the database and return its ID'''
+        if self.id is None:
+            new_creds = db_api.user_creds_create(self.context.to_dict())
+
+            s = {'name': self.name,
+                 'raw_template_id': self.t.store(),
+                 'parameters': self.parameters.user_parameters(),
+                 'owner_id': owner and owner.id,
+                 'user_creds_id': new_creds.id,
+                 'username': self.context.username}
+            new_s = db_api.stack_create(self.context, s)
+            self.id = new_s.id
+
+        return self.id
 
     def __iter__(self):
         '''
@@ -103,9 +321,11 @@ class Stack(object):
         return key in self.resources
 
     def keys(self):
+        '''Return a list of resource keys for the stack'''
         return self.resources.keys()
 
     def __str__(self):
+        '''Return a human-readable string representation of the stack'''
         return 'Stack "%s"' % self.name
 
     def validate(self):
@@ -115,8 +335,6 @@ class Stack(object):
         '''
         # TODO(sdake) Should return line number of invalid reference
 
-        response = None
-
         for res in self:
             try:
                 result = res.validate()
@@ -126,35 +344,27 @@ class Stack(object):
 
             if result:
                 err_str = 'Malformed Query Response %s' % result
-                response = {'ValidateTemplateResult': {
-                                'Description': err_str,
-                                'Parameters': []}}
+                response = {'Description': err_str,
+                            'Parameters': []}
                 return response
 
-        if response is None:
-            response = {'ValidateTemplateResult': {
-                        'Description': 'Successfully validated',
-                        'Parameters': []}}
-        for p in self.parms:
-            jp = {'member': {}}
-            res = jp['member']
-            res['NoEcho'] = 'false'
-            res['ParameterKey'] = p
-            res['Description'] = self.parms.get_attr(p, 'Description')
-            res['DefaultValue'] = self.parms.get_attr(p, 'Default')
-            response['ValidateTemplateResult']['Parameters'].append(res)
-        return response
+        def format_param(p):
+            return {'NoEcho': 'false',
+                    'ParameterKey': p,
+                    'Description': self.parameters.get_attr(p, 'Description'),
+                    'DefaultValue': self.parameters.get_attr(p, 'Default')}
 
-    def state_set(self, new_status, reason='change in resource state'):
-        if self.id != 0:
-            stack = db_api.stack_get(self.context, self.id)
-        else:
-            stack = db_api.stack_get_by_name(self.context, self.name)
+        response = {'Description': 'Successfully validated',
+                    'Parameters': [format_param(p) for p in self.parameters]}
 
-        if stack is None:
+        return response
+
+    def state_set(self, new_status, reason):
+        '''Update the stack state in the database'''
+        if self.id is None:
             return
 
-        self.id = stack.id
+        stack = db_api.stack_get(self.context, self.id)
         stack.update_and_save({'status': new_status,
                                'status_reason': reason})
 
@@ -199,7 +409,7 @@ class Stack(object):
         '''
         Delete all of the resources, and then the stack itself.
         '''
-        self.state_set(self.DELETE_IN_PROGRESS)
+        self.state_set(self.DELETE_IN_PROGRESS, 'Stack deletion started')
 
         failures = []
 
@@ -253,104 +463,25 @@ class Stack(object):
                     logger.exception('create')
                     failed = True
             else:
-                res.state_set(res.CREATE_FAILED)
+                res.state_set(res.CREATE_FAILED, 'Resource restart aborted')
         # TODO(asalkeld) if any of this fails we Should
         # restart the whole stack
 
-    def parameter_get(self, key):
-        if not key in self.parms:
-            raise exception.UserParameterMissing(key=key)
-        try:
-            return self.parms[key]
-        except ValueError:
-            raise exception.UserParameterMissing(key=key)
-
-    def _resolve_static_refs(self, s):
-        '''
-            looking for { "Ref" : "str" }
-        '''
-        def match(key, value):
-            return (key == 'Ref' and
-                    isinstance(value, basestring) and
-                    value in self.parms)
-
-        def handle(ref):
-            return self.parameter_get(ref)
-
-        return _resolve(match, handle, s)
-
-    def _resolve_availability_zones(self, s):
-        '''
-            looking for { "Fn::GetAZs" : "str" }
-        '''
-        def match(key, value):
-            return (key == 'Fn::GetAZs' and
-                    isinstance(value, basestring))
-
-        def handle(ref):
-            return ['nova']
-
-        return _resolve(match, handle, s)
-
-    def _resolve_find_in_map(self, s):
-        def handle(args):
-            try:
-                name, key, value = args
-                return self.maps[name][key][value]
-            except (ValueError, TypeError) as ex:
-                raise KeyError(str(ex))
-
-        return _resolve(lambda k, v: k == 'Fn::FindInMap', handle, s)
-
-    def _resolve_attributes(self, s):
-        '''
-            looking for something like:
-            { "Fn::GetAtt" : [ "DBInstance", "Endpoint.Address" ] }
-        '''
-        def match_ref(key, value):
-            return key == 'Ref' and value in self
-
-        def handle_ref(arg):
-            return self[arg].FnGetRefId()
-
-        def handle_getatt(args):
-            resource, att = args
-            try:
-                return self[resource].FnGetAtt(att)
-            except KeyError:
-                raise exception.InvalidTemplateAttribute(resource=resource,
-                                                         key=att)
-
-        return _resolve(lambda k, v: k == 'Fn::GetAtt', handle_getatt,
-                        _resolve(match_ref, handle_ref, s))
-
-    @staticmethod
-    def _resolve_joins(s):
-        '''
-            looking for { "Fn::Join" : [] }
-        '''
-        def handle(args):
-            delim, strings = args
-            return delim.join(strings)
-
-        return _resolve(lambda k, v: k == 'Fn::Join', handle, s)
-
-    @staticmethod
-    def _resolve_base64(s):
-        '''
-            looking for { "Fn::Base64" : "" }
-        '''
-        return _resolve(lambda k, v: k == 'Fn::Base64', lambda d: d, s)
-
     def resolve_static_data(self, snippet):
-        return transform(snippet, [self._resolve_static_refs,
-                                   self._resolve_availability_zones,
-                                   self._resolve_find_in_map])
+        return transform(snippet,
+                         [functools.partial(self.t.resolve_param_refs,
+                                            parameters=self.parameters),
+                          self.t.resolve_availability_zones,
+                          self.t.resolve_find_in_map])
 
     def resolve_runtime_data(self, snippet):
-        return transform(snippet, [self._resolve_attributes,
-                                   self._resolve_joins,
-                                   self._resolve_base64])
+        return transform(snippet,
+                         [functools.partial(self.t.resolve_resource_refs,
+                                            resources=self.resources),
+                          functools.partial(self.t.resolve_attributes,
+                                            resources=self.resources),
+                          self.t.resolve_joins,
+                          self.t.resolve_base64])
 
 
 def transform(data, transformations):
index 9bacefa6629604f8861450132aaf27a0a25e2bf9..a42f7898ea689386d9a0c7c4312f842fc3c6e84a 100644 (file)
@@ -45,45 +45,30 @@ class Stack(Resource):
         return p
 
     def nested(self):
-        if self._nested is None:
-            if self.instance_id is None:
-                return None
+        if self._nested is None and self.instance_id is not None:
+            self._nested = parser.Stack.load(self.stack.context,
+                                             self.instance_id)
 
-            st = db_api.stack_get(self.stack.context, self.instance_id)
-            if not st:
+            if self._nested is None:
                 raise exception.NotFound('Nested stack not found in DB')
 
-            n = parser.Stack(self.stack.context, st.name,
-                             st.raw_template.template,
-                             self.instance_id, self._params())
-            self._nested = n
-
         return self._nested
 
     def create_with_template(self, child_template):
         '''
         Handle the creation of the nested stack from a given JSON template.
         '''
+        template = parser.Template(child_template)
+        params = parser.Parameters(self.name, template, self._params())
+
         self._nested = parser.Stack(self.stack.context,
                                     self.name,
-                                    child_template,
-                                    parms=self._params())
-
-        rt = {'template': child_template, 'stack_name': self.name}
-        new_rt = db_api.raw_template_create(None, rt)
-
-        parent_stack = db_api.stack_get(self.stack.context, self.stack.id)
-
-        s = {'name': self.name,
-             'owner_id': self.stack.id,
-             'raw_template_id': new_rt.id,
-             'user_creds_id': parent_stack.user_creds_id,
-             'username': self.stack.context.username}
-        new_s = db_api.stack_create(None, s)
-        self._nested.id = new_s.id
+                                    template,
+                                    params)
 
+        nested_id = self._nested.store(self.stack)
+        self.instance_id_set(nested_id)
         self._nested.create()
-        self.instance_id_set(self._nested.id)
 
     def handle_create(self):
         response = urllib2.urlopen(self.properties[PROP_TEMPLATE_URL])
@@ -110,7 +95,7 @@ class Stack(Resource):
         if stack is None:
             # This seems like a hack, to get past validation
             return ''
-        if op not in self.nested().outputs:
+        if op not in stack.outputs:
             raise exception.InvalidTemplateAttribute(resource=self.name,
                                                      key=key)
 
index f7591605ee1e1103d111deb49d162dcd4862714f..e852eb3087665e02996e81f32fc85bb808d602ae 100644 (file)
@@ -206,13 +206,13 @@ class TestBinHeat():
         t = json.loads(f.read())
         f.close()
 
-        params = {}
-        params['KeyStoneCreds'] = None
-        t['Parameters']['KeyName']['Value'] = keyname
-        t['Parameters']['DBUsername']['Value'] = dbusername
-        t['Parameters']['DBPassword']['Value'] = creds['password']
+        template = parser.Template(t)
+        params = parser.Parameters('test', t,
+                                   {'KeyName': keyname,
+                                    'DBUsername': dbusername,
+                                    'DBPassword': creds['password']})
 
-        stack = parser.Stack('test', t, 0, params)
+        stack = parser.Stack(None, 'test', template, params)
         parsed_t = stack.resolve_static_refs(t)
         remote_file = sftp.open('/var/lib/cloud/instance/scripts/startup')
         remote_file_list = remote_file.read().split('\n')
index 449e8f46cf3001ffa17d0f2a39fff0a45c639d52..8cc517c7813ebda3a87ddc637374111d1603e6b8 100644 (file)
@@ -1,8 +1,13 @@
 import nose
 import unittest
 from nose.plugins.attrib import attr
+import mox
 
-from heat.engine.parser import _resolve as resolve
+import json
+from heat.common import exception
+from heat.engine import parser
+from heat.engine import checkeddict
+from heat.engine.resources import Resource
 
 
 def join(raw):
@@ -10,7 +15,7 @@ def join(raw):
         delim, strs = args
         return delim.join(strs)
 
-    return resolve(lambda k, v: k == 'Fn::Join', handle_join, raw)
+    return parser._resolve(lambda k, v: k == 'Fn::Join', handle_join, raw)
 
 
 @attr(tag=['unit', 'parser'])
@@ -76,6 +81,225 @@ class ParserTest(unittest.TestCase):
         self.assertEqual(join(raw), 'foo bar\nbaz')
 
 
+mapping_template = json.loads('''{
+  "Mappings" : {
+    "ValidMapping" : {
+      "TestKey" : { "TestValue" : "wibble" }
+    },
+    "InvalidMapping" : {
+      "ValueList" : [ "foo", "bar" ],
+      "ValueString" : "baz"
+    },
+    "MapList": [ "foo", { "bar" : "baz" } ],
+    "MapString": "foobar"
+  }
+}''')
+
+
+@attr(tag=['unit', 'parser', 'template'])
+@attr(speed='fast')
+class TemplateTest(unittest.TestCase):
+    def setUp(self):
+        self.m = mox.Mox()
+
+    def tearDown(self):
+        self.m.UnsetStubs()
+
+    def test_defaults(self):
+        empty = parser.Template({})
+        try:
+            empty[parser.VERSION]
+        except KeyError:
+            pass
+        else:
+            self.fail('Expected KeyError for version not present')
+        self.assertEqual(empty[parser.DESCRIPTION], 'No description')
+        self.assertEqual(empty[parser.MAPPINGS], {})
+        self.assertEqual(empty[parser.PARAMETERS], {})
+        self.assertEqual(empty[parser.RESOURCES], {})
+        self.assertEqual(empty[parser.OUTPUTS], {})
+
+    def test_invalid_section(self):
+        tmpl = parser.Template({'Foo': ['Bar']})
+        try:
+            tmpl['Foo']
+        except KeyError:
+            pass
+        else:
+            self.fail('Expected KeyError for invalid template key')
+
+    def test_find_in_map(self):
+        tmpl = parser.Template(mapping_template)
+        find = {'Fn::FindInMap': ["ValidMapping", "TestKey", "TestValue"]}
+        self.assertEqual(tmpl.resolve_find_in_map(find), "wibble")
+
+    def test_find_in_invalid_map(self):
+        tmpl = parser.Template(mapping_template)
+        finds = ({'Fn::FindInMap': ["InvalidMapping", "ValueList", "foo"]},
+                 {'Fn::FindInMap': ["InvalidMapping", "ValueString", "baz"]},
+                 {'Fn::FindInMap': ["MapList", "foo", "bar"]},
+                 {'Fn::FindInMap': ["MapString", "foo", "bar"]})
+
+        for find in finds:
+            self.assertRaises(KeyError, tmpl.resolve_find_in_map, find)
+
+    def test_bad_find_in_map(self):
+        tmpl = parser.Template(mapping_template)
+        finds = ({'Fn::FindInMap': "String"},
+                 {'Fn::FindInMap': {"Dict": "String"}},
+                 {'Fn::FindInMap': ["ShortList", "foo"]},
+                 {'Fn::FindInMap': ["ReallyShortList"]})
+
+        for find in finds:
+            self.assertRaises(KeyError, tmpl.resolve_find_in_map, find)
+
+    def test_param_refs(self):
+        params = {'foo': 'bar', 'blarg': 'wibble'}
+        p_snippet = {"Ref": "foo"}
+        self.assertEqual(parser.Template.resolve_param_refs(p_snippet, params),
+                         "bar")
+
+    def test_param_refs_resource(self):
+        params = {'foo': 'bar', 'blarg': 'wibble'}
+        r_snippet = {"Ref": "baz"}
+        self.assertEqual(parser.Template.resolve_param_refs(r_snippet, params),
+                         r_snippet)
+
+    def test_param_ref_missing(self):
+        params = checkeddict.CheckedDict("test")
+        params.addschema('foo', {"Required": True})
+        snippet = {"Ref": "foo"}
+        self.assertRaises(exception.UserParameterMissing,
+                          parser.Template.resolve_param_refs,
+                          snippet, params)
+
+    def test_resource_refs(self):
+        resources = {'foo': self.m.CreateMock(Resource),
+                     'blarg': self.m.CreateMock(Resource)}
+        resources['foo'].FnGetRefId().AndReturn('bar')
+        self.m.ReplayAll()
+
+        r_snippet = {"Ref": "foo"}
+        self.assertEqual(parser.Template.resolve_resource_refs(r_snippet,
+                                                               resources),
+                         "bar")
+        self.m.VerifyAll()
+
+    def test_resource_refs_param(self):
+        resources = {'foo': 'bar', 'blarg': 'wibble'}
+        p_snippet = {"Ref": "baz"}
+        self.assertEqual(parser.Template.resolve_resource_refs(p_snippet,
+                                                               resources),
+                         p_snippet)
+
+    def test_join(self):
+        join = {"Fn::Join": [" ", ["foo", "bar"]]}
+        self.assertEqual(parser.Template.resolve_joins(join), "foo bar")
+
+    def test_join_string(self):
+        join = {"Fn::Join": [" ", "foo"]}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join)
+
+    def test_join_dict(self):
+        join = {"Fn::Join": [" ", {"foo": "bar"}]}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join)
+
+    def test_join_wrong_num_args(self):
+        join0 = {"Fn::Join": []}
+        self.assertRaises(ValueError, parser.Template.resolve_joins,
+                          join0)
+        join1 = {"Fn::Join": [" "]}
+        self.assertRaises(ValueError, parser.Template.resolve_joins,
+                          join1)
+        join3 = {"Fn::Join": [" ", {"foo": "bar"}, ""]}
+        self.assertRaises(ValueError, parser.Template.resolve_joins,
+                          join3)
+
+    def test_join_string_nodelim(self):
+        join1 = {"Fn::Join": "o"}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join1)
+        join2 = {"Fn::Join": "oh"}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join2)
+        join3 = {"Fn::Join": "ohh"}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join3)
+
+    def test_join_dict_nodelim(self):
+        join1 = {"Fn::Join": {"foo": "bar"}}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join1)
+        join2 = {"Fn::Join": {"foo": "bar", "blarg": "wibble"}}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join2)
+        join3 = {"Fn::Join": {"foo": "bar", "blarg": "wibble", "baz": "quux"}}
+        self.assertRaises(TypeError, parser.Template.resolve_joins,
+                          join3)
+
+    def test_base64(self):
+        snippet = {"Fn::Base64": "foobar"}
+        # For now, the Base64 function just returns the original text, and
+        # does not convert to base64 (see issue #133)
+        self.assertEqual(parser.Template.resolve_base64(snippet), "foobar")
+
+    def test_base64_list(self):
+        list_snippet = {"Fn::Base64": ["foobar"]}
+        self.assertRaises(TypeError, parser.Template.resolve_base64,
+                          list_snippet)
+
+    def test_base64_dict(self):
+        dict_snippet = {"Fn::Base64": {"foo": "bar"}}
+        self.assertRaises(TypeError, parser.Template.resolve_base64,
+                          dict_snippet)
+
+
+params_schema = json.loads('''{
+  "Parameters" : {
+    "User" : { "Type": "String" },
+    "Defaulted" : {
+      "Type": "String",
+      "Default": "foobar"
+    }
+  }
+}''')
+
+
+@attr(tag=['unit', 'parser', 'parameters'])
+@attr(speed='fast')
+class ParametersTest(unittest.TestCase):
+    def test_pseudo_params(self):
+        params = parser.Parameters('test_stack', {"Parameters": {}})
+
+        self.assertEqual(params['AWS::StackName'], 'test_stack')
+        self.assertTrue('AWS::Region' in params)
+
+    def test_user_param(self):
+        params = parser.Parameters('test', params_schema, {'User': 'wibble'})
+        user_params = params.user_parameters()
+        self.assertEqual(user_params['User'], 'wibble')
+
+    def test_user_param_default(self):
+        params = parser.Parameters('test', params_schema)
+        user_params = params.user_parameters()
+        self.assertTrue('Defaulted' not in user_params)
+
+    def test_user_param_nonexist(self):
+        params = parser.Parameters('test', params_schema)
+        user_params = params.user_parameters()
+        self.assertTrue('User' not in user_params)
+
+    def test_schema_invariance(self):
+        params1 = parser.Parameters('test', params_schema)
+        params1['Defaulted'] = "wibble"
+        self.assertEqual(params1['Defaulted'], 'wibble')
+
+        params2 = parser.Parameters('test', params_schema)
+        self.assertEqual(params2['Defaulted'], 'foobar')
+
+
 # allows testing of the test directly, shown below
 if __name__ == '__main__':
     sys.argv.append(__file__)
index 0e924ed45a1da06a3ed1fe27f2724c809f71c24e..2b0c2cebbb27302a3fef54eebd463c1b2b7cacad 100644 (file)
@@ -34,19 +34,16 @@ class instancesTest(unittest.TestCase):
         t = json.loads(f.read())
         f.close()
 
-        parameters = {}
         t['Parameters']['KeyName']['Value'] = 'test'
-        stack = parser.Stack(None, 'test_stack', t, 0)
+        stack = parser.Stack(None, 'test_stack', parser.Template(t),
+                             stack_id=-1)
 
         self.m.StubOutWithMock(db_api, 'resource_get_by_name_and_stack')
         db_api.resource_get_by_name_and_stack(None, 'test_resource_name',
                                               stack).AndReturn(None)
 
         self.m.StubOutWithMock(instances.Instance, 'nova')
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
+        instances.Instance.nova().MultipleTimes().AndReturn(self.fc)
 
         self.m.ReplayAll()
 
@@ -80,9 +77,9 @@ class instancesTest(unittest.TestCase):
         t = json.loads(f.read())
         f.close()
 
-        parameters = {}
         t['Parameters']['KeyName']['Value'] = 'test'
-        stack = parser.Stack(None, 'test_stack', t, 0)
+        stack = parser.Stack(None, 'test_stack', parser.Template(t),
+                             stack_id=-1)
 
         self.m.StubOutWithMock(db_api, 'resource_get_by_name_and_stack')
         db_api.resource_get_by_name_and_stack(None, 'test_resource_name',
index ab758e3770270b93256f85641e54b4271f46e529..c234303d4cf529cac6f1c2208c6333511b2cc82a 100644 (file)
@@ -24,27 +24,37 @@ class stacksTest(unittest.TestCase):
     def setUp(self):
         self.m = mox.Mox()
         self.fc = fakes.FakeClient()
-        self.path = os.path.dirname(os.path.realpath(__file__)).\
-            replace('heat/tests', 'templates')
+        path = os.path.dirname(os.path.realpath(__file__))
+        self.path = path.replace(os.path.join('heat', 'tests'), 'templates')
 
     def tearDown(self):
         self.m.UnsetStubs()
         print "stackTest teardown complete"
 
+    def create_context(self, user='stacks_test_user'):
+        ctx = context.get_admin_context()
+        self.m.StubOutWithMock(ctx, 'username')
+        ctx.username = user
+        self.m.StubOutWithMock(auth, 'authenticate')
+        return ctx
+
     # We use this in a number of tests so it's factored out here.
-    def start_wordpress_stack(self, stack_name):
-        f = open("%s/WordPress_Single_Instance_gold.template" % self.path)
-        t = json.loads(f.read())
-        f.close()
-        params = {}
-        parameters = {}
-        t['Parameters']['KeyName']['Value'] = 'test'
-        stack = parser.Stack(None, stack_name, t, 0, params)
+    def get_wordpress_stack(self, stack_name, ctx=None):
+        tmpl_path = os.path.join(self.path,
+                                 'WordPress_Single_Instance_gold.template')
+        with open(tmpl_path) as f:
+            t = json.load(f)
+
+        template = parser.Template(t)
+        parameters = parser.Parameters(stack_name, template,
+                                       {'KeyName': 'test'})
+
+        stack = parser.Stack(ctx or self.create_context(),
+                             stack_name, template, parameters)
+
         self.m.StubOutWithMock(instances.Instance, 'nova')
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
-        instances.Instance.nova().AndReturn(self.fc)
+        instances.Instance.nova().MultipleTimes().AndReturn(self.fc)
+
         instance = stack.resources['WebServer']
         instance.itype_oflavor['m1.large'] = 'm1.large'
         instance.calculate_properties()
@@ -55,10 +65,11 @@ class stacksTest(unittest.TestCase):
                 name='WebServer', security_groups=None,
                 userdata=server_userdata).\
                 AndReturn(self.fc.servers.list()[-1])
+
         return stack
 
     def test_wordpress_single_instance_stack_create(self):
-        stack = self.start_wordpress_stack('test_stack')
+        stack = self.get_wordpress_stack('test_stack')
         self.m.ReplayAll()
         stack.create()
 
@@ -67,48 +78,28 @@ class stacksTest(unittest.TestCase):
         self.assertNotEqual(stack.resources['WebServer'].ipaddress, '0.0.0.0')
 
     def test_wordpress_single_instance_stack_delete(self):
-        stack = self.start_wordpress_stack('test_stack')
+        ctx = self.create_context()
+        stack = self.get_wordpress_stack('test_stack', ctx)
         self.m.ReplayAll()
-        rt = {}
-        rt['template'] = stack.t
-        rt['StackName'] = stack.name
-        new_rt = db_api.raw_template_create(None, rt)
-        ct = {'username': 'fred',
-                   'password': 'mentions_fruit'}
-        new_creds = db_api.user_creds_create(ct)
-        s = {}
-        s['name'] = stack.name
-        s['raw_template_id'] = new_rt.id
-        s['user_creds_id'] = new_creds.id
-        s['username'] = ct['username']
-        new_s = db_api.stack_create(None, s)
-        stack.id = new_s.id
+        stack_id = stack.store()
         stack.create()
+
+        db_s = db_api.stack_get(ctx, stack_id)
+        self.assertNotEqual(db_s, None)
+
         self.assertNotEqual(stack.resources['WebServer'], None)
         self.assertTrue(stack.resources['WebServer'].instance_id > 0)
 
         stack.delete()
 
         self.assertEqual(stack.resources['WebServer'].state, 'DELETE_COMPLETE')
-        self.assertEqual(new_s.status, 'DELETE_COMPLETE')
+        self.assertEqual(db_api.stack_get(ctx, stack_id), None)
+        self.assertEqual(db_s.status, 'DELETE_COMPLETE')
 
     def test_stack_event_list(self):
-        stack = self.start_wordpress_stack('test_event_list_stack')
+        stack = self.get_wordpress_stack('test_event_list_stack')
         self.m.ReplayAll()
-        rt = {}
-        rt['template'] = stack.t
-        rt['StackName'] = stack.name
-        new_rt = db_api.raw_template_create(None, rt)
-        ct = {'username': 'fred',
-                   'password': 'mentions_fruit'}
-        new_creds = db_api.user_creds_create(ct)
-        s = {}
-        s['name'] = stack.name
-        s['raw_template_id'] = new_rt.id
-        s['user_creds_id'] = new_creds.id
-        s['username'] = ct['username']
-        new_s = db_api.stack_create(None, s)
-        stack.id = new_s.id
+        stack.store()
         stack.create()
 
         self.assertNotEqual(stack.resources['WebServer'], None)
@@ -135,46 +126,82 @@ class stacksTest(unittest.TestCase):
                              'm1.large')
 
     def test_stack_list(self):
-        stack = self.start_wordpress_stack('test_stack_list')
-        rt = {}
-        rt['template'] = stack.t
-        rt['StackName'] = stack.name
-        new_rt = db_api.raw_template_create(None, rt)
-        ct = {'username': 'fred',
-              'password': 'mentions_fruit'}
-        new_creds = db_api.user_creds_create(ct)
+        ctx = self.create_context()
+        auth.authenticate(ctx).AndReturn(True)
 
-        ctx = context.get_admin_context()
-        self.m.StubOutWithMock(ctx, 'username')
-        ctx.username = 'fred'
-        self.m.StubOutWithMock(auth, 'authenticate')
+        stack = self.get_wordpress_stack('test_stack_list', ctx)
+
+        self.m.ReplayAll()
+        stack.store()
+        stack.create()
+
+        man = manager.EngineManager()
+        sl = man.list_stacks(ctx, {})
+
+        self.assertTrue(len(sl['stacks']) > 0)
+        for s in sl['stacks']:
+            self.assertNotEqual(s['StackId'], None)
+            self.assertNotEqual(s['TemplateDescription'].find('WordPress'), -1)
+
+    def test_stack_describe_all(self):
+        ctx = self.create_context('stack_describe_all')
         auth.authenticate(ctx).AndReturn(True)
 
-        s = {}
-        s['name'] = stack.name
-        s['raw_template_id'] = new_rt.id
-        s['user_creds_id'] = new_creds.id
-        s['username'] = ct['username']
-        new_s = db_api.stack_create(ctx, s)
-        stack.id = new_s.id
-        instances.Instance.nova().AndReturn(self.fc)
+        stack = self.get_wordpress_stack('test_stack_desc_all', ctx)
+
         self.m.ReplayAll()
+        stack.store()
         stack.create()
 
-        f = open("%s/WordPress_Single_Instance_gold.template" % self.path)
-        t = json.loads(f.read())
-        params = {}
-        parameters = {}
-        t['Parameters']['KeyName']['Value'] = 'test'
-        stack = parser.Stack(ctx, 'test_stack_list', t, 0, params)
+        man = manager.EngineManager()
+        sl = man.show_stack(ctx, None, {})
+
+        self.assertEqual(len(sl['stacks']), 1)
+        for s in sl['stacks']:
+            self.assertNotEqual(s['StackId'], None)
+            self.assertNotEqual(s['Description'].find('WordPress'), -1)
+
+    def test_stack_describe_all_empty(self):
+        ctx = self.create_context('stack_describe_all_empty')
+        auth.authenticate(ctx).AndReturn(True)
+
+        self.m.ReplayAll()
 
         man = manager.EngineManager()
-        sl = man.list_stacks(ctx, params)
+        sl = man.show_stack(ctx, None, {})
+
+        self.assertEqual(len(sl['stacks']), 0)
+
+    def test_stack_describe_nonexistent(self):
+        ctx = self.create_context()
+        auth.authenticate(ctx).AndReturn(True)
+
+        self.m.ReplayAll()
+
+        man = manager.EngineManager()
+        sl = man.show_stack(ctx, 'wibble', {})
+
+        self.assertEqual(len(sl['stacks']), 0)
+
+    def test_stack_describe(self):
+        ctx = self.create_context('stack_describe')
+        auth.authenticate(ctx).AndReturn(True)
+
+        stack = self.get_wordpress_stack('test_stack_desc', ctx)
+
+        self.m.ReplayAll()
+        stack.store()
+        stack.create()
+
+        man = manager.EngineManager()
+        sl = man.show_stack(ctx, 'test_stack_desc', {})
 
         self.assertTrue(len(sl['stacks']) > 0)
         for s in sl['stacks']:
-            self.assertTrue(s['StackId'] > 0)
-            self.assertNotEqual(s['TemplateDescription'].find('WordPress'), -1)
+            self.assertEqual(s['StackName'], 'test_stack_desc')
+            self.assertTrue('CreationTime' in s)
+            self.assertNotEqual(s['StackId'], None)
+            self.assertNotEqual(s['Description'].find('WordPress'), -1)
 
     # allows testing of the test directly
     if __name__ == '__main__':
index 14713cac846f6e7a8698893ca66f8bf9f37dfb2b..967872d19b9f5a301db4faf3b10aa1911c790a8f 100644 (file)
@@ -215,8 +215,7 @@ class validateTest(unittest.TestCase):
         t = json.loads(test_template_volumeattach % 'vdq')
         self.m.StubOutWithMock(auth, 'authenticate')
         auth.authenticate(None).AndReturn(True)
-        params = {}
-        stack = parser.Stack(None, 'test_stack', t, 0, params)
+        stack = parser.Stack(None, 'test_stack', parser.Template(t))
 
         self.m.StubOutWithMock(db_api, 'resource_get_by_name_and_stack')
         db_api.resource_get_by_name_and_stack(None, 'test_resource_name',
@@ -224,14 +223,13 @@ class validateTest(unittest.TestCase):
 
         self.m.ReplayAll()
         volumeattach = stack.resources['MountPoint']
-        assert(volumeattach.validate() is None)
+        self.assertTrue(volumeattach.validate() is None)
 
     def test_validate_volumeattach_invalid(self):
         t = json.loads(test_template_volumeattach % 'sda')
         self.m.StubOutWithMock(auth, 'authenticate')
         auth.authenticate(None).AndReturn(True)
-        params = {}
-        stack = parser.Stack(None, 'test_stack', t, 0, params)
+        stack = parser.Stack(None, 'test_stack', parser.Template(t))
 
         self.m.StubOutWithMock(db_api, 'resource_get_by_name_and_stack')
         db_api.resource_get_by_name_and_stack(None, 'test_resource_name',
index f881bb562871d02f90fd0a21d50fb53ed7f73e76..6d846344253b2eb1eae9f84aef56a9bcaca32343 100644 (file)
@@ -7,12 +7,12 @@ import sys
 
 import nose
 import unittest
-import mox
 from nose.plugins.attrib import attr
 from nose import with_setup
 
 import heat.db as db_api
 from heat.engine import parser
+from heat.common import context
 
 logger = logging.getLogger('test_waitcondition')
 
@@ -41,32 +41,15 @@ test_template_waitcondition = '''
 @attr(speed='slow')
 class stacksTest(unittest.TestCase):
     def setUp(self):
-        self.m = mox.Mox()
         self.greenpool = eventlet.GreenPool()
 
-    def tearDown(self):
-        self.m.UnsetStubs()
-
     def create_stack(self, stack_name, temp, params):
-        stack = parser.Stack(None, stack_name, temp, 0, params)
-
-        rt = {}
-        rt['template'] = temp
-        rt['StackName'] = stack_name
-        new_rt = db_api.raw_template_create(None, rt)
-
-        ct = {'username': 'fred',
-              'password': 'mentions_fruit'}
-        new_creds = db_api.user_creds_create(ct)
-
-        s = {}
-        s['name'] = stack_name
-        s['raw_template_id'] = new_rt.id
-        s['user_creds_id'] = new_creds.id
-        s['username'] = ct['username']
-        new_s = db_api.stack_create(None, s)
-        stack.id = new_s.id
+        template = parser.Template(temp)
+        parameters = parser.Parameters(stack_name, template, params)
+        stack = parser.Stack(context.get_admin_context(), stack_name,
+                             template, parameters)
 
+        stack.store()
         return stack
 
     def test_post_success_to_handle(self):
@@ -74,7 +57,6 @@ class stacksTest(unittest.TestCase):
         t = json.loads(test_template_waitcondition)
         stack = self.create_stack('test_stack', t, params)
 
-        self.m.ReplayAll()
         self.greenpool.spawn_n(stack.create)
         eventlet.sleep(1)
         self.assertEqual(stack.resources['WaitForTheHandle'].state,