]> review.fuel-infra Code Review - openstack-build/heat-build.git/commitdiff
Wrap the resource registration in a global environment
authorAngus Salkeld <asalkeld@redhat.com>
Mon, 19 Aug 2013 05:28:58 +0000 (15:28 +1000)
committerAngus Salkeld <asalkeld@redhat.com>
Mon, 19 Aug 2013 23:52:47 +0000 (09:52 +1000)
Change-Id: I065dadc9fae19ab21d6d4aeae08334f20da161bb

heat/engine/environment.py
heat/engine/resource.py
heat/engine/resources/__init__.py
heat/engine/resources/template_resource.py
heat/tests/test_environment.py
heat/tests/test_properties.py
heat/tests/test_provider_template.py

index 1f6cd8d5a2bc215a241ec7fafef6e0a97d36056f..e38e9c6efb38af778c78cc5d481047a6e2eaa311 100644 (file)
 #  DBUsername: wp_admin
 #  LinuxDistribution: F17
 
+import itertools
+
+from heat.openstack.common import log
+from heat.common import exception
+
+
+LOG = log.getLogger(__name__)
+
+
+class ResourceInfo(object):
+    """Base mapping of resource type to implementation."""
+
+    def __new__(cls, registry, path, value, **kwargs):
+        '''Create a new ResourceInfo of the appropriate class.'''
+
+        if cls != ResourceInfo:
+            # Call is already for a subclass, so pass it through
+            return super(ResourceInfo, cls).__new__(cls)
+
+        name = path[-1]
+        if name.endswith(('.yaml', '.template')):
+            # a template url for the resource "Type"
+            return TemplateResourceInfo(registry, path, value)
+        elif not isinstance(value, basestring):
+            return ClassResourceInfo(registry, path, value)
+        elif value.endswith(('.yaml', '.template')):
+            # a registered template
+            return TemplateResourceInfo(registry, path, value)
+        elif name.endswith('*'):
+            return GlobResourceInfo(registry, path, value)
+        else:
+            return MapResourceInfo(registry, path, value)
+
+    def __init__(self, registry, path, value):
+        self.registry = registry
+        self.path = path
+        self.name = path[-1]
+        self.value = value
+        self.user_resource = True
+
+    def __eq__(self, other):
+        return (self.path == other.path and
+                self.value == other.value and
+                self.user_resource == other.user_resource)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __lt__(self, other):
+        if self.user_resource != other.user_resource:
+            # user resource must be sorted above system ones.
+            return self.user_resource > other.user_resource
+        if len(self.path) != len(other.path):
+            # more specific (longer) path must be sorted above system ones.
+            return len(self.path) > len(other.path)
+        return self.path < other.path
+
+    def __gt__(self, other):
+        return other.__lt__(self)
+
+    def get_resource_info(self, resource_type=None, resource_name=None):
+        return self
+
+    def matches(self, resource_type):
+        return False
+
+    def __str__(self):
+        return '[%s](User:%s) %s -> %s' % (self.description,
+                                           self.user_resource,
+                                           self.name, str(self.value))
+
+
+class ClassResourceInfo(ResourceInfo):
+    """Store the mapping of resource name to python class implementation."""
+    description = 'Plugin'
+
+    def get_class(self):
+        return self.value
+
+
+class TemplateResourceInfo(ResourceInfo):
+    """Store the info needed to start a TemplateResource.
+    """
+    description = 'Template'
+
+    def __init__(self, registry, path, value):
+        super(TemplateResourceInfo, self).__init__(registry, path, value)
+        if self.name.endswith(('.yaml', '.template')):
+            self.template_name = self.name
+        else:
+            self.template_name = value
+
+    def get_class(self):
+        from heat.engine.resources import template_resource
+        return template_resource.TemplateResource
+
+
+class MapResourceInfo(ResourceInfo):
+    """Store the mapping of one resource type to another.
+    like: OS::Networking::FloatingIp -> OS::Neutron::FloatingIp
+    """
+    description = 'Mapping'
+
+    def get_class(self):
+        return None
+
+    def get_resource_info(self, resource_type=None, resource_name=None):
+        return self.registry.get_resource_info(self.value, resource_name)
+
+
+class GlobResourceInfo(MapResourceInfo):
+    """Store the mapping (with wild cards) of one resource type to another.
+    like: OS::Networking::* -> OS::Neutron::*
+    """
+    description = 'Wildcard Mapping'
+
+    def get_resource_info(self, resource_type=None, resource_name=None):
+        orig_prefix = self.name[:-1]
+        new_type = self.value[:-1] + resource_type[len(orig_prefix):]
+        return self.registry.get_resource_info(new_type, resource_name)
+
+    def matches(self, resource_type):
+        return resource_type.startswith(self.name[:-1])
+
+
+class ResourceRegistry(object):
+    """By looking at the environment, find the resource implementation."""
+
+    def __init__(self, global_registry):
+        self._registry = {'resources': {}}
+        self.global_registry = global_registry
+
+    def load(self, json_snippet):
+        self._load_registry([], json_snippet)
+
+    def register_class(self, resource_type, resource_class):
+        ri = ResourceInfo(self, [resource_type], resource_class)
+        self._register_info([resource_type], ri)
+
+    def _load_registry(self, path, registry):
+        for k, v in iter(registry.items()):
+            if isinstance(v, dict):
+                self._load_registry(path + [k], v)
+            else:
+                self._register_info(path + [k],
+                                    ResourceInfo(self, path + [k], v))
+
+    def _register_info(self, path, info):
+        """place the new info in the correct location in the registry.
+        path: a list of keys ['resources', 'my_server', 'OS::Compute::Server']
+        """
+        descriptive_path = '/'.join(path)
+        name = path[-1]
+        # create the structure if needed
+        registry = self._registry
+        for key in path[:-1]:
+            if key not in registry:
+                registry[key] = {}
+            registry = registry[key]
+
+        if name in registry and isinstance(registry[name], ResourceInfo):
+            details = {
+                'path': descriptive_path,
+                'was': str(registry[name].value),
+                'now': str(info.value)}
+            LOG.warn(_('Changing %(path)s from %(was)s to %(now)s') % details)
+        else:
+            LOG.info(_('Registering %(path)s -> %(value)s') % {
+                'path': descriptive_path,
+                'value': str(info.value)})
+        info.user_resource = (self.global_registry is not None)
+        registry[name] = info
+
+    def iterable_by(self, resource_type, resource_name=None):
+        if resource_type.endswith(('.yaml', '.template')):
+            # resource with a Type == a template
+            # we dynamically create an entry as it has not been registered.
+            if resource_type not in self._registry:
+                res = ResourceInfo(self, [resource_type], None)
+                self._register_info([resource_type], res)
+            yield self._registry[resource_type]
+
+        # handle a specific resource mapping.
+        if resource_name:
+            impl = self._registry['resources'].get(resource_name)
+            if impl and resource_type in impl:
+                yield impl[resource_type]
+
+        # handle: "OS::Compute::Server" -> "Rackspace::Compute::Server"
+        impl = self._registry.get(resource_type)
+        if impl:
+            yield impl
+
+        # handle: "OS::*" -> "Dreamhost::*"
+        def is_a_glob(resource_type):
+            return resource_type.endswith('*')
+        globs = itertools.ifilter(is_a_glob, self._registry.keys())
+        for glob in globs:
+            if self._registry[glob].matches(resource_type):
+                yield self._registry[glob]
+
+    def get_resource_info(self, resource_type, resource_name=None,
+                          registry_type=None):
+        """Find possible matches to the resource type and name.
+        chain the results from the global and user registry to find
+        a match.
+        """
+        # use cases
+        # 1) get the impl.
+        #    - filter_by(res_type=X), sort_by(res_name=W, is_user=True)
+        # 2) in TemplateResource we need to get both the
+        #    TemplateClass and the ResourceClass
+        #    - filter_by(res_type=X, impl_type=TemplateResourceInfo),
+        #      sort_by(res_name=W, is_user=True)
+        #    - filter_by(res_type=X, impl_type=ClassResourceInfo),
+        #      sort_by(res_name=W, is_user=True)
+        # 3) get_types() from the api
+        #    - filter_by(is_user=False)
+        # 4) as_dict() to write to the db
+        #    - filter_by(is_user=True)
+        if self.global_registry is not None:
+            giter = self.global_registry.iterable_by(resource_type,
+                                                     resource_name)
+        else:
+            giter = []
+
+        matches = itertools.chain(self.iterable_by(resource_type,
+                                                   resource_name),
+                                  giter)
+
+        for info in sorted(matches):
+            match = info.get_resource_info(resource_type,
+                                           resource_name)
+            if registry_type is None or isinstance(match, registry_type):
+                return match
+
+    def get_class(self, resource_type, resource_name=None):
+        info = self.get_resource_info(resource_type,
+                                      resource_name=resource_name)
+        if info is None:
+            msg = "Unknown resource Type : %s" % resource_type
+            raise exception.StackValidationFailed(message=msg)
+        return info.get_class()
+
+    def as_dict(self):
+        """Return user resources in a dict format."""
+        def _as_dict(level):
+            tmp = {}
+            for k, v in iter(level.items()):
+                if isinstance(v, dict):
+                    tmp[k] = _as_dict(v)
+                elif v.user_resource:
+                    tmp[k] = v.value
+            return tmp
+
+        return _as_dict(self._registry)
+
+    def get_types(self):
+        '''Return a list of valid resource types.'''
+        def is_plugin(key):
+            if isinstance(self._registry[key], ClassResourceInfo):
+                return True
+            return False
+        return [k for k in self._registry if is_plugin(k)]
+
 
 class Environment(object):
 
-    def __init__(self, env=None):
+    def __init__(self, env=None, user_env=True):
         """Create an Environment from a dict of varing format.
         1) old-school flat parameters
         2) or newer {resource_registry: bla, parameters: foo}
 
         :param env: the json environment
+        :param user_env: boolean, if false then we manage python resources too.
         """
         if env is None:
             env = {}
-        self.resource_registry = env.get('resource_registry', {})
-        if 'resources' not in self.resource_registry:
-            self.resource_registry['resources'] = {}
+        if user_env:
+            from heat.engine import resources
+            global_registry = resources.global_env().registry
+        else:
+            global_registry = None
+
+        self.registry = ResourceRegistry(global_registry)
+        self.registry.load(env.get('resource_registry', {}))
+
         if 'parameters' in env:
             self.params = env['parameters']
         else:
             self.params = dict((k, v) for (k, v) in env.iteritems()
                                if k != 'resource_registry')
 
-    def get_resource_type(self, resource_type, resource_name):
-        """Get the specific resource type that the user wants to implement
-        'resource_type'.
-        """
-        impl = self.resource_registry['resources'].get(resource_name)
-        if impl and resource_type in impl:
-            return impl[resource_type]
-
-        # handle: "OS::Compute::Server" -> "Rackspace::Compute::Server"
-        impl = self.resource_registry.get(resource_type)
-        if impl:
-            return impl
-        # handle: "OS::*" -> "Dreamhost::*"
-        for k, v in iter(self.resource_registry.items()):
-            if k.endswith('*'):
-                orig_prefix = k[:-1]
-                if resource_type.startswith(orig_prefix):
-                    return v[:-1] + resource_type[len(orig_prefix):]
-        # no special handling, just return what we were given.
-        return resource_type
+    def load(self, env_snippet):
+        self.registry.load(env_snippet.get('resource_registry', {}))
+        self.params.update(env_snippet.get('parameters', {}))
 
     def user_env_as_dict(self):
         """Get the environment as a dict, ready for storing in the db."""
-        return {'resource_registry': self.resource_registry,
+        return {'resource_registry': self.registry.as_dict(),
                 'parameters': self.params}
+
+    def register_class(self, resource_type, resource_class):
+        self.registry.register_class(resource_type, resource_class)
+
+    def get_class(self, resource_type, resource_name=None):
+        return self.registry.get_class(resource_type, resource_name)
+
+    def get_types(self):
+        return self.registry.get_types()
+
+    def get_resource_info(self, resource_type, resource_name=None,
+                          registry_type=None):
+        return self.registry.get_resource_info(resource_type, resource_name,
+                                               registry_type)
index d50031c6435428c46d778fad6fffb857156b548d..331f89c8f6293e890db245b77019c161fe1cc917 100644 (file)
@@ -22,6 +22,7 @@ from heat.openstack.common import excutils
 from heat.db import api as db_api
 from heat.common import identifier
 from heat.common import short_id
+from heat.engine import resources
 from heat.engine import timestamp
 # import class to avoid name collisions and ugly aliasing
 from heat.engine.attributes import Attributes
@@ -33,45 +34,18 @@ from heat.openstack.common.gettextutils import _
 logger = logging.getLogger(__name__)
 
 
-_resource_classes = {}
-_template_class = None
-
-
 def get_types():
     '''Return an iterator over the list of valid resource types.'''
-    return iter(_resource_classes)
+    return iter(resources.global_env().get_types())
 
 
-def get_class(resource_type, resource_name=None, environment=None):
+def get_class(resource_type, resource_name=None):
     '''Return the Resource class for a given resource type.'''
-    if environment:
-        resource_type = environment.get_resource_type(resource_type,
-                                                      resource_name)
-
-    if resource_type.endswith(('.yaml', '.template')):
-        cls = _template_class
-    else:
-        cls = _resource_classes.get(resource_type)
-    if cls is None:
-        msg = "Unknown resource Type : %s" % resource_type
-        raise exception.StackValidationFailed(message=msg)
-    else:
-        return cls
+    return resources.global_env().get_class(resource_type, resource_name)
 
 
 def _register_class(resource_type, resource_class):
-    logger.info(_('Registering resource type %s') % resource_type)
-    if resource_type in _resource_classes:
-        logger.warning(_('Replacing existing resource type %s') %
-                       resource_type)
-
-    _resource_classes[resource_type] = resource_class
-
-
-def register_template_class(cls):
-    global _template_class
-    if _template_class is None:
-        _template_class = cls
+    resources.global_env().register_class(resource_type, resource_class)
 
 
 class UpdateReplace(Exception):
@@ -149,9 +123,8 @@ class Resource(object):
             return super(Resource, cls).__new__(cls)
 
         # Select the correct subclass to instantiate
-        ResourceClass = get_class(json['Type'],
-                                  resource_name=name,
-                                  environment=stack.env)
+        ResourceClass = stack.env.get_class(json['Type'],
+                                            resource_name=name)
         return ResourceClass(name, json, stack)
 
     def __init__(self, name, json_snippet, stack):
index 924114bea43f9819e1c7891583e0bff95537b817..535e707e58219547fa38010189d1d5a593d4eaf0 100644 (file)
 #    under the License.
 from heat.openstack.common import log as logging
 from heat.openstack.common.gettextutils import _
+from heat.engine import environment
 
 
 logger = logging.getLogger(__name__)
 
 
 def _register_resources(type_pairs):
-    from heat.engine import resource
 
     for res_name, res_class in type_pairs:
-        resource._register_class(res_name, res_class)
+        _environment.register_class(res_name, res_class)
 
 
 def _get_module_resources(module):
@@ -43,16 +43,24 @@ def _register_modules(modules):
     _register_resources(itertools.chain.from_iterable(resource_lists))
 
 
-_initialized = False
+_environment = None
+
+
+def global_env():
+    global _environment
+    if _environment is None:
+        initialise()
+    return _environment
 
 
 def initialise():
-    global _initialized
-    if _initialized:
+    global _environment
+    if _environment is not None:
         return
     import sys
     from heat.common import plugin_loader
 
+    _environment = environment.Environment({}, user_env=False)
     _register_modules(plugin_loader.load_modules(sys.modules[__name__]))
 
     from oslo.config import cfg
index fd3ce62b342a27640f444f4b47bbbf2b86b60e63..875a6e2df5f615451cf6389a119aea6bc9f7f596 100644 (file)
@@ -18,8 +18,8 @@ from requests import exceptions
 from heat.common import template_format
 from heat.common import urlfetch
 from heat.engine import attributes
+from heat.engine import environment
 from heat.engine import properties
-from heat.engine import resource
 from heat.engine import stack_resource
 
 from heat.openstack.common import log as logging
@@ -37,16 +37,20 @@ class TemplateResource(stack_resource.StackResource):
     '''
 
     def __init__(self, name, json_snippet, stack):
-        self.template_name = stack.env.get_resource_type(json_snippet['Type'],
-                                                         name)
         self._parsed_nested = None
         self.stack = stack
-        # on purpose don't pass in the environment so we get
-        # the official/facade class in case we need to copy it's schema.
-        cls_facade = resource.get_class(json_snippet['Type'])
+        tri = stack.env.get_resource_info(
+            json_snippet['Type'],
+            registry_type=environment.TemplateResourceInfo)
+        self.template_name = tri.template_name
+
+        cri = stack.env.get_resource_info(
+            json_snippet['Type'],
+            registry_type=environment.ClassResourceInfo)
+
         # if we're not overriding via the environment, mirror the template as
         # a new resource
-        if cls_facade == self.__class__:
+        if cri is None or cri.get_class() == self.__class__:
             self.properties_schema = (properties.Properties
                 .schema_from_params(self.parsed_nested.get('Parameters')))
             self.attributes_schema = (attributes.Attributes
@@ -54,6 +58,7 @@ class TemplateResource(stack_resource.StackResource):
         # otherwise we are overriding a resource type via the environment
         # and should mimic that type
         else:
+            cls_facade = cri.get_class()
             self.properties_schema = cls_facade.properties_schema
             self.attributes_schema = cls_facade.attributes_schema
 
@@ -112,6 +117,3 @@ class TemplateResource(stack_resource.StackResource):
         if not self.nested():
             return unicode(self.name)
         return self.nested().identifier().arn()
-
-
-resource.register_template_class(TemplateResource)
index a3363954fa1c6b4aa2c6f2992894c479a9ae6bad..487ded6fdd6016e3cc5fe98a6185a5262daaf90b 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-import testtools
-
 from heat.engine import environment
+from heat.engine import resources
+
+from heat.tests import generic_resource
+from heat.tests import common
+
 
+class EnvironmentTest(common.HeatTestCase):
+    def setUp(self):
+        super(EnvironmentTest, self).setUp()
+        self.g_env = resources.global_env()
 
-class EnvironmentTest(testtools.TestCase):
     def test_load_old_parameters(self):
         old = {u'a': u'ff', u'b': u'ss'}
         expected = {u'parameters': old,
@@ -28,17 +34,20 @@ class EnvironmentTest(testtools.TestCase):
 
     def test_load_new_env(self):
         new_env = {u'parameters': {u'a': u'ff', u'b': u'ss'},
-                   u'resource_registry': {u'OS::Food': 'fruity'}}
+                   u'resource_registry': {u'OS::Food': u'fruity.yaml',
+                                          u'resources': {}}}
         env = environment.Environment(new_env)
         self.assertEqual(new_env, env.user_env_as_dict())
 
     def test_global_registry(self):
+        self.g_env.register_class('CloudX::Compute::Server',
+                                  generic_resource.GenericResource)
         new_env = {u'parameters': {u'a': u'ff', u'b': u'ss'},
                    u'resource_registry': {u'OS::*': 'CloudX::*'}}
         env = environment.Environment(new_env)
         self.assertEqual('CloudX::Compute::Server',
-                         env.get_resource_type('OS::Compute::Server',
-                                               'my_db_server'))
+                         env.get_resource_info('OS::Compute::Server',
+                                               'my_db_server').name)
 
     def test_map_one_resource_type(self):
         new_env = {u'parameters': {u'a': u'ff', u'b': u'ss'},
@@ -46,19 +55,44 @@ class EnvironmentTest(testtools.TestCase):
                                           {u'my_db_server':
                                            {u'OS::DBInstance': 'db.yaml'}}}}
         env = environment.Environment(new_env)
-        self.assertEqual('db.yaml',
-                         env.get_resource_type('OS::DBInstance',
-                                               'my_db_server'))
-        self.assertEqual('OS::Compute::Server',
-                         env.get_resource_type('OS::Compute::Server',
-                                               'my_other_server'))
+
+        info = env.get_resource_info('OS::DBInstance', 'my_db_server')
+        self.assertEqual('db.yaml', info.value)
 
     def test_map_all_resources_of_type(self):
+        self.g_env.register_class('OS::Nova::FloatingIP',
+                                  generic_resource.GenericResource)
+
         new_env = {u'parameters': {u'a': u'ff', u'b': u'ss'},
                    u'resource_registry':
                    {u'OS::Networking::FloatingIP': 'OS::Nova::FloatingIP',
                     u'OS::Loadbalancer': 'lb.yaml'}}
+
         env = environment.Environment(new_env)
         self.assertEqual('OS::Nova::FloatingIP',
-                         env.get_resource_type('OS::Networking::FloatingIP',
+                         env.get_resource_info('OS::Networking::FloatingIP',
+                                               'my_fip').name)
+
+    def test_resource_sort_order_len(self):
+        new_env = {u'resource_registry': {u'resources': {u'my_fip': {
+            u'OS::Networking::FloatingIP': 'ip.yaml'}}},
+            u'OS::Networking::FloatingIP': 'OS::Nova::FloatingIP'}
+
+        env = environment.Environment(new_env)
+        self.assertEqual('ip.yaml',
+                         env.get_resource_info('OS::Networking::FloatingIP',
+                                               'my_fip').value)
+
+    def test_env_load(self):
+        new_env = {u'resource_registry': {u'resources': {u'my_fip': {
+            u'OS::Networking::FloatingIP': 'ip.yaml'}}}}
+
+        env = environment.Environment()
+        self.assertEqual(None,
+                         env.get_resource_info('OS::Networking::FloatingIP',
                                                'my_fip'))
+
+        env.load(new_env)
+        self.assertEqual('ip.yaml',
+                         env.get_resource_info('OS::Networking::FloatingIP',
+                                               'my_fip').value)
index 67d4df45b5200cdec7063ae97e52a6c1d079e433..db64551eaf495090d233870555d4eb3e5f39e1a6 100644 (file)
@@ -16,7 +16,7 @@
 import testtools
 
 from heat.engine import properties
-from heat.engine import resource
+from heat.engine import resources
 from heat.common import exception
 
 
@@ -184,7 +184,7 @@ class SchemaTest(testtools.TestCase):
         self.assertEqual(d, dict(l))
 
     def test_all_resource_schemata(self):
-        for resource_type in resource._resource_classes.itervalues():
+        for resource_type in resources.global_env().get_types():
             for schema in getattr(resource_type,
                                   'properties_schema',
                                   {}).itervalues():
index a20c9074b12a978d6f3255134f6887905c965446..c2f39f31984cdbbe62a7e0b8bf936f4080fec6f8 100644 (file)
@@ -47,46 +47,43 @@ class ProviderTemplateTest(HeatTestCase):
         # default class.
         env_str = {'resource_registry': {}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
-        self.assertEqual(cls, generic_rsrc.GenericResource)
+        cls = env.get_class('OS::ResourceType', 'fred')
+        self.assertEqual(generic_rsrc.GenericResource, cls)
 
     def test_get_mine_global_map(self):
         # assertion: with a global rule we get the "mycloud" class.
         env_str = {'resource_registry': {"OS::*": "myCloud::*"}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
-        self.assertEqual(cls, MyCloudResource)
+        cls = env.get_class('OS::ResourceType', 'fred')
+        self.assertEqual(MyCloudResource, cls)
 
     def test_get_mine_type_map(self):
         # assertion: with a global rule we get the "mycloud" class.
         env_str = {'resource_registry': {
             "OS::ResourceType": "myCloud::ResourceType"}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
-        self.assertEqual(cls, MyCloudResource)
+        cls = env.get_class('OS::ResourceType', 'fred')
+        self.assertEqual(MyCloudResource, cls)
 
     def test_get_mine_resource_map(self):
         # assertion: with a global rule we get the "mycloud" class.
         env_str = {'resource_registry': {'resources': {'fred': {
             "OS::ResourceType": "myCloud::ResourceType"}}}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
-        self.assertEqual(cls, MyCloudResource)
+        cls = env.get_class('OS::ResourceType', 'fred')
+        self.assertEqual(MyCloudResource, cls)
 
     def test_get_os_no_match(self):
         # assertion: make sure 'fred' doesn't match 'jerry'.
         env_str = {'resource_registry': {'resources': {'jerry': {
             "OS::ResourceType": "myCloud::ResourceType"}}}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
-        self.assertEqual(cls, generic_rsrc.GenericResource)
+        cls = env.get_class('OS::ResourceType', 'fred')
+        self.assertEqual(generic_rsrc.GenericResource, cls)
 
     def test_to_parameters(self):
         """Tests property conversion to parameter values."""
         utils.setup_dummy_db()
-        stack = parser.Stack(utils.dummy_context(), 'test_stack',
-                             parser.Template({}),
-                             stack_id=uuidutils.generate_uuid())
 
         class DummyResource(object):
             attributes_schema = {"Foo": "A test attribute"}
@@ -97,6 +94,14 @@ class ProviderTemplateTest(HeatTestCase):
                 "AMap": {"Type": "Map"}
             }
 
+        env = environment.Environment()
+        resource._register_class('DummyResource', DummyResource)
+        env.load({'resource_registry':
+                  {'DummyResource': 'test_resource.template'}})
+        stack = parser.Stack(utils.dummy_context(), 'test_stack',
+                             parser.Template({}), env=env,
+                             stack_id=uuidutils.generate_uuid())
+
         map_prop_val = {
             "key1": "val1",
             "key2": ["lval1", "lval2", "lval3"],
@@ -106,7 +111,7 @@ class ProviderTemplateTest(HeatTestCase):
             }
         }
         json_snippet = {
-            "Type": "test_resource.template",
+            "Type": "DummyResource",
             "Properties": {
                 "Foo": "Bar",
                 "AList": ["one", "two", "three"],
@@ -114,9 +119,6 @@ class ProviderTemplateTest(HeatTestCase):
                 "AMap": map_prop_val
             }
         }
-        self.m.StubOutWithMock(template_resource.resource, "get_class")
-        (template_resource.resource.get_class("test_resource.template")
-         .AndReturn(DummyResource))
         self.m.ReplayAll()
         temp_res = template_resource.TemplateResource('test_t_res',
                                                       json_snippet, stack)
@@ -143,7 +145,7 @@ class ProviderTemplateTest(HeatTestCase):
         env_str = {'resource_registry': {'resources': {'fred': {
             "OS::ResourceType": "some_magic.yaml"}}}}
         env = environment.Environment(env_str)
-        cls = resource.get_class('OS::ResourceType', 'fred', env)
+        cls = env.get_class('OS::ResourceType', 'fred')
         self.assertEqual(cls, template_resource.TemplateResource)
 
     def test_template_as_resource(self):