]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Extract CommonDBMixin to a separate file
authorEugene Nikanorov <enikanorov@mirantis.com>
Wed, 9 Jul 2014 10:47:01 +0000 (14:47 +0400)
committerEugene Nikanorov <enikanorov@mirantis.com>
Thu, 10 Jul 2014 09:04:30 +0000 (13:04 +0400)
db_base_plugin_v2 imports too much modules that are not necessary
usually, so extract CommonDBMixin in different file.
Plus using db_base_plugin_v2 for some types of modules can lead to
cycles in imports, this refactoring should resolve the issue.

Closes-Bug: #1340145
Change-Id: Idb027d7c5cee2d5bc7598f805c56c55fd4aca048

12 files changed:
neutron/db/common_db_mixin.py [new file with mode: 0644]
neutron/db/db_base_plugin_v2.py
neutron/db/firewall/firewall_db.py
neutron/db/loadbalancer/loadbalancer_db.py
neutron/db/metering/metering_db.py
neutron/db/vpn/vpn_db.py
neutron/plugins/ml2/drivers/l2pop/db.py
neutron/plugins/nuage/nuagedb.py
neutron/services/l3_router/l3_apic.py
neutron/services/l3_router/l3_router_plugin.py
neutron/tests/unit/cisco/n1kv/test_n1kv_db.py
neutron/tests/unit/test_l3_plugin.py

diff --git a/neutron/db/common_db_mixin.py b/neutron/db/common_db_mixin.py
new file mode 100644 (file)
index 0000000..7351e84
--- /dev/null
@@ -0,0 +1,197 @@
+# Copyright (c) 2014 OpenStack Foundation.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import weakref
+
+from oslo.config import cfg
+from sqlalchemy import sql
+
+from neutron.common import exceptions as n_exc
+from neutron.db import sqlalchemyutils
+from neutron.plugins.common import constants as service_constants
+
+
+class CommonDbMixin(object):
+    """Common methods used in core and service plugins."""
+    # Plugins, mixin classes implementing extension will register
+    # hooks into the dict below for "augmenting" the "core way" of
+    # building a query for retrieving objects from a model class.
+    # To this aim, the register_model_query_hook and unregister_query_hook
+    # from this class should be invoked
+    _model_query_hooks = {}
+
+    # This dictionary will store methods for extending attributes of
+    # api resources. Mixins can use this dict for adding their own methods
+    # TODO(salvatore-orlando): Avoid using class-level variables
+    _dict_extend_functions = {}
+
+    @classmethod
+    def register_model_query_hook(cls, model, name, query_hook, filter_hook,
+                                  result_filters=None):
+        """Register a hook to be invoked when a query is executed.
+
+        Add the hooks to the _model_query_hooks dict. Models are the keys
+        of this dict, whereas the value is another dict mapping hook names to
+        callables performing the hook.
+        Each hook has a "query" component, used to build the query expression
+        and a "filter" component, which is used to build the filter expression.
+
+        Query hooks take as input the query being built and return a
+        transformed query expression.
+
+        Filter hooks take as input the filter expression being built and return
+        a transformed filter expression
+        """
+        model_hooks = cls._model_query_hooks.get(model)
+        if not model_hooks:
+            # add key to dict
+            model_hooks = {}
+            cls._model_query_hooks[model] = model_hooks
+        model_hooks[name] = {'query': query_hook, 'filter': filter_hook,
+                             'result_filters': result_filters}
+
+    @property
+    def safe_reference(self):
+        """Return a weakref to the instance.
+
+        Minimize the potential for the instance persisting
+        unnecessarily in memory by returning a weakref proxy that
+        won't prevent deallocation.
+        """
+        return weakref.proxy(self)
+
+    def _model_query(self, context, model):
+        query = context.session.query(model)
+        # define basic filter condition for model query
+        # NOTE(jkoelker) non-admin queries are scoped to their tenant_id
+        # NOTE(salvatore-orlando): unless the model allows for shared objects
+        query_filter = None
+        if not context.is_admin and hasattr(model, 'tenant_id'):
+            if hasattr(model, 'shared'):
+                query_filter = ((model.tenant_id == context.tenant_id) |
+                                (model.shared == sql.true()))
+            else:
+                query_filter = (model.tenant_id == context.tenant_id)
+        # Execute query hooks registered from mixins and plugins
+        for _name, hooks in self._model_query_hooks.get(model,
+                                                        {}).iteritems():
+            query_hook = hooks.get('query')
+            if isinstance(query_hook, basestring):
+                query_hook = getattr(self, query_hook, None)
+            if query_hook:
+                query = query_hook(context, model, query)
+
+            filter_hook = hooks.get('filter')
+            if isinstance(filter_hook, basestring):
+                filter_hook = getattr(self, filter_hook, None)
+            if filter_hook:
+                query_filter = filter_hook(context, model, query_filter)
+
+        # NOTE(salvatore-orlando): 'if query_filter' will try to evaluate the
+        # condition, raising an exception
+        if query_filter is not None:
+            query = query.filter(query_filter)
+        return query
+
+    def _fields(self, resource, fields):
+        if fields:
+            return dict(((key, item) for key, item in resource.items()
+                         if key in fields))
+        return resource
+
+    def _get_tenant_id_for_create(self, context, resource):
+        if context.is_admin and 'tenant_id' in resource:
+            tenant_id = resource['tenant_id']
+        elif ('tenant_id' in resource and
+              resource['tenant_id'] != context.tenant_id):
+            reason = _('Cannot create resource for another tenant')
+            raise n_exc.AdminRequired(reason=reason)
+        else:
+            tenant_id = context.tenant_id
+        return tenant_id
+
+    def _get_by_id(self, context, model, id):
+        query = self._model_query(context, model)
+        return query.filter(model.id == id).one()
+
+    def _apply_filters_to_query(self, query, model, filters):
+        if filters:
+            for key, value in filters.iteritems():
+                column = getattr(model, key, None)
+                if column:
+                    query = query.filter(column.in_(value))
+            for _name, hooks in self._model_query_hooks.get(model,
+                                                            {}).iteritems():
+                result_filter = hooks.get('result_filters', None)
+                if isinstance(result_filter, basestring):
+                    result_filter = getattr(self, result_filter, None)
+
+                if result_filter:
+                    query = result_filter(query, filters)
+        return query
+
+    def _apply_dict_extend_functions(self, resource_type,
+                                     response, db_object):
+        for func in self._dict_extend_functions.get(
+            resource_type, []):
+            args = (response, db_object)
+            if isinstance(func, basestring):
+                func = getattr(self, func, None)
+            else:
+                # must call unbound method - use self as 1st argument
+                args = (self,) + args
+            if func:
+                func(*args)
+
+    def _get_collection_query(self, context, model, filters=None,
+                              sorts=None, limit=None, marker_obj=None,
+                              page_reverse=False):
+        collection = self._model_query(context, model)
+        collection = self._apply_filters_to_query(collection, model, filters)
+        if limit and page_reverse and sorts:
+            sorts = [(s[0], not s[1]) for s in sorts]
+        collection = sqlalchemyutils.paginate_query(collection, model, limit,
+                                                    sorts,
+                                                    marker_obj=marker_obj)
+        return collection
+
+    def _get_collection(self, context, model, dict_func, filters=None,
+                        fields=None, sorts=None, limit=None, marker_obj=None,
+                        page_reverse=False):
+        query = self._get_collection_query(context, model, filters=filters,
+                                           sorts=sorts,
+                                           limit=limit,
+                                           marker_obj=marker_obj,
+                                           page_reverse=page_reverse)
+        items = [dict_func(c, fields) for c in query]
+        if limit and page_reverse:
+            items.reverse()
+        return items
+
+    def _get_collection_count(self, context, model, filters=None):
+        return self._get_collection_query(context, model, filters).count()
+
+    def _get_marker_obj(self, context, resource, limit, marker):
+        if limit and marker:
+            return getattr(self, '_get_%s' % resource)(context, marker)
+        return None
+
+    def _filter_non_model_columns(self, data, model):
+        """Remove all the attributes from data which are not columns of
+        the model passed as second parameter.
+        """
+        columns = [c.name for c in model.__table__.columns]
+        return dict((k, v) for (k, v) in
+                    data.iteritems() if k in columns)
index 4d804f559d896f9063fb43401d7e8fb6c003001f..484f4041333d94255d4641bb45fc8ad1a8738b6c 100644 (file)
@@ -29,6 +29,7 @@ from neutron.common import exceptions as n_exc
 from neutron.common import ipv6_utils
 from neutron import context as ctx
 from neutron.db import api as db
+from neutron.db import common_db_mixin
 from neutron.db import models_v2
 from neutron.db import sqlalchemyutils
 from neutron.extensions import l3
@@ -52,182 +53,8 @@ LOG = logging.getLogger(__name__)
 AUTO_DELETE_PORT_OWNERS = [constants.DEVICE_OWNER_DHCP]
 
 
-class CommonDbMixin(object):
-    """Common methods used in core and service plugins."""
-    # Plugins, mixin classes implementing extension will register
-    # hooks into the dict below for "augmenting" the "core way" of
-    # building a query for retrieving objects from a model class.
-    # To this aim, the register_model_query_hook and unregister_query_hook
-    # from this class should be invoked
-    _model_query_hooks = {}
-
-    # This dictionary will store methods for extending attributes of
-    # api resources. Mixins can use this dict for adding their own methods
-    # TODO(salvatore-orlando): Avoid using class-level variables
-    _dict_extend_functions = {}
-
-    @classmethod
-    def register_model_query_hook(cls, model, name, query_hook, filter_hook,
-                                  result_filters=None):
-        """Register a hook to be invoked when a query is executed.
-
-        Add the hooks to the _model_query_hooks dict. Models are the keys
-        of this dict, whereas the value is another dict mapping hook names to
-        callables performing the hook.
-        Each hook has a "query" component, used to build the query expression
-        and a "filter" component, which is used to build the filter expression.
-
-        Query hooks take as input the query being built and return a
-        transformed query expression.
-
-        Filter hooks take as input the filter expression being built and return
-        a transformed filter expression
-        """
-        model_hooks = cls._model_query_hooks.get(model)
-        if not model_hooks:
-            # add key to dict
-            model_hooks = {}
-            cls._model_query_hooks[model] = model_hooks
-        model_hooks[name] = {'query': query_hook, 'filter': filter_hook,
-                             'result_filters': result_filters}
-
-    @property
-    def safe_reference(self):
-        """Return a weakref to the instance.
-
-        Minimize the potential for the instance persisting
-        unnecessarily in memory by returning a weakref proxy that
-        won't prevent deallocation.
-        """
-        return weakref.proxy(self)
-
-    def _model_query(self, context, model):
-        query = context.session.query(model)
-        # define basic filter condition for model query
-        # NOTE(jkoelker) non-admin queries are scoped to their tenant_id
-        # NOTE(salvatore-orlando): unless the model allows for shared objects
-        query_filter = None
-        if not context.is_admin and hasattr(model, 'tenant_id'):
-            if hasattr(model, 'shared'):
-                query_filter = ((model.tenant_id == context.tenant_id) |
-                                (model.shared == sql.true()))
-            else:
-                query_filter = (model.tenant_id == context.tenant_id)
-        # Execute query hooks registered from mixins and plugins
-        for _name, hooks in self._model_query_hooks.get(model,
-                                                        {}).iteritems():
-            query_hook = hooks.get('query')
-            if isinstance(query_hook, basestring):
-                query_hook = getattr(self, query_hook, None)
-            if query_hook:
-                query = query_hook(context, model, query)
-
-            filter_hook = hooks.get('filter')
-            if isinstance(filter_hook, basestring):
-                filter_hook = getattr(self, filter_hook, None)
-            if filter_hook:
-                query_filter = filter_hook(context, model, query_filter)
-
-        # NOTE(salvatore-orlando): 'if query_filter' will try to evaluate the
-        # condition, raising an exception
-        if query_filter is not None:
-            query = query.filter(query_filter)
-        return query
-
-    def _fields(self, resource, fields):
-        if fields:
-            return dict(((key, item) for key, item in resource.items()
-                         if key in fields))
-        return resource
-
-    def _get_tenant_id_for_create(self, context, resource):
-        if context.is_admin and 'tenant_id' in resource:
-            tenant_id = resource['tenant_id']
-        elif ('tenant_id' in resource and
-              resource['tenant_id'] != context.tenant_id):
-            reason = _('Cannot create resource for another tenant')
-            raise n_exc.AdminRequired(reason=reason)
-        else:
-            tenant_id = context.tenant_id
-        return tenant_id
-
-    def _get_by_id(self, context, model, id):
-        query = self._model_query(context, model)
-        return query.filter(model.id == id).one()
-
-    def _apply_filters_to_query(self, query, model, filters):
-        if filters:
-            for key, value in filters.iteritems():
-                column = getattr(model, key, None)
-                if column:
-                    query = query.filter(column.in_(value))
-            for _name, hooks in self._model_query_hooks.get(model,
-                                                            {}).iteritems():
-                result_filter = hooks.get('result_filters', None)
-                if isinstance(result_filter, basestring):
-                    result_filter = getattr(self, result_filter, None)
-
-                if result_filter:
-                    query = result_filter(query, filters)
-        return query
-
-    def _apply_dict_extend_functions(self, resource_type,
-                                     response, db_object):
-        for func in self._dict_extend_functions.get(
-            resource_type, []):
-            args = (response, db_object)
-            if isinstance(func, basestring):
-                func = getattr(self, func, None)
-            else:
-                # must call unbound method - use self as 1st argument
-                args = (self,) + args
-            if func:
-                func(*args)
-
-    def _get_collection_query(self, context, model, filters=None,
-                              sorts=None, limit=None, marker_obj=None,
-                              page_reverse=False):
-        collection = self._model_query(context, model)
-        collection = self._apply_filters_to_query(collection, model, filters)
-        if limit and page_reverse and sorts:
-            sorts = [(s[0], not s[1]) for s in sorts]
-        collection = sqlalchemyutils.paginate_query(collection, model, limit,
-                                                    sorts,
-                                                    marker_obj=marker_obj)
-        return collection
-
-    def _get_collection(self, context, model, dict_func, filters=None,
-                        fields=None, sorts=None, limit=None, marker_obj=None,
-                        page_reverse=False):
-        query = self._get_collection_query(context, model, filters=filters,
-                                           sorts=sorts,
-                                           limit=limit,
-                                           marker_obj=marker_obj,
-                                           page_reverse=page_reverse)
-        items = [dict_func(c, fields) for c in query]
-        if limit and page_reverse:
-            items.reverse()
-        return items
-
-    def _get_collection_count(self, context, model, filters=None):
-        return self._get_collection_query(context, model, filters).count()
-
-    def _get_marker_obj(self, context, resource, limit, marker):
-        if limit and marker:
-            return getattr(self, '_get_%s' % resource)(context, marker)
-        return None
-
-    def _filter_non_model_columns(self, data, model):
-        """Remove all the attributes from data which are not columns of
-        the model passed as second parameter.
-        """
-        columns = [c.name for c in model.__table__.columns]
-        return dict((k, v) for (k, v) in
-                    data.iteritems() if k in columns)
-
-
 class NeutronDbPluginV2(neutron_plugin_base_v2.NeutronPluginBaseV2,
-                        CommonDbMixin):
+                        common_db_mixin.CommonDbMixin):
     """V2 Neutron plugin interface implementation using SQLAlchemy models.
 
     Whenever a non-read call happens the plugin will call an event handler
index 58b930d953b84dbe8e9aaf7e562fcd1aaf3dbd7e..46042df84b1a09455d2c855544e6e39e9068f560 100644 (file)
@@ -20,7 +20,7 @@ from sqlalchemy.ext.orderinglist import ordering_list
 from sqlalchemy import orm
 from sqlalchemy.orm import exc
 
-from neutron.db import db_base_plugin_v2 as base_db
+from neutron.db import common_db_mixin as base_db
 from neutron.db import model_base
 from neutron.db import models_v2
 from neutron.extensions import firewall
index 0940a945f340adf75af0a1f957417f8139c20f3c..e81e0bd74cc4a1f2a8e532c5ae3f6def5ad8532b 100644 (file)
@@ -21,7 +21,7 @@ from sqlalchemy.orm import validates
 
 from neutron.api.v2 import attributes
 from neutron.common import exceptions as n_exc
-from neutron.db import db_base_plugin_v2 as base_db
+from neutron.db import common_db_mixin as base_db
 from neutron.db import model_base
 from neutron.db import models_v2
 from neutron.db import servicetype_db as st_db
index fe48ae4fd01f1d3f5993c234f23847f9d70f9327..8de0cc427a229531a0c80689b7183b0a4a74e6c9 100644 (file)
@@ -21,7 +21,7 @@ from sqlalchemy import orm
 from neutron.api.rpc.agentnotifiers import metering_rpc_agent_api
 from neutron.common import constants
 from neutron.db import api as dbapi
-from neutron.db import db_base_plugin_v2 as base_db
+from neutron.db import common_db_mixin as base_db
 from neutron.db import l3_db
 from neutron.db import model_base
 from neutron.db import models_v2
index f3d11fecab6c042e2ae7ee1dedc2e0463032d9e7..5f4e6511755d74bfe4ce36fbad9bbaafee236685 100644 (file)
@@ -22,7 +22,7 @@ from sqlalchemy.orm import exc
 
 from neutron.common import constants as n_constants
 from neutron.db import api as qdbapi
-from neutron.db import db_base_plugin_v2 as base_db
+from neutron.db import common_db_mixin as base_db
 from neutron.db import l3_agentschedulers_db as l3_agent_db
 from neutron.db import l3_db
 from neutron.db import model_base
index 3c4fc9bcea1d9a043589ab8a58bcee6bf6259771..9f42f978a6fb9bc68b07b95a2fdfca0b8a5d339f 100644 (file)
@@ -21,7 +21,7 @@ from sqlalchemy import sql
 
 from neutron.common import constants as const
 from neutron.db import agents_db
-from neutron.db import db_base_plugin_v2 as base_db
+from neutron.db import common_db_mixin as base_db
 from neutron.db import models_v2
 from neutron.openstack.common import jsonutils
 from neutron.openstack.common import timeutils
index bd1b2f3d221e98a711aef8cba915275a7f99bce3..2f0affeba2c9a2c7759fac7ea97ada4a32a388f9 100644 (file)
@@ -14,7 +14,7 @@
 #
 # @author: Ronak Shah, Nuage Networks, Alcatel-Lucent USA Inc.
 
-from neutron.db import db_base_plugin_v2
+from neutron.db import common_db_mixin
 from neutron.plugins.nuage import nuage_models
 
 
@@ -130,7 +130,7 @@ def get_net_partition_by_id(session, id):
 
 def get_net_partitions(session, filters=None, fields=None):
     query = session.query(nuage_models.NetPartition)
-    common_db = db_base_plugin_v2.CommonDbMixin()
+    common_db = common_db_mixin.CommonDbMixin()
     query = common_db._apply_filters_to_query(query,
                                               nuage_models.NetPartition,
                                               filters)
index 02198e8dc290e1fbb9ba270f73385f982183f674..54c202ecbd2a1c0e30c657037d901f40945629cf 100644 (file)
@@ -29,7 +29,6 @@ LOG = logging.getLogger(__name__)
 
 
 class ApicL3ServicePlugin(db_base_plugin_v2.NeutronDbPluginV2,
-                          db_base_plugin_v2.CommonDbMixin,
                           extraroute_db.ExtraRoute_db_mixin,
                           l3_gwmode_db.L3_NAT_db_mixin):
     """Implementation of the APIC L3 Router Service Plugin.
index c018a3c4e6dfbde0dc3bd9c6bcc844270718989f..33dc46794188e54385b054d22087c868e8c033dc 100644 (file)
@@ -22,7 +22,7 @@ from neutron.common import constants as q_const
 from neutron.common import rpc as n_rpc
 from neutron.common import topics
 from neutron.db import api as qdbapi
-from neutron.db import db_base_plugin_v2
+from neutron.db import common_db_mixin
 from neutron.db import extraroute_db
 from neutron.db import l3_agentschedulers_db
 from neutron.db import l3_gwmode_db
@@ -38,7 +38,7 @@ class L3RouterPluginRpcCallbacks(n_rpc.RpcCallback,
     RPC_API_VERSION = '1.1'
 
 
-class L3RouterPlugin(db_base_plugin_v2.CommonDbMixin,
+class L3RouterPlugin(common_db_mixin.CommonDbMixin,
                      extraroute_db.ExtraRoute_db_mixin,
                      l3_gwmode_db.L3_NAT_db_mixin,
                      l3_agentschedulers_db.L3AgentSchedulerDbMixin):
index dac40e4c744659b54b40dd053b63f70bbe2eb0d5..bf7c77b70fe87bf7094ce1adb1b5f94caf9fe18f 100644 (file)
@@ -23,7 +23,7 @@ from testtools import matchers
 from neutron.common import exceptions as n_exc
 from neutron import context
 from neutron.db import api as db
-from neutron.db import db_base_plugin_v2
+from neutron.db import common_db_mixin
 from neutron.plugins.cisco.common import cisco_constants
 from neutron.plugins.cisco.common import cisco_exceptions as c_exc
 from neutron.plugins.cisco.db import n1kv_db_v2
@@ -769,7 +769,7 @@ class PolicyProfileTests(base.BaseTestCase):
 
 class ProfileBindingTests(base.BaseTestCase,
                           n1kv_db_v2.NetworkProfile_db_mixin,
-                          db_base_plugin_v2.CommonDbMixin):
+                          common_db_mixin.CommonDbMixin):
 
     def setUp(self):
         super(ProfileBindingTests, self).setUp()
index 4eb80d0d33dfb3d6941b29e834804ddf8fed5632..db8698e283117d868e96af7375248469a97437d1 100644 (file)
@@ -27,6 +27,7 @@ from neutron.common import constants as l3_constants
 from neutron.common import exceptions as n_exc
 from neutron import context
 from neutron.db import api as qdbapi
+from neutron.db import common_db_mixin
 from neutron.db import db_base_plugin_v2
 from neutron.db import external_net_db
 from neutron.db import l3_agentschedulers_db
@@ -283,7 +284,7 @@ class TestNoL3NatPlugin(TestL3NatBasePlugin):
 
 # A L3 routing service plugin class for tests with plugins that
 # delegate away L3 routing functionality
-class TestL3NatServicePlugin(db_base_plugin_v2.CommonDbMixin,
+class TestL3NatServicePlugin(common_db_mixin.CommonDbMixin,
                              l3_db.L3_NAT_db_mixin):
 
     supported_extension_aliases = ["router"]