]> review.fuel-infra Code Review - openstack-build/heat-build.git/commitdiff
Refactor template resolution
authorZane Bitter <zbitter@redhat.com>
Fri, 1 Jun 2012 08:50:15 +0000 (10:50 +0200)
committerZane Bitter <zbitter@redhat.com>
Mon, 4 Jun 2012 09:31:13 +0000 (11:31 +0200)
Resolve functions in templates by making a copy of the data rather than
modifying the original. This means that e.g. a resource resolving functions
in its own template data does not result in changes to the data held by the
Stack.

This patch also refactors all of the template resolution methods to operate
using a common parsing algorithm to move through the tree.

Finally, the resources have been worked to load data as it is needed,
rather than requiring external code to put them into the correct state
before using them.

Change-Id: I79eafaefc9ced07b652fac7162aa2edbfa7f547a
Signed-off-by: Zane Bitter <zbitter@redhat.com>
heat/engine/eip.py
heat/engine/instance.py
heat/engine/parser.py
heat/engine/resources.py
heat/tests/test_resources.py
heat/tests/test_stacks.py
heat/tests/test_validate.py

index 22bed5d2c5e533dcba5cebba8bc2130e9598ffa5..dd7f21d5188a34bca790c7f2d8db690f907b8f33 100644 (file)
@@ -28,7 +28,18 @@ class ElasticIp(Resource):
 
     def __init__(self, name, json_snippet, stack):
         super(ElasticIp, self).__init__(name, json_snippet, stack)
-        self.ipaddress = ''
+        self.ipaddress = None
+
+    def _ipaddress(self):
+        if self.ipaddress is None:
+            if self.instance_id is not None:
+                try:
+                    ips = self.nova().floating_ips.get(self.instance_id)
+                except NotFound as ex:
+                    logger.warn("Floating IPs not found: %s" % str(ex))
+                else:
+                    self.ipaddress = ips.ip
+        return self.ipaddress or ''
 
     def create(self):
         """Allocate a floating IP for the current tenant."""
@@ -49,19 +60,6 @@ class ElasticIp(Resource):
         '''
         return Resource.validate(self)
 
-    def reload(self):
-        '''
-        get the ipaddress here
-        '''
-        if self.instance_id is not None:
-            try:
-                ips = self.nova().floating_ips.get(self.instance_id)
-                self.ipaddress = ips.ip
-            except Exception as ex:
-                logger.warn("Error getting floating IPs: %s" % str(ex))
-
-        Resource.reload(self)
-
     def delete(self):
         """De-allocate a floating IP."""
         if self.state == self.DELETE_IN_PROGRESS or \
@@ -77,7 +75,7 @@ class ElasticIp(Resource):
         self.state_set(self.DELETE_COMPLETE)
 
     def FnGetRefId(self):
-        return unicode(self.ipaddress)
+        return unicode(self._ipaddress())
 
     def FnGetAtt(self, key):
         if key == 'AllocationId':
index df6ae49d0caa3336e44365ae7ac7efe8dd4551e9..f0935b3ef0f6dd386b3404323c25a5eb927a43f8 100644 (file)
@@ -110,7 +110,7 @@ class Instance(Resource):
 
     def __init__(self, name, json_snippet, stack):
         super(Instance, self).__init__(name, json_snippet, stack)
-        self.ipaddress = '0.0.0.0'
+        self.ipaddress = None
         self.mime_string = None
 
         self.itype_oflavor = {'t1.micro': 'm1.tiny',
@@ -126,15 +126,30 @@ class Instance(Resource):
             'cc2.8xlarge': 'm1.large',
             'cg1.4xlarge': 'm1.large'}
 
-    def FnGetAtt(self, key):
+    def _ipaddress(self):
+        '''
+        Return the server's IP address, fetching it from Nova if necessary
+        '''
+        if self.ipaddress is None:
+            try:
+                server = self.nova().servers.get(self.instance_id)
+            except NotFound as ex:
+                logger.warn('Instance IP address not found (%s)' % str(ex))
+            else:
+                for n in server.networks:
+                    self.ipaddress = server.networks[n][0]
+                    break
+
+        return self.ipaddress or '0.0.0.0'
 
+    def FnGetAtt(self, key):
         res = None
         if key == 'AvailabilityZone':
             res = self.properties['AvailabilityZone']
         elif key == 'PublicIp':
-            res = self.ipaddress
+            res = self._ipaddress()
         elif key == 'PrivateDnsName':
-            res = self.ipaddress
+            res = self._ipaddress()
         else:
             raise exception.InvalidTemplateAttribute(resource=self.name,
                                                      key=key)
@@ -259,19 +274,6 @@ class Instance(Resource):
                         'Provided KeyName is not registered with nova'}
         return None
 
-    def reload(self):
-        '''
-        re-read the server's ipaddress so FnGetAtt works.
-        '''
-        try:
-            server = self.nova().servers.get(self.instance_id)
-            for n in server.networks:
-                self.ipaddress = server.networks[n][0]
-        except NotFound:
-            self.ipaddress = '0.0.0.0'
-
-        Resource.reload(self)
-
     def delete(self):
         if self.state == self.DELETE_IN_PROGRESS or \
            self.state == self.DELETE_COMPLETE:
index 5ca502bd0191a6b35b54179eaf29a1bb8503bd49..eefb89036f568c1c9b957e0b140e163a125e9aa3 100644 (file)
@@ -15,6 +15,7 @@
 
 import eventlet
 import json
+import itertools
 import logging
 from heat.common import exception
 from heat.engine import checkeddict
@@ -72,7 +73,7 @@ class Stack(object):
             res = Resource(rname, rdesc, self)
             self.resources[rname] = res
 
-            self.calulate_dependencies(rdesc, res)
+            self.calulate_dependencies(res.t, res)
 
     def validate(self):
         '''
@@ -233,23 +234,15 @@ class Stack(object):
         pool.spawn_n(self.delete_blocking)
 
     def get_outputs(self):
+        outputs = self.resolve_runtime_data(self.outputs)
 
-        for r in self.resources:
-            self.resources[r].reload()
+        def output_dict(k):
+            return {'Description': outputs[k].get('Description',
+                                                  'No description given'),
+                    'OutputKey': k,
+                    'OutputValue': outputs[k].get('Value', '')}
 
-        self.resolve_attributes(self.outputs)
-        self.resolve_joins(self.outputs)
-
-        outs = []
-        for o in self.outputs:
-            out = {}
-            out['Description'] = self.outputs[o].get('Description',
-                                                     'No description given')
-            out['OutputKey'] = o
-            out['OutputValue'] = self.outputs[o].get('Value', '')
-            outs.append(out)
-
-        return outs
+        return [output_dict(key) for key in outputs]
 
     def restart_resource_blocking(self, resource_name):
         '''
@@ -334,110 +327,112 @@ class Stack(object):
         except ValueError:
             raise exception.UserParameterMissing(key=key)
 
-    def resolve_static_refs(self, s):
+    def _resolve_static_refs(self, s):
         '''
-            looking for { "Ref": "str" }
+            looking for { "Ref" : "str" }
         '''
-        if isinstance(s, dict):
-            for i in s:
-                if i == 'Ref' and \
-                      isinstance(s[i], (basestring, unicode)) and \
-                      s[i] in self.parms:
-                    return self.parameter_get(s[i])
-                else:
-                    s[i] = self.resolve_static_refs(s[i])
-        elif isinstance(s, list):
-            for index, item in enumerate(s):
-                #print 'resolve_static_refs %d %s' % (index, item)
-                s[index] = self.resolve_static_refs(item)
-        return s
+        def match(key, value):
+            return (key == 'Ref' and
+                    isinstance(value, basestring) and
+                    value in self.parms)
 
-    def resolve_find_in_map(self, s):
-        '''
-            looking for { "Fn::FindInMap": ["str", "str"] }
-        '''
-        if isinstance(s, dict):
-            for i in s:
-                if i == 'Fn::FindInMap':
-                    obj = self.maps
-                    if isinstance(s[i], list):
-                        #print 'map list: %s' % s[i]
-                        for index, item in enumerate(s[i]):
-                            if isinstance(item, dict):
-                                item = self.resolve_find_in_map(item)
-                                #print 'map item dict: %s' % (item)
-                            else:
-                                pass
-                                #print 'map item str: %s' % (item)
-                            obj = obj[item]
-                    else:
-                        obj = obj[s[i]]
-                    return obj
-                else:
-                    s[i] = self.resolve_find_in_map(s[i])
-        elif isinstance(s, list):
-            for index, item in enumerate(s):
-                s[index] = self.resolve_find_in_map(item)
-        return s
+        def handle(ref):
+            return self.parameter_get(ref)
+
+        return _resolve(match, handle, s)
 
-    def resolve_attributes(self, s):
+    def _resolve_find_in_map(self, s):
+        def handle(args):
+            try:
+                name, key, value = args
+                return self.maps[name][key][value]
+            except (ValueError, TypeError) as ex:
+                raise KeyError(str(ex))
+
+        return _resolve(lambda k, v: k == 'Fn::FindInMap', handle, s)
+
+    def _resolve_attributes(self, s):
         '''
             looking for something like:
-            {"Fn::GetAtt" : ["DBInstance", "Endpoint.Address"]}
+            { "Fn::GetAtt" : [ "DBInstance", "Endpoint.Address" ] }
         '''
-        if isinstance(s, dict):
-            for i in s:
-                if i == 'Ref' and s[i] in self.resources:
-                    return self.resources[s[i]].FnGetRefId()
-                elif i == 'Fn::GetAtt':
-                    resource_name = s[i][0]
-                    key_name = s[i][1]
-                    res = self.resources.get(resource_name)
-                    rc = None
-                    if res:
-                        return res.FnGetAtt(key_name)
-                    else:
-                        raise exception.InvalidTemplateAttribute(
-                                        resource=resource_name, key=key_name)
-                    return rc
-                else:
-                    s[i] = self.resolve_attributes(s[i])
-        elif isinstance(s, list):
-            for index, item in enumerate(s):
-                s[index] = self.resolve_attributes(item)
-        return s
+        def match_ref(key, value):
+            return key == 'Ref' and value in self.resources
+
+        def handle_ref(arg):
+            return self.resources[arg].FnGetRefId()
+
+        def handle_getatt(args):
+            resource, att = args
+            try:
+                return self.resources[resource].FnGetAtt(att)
+            except KeyError:
+                raise exception.InvalidTemplateAttribute(resource=resource,
+                                                         key=att)
+
+        return _resolve(lambda k, v: k == 'Fn::GetAtt', handle_getatt,
+                        _resolve(match_ref, handle_ref, s))
 
-    def resolve_joins(self, s):
+    @staticmethod
+    def _resolve_joins(s):
         '''
-            looking for { "Fn::join": []}
+            looking for { "Fn::Join" : [] }
         '''
-        if isinstance(s, dict):
-            for i in s:
-                if i == 'Fn::Join':
-                    j = None
-                    try:
-                        j = s[i][0].join(s[i][1])
-                    except Exception:
-                        logger.error('Could not join %s' % str(s[i]))
-                    return j
-                else:
-                    s[i] = self.resolve_joins(s[i])
-        elif isinstance(s, list):
-            for index, item in enumerate(s):
-                s[index] = self.resolve_joins(item)
-        return s
+        def handle(args):
+            delim, strings = args
+            return delim.join(strings)
+
+        return _resolve(lambda k, v: k == 'Fn::Join', handle, s)
 
-    def resolve_base64(self, s):
+    @staticmethod
+    def _resolve_base64(s):
         '''
-            looking for { "Fn::join": [] }
+            looking for { "Fn::Base64" : "" }
         '''
-        if isinstance(s, dict):
-            for i in s:
-                if i == 'Fn::Base64':
-                    return s[i]
-                else:
-                    s[i] = self.resolve_base64(s[i])
-        elif isinstance(s, list):
-            for index, item in enumerate(s):
-                s[index] = self.resolve_base64(item)
-        return s
+        return _resolve(lambda k, v: k == 'Fn::Base64', lambda d: d, s)
+
+    def resolve_static_data(self, snippet):
+        return transform(snippet, [self._resolve_static_refs,
+                                   self._resolve_find_in_map])
+
+    def resolve_runtime_data(self, snippet):
+        return transform(snippet, [self._resolve_attributes,
+                                   self._resolve_joins,
+                                   self._resolve_base64])
+
+
+def transform(data, transformations):
+    '''
+    Apply each of the transformation functions in the supplied list to the data
+    in turn.
+    '''
+    for t in transformations:
+        data = t(data)
+    return data
+
+
+def _resolve(match, handle, snippet):
+    '''
+    Resolve constructs in a snippet of a template. The supplied match function
+    should return True if a particular key-value pair should be substituted,
+    and the handle function should return the correct substitution when passed
+    the argument list as parameters.
+
+    Returns a copy of the original snippet with the substitutions performed.
+    '''
+    recurse = lambda k: _resolve(match, handle, snippet[k])
+
+    if isinstance(snippet, dict):
+        should_handle = lambda k: match(k, snippet[k])
+        matches = itertools.imap(recurse,
+                                 itertools.ifilter(should_handle, snippet))
+        try:
+            args = next(matches)
+        except StopIteration:
+            # No matches
+            return dict((k, recurse(k)) for k in snippet)
+        else:
+            return handle(args)
+    elif isinstance(snippet, list):
+        return [recurse(i) for i in range(len(snippet))]
+    return snippet
index f7b7ed78bbf795b0c3b87b0deb2f43a9271ad0ac..1f1361beecb50221ed47ac0adbff95f4d52a7df0 100644 (file)
@@ -53,13 +53,13 @@ class Resource(object):
         return ResourceClass(name, json, stack)
 
     def __init__(self, name, json_snippet, stack):
-        self.t = json_snippet
         self.depends_on = []
         self.references = []
         self.stack = stack
         self.name = name
+        self.t = stack.resolve_static_data(json_snippet)
         self.properties = checkeddict.Properties(name, self.properties_schema)
-        if not 'Properties' in self.t:
+        if 'Properties' not in self.t:
             # make a dummy entry to prevent having to check all over the
             # place for it.
             self.t['Properties'] = {}
@@ -75,9 +75,6 @@ class Resource(object):
             self.id = None
         self._nova = {}
 
-        stack.resolve_static_refs(self.t)
-        stack.resolve_find_in_map(self.t)
-
     def nova(self, service_type='compute'):
         if service_type in self._nova:
             return self._nova[service_type]
@@ -98,26 +95,22 @@ class Resource(object):
                                                  service_name=service_name)
         return self._nova[service_type]
 
+    def calculate_properties(self):
+        template = self.stack.resolve_runtime_data(self.t)
+
+        for p, v in template['Properties'].items():
+            self.properties[p] = v
+
     def create(self):
         logger.info('creating %s name:%s' % (self.t['Type'], self.name))
-
-        self.stack.resolve_attributes(self.t)
-        self.stack.resolve_joins(self.t)
-        self.stack.resolve_base64(self.t)
-        for p in self.t['Properties']:
-            self.properties[p] = self.t['Properties'][p]
+        self.calculate_properties()
         self.properties.validate()
 
     def validate(self):
         logger.info('validating %s name:%s' % (self.t['Type'], self.name))
 
-        self.stack.resolve_attributes(self.t)
-        self.stack.resolve_joins(self.t)
-        self.stack.resolve_base64(self.t)
-
         try:
-            for p in self.t['Properties']:
-                self.properties[p] = self.t['Properties'][p]
+            self.calculate_properties()
         except ValueError as ex:
                 return {'Error': '%s' % str(ex)}
         self.properties.validate()
@@ -160,7 +153,8 @@ class Resource(object):
             ev['name'] = new_state
             ev['resource_status_reason'] = reason
             ev['resource_type'] = self.t['Type']
-            ev['resource_properties'] = self.t['Properties']
+            self.calculate_properties()
+            ev['resource_properties'] = dict(self.properties)
             try:
                 db_api.event_create(None, ev)
             except Exception as ex:
@@ -168,24 +162,10 @@ class Resource(object):
             self.state = new_state
 
     def delete(self):
-        self.reload()
         logger.info('deleting %s name:%s inst:%s db_id:%s' %
                     (self.t['Type'], self.name,
                      self.instance_id, str(self.id)))
 
-    def reload(self):
-        '''
-        The point of this function is to get the Resource instance back
-        into the state that it was just after it was created. So we
-        need to retrieve things like ipaddresses and other variables
-        used by FnGetAtt and FnGetRefId. classes inheriting from Resource
-        might need to override this, but still call it.
-        This is currently used by stack.get_outputs()
-        '''
-        logger.info('reloading %s name:%s instance_id:%s' %
-                    (self.t['Type'], self.name, self.instance_id))
-        self.stack.resolve_attributes(self.t)
-
     def FnGetRefId(self):
         '''
         http://docs.amazonwebservices.com/AWSCloudFormation/latest/UserGuide/ \
index 4dea45c841574e8c31929ca8015b662ece44aae4..59a1075e1e4c0968d5bfeddc159aff6738afa0d2 100644 (file)
@@ -59,9 +59,7 @@ class instancesTest(unittest.TestCase):
                                       t['Resources']['WebServer'], stack)
 
         instance.itype_oflavor['256 MB Server'] = '256 MB Server'
-        instance.stack.resolve_attributes(instance.t)
-        instance.stack.resolve_joins(instance.t)
-        instance.stack.resolve_base64(instance.t)
+        instance.t = instance.stack.resolve_runtime_data(instance.t)
 
         # need to resolve the template functions
         server_userdata = instance._build_userdata(
@@ -109,9 +107,7 @@ class instancesTest(unittest.TestCase):
                                       t['Resources']['WebServer'], stack)
 
         instance.itype_oflavor['256 MB Server'] = '256 MB Server'
-        instance.stack.resolve_attributes(instance.t)
-        instance.stack.resolve_joins(instance.t)
-        instance.stack.resolve_base64(instance.t)
+        instance.t = instance.stack.resolve_runtime_data(instance.t)
 
         # need to resolve the template functions
         server_userdata = instance._build_userdata(
index 75788595b037ec1107d3cc3e529d092f026e3ea1..ec8a64dfce36bfdf9908c267c1861c3c3940ee3a 100644 (file)
@@ -45,11 +45,9 @@ class stacksTest(unittest.TestCase):
         instances.Instance.nova().AndReturn(self.fc)
         instance = stack.resources['WebServer']
         instance.itype_oflavor['m1.large'] = 'm1.large'
-        instance.stack.resolve_attributes(instance.t)
-        instance.stack.resolve_joins(instance.t)
-        instance.stack.resolve_base64(instance.t)
+        instance.calculate_properties()
         server_userdata = instance._build_userdata(
-                                instance.t['Properties']['UserData'])
+                                instance.properties['UserData'])
         self.m.StubOutWithMock(self.fc.servers, 'create')
         self.fc.servers.create(image=744, flavor=3, key_name='test',
                 name='WebServer', security_groups=None,
index c4f14383e14482c814c635b5712a334f9ebf7a53..a073935ac8097626297fede6201af6439678a8e9 100644 (file)
@@ -223,9 +223,6 @@ class validateTest(unittest.TestCase):
 
         self.m.ReplayAll()
         volumeattach = stack.resources['MountPoint']
-        stack.resolve_attributes(volumeattach.t)
-        stack.resolve_joins(volumeattach.t)
-        stack.resolve_base64(volumeattach.t)
         assert(volumeattach.validate() is None)
 
     def test_validate_volumeattach_invalid(self):
@@ -241,9 +238,6 @@ class validateTest(unittest.TestCase):
 
         self.m.ReplayAll()
         volumeattach = stack.resources['MountPoint']
-        stack.resolve_attributes(volumeattach.t)
-        stack.resolve_joins(volumeattach.t)
-        stack.resolve_base64(volumeattach.t)
         assert(volumeattach.validate())
 
     def test_validate_ref_valid(self):