1 # Copyright (c) 2014 OpenStack Foundation.
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may
5 # not use this file except in compliance with the License. You may obtain
6 # a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations
18 from debtcollector import removals
20 from sqlalchemy import and_
21 from sqlalchemy import or_
22 from sqlalchemy import sql
24 from neutron._i18n import _
25 from neutron.common import exceptions as n_exc
26 from neutron.db import sqlalchemyutils
def model_query_scope(context, model):
    """Tell whether queries on ``model`` must be limited to one tenant.

    A query is scoped only when the model carries a ``tenant_id``
    attribute and the request context has neither 'admin' nor
    'advanced-service' rights.

    :param context: request context exposing ``is_admin`` and ``is_advsvc``.
    :param model: model class (or object) to be queried.
    :returns: True if the query must be filtered by tenant, else False.
    """
    # Admin contexts and models without a tenant_id are never scoped.
    if context.is_admin or not hasattr(model, 'tenant_id'):
        return False
    # Advanced-service contexts also see every tenant's rows.
    return not context.is_advsvc
def model_query(context, model):
    """Build a session query for ``model``, tenant-scoped when required.

    :param context: request context providing ``session`` and ``tenant_id``.
    :param model: model class to query.
    :returns: the query, filtered on ``model.tenant_id`` when
        model_query_scope() says the context must be restricted.
    """
    query = context.session.query(model)
    if not model_query_scope(context, model):
        # Unscoped contexts get the plain, unfiltered query.
        return query
    # Basic filter condition: only rows owned by the caller's tenant.
    return query.filter(model.tenant_id == context.tenant_id)
class CommonDbMixin(object):
    """Common methods used in core and service plugins."""

    # Plugins and mixin classes implementing extensions register hooks
    # into the dict below for "augmenting" the "core way" of building a
    # query for retrieving objects from a model class. Hooks are added
    # through register_model_query_hook(). Models are the keys; each value
    # maps a hook name to a dict with 'query', 'filter' and
    # 'result_filters' entries.
    _model_query_hooks = {}

    # This dictionary will store methods for extending attributes of
    # api resources. Mixins can use this dict for adding their own methods
    # TODO(salvatore-orlando): Avoid using class-level variables
    _dict_extend_functions = {}

    @classmethod
    def register_model_query_hook(cls, model, name, query_hook, filter_hook,
                                  result_filters=None):
        """Register a hook to be invoked when a query is executed.

        Add the hooks to the _model_query_hooks dict. Models are the keys
        of this dict, whereas the value is another dict mapping hook names to
        callables performing the hook.
        Each hook has a "query" component, used to build the query expression
        and a "filter" component, which is used to build the filter expression.

        Query hooks take as input the query being built and return a
        transformed query expression. They are called as
        ``query_hook(context, model, query)``.

        Filter hooks take as input the filter expression being built and return
        a transformed filter expression. They are called as
        ``filter_hook(context, model, query_filter)``.

        Result filters, when given, are called as
        ``result_filters(query, filters)`` while API filters are applied and
        return the transformed query.

        Any of the three hooks may also be a string, in which case it is
        looked up as an attribute of the plugin instance when the query is
        built. Registering the same ``name`` twice for a model replaces the
        earlier entry.
        """
        cls._model_query_hooks.setdefault(model, {})[name] = {
            'query': query_hook, 'filter': filter_hook,
            'result_filters': result_filters}

    @classmethod
    def register_dict_extend_funcs(cls, resource, funcs):
        """Append ``funcs`` to the dict-extension functions of ``resource``.

        :param resource: key identifying the API resource type.
        :param funcs: iterable of callables or method names; see
            _apply_dict_extend_functions() for how each is invoked.
        """
        cls._dict_extend_functions.setdefault(resource, []).extend(funcs)

    @property
    def safe_reference(self):
        """Return a weakref to the instance.

        Minimize the potential for the instance persisting
        unnecessarily in memory by returning a weakref proxy that
        won't prevent deallocation.
        """
        return weakref.proxy(self)

    def model_query_scope(self, context, model):
        """Delegate to the module-level model_query_scope()."""
        return model_query_scope(context, model)

    def _model_query(self, context, model):
        """Build the base query for a model or a UnionModel."""
        if isinstance(model, UnionModel):
            return self._union_model_query(context, model)
        return self._single_model_query(context, model)

    def _union_model_query(self, context, model):
        """Build one query spanning every model of a UnionModel.

        :param context: request context used to scope each sub-query.
        :param model: UnionModel whose ``model_map`` lists the components.
        :returns: the union of the per-model queries.
        """
        # A union query is a query that combines multiple sets of data
        # together and represents them as one. So if a UnionModel was
        # passed in, we generate the query for each model with the
        # appropriate filters and then combine them together with the
        # .union operator. This allows any subsequent users of the query
        # to handle it like a normal query (e.g. add pagination/sorting/etc)
        first_query = None
        remaining_queries = []
        for name, component_model in model.model_map.items():
            query = self._single_model_query(context, component_model)
            if model.column_type_name:
                # Query is generative: add_column() returns a new query, so
                # the result has to be kept for the literal column naming
                # the record's source to be part of the SELECT.
                type_column = sql.expression.column('"%s"' % name,
                                                    is_literal=True)
                query = query.add_column(
                    type_column.label(model.column_type_name))
            if first_query is None:
                first_query = query
            else:
                remaining_queries.append(query)
        return first_query.union(*remaining_queries)

    def _single_model_query(self, context, model):
        """Build the scoped, hook-augmented query for a single model.

        When model_query_scope() requires scoping, rows are limited to those
        owned by the context's tenant, plus:

        * rows with an 'access_as_shared' RBAC entry targeting the tenant or
          the '*' wildcard, when the model has ``rbac_entries``;
        * rows flagged ``shared``, when the model has a ``shared`` attribute.

        Registered query and filter hooks are then applied in turn.
        """
        query = context.session.query(model)
        # define basic filter condition for model query
        query_filter = None
        if self.model_query_scope(context, model):
            if hasattr(model, 'rbac_entries'):
                query = query.outerjoin(model.rbac_entries)
                rbac_model = model.rbac_entries.property.mapper.class_
                query_filter = (
                    (model.tenant_id == context.tenant_id) |
                    ((rbac_model.action == 'access_as_shared') &
                     ((rbac_model.target_tenant == context.tenant_id) |
                      (rbac_model.target_tenant == '*'))))
            elif hasattr(model, 'shared'):
                query_filter = ((model.tenant_id == context.tenant_id) |
                                (model.shared == sql.true()))
            else:
                query_filter = (model.tenant_id == context.tenant_id)
        # Execute query hooks registered from mixins and plugins
        for _name, hooks in six.iteritems(self._model_query_hooks.get(model,
                                                                      {})):
            query_hook = hooks.get('query')
            if isinstance(query_hook, six.string_types):
                # Hooks given by name are resolved on this instance.
                query_hook = getattr(self, query_hook, None)
            if query_hook:
                query = query_hook(context, model, query)

            filter_hook = hooks.get('filter')
            if isinstance(filter_hook, six.string_types):
                filter_hook = getattr(self, filter_hook, None)
            if filter_hook:
                query_filter = filter_hook(context, model, query_filter)

        # NOTE(salvatore-orlando): 'if query_filter' will try to evaluate the
        # condition, raising an exception
        if query_filter is not None:
            query = query.filter(query_filter)
        return query

    def _fields(self, resource, fields):
        """Restrict ``resource`` to the requested ``fields``.

        :param resource: dict describing a resource.
        :param fields: container of keys to keep. When it is empty or None
            the very same ``resource`` object is returned untouched.
        :returns: a new dict with only the requested keys, or ``resource``.
        """
        if fields:
            return {key: item for key, item in resource.items()
                    if key in fields}
        return resource

    @removals.remove(message='This method will be removed in N')
    def _get_tenant_id_for_create(self, context, resource):
        """Work out the tenant that will own a resource being created.

        An admin context may create the resource for whichever tenant the
        resource names. Any other context may only name its own tenant.

        :raises n_exc.AdminRequired: a non-admin context names a tenant
            other than its own.
        :returns: the resource's ``tenant_id`` for admins that set one,
            otherwise the context's tenant_id.
        """
        if context.is_admin and 'tenant_id' in resource:
            tenant_id = resource['tenant_id']
        elif ('tenant_id' in resource and
              resource['tenant_id'] != context.tenant_id):
            reason = _('Cannot create resource for another tenant')
            raise n_exc.AdminRequired(reason=reason)
        else:
            tenant_id = context.tenant_id
        return tenant_id

    def _get_by_id(self, context, model, id):
        """Fetch the single ``model`` row whose id is ``id``.

        The lookup goes through _model_query(), so tenant scoping and
        registered hooks apply. ``Query.one()`` is used, so anything other
        than exactly one matching row raises.
        """
        query = self._model_query(context, model)
        return query.filter(model.id == id).one()

    def _apply_filters_to_query(self, query, model, filters, context=None):
        """Narrow ``query`` according to API ``filters``.

        :param query: query to refine.
        :param model: model class, or a UnionModel, that the query targets.
        :param filters: mapping of attribute name to a sequence of accepted
            values, or None/empty for no filtering.
        :param context: optional request context; used to match RBAC
            entries targeting the caller's tenant for 'shared' filters.
        :returns: the refined query.
        """
        if isinstance(model, UnionModel):
            # NOTE(kevinbenton): a unionmodel is made up of multiple tables so
            # we apply the filter to each table
            for component_model in model.model_map.values():
                query = self._apply_filters_to_query(query, component_model,
                                                     filters, context)
            return query
        if filters:
            for key, value in six.iteritems(filters):
                column = getattr(model, key, None)
                # NOTE(kevinbenton): if column is a hybrid property that
                # references another expression, attempting to convert to
                # a boolean will fail so we must compare to None.
                # See "An Important Expression Language Gotcha" in:
                # docs.sqlalchemy.org/en/rel_0_9/changelog/migration_06.html
                if column is not None:
                    if not value:
                        # No accepted value: nothing can match, and further
                        # filters or hooks cannot change that.
                        query = query.filter(sql.false())
                        return query
                    query = query.filter(column.in_(value))
                elif key == 'shared' and hasattr(model, 'rbac_entries'):
                    # translate a filter on shared into a query against the
                    # object's rbac entries
                    query = query.outerjoin(model.rbac_entries)
                    rbac = model.rbac_entries.property.mapper.class_
                    matches = [rbac.target_tenant == '*']
                    if context:
                        matches.append(rbac.target_tenant == context.tenant_id)
                    # any 'access_as_shared' records that match the
                    # wildcard or requesting tenant
                    is_shared = and_(rbac.action == 'access_as_shared',
                                     or_(*matches))
                    if not value[0]:
                        # NOTE(kevinbenton): we need to find objects that don't
                        # have an entry that matches the criteria above so
                        # we use a subquery to exclude them.
                        # We can't just filter the inverse of the query above
                        # because that will still give us a network shared to
                        # our tenant (or wildcard) if it's shared to another
                        # tenant.
                        # This is the column joining the table to rbac via
                        # the object_id. We can't just use model.id because
                        # subnets join on network.id so we have to inspect the
                        # relationship.
                        join_cols = model.rbac_entries.property.local_columns
                        oid_col = list(join_cols)[0]
                        is_shared = ~oid_col.in_(
                            query.session.query(rbac.object_id).
                            filter(is_shared)
                        )
                    query = query.filter(is_shared)
            for _nam, hooks in six.iteritems(self._model_query_hooks.get(model,
                                                                         {})):
                result_filter = hooks.get('result_filters', None)
                if isinstance(result_filter, six.string_types):
                    result_filter = getattr(self, result_filter, None)

                if result_filter:
                    query = result_filter(query, filters)
        return query

    def _apply_dict_extend_functions(self, resource_type,
                                     response, db_object):
        """Run every registered extension function for ``resource_type``.

        Functions registered by name are looked up on this instance and
        called as ``func(response, db_object)``; functions registered as
        callables are called as ``func(self, response, db_object)``. Names
        that do not resolve to an attribute are skipped.
        """
        for func in self._dict_extend_functions.get(
                resource_type, []):
            args = (response, db_object)
            if isinstance(func, six.string_types):
                func = getattr(self, func, None)
            else:
                # must call unbound method - use self as 1st argument
                args = (self,) + args
            if func:
                func(*args)

    def _get_collection_query(self, context, model, filters=None,
                              sorts=None, limit=None, marker_obj=None,
                              page_reverse=False):
        """Build the filtered, sorted and paginated query for a collection.

        :param sorts: list of ``(key, direction)`` pairs; each direction is
            negated when paging in reverse with a limit.
        :returns: the query produced by sqlalchemyutils.paginate_query().
        """
        collection = self._model_query(context, model)
        collection = self._apply_filters_to_query(collection, model, filters,
                                                  context)
        if limit and page_reverse and sorts:
            # Walk the result set backwards by flipping every sort direction.
            sorts = [(s[0], not s[1]) for s in sorts]
        collection = sqlalchemyutils.paginate_query(collection, model, limit,
                                                    sorts,
                                                    marker_obj=marker_obj)
        return collection

    def _get_collection(self, context, model, dict_func, filters=None,
                        fields=None, sorts=None, limit=None, marker_obj=None,
                        page_reverse=False):
        """Return the collection as a list of ``dict_func(row, fields)``.

        When a reverse page was requested with a limit, the items are put
        back into forward order before being returned.
        """
        query = self._get_collection_query(context, model, filters=filters,
                                           sorts=sorts,
                                           limit=limit,
                                           marker_obj=marker_obj,
                                           page_reverse=page_reverse)
        items = [dict_func(c, fields) for c in query]
        if limit and page_reverse:
            items.reverse()
        return items

    def _get_collection_count(self, context, model, filters=None):
        """Return how many rows of ``model`` match ``filters``."""
        return self._get_collection_query(context, model, filters).count()

    def _get_marker_obj(self, context, resource, limit, marker):
        """Load the pagination marker via this instance's ``_get_<resource>``.

        :returns: the marker object when both ``limit`` and ``marker`` are
            set, otherwise None.
        """
        if limit and marker:
            return getattr(self, '_get_%s' % resource)(context, marker)
        return None

    def _filter_non_model_columns(self, data, model):
        """Remove all the attributes from data which are not columns of
        the model passed as second parameter.

        :returns: a new dict; ``data`` itself is left unmodified.
        """
        # A set gives constant-time membership tests for each key of data.
        columns = {c.name for c in model.__table__.columns}
        return {k: v for k, v in six.iteritems(data) if k in columns}
class UnionModel(object):
    """Collection of models that _model_query can query as a single table.

    ``model_map`` is a dictionary of models keyed by an arbitrary name.
    If ``column_type_name`` is specified, the resulting records will have a
    column with that name which identifies the source of each record.
    """

    def __init__(self, model_map, column_type_name=None):
        self.column_type_name = column_type_name
        self.model_map = model_map