]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
Replace to_dict() calls with a function decorator
authorJohn Schwarz <jschwarz@redhat.com>
Mon, 27 Jul 2015 09:09:10 +0000 (12:09 +0300)
committerJohn Schwarz <jschwarz@redhat.com>
Mon, 3 Aug 2015 08:42:25 +0000 (11:42 +0300)
Up until now, API server functions would need to return simple iterable
objects, such as dicts and lists of dicts. This patch introduces a
decorator which allows such functions to return non-simple objects (as
long as the returned object implements the 'to_dict()' method, or is a
list of such objects) and converts them on its own, simplifying the
user's code and removing code duplication.

Change-Id: Ib30a9213b86b33826291197cf01f00bc1dd3db52

neutron/db/db_base_plugin_common.py
neutron/services/qos/qos_plugin.py
neutron/tests/unit/db/test_db_base_plugin_common.py

index 4ce5daab7b684413cd242dcd289b9046b848a0aa..c2fbff201071b3afa10aa6f93e09b7c30e600786 100644 (file)
@@ -29,16 +29,30 @@ from neutron.db import models_v2
 LOG = logging.getLogger(__name__)
 
 
+def convert_result_to_dict(f):
+    @functools.wraps(f)
+    def inner(*args, **kwargs):
+        result = f(*args, **kwargs)
+
+        if result is None:
+            return None
+        elif isinstance(result, list):
+            return [r.to_dict() for r in result]
+        else:
+            return result.to_dict()
+    return inner
+
+
 def filter_fields(f):
     @functools.wraps(f)
     def inner_filter(*args, **kwargs):
         result = f(*args, **kwargs)
         fields = kwargs.get('fields')
         if not fields:
-            pos = f.func_code.co_varnames.index('fields')
             try:
+                pos = f.func_code.co_varnames.index('fields')
                 fields = args[pos]
-            except IndexError:
+            except (IndexError, ValueError):
                 return result
 
         do_filter = lambda d: {k: v for k, v in d.items() if k in fields}
index 23135bf82be7a942c580d4137c5ecd6d9a1fbfd7..d66acc2685cf5055f6c06023a5bcf6cd01f300ae 100644 (file)
@@ -42,18 +42,20 @@ class QoSPlugin(qos.QoSPluginBase):
         self.notification_driver_manager = (
             driver_mgr.QosServiceNotificationDriverManager())
 
+    @db_base_plugin_common.convert_result_to_dict
     def create_policy(self, context, policy):
         policy = policy_object.QosPolicy(context, **policy['policy'])
         policy.create()
         self.notification_driver_manager.create_policy(policy)
-        return policy.to_dict()
+        return policy
 
+    @db_base_plugin_common.convert_result_to_dict
     def update_policy(self, context, policy_id, policy):
         policy = policy_object.QosPolicy(context, **policy['policy'])
         policy.id = policy_id
         policy.update()
         self.notification_driver_manager.update_policy(policy)
-        return policy.to_dict()
+        return policy
 
     def delete_policy(self, context, policy_id):
         policy = policy_object.QosPolicy(context)
@@ -68,21 +70,23 @@ class QoSPlugin(qos.QoSPluginBase):
         return obj
 
     @db_base_plugin_common.filter_fields
+    @db_base_plugin_common.convert_result_to_dict
     def get_policy(self, context, policy_id, fields=None):
-        return self._get_policy_obj(context, policy_id).to_dict()
+        return self._get_policy_obj(context, policy_id)
 
     @db_base_plugin_common.filter_fields
+    @db_base_plugin_common.convert_result_to_dict
     def get_policies(self, context, filters=None, fields=None,
                      sorts=None, limit=None, marker=None,
                      page_reverse=False):
         #TODO(QoS): Support all the optional parameters
-        return [policy_obj.to_dict() for policy_obj in
-                policy_object.QosPolicy.get_objects(context)]
+        return policy_object.QosPolicy.get_objects(context)
 
     #TODO(QoS): Consider adding a proxy catch-all for rules, so
     #           we capture the API function call, and just pass
     #           the rule type as a parameter removing lots of
     #           future code duplication when we have more rules.
+    @db_base_plugin_common.convert_result_to_dict
     def create_policy_bandwidth_limit_rule(self, context, policy_id,
                                            bandwidth_limit_rule):
         #TODO(QoS): avoid creation of severan bandwidth limit rules
@@ -96,8 +100,9 @@ class QoSPlugin(qos.QoSPluginBase):
             **bandwidth_limit_rule['bandwidth_limit_rule'])
         rule.create()
         self.notification_driver_manager.update_policy(policy)
-        return rule.to_dict()
+        return rule
 
+    @db_base_plugin_common.convert_result_to_dict
     def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
                                            bandwidth_limit_rule):
         # validate that we have access to the policy
@@ -107,7 +112,7 @@ class QoSPlugin(qos.QoSPluginBase):
         rule.id = rule_id
         rule.update()
         self.notification_driver_manager.update_policy(policy)
-        return rule.to_dict()
+        return rule
 
     def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
         # validate that we have access to the policy
@@ -118,14 +123,16 @@ class QoSPlugin(qos.QoSPluginBase):
         self.notification_driver_manager.update_policy(policy)
 
     @db_base_plugin_common.filter_fields
+    @db_base_plugin_common.convert_result_to_dict
     def get_policy_bandwidth_limit_rule(self, context, rule_id,
                                         policy_id, fields=None):
         # validate that we have access to the policy
         self._get_policy_obj(context, policy_id)
         return rule_object.QosBandwidthLimitRule.get_by_id(context,
-                                                           rule_id).to_dict()
+                                                           rule_id)
 
     @db_base_plugin_common.filter_fields
+    @db_base_plugin_common.convert_result_to_dict
     def get_policy_bandwidth_limit_rules(self, context, policy_id,
                                          filters=None, fields=None,
                                          sorts=None, limit=None,
@@ -133,12 +140,11 @@ class QoSPlugin(qos.QoSPluginBase):
         #TODO(QoS): Support all the optional parameters
         # validate that we have access to the policy
         self._get_policy_obj(context, policy_id)
-        return [rule_obj.to_dict() for rule_obj in
-                rule_object.QosBandwidthLimitRule.get_objects(context)]
+        return rule_object.QosBandwidthLimitRule.get_objects(context)
 
     @db_base_plugin_common.filter_fields
+    @db_base_plugin_common.convert_result_to_dict
     def get_rule_types(self, context, filters=None, fields=None,
                        sorts=None, limit=None,
                        marker=None, page_reverse=False):
-        return [rule_type_obj.to_dict() for rule_type_obj in
-                rule_type_object.QosRuleType.get_objects()]
+        return rule_type_object.QosRuleType.get_objects()
index 9074bf6183c14f333d351b2890c8a4e775d6c1ad..21866522ad78f5df90b2fc060b838110204a50be 100644 (file)
@@ -17,6 +17,35 @@ from neutron.db import db_base_plugin_common
 from neutron.tests import base
 
 
+class DummyObject(object):
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+
+    def to_dict(self):
+        return self.kwargs
+
+
+class ConvertToDictTestCase(base.BaseTestCase):
+
+    @db_base_plugin_common.convert_result_to_dict
+    def method_dict(self, fields=None):
+        return DummyObject(one=1, two=2, three=3)
+
+    @db_base_plugin_common.convert_result_to_dict
+    def method_list(self):
+        return [DummyObject(one=1, two=2, three=3)] * 3
+
+    def test_simple_object(self):
+        expected = {'one': 1, 'two': 2, 'three': 3}
+        observed = self.method_dict()
+        self.assertEqual(expected, observed)
+
+    def test_list_of_objects(self):
+        expected = [{'one': 1, 'two': 2, 'three': 3}] * 3
+        observed = self.method_list()
+        self.assertEqual(expected, observed)
+
+
 class FilterFieldsTestCase(base.BaseTestCase):
 
     @db_base_plugin_common.filter_fields