From: Steve Baker Date: Wed, 21 Nov 2012 19:15:05 +0000 (+1300) Subject: Get db session from the context. X-Git-Tag: 2014.1~1186^2 X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=67f4f608153048ec3cb566d63d4df7b1d7fe05e4;p=openstack-build%2Fheat-build.git Get db session from the context. The aim is to use a single sqlalchemy session for an RPC request. The context object passed to EngineAPI methods is actually an RpcContext which contains the same data as the RequestContext. The @request_context decorator turns this back into a RequestContext which can now have other behaviours added to it. RequestContext now has a lazy loaded session attribute. Save calls on created entities need to be passed the shared session. Change-Id: Ied4e66deaca205362b84fb698f75cc872886607d --- diff --git a/heat/common/context.py b/heat/common/context.py index 3a8eb73c..51488934 100644 --- a/heat/common/context.py +++ b/heat/common/context.py @@ -19,6 +19,7 @@ from heat.common import wsgi from heat.openstack.common import cfg from heat.openstack.common import importutils from heat.common import utils as heat_utils +from heat.db import api as db_api def generate_request_id(): @@ -64,10 +65,17 @@ class RequestContext(object): self.owner_is_tenant = owner_is_tenant if overwrite or not hasattr(local.store, 'context'): self.update_store() + self._session = None def update_store(self): local.store.context = self + @property + def session(self): + if self._session is None: + self._session = db_api.get_session() + return self._session + def to_dict(self): return {'auth_token': self.auth_token, 'username': self.username, diff --git a/heat/db/api.py b/heat/db/api.py index 14ca76c2..d9cbbb8a 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -48,6 +48,10 @@ def configure(): SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout +def get_session(): + return IMPL.get_session() + + def raw_template_get(context, template_id): return IMPL.raw_template_get(context, template_id) diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 57680b22..459e0d9c 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -22,17 +22,17 @@ from heat.db.sqlalchemy.session import get_session from heat.engine import auth -def model_query(context, *args, **kwargs): - """ - :param session: if present, the session to use - """ - session = kwargs.get('session') or get_session() - +def model_query(context, *args): + session = _session(context) query = session.query(*args) return query +def _session(context): + return (context and context.session) or get_session() + + def raw_template_get(context, template_id): result = model_query(context, models.RawTemplate).get(template_id) @@ -54,7 +54,7 @@ def raw_template_get_all(context): def raw_template_create(context, values): raw_template_ref = models.RawTemplate() raw_template_ref.update(values) - raw_template_ref.save() + raw_template_ref.save(_session(context)) return raw_template_ref @@ -97,7 +97,7 @@ def resource_get_all(context): def resource_create(context, values): resource_ref = models.Resource() resource_ref.update(values) - resource_ref.save() + resource_ref.save(_session(context)) return resource_ref @@ -147,7 +147,7 @@ def stack_get_all_by_tenant(context): def stack_create(context, values): stack_ref = models.Stack() stack_ref.update(values) - stack_ref.save() + stack_ref.save(_session(context)) return stack_ref @@ -161,7 +161,7 @@ def stack_update(context, stack_id, values): old_template_id = stack.raw_template_id stack.update(values) - stack.save() + stack.save(_session(context)) # When the raw_template ID changes, we delete the old template # after storing the new template ID @@ -196,13 +196,14 @@ def stack_delete(context, stack_id): session.flush() -def user_creds_create(values): +def user_creds_create(context): + values = context.to_dict() user_creds_ref = models.UserCreds() user_creds_ref.update(values) user_creds_ref.password = auth.encrypt(values['password']) user_creds_ref.service_password = auth.encrypt(values['service_password']) user_creds_ref.aws_creds = auth.encrypt(values['aws_creds']) - user_creds_ref.save() + user_creds_ref.save(_session(context)) return user_creds_ref @@ -250,7 +251,7 @@ def event_get_all_by_stack(context, stack_id): def event_create(context, values): event_ref = models.Event() event_ref.update(values) - event_ref.save() + event_ref.save(_session(context)) return event_ref @@ -280,7 +281,7 @@ def watch_rule_get_all_by_stack(context, stack_id): def watch_rule_create(context, values): obj_ref = models.WatchRule() obj_ref.update(values) - obj_ref.save() + obj_ref.save(_session(context)) return obj_ref @@ -292,7 +293,7 @@ def watch_rule_update(context, watch_id, values): (watch_id, 'that does not exist')) wr.update(values) - wr.save() + wr.save(_session(context)) def watch_rule_delete(context, watch_name): @@ -315,7 +316,7 @@ def watch_rule_delete(context, watch_name): def watch_data_create(context, values): obj_ref = models.WatchData() obj_ref.update(values) - obj_ref.save() + obj_ref.save(_session(context)) return obj_ref diff --git a/heat/engine/parser.py b/heat/engine/parser.py index 2a6e575a..edf44783 100644 --- a/heat/engine/parser.py +++ b/heat/engine/parser.py @@ -114,11 +114,11 @@ class Stack(object): Store the stack in the database and return its ID If self.id is set, we update the existing stack ''' - new_creds = db_api.user_creds_create(self.context.to_dict()) + new_creds = db_api.user_creds_create(self.context) s = { 'name': self.name, - 'raw_template_id': self.t.store(), + 'raw_template_id': self.t.store(self.context), 'parameters': self.parameters.user_parameters(), 'owner_id': owner and owner.id, 'user_creds_id': new_creds.id, diff --git a/heat/engine/resources/resource.py b/heat/engine/resources/resource.py index d22a12a3..a7ed5315 100644 --- a/heat/engine/resources/resource.py +++ b/heat/engine/resources/resource.py @@ -303,7 +303,7 @@ class Resource(object): self.resource_id = inst if self.id is not None: try: - rs = db_api.resource_get(self.stack.context, self.id) + rs = db_api.resource_get(self.context, self.id) rs.update_and_save({'nova_instance': self.resource_id}) except Exception as ex: logger.warn('db error %s' % str(ex)) diff --git a/heat/engine/service.py b/heat/engine/service.py index 83098780..fc0d40ae 100644 --- a/heat/engine/service.py +++ b/heat/engine/service.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import functools import webob from heat.common import context @@ -32,6 +33,15 @@ from heat.openstack.common.rpc import service logger = logging.getLogger(__name__) +def request_context(func): + @functools.wraps(func) + def wrapped(self, ctx, *args, **kwargs): + if ctx is not None and not isinstance(ctx, context.RequestContext): + ctx = context.RequestContext.from_dict(ctx.to_dict()) + return func(self, ctx, *args, **kwargs) + return wrapped + + class EngineService(service.Service): """ Manages the running instances from creation to destruction. @@ -96,6 +106,7 @@ class EngineService(service.Service): context=stack_context, sid=s.id) + @request_context def identify_stack(self, context, stack_name): """ The identify_stack method returns the full stack identifier for a @@ -129,6 +140,7 @@ class EngineService(service.Service): return s + @request_context def show_stack(self, context, stack_identity): """ The show_stack method returns the attributes of one stack. @@ -146,6 +158,7 @@ class EngineService(service.Service): return {'stacks': [format_stack_detail(s) for s in stacks]} + @request_context def list_stacks(self, context): """ The list_stacks method returns attributes of all stacks. @@ -159,6 +172,7 @@ class EngineService(service.Service): return {'stacks': [format_stack_detail(s) for s in stacks]} + @request_context def create_stack(self, context, stack_name, template, params, args): """ The create_stack method creates a new stack using the template @@ -201,6 +215,7 @@ class EngineService(service.Service): return dict(stack.identifier()) + @request_context def update_stack(self, context, stack_identity, template, params, args): """ The update_stack method updates an existing stack based on the @@ -240,6 +255,7 @@ class EngineService(service.Service): return dict(current_stack.identifier()) + @request_context def validate_template(self, context, template): """ The validate_template method uses the stack parser to check @@ -282,6 +298,7 @@ class EngineService(service.Service): } return result + @request_context def get_template(self, context, stack_identity): """ Get the template. @@ -293,6 +310,7 @@ class EngineService(service.Service): return s.raw_template.template return None + @request_context def delete_stack(self, context, stack_identity): """ The delete_stack method deletes a given stack. @@ -313,6 +331,7 @@ class EngineService(service.Service): self.tg.add_thread(stack.delete) return None + @request_context def list_events(self, context, stack_identity): """ The list_events method lists all events associated with a given stack. @@ -328,6 +347,7 @@ class EngineService(service.Service): return {'events': [api.format_event(context, e) for e in events]} + @request_context def describe_stack_resource(self, context, stack_identity, resource_name): s = self._get_stack(context, stack_identity) @@ -341,6 +361,7 @@ class EngineService(service.Service): return api.format_stack_resource(stack[resource_name]) + @request_context def describe_stack_resources(self, context, stack_identity, physical_resource_id, logical_resource_id): if stack_identity is not None: @@ -367,6 +388,7 @@ class EngineService(service.Service): for resource in stack if resource.id is not None and name_match(resource)] + @request_context def list_stack_resources(self, context, stack_identity): s = self._get_stack(context, stack_identity) @@ -375,6 +397,7 @@ class EngineService(service.Service): return [api.format_stack_resource(resource, detail=False) for resource in stack if resource.id is not None] + @request_context def metadata_update(self, context, stack_id, resource_name, metadata): """ Update the metadata for the given resource. @@ -406,6 +429,7 @@ class EngineService(service.Service): rule = watchrule.WatchRule.load(context, watch=wr) rule.evaluate() + @request_context def create_watch_data(self, context, watch_name, stats_data): ''' This could be used by CloudWatch and WaitConditions @@ -416,6 +440,7 @@ class EngineService(service.Service): logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data))) return stats_data + @request_context def show_watch(self, context, watch_name): ''' The show_watch method returns the attributes of one watch/alarm @@ -435,6 +460,7 @@ class EngineService(service.Service): result = [api.format_watch(w) for w in wrs] return result + @request_context def show_watch_metric(self, context, namespace=None, metric_name=None): ''' The show_watch method returns the datapoints for a metric @@ -459,6 +485,7 @@ class EngineService(service.Service): result = [api.format_watch_data(w) for w in wds] return result + @request_context def set_watch_state(self, context, watch_name, state): ''' Temporarily set the state of a given watch