]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Adding marker, pagination, sort key and sort direction to v2 api
authorMike Perez <thingee@gmail.com>
Tue, 18 Dec 2012 08:42:49 +0000 (00:42 -0800)
committerMike Perez <thingee@gmail.com>
Sun, 23 Dec 2012 07:46:29 +0000 (02:46 -0500)
Taking a cue from quantum, use glance's pagination function.

bp api-pagination

Change-Id: Ida7ade8d9332c88679849f7c640651df7e855abb

cinder/api/v1/volumes.py
cinder/api/v2/volumes.py
cinder/common/sqlalchemyutils.py [new file with mode: 0755]
cinder/db/api.py
cinder/db/sqlalchemy/api.py
cinder/tests/api/v1/test_volumes.py
cinder/tests/api/v2/stubs.py
cinder/tests/api/v2/test_volumes.py
cinder/volume/api.py

index acf824fa3aed7e70baf9c2a0480ebf92d21ac0e4..7c991a1377062362ebb6b7b9f03bbf4fa11305c1 100644 (file)
@@ -260,7 +260,9 @@ class VolumeController(wsgi.Controller):
         remove_invalid_options(context,
                                search_opts, self._get_volume_search_options())
 
-        volumes = self.volume_api.get_all(context, search_opts=search_opts)
+        volumes = self.volume_api.get_all(context, marker=None, limit=None,
+                                          sort_key='created_at',
+                                          sort_dir='desc', filters=search_opts)
         limited_list = common.limited(volumes, req)
         res = [entity_maker(context, vol) for vol in limited_list]
         return {'volumes': res}
index cd5ae6963c32c2d89e100cdc6403e6fde108959a..46f81da8075c0103e4d7825076687ff2133bcd1c 100644 (file)
@@ -171,20 +171,27 @@ class VolumeController(wsgi.Controller):
     def _get_volumes(self, req, is_detail):
         """Returns a list of volumes, transformed through view builder."""
 
-        search_opts = {}
-        search_opts.update(req.GET)
-
         context = req.environ['cinder.context']
+
+        params = req.params.copy()
+        marker = params.pop('marker', None)
+        limit = params.pop('limit', None)
+        sort_key = params.pop('sort_key', 'created_at')
+        sort_dir = params.pop('sort_dir', 'desc')
+        filters = params
+
         remove_invalid_options(context,
-                               search_opts, self._get_volume_search_options())
+                               filters, self._get_volume_filter_options())
 
         # NOTE(thingee): v2 API allows name instead of display_name
-        if 'name' in search_opts:
-            search_opts['display_name'] = search_opts['name']
-            del search_opts['name']
+        if 'name' in filters:
+            filters['display_name'] = filters['name']
+            del filters['name']
 
-        volumes = self.volume_api.get_all(context, search_opts=search_opts)
+        volumes = self.volume_api.get_all(context, marker, limit, sort_key,
+                                          sort_dir, filters)
         limited_list = common.limited(volumes, req)
+
         if is_detail:
             volumes = self._view_builder.detail_list(req, limited_list)
         else:
@@ -273,7 +280,7 @@ class VolumeController(wsgi.Controller):
 
         return retval
 
-    def _get_volume_search_options(self):
+    def _get_volume_filter_options(self):
         """Return volume search options allowed by non-admin."""
         return ('name', 'status')
 
@@ -321,16 +328,16 @@ def create_resource(ext_mgr):
     return wsgi.Resource(VolumeController(ext_mgr))
 
 
-def remove_invalid_options(context, search_options, allowed_search_options):
+def remove_invalid_options(context, filters, allowed_search_options):
     """Remove search options that are not valid for non-admin API/context."""
     if context.is_admin:
         # Allow all options
         return
     # Otherwise, strip out all unknown options
-    unknown_options = [opt for opt in search_options
+    unknown_options = [opt for opt in filters
                        if opt not in allowed_search_options]
     bad_options = ", ".join(unknown_options)
-    log_msg = _("Removing options '%(bad_options)s' from query") % locals()
+    log_msg = _("Removing options '%s' from query") % bad_options
     LOG.debug(log_msg)
     for opt in unknown_options:
-        del search_options[opt]
+        del filters[opt]
diff --git a/cinder/common/sqlalchemyutils.py b/cinder/common/sqlalchemyutils.py
new file mode 100755 (executable)
index 0000000..19b7ca9
--- /dev/null
@@ -0,0 +1,128 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2010-2011 OpenStack LLC.
+# Copyright 2012 Justin Santa Barbara
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+"""Implementation of paginate query."""
+
+import sqlalchemy
+
+from cinder import exception
+from cinder.openstack.common import log as logging
+
+
+LOG = logging.getLogger(__name__)
+
+
+# copied from glance/db/sqlalchemy/api.py
+def paginate_query(query, model, limit, sort_keys, marker=None,
+                   sort_dir=None, sort_dirs=None):
+    """Returns a query with sorting / pagination criteria added.
+
+    Pagination works by requiring a unique sort_key, specified by sort_keys.
+    (If sort_keys is not unique, then we risk looping through values.)
+    We use the last row in the previous page as the 'marker' for pagination.
+    So we must return values that follow the passed marker in the order.
+    With a single-valued sort_key, this would be easy: sort_key > X.
+    With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
+    the lexicographical ordering:
+    (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
+
+    We also have to cope with different sort_directions.
+
+    Typically, the id of the last row is used as the client-facing pagination
+    marker, then the actual marker object must be fetched from the db and
+    passed in to us as marker.
+
+    :param query: the query object to which we should add paging/sorting
+    :param model: the ORM model class
+    :param limit: maximum number of items to return
+    :param sort_keys: array of attributes by which results should be sorted
+    :param marker: the last item of the previous page; we returns the next
+                    results after this value.
+    :param sort_dir: direction in which results should be sorted (asc, desc)
+    :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
+
+    :rtype: sqlalchemy.orm.query.Query
+    :return: The query with sorting/pagination added.
+    """
+
+    if 'id' not in sort_keys:
+        # TODO(justinsb): If this ever gives a false-positive, check
+        # the actual primary key, rather than assuming its id
+        LOG.warn(_('Id not in sort_keys; is sort_keys unique?'))
+
+    assert(not (sort_dir and sort_dirs))
+
+    # Default the sort direction to ascending
+    if sort_dirs is None and sort_dir is None:
+        sort_dir = 'asc'
+
+    # Ensure a per-column sort direction
+    if sort_dirs is None:
+        sort_dirs = [sort_dir for _sort_key in sort_keys]
+
+    assert(len(sort_dirs) == len(sort_keys))
+
+    # Add sorting
+    for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
+        sort_dir_func = {
+            'asc': sqlalchemy.asc,
+            'desc': sqlalchemy.desc,
+        }[current_sort_dir]
+
+        try:
+            sort_key_attr = getattr(model, current_sort_key)
+        except AttributeError:
+            raise exception.InvalidInput(reason='Invalid sort key')
+        query = query.order_by(sort_dir_func(sort_key_attr))
+
+    # Add pagination
+    if marker is not None:
+        marker_values = []
+        for sort_key in sort_keys:
+            v = getattr(marker, sort_key)
+            marker_values.append(v)
+
+        # Build up an array of sort criteria as in the docstring
+        criteria_list = []
+        for i in xrange(0, len(sort_keys)):
+            crit_attrs = []
+            for j in xrange(0, i):
+                model_attr = getattr(model, sort_keys[j])
+                crit_attrs.append((model_attr == marker_values[j]))
+
+            model_attr = getattr(model, sort_keys[i])
+            if sort_dirs[i] == 'desc':
+                crit_attrs.append((model_attr < marker_values[i]))
+            elif sort_dirs[i] == 'asc':
+                crit_attrs.append((model_attr > marker_values[i]))
+            else:
+                raise ValueError(_("Unknown sort direction, "
+                                   "must be 'desc' or 'asc'"))
+
+            criteria = sqlalchemy.sql.and_(*crit_attrs)
+            criteria_list.append(criteria)
+
+        f = sqlalchemy.sql.or_(*criteria_list)
+        query = query.filter(f)
+
+    if limit is not None:
+        query = query.limit(limit)
+
+    return query
index 0f5e11fcd51aed21ca17da84e9246cfccb5271e3..7aa00ab787257152536c65f4460e5db7d6433065 100644 (file)
@@ -229,9 +229,9 @@ def volume_get(context, volume_id):
     return IMPL.volume_get(context, volume_id)
 
 
-def volume_get_all(context):
+def volume_get_all(context, marker, limit, sort_key, sort_dir):
     """Get all volumes."""
-    return IMPL.volume_get_all(context)
+    return IMPL.volume_get_all(context, marker, limit, sort_key, sort_dir)
 
 
 def volume_get_all_by_host(context, host):
@@ -244,9 +244,11 @@ def volume_get_all_by_instance_uuid(context, instance_uuid):
     return IMPL.volume_get_all_by_instance_uuid(context, instance_uuid)
 
 
-def volume_get_all_by_project(context, project_id):
+def volume_get_all_by_project(context, project_id, marker, limit, sort_key,
+                              sort_dir):
     """Get all volumes belonging to a project."""
-    return IMPL.volume_get_all_by_project(context, project_id)
+    return IMPL.volume_get_all_by_project(context, project_id, marker, limit,
+                                          sort_key, sort_dir)
 
 
 def volume_get_iscsi_target_num(context, volume_id):
index 78033be7eaa3a2905196262b9a817a79c955a3f8..aef2adf439a6cae6cb7db900b102e2544d583be0 100644 (file)
@@ -29,6 +29,7 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.sql.expression import literal_column
 from sqlalchemy.sql import func
 
+from cinder.common import sqlalchemyutils
 from cinder import db
 from cinder.db.sqlalchemy import models
 from cinder.db.sqlalchemy.session import get_session
@@ -1022,8 +1023,19 @@ def volume_get(context, volume_id, session=None):
 
 
 @require_admin_context
-def volume_get_all(context):
-    return _volume_get_query(context).all()
+def volume_get_all(context, marker, limit, sort_key, sort_dir):
+    query = _volume_get_query(context)
+
+    marker_volume = None
+    if marker is not None:
+        marker_volume = volume_get(context, marker)
+
+    query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
+                                           [sort_key, 'created_at', 'id'],
+                                           marker=marker_volume,
+                                           sort_dir=sort_dir)
+
+    return query.all()
 
 
 @require_admin_context
@@ -1046,9 +1058,21 @@ def volume_get_all_by_instance_uuid(context, instance_uuid):
 
 
 @require_context
-def volume_get_all_by_project(context, project_id):
+def volume_get_all_by_project(context, project_id, marker, limit, sort_key,
+                              sort_dir):
     authorize_project_context(context, project_id)
-    return _volume_get_query(context).filter_by(project_id=project_id).all()
+    query = _volume_get_query(context).filter_by(project_id=project_id)
+
+    marker_volume = None
+    if marker is not None:
+        marker_volume = volume_get(context, marker)
+
+    query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
+                                           [sort_key, 'created_at', 'id'],
+                                           marker=marker_volume,
+                                           sort_dir=sort_dir)
+
+    return query.all()
 
 
 @require_admin_context
index fd5e7f91340677d6cdb65c35a5686d4dbc39b4f3..3c15010b1fd2269932126228ff2626b77ed08bd5 100644 (file)
@@ -331,7 +331,8 @@ class VolumeApiTest(test.TestCase):
         self.assertEqual(res_dict, expected)
 
     def test_volume_list_by_name(self):
-        def stub_volume_get_all_by_project(context, project_id):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
             return [
                 stubs.stub_volume(1, display_name='vol1'),
                 stubs.stub_volume(2, display_name='vol2'),
@@ -355,7 +356,8 @@ class VolumeApiTest(test.TestCase):
         self.assertEqual(len(resp['volumes']), 0)
 
     def test_volume_list_by_status(self):
-        def stub_volume_get_all_by_project(context, project_id):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
             return [
                 stubs.stub_volume(1, display_name='vol1', status='available'),
                 stubs.stub_volume(2, display_name='vol2', status='available'),
index 2d8d1403268fac10bf16da50d6e0d558481e47f4..307459142900992c7c9c878c53775cec2cec0bd4 100644 (file)
@@ -91,13 +91,15 @@ def stub_volume_get_notfound(self, context, volume_id):
     raise exc.NotFound
 
 
-def stub_volume_get_all(context, search_opts=None):
+def stub_volume_get_all(context, search_opts=None, marker=None, limit=None,
+                        sort_key='created_at', sort_dir='desc'):
     return [stub_volume(100, project_id='fake'),
             stub_volume(101, project_id='superfake'),
             stub_volume(102, project_id='superduperfake')]
 
 
-def stub_volume_get_all_by_project(self, context, search_opts=None):
+def stub_volume_get_all_by_project(self, context, marker, limit, sort_key,
+                                   sort_dir, filters={}):
     return [stub_volume_get(self, context, '1')]
 
 
index 0e22fcb4491373ada623fa612a3008563214355d..45d35025d8b29682fcbadca0cd4e33ed590fdc1d 100644 (file)
@@ -392,8 +392,91 @@ class VolumeApiTest(test.TestCase):
         }
         self.assertEqual(res_dict, expected)
 
+    def test_volume_index_with_marker(self):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
+            return [
+                stubs.stub_volume(1, display_name='vol1'),
+                stubs.stub_volume(2, display_name='vol2'),
+            ]
+        self.stubs.Set(db, 'volume_get_all_by_project',
+                       stub_volume_get_all_by_project)
+        req = fakes.HTTPRequest.blank('/v2/volumes?marker=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 2)
+        self.assertEquals(volumes[0]['id'], 1)
+        self.assertEquals(volumes[1]['id'], 2)
+
+    def test_volume_index_limit(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes?limit=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 1)
+
+    def test_volume_index_limit_negative(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes?limit=-1')
+        self.assertRaises(webob.exc.HTTPBadRequest,
+                          self.controller.index,
+                          req)
+
+    def test_volume_index_limit_non_int(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes?limit=a')
+        self.assertRaises(webob.exc.HTTPBadRequest,
+                          self.controller.index,
+                          req)
+
+    def test_volume_index_limit_marker(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes?marker=1&limit=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 1)
+        self.assertEquals(volumes[0]['id'], '1')
+
+    def test_volume_detail_with_marker(self):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
+            return [
+                stubs.stub_volume(1, display_name='vol1'),
+                stubs.stub_volume(2, display_name='vol2'),
+            ]
+        self.stubs.Set(db, 'volume_get_all_by_project',
+                       stub_volume_get_all_by_project)
+        req = fakes.HTTPRequest.blank('/v2/volumes/detail?marker=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 2)
+        self.assertEquals(volumes[0]['id'], 1)
+        self.assertEquals(volumes[1]['id'], 2)
+
+    def test_volume_detail_limit(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 1)
+
+    def test_volume_detail_limit_negative(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=-1')
+        self.assertRaises(webob.exc.HTTPBadRequest,
+                          self.controller.index,
+                          req)
+
+    def test_volume_detail_limit_non_int(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=a')
+        self.assertRaises(webob.exc.HTTPBadRequest,
+                          self.controller.index,
+                          req)
+
+    def test_volume_detail_limit_marker(self):
+        req = fakes.HTTPRequest.blank('/v2/volumes/detail?marker=1&limit=1')
+        res_dict = self.controller.index(req)
+        volumes = res_dict['volumes']
+        self.assertEquals(len(volumes), 1)
+        self.assertEquals(volumes[0]['id'], '1')
+
     def test_volume_list_by_name(self):
-        def stub_volume_get_all_by_project(context, project_id):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
             return [
                 stubs.stub_volume(1, display_name='vol1'),
                 stubs.stub_volume(2, display_name='vol2'),
@@ -408,7 +491,6 @@ class VolumeApiTest(test.TestCase):
         self.assertEqual(len(resp['volumes']), 3)
         # filter on name
         req = fakes.HTTPRequest.blank('/v2/volumes?name=vol2')
-        #import pdb; pdb.set_trace()
         resp = self.controller.index(req)
         self.assertEqual(len(resp['volumes']), 1)
         self.assertEqual(resp['volumes'][0]['name'], 'vol2')
@@ -418,7 +500,8 @@ class VolumeApiTest(test.TestCase):
         self.assertEqual(len(resp['volumes']), 0)
 
     def test_volume_list_by_status(self):
-        def stub_volume_get_all_by_project(context, project_id):
+        def stub_volume_get_all_by_project(context, project_id, marker, limit,
+                                           sort_key, sort_dir):
             return [
                 stubs.stub_volume(1, display_name='vol1', status='available'),
                 stubs.stub_volume(2, display_name='vol2', status='available'),
index 5c95233689d4e5aef7749f7ec17c158278a4f7e9..97eb6d29ce2c0f801c475da1614e86fe3119a652 100644 (file)
@@ -23,6 +23,7 @@ Handles all requests relating to volumes.
 import functools
 
 from cinder.db import base
+from cinder.db.sqlalchemy import models
 from cinder import exception
 from cinder import flags
 from cinder.image import glance
@@ -267,21 +268,23 @@ class API(base.Base):
         check_policy(context, 'get', volume)
         return volume
 
-    def get_all(self, context, search_opts=None):
+    def get_all(self, context, marker=None, limit=None, sort_key='created_at',
+                sort_dir='desc', filters={}):
         check_policy(context, 'get_all')
 
-        if search_opts is None:
-            search_opts = {}
-
-        if (context.is_admin and 'all_tenants' in search_opts):
+        if (context.is_admin and 'all_tenants' in filters):
             # Need to remove all_tenants to pass the filtering below.
-            del search_opts['all_tenants']
-            volumes = self.db.volume_get_all(context)
+            del filters['all_tenants']
+            volumes = self.db.volume_get_all(context, marker, limit, sort_key,
+                                             sort_dir)
         else:
             volumes = self.db.volume_get_all_by_project(context,
-                                                        context.project_id)
-        if search_opts:
-            LOG.debug(_("Searching by: %s") % str(search_opts))
+                                                        context.project_id,
+                                                        marker, limit,
+                                                        sort_key, sort_dir)
+
+        if filters:
+            LOG.debug(_("Searching by: %s") % str(filters))
 
             def _check_metadata_match(volume, searchdict):
                 volume_metadata = {}
@@ -301,7 +304,7 @@ class API(base.Base):
             not_found = object()
             for volume in volumes:
                 # go over all filters in the list
-                for opt, values in search_opts.iteritems():
+                for opt, values in filters.iteritems():
                     try:
                         filter_func = filter_mapping[opt]
                     except KeyError:
@@ -312,6 +315,7 @@ class API(base.Base):
                 else:  # did not break out loop
                     result.append(volume)  # volume matches all filters
             volumes = result
+
         return volumes
 
     def get_snapshot(self, context, snapshot_id):