]> review.fuel-infra Code Review - openstack-build/heat-build.git/commitdiff
Get db session from the context.
authorSteve Baker <sbaker@redhat.com>
Wed, 21 Nov 2012 19:15:05 +0000 (08:15 +1300)
committerSteve Baker <sbaker@redhat.com>
Wed, 21 Nov 2012 19:15:05 +0000 (08:15 +1300)
The aim is to use a single sqlalchemy session for an RPC request.

The context object passed to EngineAPI methods is actually an RpcContext
which contains the same data as the RequestContext. The @request_context
decorator turns this back into a RequestContext which can now have other
behaviours added to it.

RequestContext now has a lazy loaded session attribute.

Save calls on created entities need to be passed the shared session.

Change-Id: Ied4e66deaca205362b84fb698f75cc872886607d

heat/common/context.py
heat/db/api.py
heat/db/sqlalchemy/api.py
heat/engine/parser.py
heat/engine/resources/resource.py
heat/engine/service.py

index 3a8eb73cac57adbfa5a97c294f2aec52c0ac1567..51488934f50776129bcd37b186bcfe4c7706cd63 100644 (file)
@@ -19,6 +19,7 @@ from heat.common import wsgi
 from heat.openstack.common import cfg
 from heat.openstack.common import importutils
 from heat.common import utils as heat_utils
+from heat.db import api as db_api
 
 
 def generate_request_id():
@@ -64,10 +65,17 @@ class RequestContext(object):
         self.owner_is_tenant = owner_is_tenant
         if overwrite or not hasattr(local.store, 'context'):
             self.update_store()
+        self._session = None
 
     def update_store(self):
         local.store.context = self
 
+    @property
+    def session(self):
+        if self._session is None:
+            self._session = db_api.get_session()
+        return self._session
+
     def to_dict(self):
         return {'auth_token': self.auth_token,
                 'username': self.username,
index 14ca76c29d6e9397cf5cb3556a14fbc910030408..d9cbbb8a56adc5b41d37068132f8feed2449cbb9 100644 (file)
@@ -48,6 +48,10 @@ def configure():
     SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout
 
 
+def get_session():
+    return IMPL.get_session()
+
+
 def raw_template_get(context, template_id):
     return IMPL.raw_template_get(context, template_id)
 
index 57680b22a2bc1cccf84e69c10a96d18fcedcd4f6..459e0d9cc5e94e53d9ce0c0a18bca7ffae6db85c 100644 (file)
@@ -22,17 +22,17 @@ from heat.db.sqlalchemy.session import get_session
 from heat.engine import auth
 
 
-def model_query(context, *args, **kwargs):
-    """
-    :param session: if present, the session to use
-    """
-    session = kwargs.get('session') or get_session()
-
+def model_query(context, *args):
+    session = _session(context)
     query = session.query(*args)
 
     return query
 
 
+def _session(context):
+    return (context and context.session) or get_session()
+
+
 def raw_template_get(context, template_id):
     result = model_query(context, models.RawTemplate).get(template_id)
 
@@ -54,7 +54,7 @@ def raw_template_get_all(context):
 def raw_template_create(context, values):
     raw_template_ref = models.RawTemplate()
     raw_template_ref.update(values)
-    raw_template_ref.save()
+    raw_template_ref.save(_session(context))
     return raw_template_ref
 
 
@@ -97,7 +97,7 @@ def resource_get_all(context):
 def resource_create(context, values):
     resource_ref = models.Resource()
     resource_ref.update(values)
-    resource_ref.save()
+    resource_ref.save(_session(context))
     return resource_ref
 
 
@@ -147,7 +147,7 @@ def stack_get_all_by_tenant(context):
 def stack_create(context, values):
     stack_ref = models.Stack()
     stack_ref.update(values)
-    stack_ref.save()
+    stack_ref.save(_session(context))
     return stack_ref
 
 
@@ -161,7 +161,7 @@ def stack_update(context, stack_id, values):
     old_template_id = stack.raw_template_id
 
     stack.update(values)
-    stack.save()
+    stack.save(_session(context))
 
     # When the raw_template ID changes, we delete the old template
     # after storing the new template ID
@@ -196,13 +196,14 @@ def stack_delete(context, stack_id):
     session.flush()
 
 
-def user_creds_create(values):
+def user_creds_create(context):
+    values = context.to_dict()
     user_creds_ref = models.UserCreds()
     user_creds_ref.update(values)
     user_creds_ref.password = auth.encrypt(values['password'])
     user_creds_ref.service_password = auth.encrypt(values['service_password'])
     user_creds_ref.aws_creds = auth.encrypt(values['aws_creds'])
-    user_creds_ref.save()
+    user_creds_ref.save(_session(context))
     return user_creds_ref
 
 
@@ -250,7 +251,7 @@ def event_get_all_by_stack(context, stack_id):
 def event_create(context, values):
     event_ref = models.Event()
     event_ref.update(values)
-    event_ref.save()
+    event_ref.save(_session(context))
     return event_ref
 
 
@@ -280,7 +281,7 @@ def watch_rule_get_all_by_stack(context, stack_id):
 def watch_rule_create(context, values):
     obj_ref = models.WatchRule()
     obj_ref.update(values)
-    obj_ref.save()
+    obj_ref.save(_session(context))
     return obj_ref
 
 
@@ -292,7 +293,7 @@ def watch_rule_update(context, watch_id, values):
                         (watch_id, 'that does not exist'))
 
     wr.update(values)
-    wr.save()
+    wr.save(_session(context))
 
 
 def watch_rule_delete(context, watch_name):
@@ -315,7 +316,7 @@ def watch_rule_delete(context, watch_name):
 def watch_data_create(context, values):
     obj_ref = models.WatchData()
     obj_ref.update(values)
-    obj_ref.save()
+    obj_ref.save(_session(context))
     return obj_ref
 
 
index 2a6e575a808464732d4c6f3e78803a922e3fa9ae..edf44783ee1852844519a230c42fa52f8efcbbf8 100644 (file)
@@ -114,11 +114,11 @@ class Stack(object):
         Store the stack in the database and return its ID
         If self.id is set, we update the existing stack
         '''
-        new_creds = db_api.user_creds_create(self.context.to_dict())
+        new_creds = db_api.user_creds_create(self.context)
 
         s = {
             'name': self.name,
-            'raw_template_id': self.t.store(),
+            'raw_template_id': self.t.store(self.context),
             'parameters': self.parameters.user_parameters(),
             'owner_id': owner and owner.id,
             'user_creds_id': new_creds.id,
index d22a12a308dfb194e4b13e2dbd0efc56347980c1..a7ed5315e8ac411652d6c5e1a421eea86a56cba9 100644 (file)
@@ -303,7 +303,7 @@ class Resource(object):
         self.resource_id = inst
         if self.id is not None:
             try:
-                rs = db_api.resource_get(self.stack.context, self.id)
+                rs = db_api.resource_get(self.context, self.id)
                 rs.update_and_save({'nova_instance': self.resource_id})
             except Exception as ex:
                 logger.warn('db error %s' % str(ex))
index 83098780ced13e238dee3bd804c3be74a27ca305..fc0d40ae83fe6aaeb09847eea7e3d439125ea85c 100644 (file)
@@ -13,6 +13,7 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import functools
 import webob
 
 from heat.common import context
@@ -32,6 +33,15 @@ from heat.openstack.common.rpc import service
 logger = logging.getLogger(__name__)
 
 
+def request_context(func):
+    @functools.wraps(func)
+    def wrapped(self, ctx, *args, **kwargs):
+        if ctx is not None and not isinstance(ctx, context.RequestContext):
+            ctx = context.RequestContext.from_dict(ctx.to_dict())
+        return func(self, ctx, *args, **kwargs)
+    return wrapped
+
+
 class EngineService(service.Service):
     """
     Manages the running instances from creation to destruction.
@@ -96,6 +106,7 @@ class EngineService(service.Service):
                                   context=stack_context,
                                   sid=s.id)
 
+    @request_context
     def identify_stack(self, context, stack_name):
         """
         The identify_stack method returns the full stack identifier for a
@@ -129,6 +140,7 @@ class EngineService(service.Service):
 
         return s
 
+    @request_context
     def show_stack(self, context, stack_identity):
         """
         The show_stack method returns the attributes of one stack.
@@ -146,6 +158,7 @@ class EngineService(service.Service):
 
         return {'stacks': [format_stack_detail(s) for s in stacks]}
 
+    @request_context
     def list_stacks(self, context):
         """
         The list_stacks method returns attributes of all stacks.
@@ -159,6 +172,7 @@ class EngineService(service.Service):
 
         return {'stacks': [format_stack_detail(s) for s in stacks]}
 
+    @request_context
     def create_stack(self, context, stack_name, template, params, args):
         """
         The create_stack method creates a new stack using the template
@@ -201,6 +215,7 @@ class EngineService(service.Service):
 
         return dict(stack.identifier())
 
+    @request_context
     def update_stack(self, context, stack_identity, template, params, args):
         """
         The update_stack method updates an existing stack based on the
@@ -240,6 +255,7 @@ class EngineService(service.Service):
 
         return dict(current_stack.identifier())
 
+    @request_context
     def validate_template(self, context, template):
         """
         The validate_template method uses the stack parser to check
@@ -282,6 +298,7 @@ class EngineService(service.Service):
         }
         return result
 
+    @request_context
     def get_template(self, context, stack_identity):
         """
         Get the template.
@@ -293,6 +310,7 @@ class EngineService(service.Service):
             return s.raw_template.template
         return None
 
+    @request_context
     def delete_stack(self, context, stack_identity):
         """
         The delete_stack method deletes a given stack.
@@ -313,6 +331,7 @@ class EngineService(service.Service):
         self.tg.add_thread(stack.delete)
         return None
 
+    @request_context
     def list_events(self, context, stack_identity):
         """
         The list_events method lists all events associated with a given stack.
@@ -328,6 +347,7 @@ class EngineService(service.Service):
 
         return {'events': [api.format_event(context, e) for e in events]}
 
+    @request_context
     def describe_stack_resource(self, context, stack_identity, resource_name):
         s = self._get_stack(context, stack_identity)
 
@@ -341,6 +361,7 @@ class EngineService(service.Service):
 
         return api.format_stack_resource(stack[resource_name])
 
+    @request_context
     def describe_stack_resources(self, context, stack_identity,
                                  physical_resource_id, logical_resource_id):
         if stack_identity is not None:
@@ -367,6 +388,7 @@ class EngineService(service.Service):
                 for resource in stack if resource.id is not None and
                                          name_match(resource)]
 
+    @request_context
     def list_stack_resources(self, context, stack_identity):
         s = self._get_stack(context, stack_identity)
 
@@ -375,6 +397,7 @@ class EngineService(service.Service):
         return [api.format_stack_resource(resource, detail=False)
                 for resource in stack if resource.id is not None]
 
+    @request_context
     def metadata_update(self, context, stack_id, resource_name, metadata):
         """
         Update the metadata for the given resource.
@@ -406,6 +429,7 @@ class EngineService(service.Service):
             rule = watchrule.WatchRule.load(context, watch=wr)
             rule.evaluate()
 
+    @request_context
     def create_watch_data(self, context, watch_name, stats_data):
         '''
         This could be used by CloudWatch and WaitConditions
@@ -416,6 +440,7 @@ class EngineService(service.Service):
         logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data)))
         return stats_data
 
+    @request_context
     def show_watch(self, context, watch_name):
         '''
         The show_watch method returns the attributes of one watch/alarm
@@ -435,6 +460,7 @@ class EngineService(service.Service):
         result = [api.format_watch(w) for w in wrs]
         return result
 
+    @request_context
     def show_watch_metric(self, context, namespace=None, metric_name=None):
         '''
         The show_watch method returns the datapoints for a metric
@@ -459,6 +485,7 @@ class EngineService(service.Service):
         result = [api.format_watch_data(w) for w in wds]
         return result
 
+    @request_context
     def set_watch_state(self, context, watch_name, state):
         '''
         Temporarily set the state of a given watch