b504cfbb5a74cdd97d728743ffcbf9edb139c59e
[openstack-build/neutron-build.git] / neutron / db / common_db_mixin.py
1 # Copyright (c) 2014 OpenStack Foundation.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 import weakref
17
18 from debtcollector import removals
19 import six
20 from sqlalchemy import and_
21 from sqlalchemy import or_
22 from sqlalchemy import sql
23
24 from neutron._i18n import _
25 from neutron.common import exceptions as n_exc
26 from neutron.db import sqlalchemyutils
27
28
29 def model_query_scope(context, model):
30     # Unless a context has 'admin' or 'advanced-service' rights the
31     # query will be scoped to a single tenant_id
32     return ((not context.is_admin and hasattr(model, 'tenant_id')) and
33             (not context.is_advsvc and hasattr(model, 'tenant_id')))
34
35
36 def model_query(context, model):
37     query = context.session.query(model)
38     # define basic filter condition for model query
39     query_filter = None
40     if model_query_scope(context, model):
41         query_filter = (model.tenant_id == context.tenant_id)
42
43     if query_filter is not None:
44         query = query.filter(query_filter)
45     return query
46
47
48 class CommonDbMixin(object):
49     """Common methods used in core and service plugins."""
50     # Plugins, mixin classes implementing extension will register
51     # hooks into the dict below for "augmenting" the "core way" of
52     # building a query for retrieving objects from a model class.
53     # To this aim, the register_model_query_hook and unregister_query_hook
54     # from this class should be invoked
55     _model_query_hooks = {}
56
57     # This dictionary will store methods for extending attributes of
58     # api resources. Mixins can use this dict for adding their own methods
59     # TODO(salvatore-orlando): Avoid using class-level variables
60     _dict_extend_functions = {}
61
62     @classmethod
63     def register_model_query_hook(cls, model, name, query_hook, filter_hook,
64                                   result_filters=None):
65         """Register a hook to be invoked when a query is executed.
66
67         Add the hooks to the _model_query_hooks dict. Models are the keys
68         of this dict, whereas the value is another dict mapping hook names to
69         callables performing the hook.
70         Each hook has a "query" component, used to build the query expression
71         and a "filter" component, which is used to build the filter expression.
72
73         Query hooks take as input the query being built and return a
74         transformed query expression.
75
76         Filter hooks take as input the filter expression being built and return
77         a transformed filter expression
78         """
79         cls._model_query_hooks.setdefault(model, {})[name] = {
80             'query': query_hook, 'filter': filter_hook,
81             'result_filters': result_filters}
82
83     @classmethod
84     def register_dict_extend_funcs(cls, resource, funcs):
85         cls._dict_extend_functions.setdefault(resource, []).extend(funcs)
86
87     @property
88     def safe_reference(self):
89         """Return a weakref to the instance.
90
91         Minimize the potential for the instance persisting
92         unnecessarily in memory by returning a weakref proxy that
93         won't prevent deallocation.
94         """
95         return weakref.proxy(self)
96
97     def model_query_scope(self, context, model):
98         return model_query_scope(context, model)
99
100     def _model_query(self, context, model):
101         if isinstance(model, UnionModel):
102             return self._union_model_query(context, model)
103         else:
104             return self._single_model_query(context, model)
105
106     def _union_model_query(self, context, model):
107         # A union query is a query that combines multiple sets of data
108         # together and represents them as one. So if a UnionModel was
109         # passed in, we generate the query for each model with the
110         # appropriate filters and then combine them together with the
111         # .union operator. This allows any subsequent users of the query
112         # to handle it like a normal query (e.g. add pagination/sorting/etc)
113         first_query = None
114         remaining_queries = []
115         for name, component_model in model.model_map.items():
116             query = self._single_model_query(context, component_model)
117             if model.column_type_name:
118                 query.add_columns(
119                     sql.expression.column('"%s"' % name, is_literal=True).
120                     label(model.column_type_name)
121                 )
122             if first_query is None:
123                 first_query = query
124             else:
125                 remaining_queries.append(query)
126         return first_query.union(*remaining_queries)
127
128     def _single_model_query(self, context, model):
129         query = context.session.query(model)
130         # define basic filter condition for model query
131         query_filter = None
132         if self.model_query_scope(context, model):
133             if hasattr(model, 'rbac_entries'):
134                 query = query.outerjoin(model.rbac_entries)
135                 rbac_model = model.rbac_entries.property.mapper.class_
136                 query_filter = (
137                     (model.tenant_id == context.tenant_id) |
138                     ((rbac_model.action == 'access_as_shared') &
139                      ((rbac_model.target_tenant == context.tenant_id) |
140                       (rbac_model.target_tenant == '*'))))
141             elif hasattr(model, 'shared'):
142                 query_filter = ((model.tenant_id == context.tenant_id) |
143                                 (model.shared == sql.true()))
144             else:
145                 query_filter = (model.tenant_id == context.tenant_id)
146         # Execute query hooks registered from mixins and plugins
147         for _name, hooks in six.iteritems(self._model_query_hooks.get(model,
148                                                                       {})):
149             query_hook = hooks.get('query')
150             if isinstance(query_hook, six.string_types):
151                 query_hook = getattr(self, query_hook, None)
152             if query_hook:
153                 query = query_hook(context, model, query)
154
155             filter_hook = hooks.get('filter')
156             if isinstance(filter_hook, six.string_types):
157                 filter_hook = getattr(self, filter_hook, None)
158             if filter_hook:
159                 query_filter = filter_hook(context, model, query_filter)
160
161         # NOTE(salvatore-orlando): 'if query_filter' will try to evaluate the
162         # condition, raising an exception
163         if query_filter is not None:
164             query = query.filter(query_filter)
165         return query
166
167     def _fields(self, resource, fields):
168         if fields:
169             return dict(((key, item) for key, item in resource.items()
170                          if key in fields))
171         return resource
172
173     @removals.remove(message='This method will be removed in N')
174     def _get_tenant_id_for_create(self, context, resource):
175         if context.is_admin and 'tenant_id' in resource:
176             tenant_id = resource['tenant_id']
177         elif ('tenant_id' in resource and
178               resource['tenant_id'] != context.tenant_id):
179             reason = _('Cannot create resource for another tenant')
180             raise n_exc.AdminRequired(reason=reason)
181         else:
182             tenant_id = context.tenant_id
183         return tenant_id
184
185     def _get_by_id(self, context, model, id):
186         query = self._model_query(context, model)
187         return query.filter(model.id == id).one()
188
189     def _apply_filters_to_query(self, query, model, filters, context=None):
190         if isinstance(model, UnionModel):
191             # NOTE(kevinbenton): a unionmodel is made up of multiple tables so
192             # we apply the filter to each table
193             for component_model in model.model_map.values():
194                 query = self._apply_filters_to_query(query, component_model,
195                                                      filters, context)
196             return query
197         if filters:
198             for key, value in six.iteritems(filters):
199                 column = getattr(model, key, None)
200                 # NOTE(kevinbenton): if column is a hybrid property that
201                 # references another expression, attempting to convert to
202                 # a boolean will fail so we must compare to None.
203                 # See "An Important Expression Language Gotcha" in:
204                 # docs.sqlalchemy.org/en/rel_0_9/changelog/migration_06.html
205                 if column is not None:
206                     if not value:
207                         query = query.filter(sql.false())
208                         return query
209                     query = query.filter(column.in_(value))
210                 elif key == 'shared' and hasattr(model, 'rbac_entries'):
211                     # translate a filter on shared into a query against the
212                     # object's rbac entries
213                     query = query.outerjoin(model.rbac_entries)
214                     rbac = model.rbac_entries.property.mapper.class_
215                     matches = [rbac.target_tenant == '*']
216                     if context:
217                         matches.append(rbac.target_tenant == context.tenant_id)
218                     # any 'access_as_shared' records that match the
219                     # wildcard or requesting tenant
220                     is_shared = and_(rbac.action == 'access_as_shared',
221                                      or_(*matches))
222                     if not value[0]:
223                         # NOTE(kevinbenton): we need to find objects that don't
224                         # have an entry that matches the criteria above so
225                         # we use a subquery to exclude them.
226                         # We can't just filter the inverse of the query above
227                         # because that will still give us a network shared to
228                         # our tenant (or wildcard) if it's shared to another
229                         # tenant.
230                         # This is the column joining the table to rbac via
231                         # the object_id. We can't just use model.id because
232                         # subnets join on network.id so we have to inspect the
233                         # relationship.
234                         join_cols = model.rbac_entries.property.local_columns
235                         oid_col = list(join_cols)[0]
236                         is_shared = ~oid_col.in_(
237                             query.session.query(rbac.object_id).
238                             filter(is_shared)
239                         )
240                     query = query.filter(is_shared)
241             for _nam, hooks in six.iteritems(self._model_query_hooks.get(model,
242                                                                          {})):
243                 result_filter = hooks.get('result_filters', None)
244                 if isinstance(result_filter, six.string_types):
245                     result_filter = getattr(self, result_filter, None)
246
247                 if result_filter:
248                     query = result_filter(query, filters)
249         return query
250
251     def _apply_dict_extend_functions(self, resource_type,
252                                      response, db_object):
253         for func in self._dict_extend_functions.get(
254             resource_type, []):
255             args = (response, db_object)
256             if isinstance(func, six.string_types):
257                 func = getattr(self, func, None)
258             else:
259                 # must call unbound method - use self as 1st argument
260                 args = (self,) + args
261             if func:
262                 func(*args)
263
264     def _get_collection_query(self, context, model, filters=None,
265                               sorts=None, limit=None, marker_obj=None,
266                               page_reverse=False):
267         collection = self._model_query(context, model)
268         collection = self._apply_filters_to_query(collection, model, filters,
269                                                   context)
270         if limit and page_reverse and sorts:
271             sorts = [(s[0], not s[1]) for s in sorts]
272         collection = sqlalchemyutils.paginate_query(collection, model, limit,
273                                                     sorts,
274                                                     marker_obj=marker_obj)
275         return collection
276
277     def _get_collection(self, context, model, dict_func, filters=None,
278                         fields=None, sorts=None, limit=None, marker_obj=None,
279                         page_reverse=False):
280         query = self._get_collection_query(context, model, filters=filters,
281                                            sorts=sorts,
282                                            limit=limit,
283                                            marker_obj=marker_obj,
284                                            page_reverse=page_reverse)
285         items = [dict_func(c, fields) for c in query]
286         if limit and page_reverse:
287             items.reverse()
288         return items
289
290     def _get_collection_count(self, context, model, filters=None):
291         return self._get_collection_query(context, model, filters).count()
292
293     def _get_marker_obj(self, context, resource, limit, marker):
294         if limit and marker:
295             return getattr(self, '_get_%s' % resource)(context, marker)
296         return None
297
298     def _filter_non_model_columns(self, data, model):
299         """Remove all the attributes from data which are not columns of
300         the model passed as second parameter.
301         """
302         columns = [c.name for c in model.__table__.columns]
303         return dict((k, v) for (k, v) in
304                     six.iteritems(data) if k in columns)
305
306
307 class UnionModel(object):
308     """Collection of models that _model_query can query as a single table."""
309
310     def __init__(self, model_map, column_type_name=None):
311         # model_map is a dictionary of models keyed by an arbitrary name.
312         # If column_type_name is specified, the resulting records will have a
313         # column with that name which identifies the source of each record
314         self.model_map = model_map
315         self.column_type_name = column_type_name