]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Implement volume quota support in Cinder
authorJohn Griffith <john.griffith@solidfire.com>
Thu, 16 Aug 2012 21:52:52 +0000 (15:52 -0600)
committerJohn Griffith <john.griffith@solidfire.com>
Thu, 30 Aug 2012 19:24:26 +0000 (13:24 -0600)
parital fix for bug 1023311
  * To use needs cinderclient https://review.openstack.org/#/c/11509/
  * Updates quota classes with changes in Nova
  * Adds needed quota related DB tables
  * Updates test_quota to reflect changes in Nova
  * Adds absolute limits and empty rate limit functions
  * Updates test/integration/test_volume to make it work w/ above changes

Change-Id: I221c7a9dc51a2bb9bf7228c056f63ba9546cf5f9

20 files changed:
cinder/api/openstack/volume/__init__.py
cinder/api/openstack/volume/contrib/quota_classes.py [new file with mode: 0644]
cinder/api/openstack/volume/contrib/quotas.py [new file with mode: 0644]
cinder/api/openstack/volume/limits.py [new file with mode: 0644]
cinder/api/openstack/volume/schemas/v1.1/limits.rng [new file with mode: 0644]
cinder/api/openstack/volume/views/limits.py [new file with mode: 0644]
cinder/api/openstack/wsgi.py
cinder/db/api.py
cinder/db/sqlalchemy/api.py
cinder/db/sqlalchemy/migrate_repo/versions/002_quota_class.py [new file with mode: 0644]
cinder/db/sqlalchemy/models.py
cinder/exception.py
cinder/quota.py
cinder/tests/api/openstack/fakes.py
cinder/tests/api/openstack/volume/test_limits.py [new file with mode: 0644]
cinder/tests/integrated/test_volumes.py
cinder/tests/test_quota.py
cinder/tests/test_volume.py
cinder/volume/api.py
etc/cinder/policy.json

index 542477ebf3e05a23dbdde4135b149d703b93befe..4bfe9938df0fdaaced7316144cfb55c00c721461 100644 (file)
@@ -22,6 +22,7 @@ WSGI middleware for OpenStack Volume API.
 
 import cinder.api.openstack
 from cinder.api.openstack.volume import extensions
+from cinder.api.openstack.volume import limits
 from cinder.api.openstack.volume import snapshots
 from cinder.api.openstack.volume import types
 from cinder.api.openstack.volume import volumes
@@ -61,3 +62,7 @@ class APIRouter(cinder.api.openstack.APIRouter):
         mapper.resource("snapshot", "snapshots",
                         controller=self.resources['snapshots'],
                         collection={'detail': 'GET'})
+
+        self.resources['limits'] = limits.create_resource()
+        mapper.resource("limit", "limits",
+                        controller=self.resources['limits'])
diff --git a/cinder/api/openstack/volume/contrib/quota_classes.py b/cinder/api/openstack/volume/contrib/quota_classes.py
new file mode 100644 (file)
index 0000000..cbad0e3
--- /dev/null
@@ -0,0 +1,105 @@
+# Copyright 2012 OpenStack LLC.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import webob
+
+from cinder.api.openstack import extensions
+from cinder.api.openstack import wsgi
+from cinder.api.openstack import xmlutil
+from cinder import db
+from cinder import exception
+from cinder import quota
+
+
+QUOTAS = quota.QUOTAS
+
+
+authorize = extensions.extension_authorizer('volume', 'quota_classes')
+
+
+class QuotaClassTemplate(xmlutil.TemplateBuilder):
+    def construct(self):
+        root = xmlutil.TemplateElement('quota_class_set',
+                                       selector='quota_class_set')
+        root.set('id')
+
+        for resource in QUOTAS.resources:
+            elem = xmlutil.SubTemplateElement(root, resource)
+            elem.text = resource
+
+        return xmlutil.MasterTemplate(root, 1)
+
+
+class QuotaClassSetsController(object):
+
+    def _format_quota_set(self, quota_class, quota_set):
+        """Convert the quota object to a result dict"""
+
+        result = dict(id=str(quota_class))
+
+        for resource in QUOTAS.resources:
+            result[resource] = quota_set[resource]
+
+        return dict(quota_class_set=result)
+
+    @wsgi.serializers(xml=QuotaClassTemplate)
+    def show(self, req, id):
+        context = req.environ['cinder.context']
+        authorize(context)
+        try:
+            db.sqlalchemy.api.authorize_quota_class_context(context, id)
+        except exception.NotAuthorized:
+            raise webob.exc.HTTPForbidden()
+
+        return self._format_quota_set(
+            id,
+            QUOTAS.get_class_quotas(context, id)
+            )
+
+    @wsgi.serializers(xml=QuotaClassTemplate)
+    def update(self, req, id, body):
+        context = req.environ['cinder.context']
+        authorize(context)
+        quota_class = id
+        for key in body['quota_class_set'].keys():
+            if key in QUOTAS:
+                value = int(body['quota_class_set'][key])
+                try:
+                    db.quota_class_update(context, quota_class, key, value)
+                except exception.QuotaClassNotFound:
+                    db.quota_class_create(context, quota_class, key, value)
+                except exception.AdminRequired:
+                    raise webob.exc.HTTPForbidden()
+        return {'quota_class_set': QUOTAS.get_class_quotas(context,
+                                                           quota_class)}
+
+
+class Quota_classes(extensions.ExtensionDescriptor):
+    """Quota classes management support"""
+
+    name = "QuotaClasses"
+    alias = "os-quota-class-sets"
+    namespace = ("http://docs.openstack.org/volume/ext/"
+                 "quota-classes-sets/api/v1.1")
+    updated = "2012-03-12T00:00:00+00:00"
+
+    def get_resources(self):
+        resources = []
+
+        res = extensions.ResourceExtension('os-quota-class-sets',
+                                           QuotaClassSetsController())
+        resources.append(res)
+
+        return resources
diff --git a/cinder/api/openstack/volume/contrib/quotas.py b/cinder/api/openstack/volume/contrib/quotas.py
new file mode 100644 (file)
index 0000000..7f00863
--- /dev/null
@@ -0,0 +1,125 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import webob
+
+from cinder.api.openstack import extensions
+from cinder.api.openstack import wsgi
+from cinder.api.openstack import xmlutil
+from cinder import db
+from cinder.db.sqlalchemy import api as sqlalchemy_api
+from cinder import exception
+from cinder import quota
+
+
+QUOTAS = quota.QUOTAS
+
+
+authorize_update = extensions.extension_authorizer('compute', 'quotas:update')
+authorize_show = extensions.extension_authorizer('compute', 'quotas:show')
+
+
+class QuotaTemplate(xmlutil.TemplateBuilder):
+    def construct(self):
+        root = xmlutil.TemplateElement('quota_set', selector='quota_set')
+        root.set('id')
+
+        for resource in QUOTAS.resources:
+            elem = xmlutil.SubTemplateElement(root, resource)
+            elem.text = resource
+
+        return xmlutil.MasterTemplate(root, 1)
+
+
+class QuotaSetsController(object):
+
+    def _format_quota_set(self, project_id, quota_set):
+        """Convert the quota object to a result dict"""
+
+        result = dict(id=str(project_id))
+
+        for resource in QUOTAS.resources:
+            result[resource] = quota_set[resource]
+
+        return dict(quota_set=result)
+
+    def _validate_quota_limit(self, limit):
+        # NOTE: -1 is a flag value for unlimited
+        if limit < -1:
+            msg = _("Quota limit must be -1 or greater.")
+            raise webob.exc.HTTPBadRequest(explanation=msg)
+
+    def _get_quotas(self, context, id, usages=False):
+        values = QUOTAS.get_project_quotas(context, id, usages=usages)
+
+        if usages:
+            return values
+        else:
+            return dict((k, v['limit']) for k, v in values.items())
+
+    @wsgi.serializers(xml=QuotaTemplate)
+    def show(self, req, id):
+        context = req.environ['cinder.context']
+        authorize_show(context)
+        try:
+            sqlalchemy_api.authorize_project_context(context, id)
+        except exception.NotAuthorized:
+            raise webob.exc.HTTPForbidden()
+
+        return self._format_quota_set(id, self._get_quotas(context, id))
+
+    @wsgi.serializers(xml=QuotaTemplate)
+    def update(self, req, id, body):
+        context = req.environ['cinder.context']
+        authorize_update(context)
+        project_id = id
+        for key in body['quota_set'].keys():
+            if key in QUOTAS:
+                value = int(body['quota_set'][key])
+                self._validate_quota_limit(value)
+                try:
+                    db.quota_update(context, project_id, key, value)
+                except exception.ProjectQuotaNotFound:
+                    db.quota_create(context, project_id, key, value)
+                except exception.AdminRequired:
+                    raise webob.exc.HTTPForbidden()
+        return {'quota_set': self._get_quotas(context, id)}
+
+    @wsgi.serializers(xml=QuotaTemplate)
+    def defaults(self, req, id):
+        context = req.environ['cinder.context']
+        authorize_show(context)
+        return self._format_quota_set(id, QUOTAS.get_defaults(context))
+
+
+class Quotas(extensions.ExtensionDescriptor):
+    """Quotas management support"""
+
+    name = "Quotas"
+    alias = "os-quota-sets"
+    namespace = "http://docs.openstack.org/compute/ext/quotas-sets/api/v1.1"
+    updated = "2011-08-08T00:00:00+00:00"
+
+    def get_resources(self):
+        resources = []
+
+        res = extensions.ResourceExtension('os-quota-sets',
+                                            QuotaSetsController(),
+                                            member_actions={'defaults': 'GET'})
+        resources.append(res)
+
+        return resources
diff --git a/cinder/api/openstack/volume/limits.py b/cinder/api/openstack/volume/limits.py
new file mode 100644 (file)
index 0000000..a9e62dd
--- /dev/null
@@ -0,0 +1,482 @@
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+"""
+Module dedicated functions/classes dealing with rate limiting requests.
+"""
+
+import collections
+import copy
+import httplib
+import math
+import re
+import time
+
+import webob.dec
+import webob.exc
+
+from cinder.api.openstack.volume.views import limits as limits_views
+from cinder.api.openstack import wsgi
+from cinder.api.openstack import xmlutil
+from cinder.openstack.common import importutils
+from cinder.openstack.common import jsonutils
+from cinder import quota
+from cinder import wsgi as base_wsgi
+
+QUOTAS = quota.QUOTAS
+
+
+# Convenience constants for the limits dictionary passed to Limiter().
+PER_SECOND = 1
+PER_MINUTE = 60
+PER_HOUR = 60 * 60
+PER_DAY = 60 * 60 * 24
+
+
+limits_nsmap = {None: xmlutil.XMLNS_COMMON_V10, 'atom': xmlutil.XMLNS_ATOM}
+
+
+class LimitsTemplate(xmlutil.TemplateBuilder):
+    def construct(self):
+        root = xmlutil.TemplateElement('limits', selector='limits')
+
+        rates = xmlutil.SubTemplateElement(root, 'rates')
+        rate = xmlutil.SubTemplateElement(rates, 'rate', selector='rate')
+        rate.set('uri', 'uri')
+        rate.set('regex', 'regex')
+        limit = xmlutil.SubTemplateElement(rate, 'limit', selector='limit')
+        limit.set('value', 'value')
+        limit.set('verb', 'verb')
+        limit.set('remaining', 'remaining')
+        limit.set('unit', 'unit')
+        limit.set('next-available', 'next-available')
+
+        absolute = xmlutil.SubTemplateElement(root, 'absolute',
+                                              selector='absolute')
+        limit = xmlutil.SubTemplateElement(absolute, 'limit',
+                                           selector=xmlutil.get_items)
+        limit.set('name', 0)
+        limit.set('value', 1)
+
+        return xmlutil.MasterTemplate(root, 1, nsmap=limits_nsmap)
+
+
+class LimitsController(object):
+    """
+    Controller for accessing limits in the OpenStack API.
+    """
+
+    @wsgi.serializers(xml=LimitsTemplate)
+    def index(self, req):
+        """
+        Return all global and rate limit information.
+        """
+        context = req.environ['cinder.context']
+        quotas = QUOTAS.get_project_quotas(context, context.project_id,
+                                           usages=False)
+        abs_limits = dict((k, v['limit']) for k, v in quotas.items())
+        rate_limits = req.environ.get("cinder.limits", [])
+
+        builder = self._get_view_builder(req)
+        return builder.build(rate_limits, abs_limits)
+
+    def _get_view_builder(self, req):
+        return limits_views.ViewBuilder()
+
+
+def create_resource():
+    return wsgi.Resource(LimitsController())
+
+
+class Limit(object):
+    """
+    Stores information about a limit for HTTP requests.
+    """
+
+    UNITS = {
+        1: "SECOND",
+        60: "MINUTE",
+        60 * 60: "HOUR",
+        60 * 60 * 24: "DAY",
+    }
+
+    UNIT_MAP = dict([(v, k) for k, v in UNITS.items()])
+
+    def __init__(self, verb, uri, regex, value, unit):
+        """
+        Initialize a new `Limit`.
+
+        @param verb: HTTP verb (POST, PUT, etc.)
+        @param uri: Human-readable URI
+        @param regex: Regular expression format for this limit
+        @param value: Integer number of requests which can be made
+        @param unit: Unit of measure for the value parameter
+        """
+        self.verb = verb
+        self.uri = uri
+        self.regex = regex
+        self.value = int(value)
+        self.unit = unit
+        self.unit_string = self.display_unit().lower()
+        self.remaining = int(value)
+
+        if value <= 0:
+            raise ValueError("Limit value must be > 0")
+
+        self.last_request = None
+        self.next_request = None
+
+        self.water_level = 0
+        self.capacity = self.unit
+        self.request_value = float(self.capacity) / float(self.value)
+        msg = _("Only %(value)s %(verb)s request(s) can be "
+                "made to %(uri)s every %(unit_string)s.")
+        self.error_message = msg % self.__dict__
+
+    def __call__(self, verb, url):
+        """
+        Represents a call to this limit from a relevant request.
+
+        @param verb: string http verb (POST, GET, etc.)
+        @param url: string URL
+        """
+        if self.verb != verb or not re.match(self.regex, url):
+            return
+
+        now = self._get_time()
+
+        if self.last_request is None:
+            self.last_request = now
+
+        leak_value = now - self.last_request
+
+        self.water_level -= leak_value
+        self.water_level = max(self.water_level, 0)
+        self.water_level += self.request_value
+
+        difference = self.water_level - self.capacity
+
+        self.last_request = now
+
+        if difference > 0:
+            self.water_level -= self.request_value
+            self.next_request = now + difference
+            return difference
+
+        cap = self.capacity
+        water = self.water_level
+        val = self.value
+
+        self.remaining = math.floor(((cap - water) / cap) * val)
+        self.next_request = now
+
+    def _get_time(self):
+        """Retrieve the current time. Broken out for testability."""
+        return time.time()
+
+    def display_unit(self):
+        """Display the string name of the unit."""
+        return self.UNITS.get(self.unit, "UNKNOWN")
+
+    def display(self):
+        """Return a useful representation of this class."""
+        return {
+            "verb": self.verb,
+            "URI": self.uri,
+            "regex": self.regex,
+            "value": self.value,
+            "remaining": int(self.remaining),
+            "unit": self.display_unit(),
+            "resetTime": int(self.next_request or self._get_time()),
+        }
+
+# "Limit" format is a dictionary with the HTTP verb, human-readable URI,
+# a regular-expression to match, value and unit of measure (PER_DAY, etc.)
+
+DEFAULT_LIMITS = [
+    Limit("POST", "*", ".*", 10, PER_MINUTE),
+    Limit("POST", "*/servers", "^/servers", 50, PER_DAY),
+    Limit("PUT", "*", ".*", 10, PER_MINUTE),
+    Limit("GET", "*changes-since*", ".*changes-since.*", 3, PER_MINUTE),
+    Limit("DELETE", "*", ".*", 100, PER_MINUTE),
+]
+
+
+class RateLimitingMiddleware(base_wsgi.Middleware):
+    """
+    Rate-limits requests passing through this middleware. All limit information
+    is stored in memory for this implementation.
+    """
+
+    def __init__(self, application, limits=None, limiter=None, **kwargs):
+        """
+        Initialize new `RateLimitingMiddleware`, which wraps the given WSGI
+        application and sets up the given limits.
+
+        @param application: WSGI application to wrap
+        @param limits: String describing limits
+        @param limiter: String identifying class for representing limits
+
+        Other parameters are passed to the constructor for the limiter.
+        """
+        base_wsgi.Middleware.__init__(self, application)
+
+        # Select the limiter class
+        if limiter is None:
+            limiter = Limiter
+        else:
+            limiter = importutils.import_class(limiter)
+
+        # Parse the limits, if any are provided
+        if limits is not None:
+            limits = limiter.parse_limits(limits)
+
+        self._limiter = limiter(limits or DEFAULT_LIMITS, **kwargs)
+
+    @webob.dec.wsgify(RequestClass=wsgi.Request)
+    def __call__(self, req):
+        """
+        Represents a single call through this middleware. We should record the
+        request if we have a limit relevant to it. If no limit is relevant to
+        the request, ignore it.
+
+        If the request should be rate limited, return a fault telling the user
+        they are over the limit and need to retry later.
+        """
+        verb = req.method
+        url = req.url
+        context = req.environ.get("cinder.context")
+
+        if context:
+            username = context.user_id
+        else:
+            username = None
+
+        delay, error = self._limiter.check_for_delay(verb, url, username)
+
+        if delay:
+            msg = _("This request was rate-limited.")
+            retry = time.time() + delay
+            return wsgi.OverLimitFault(msg, error, retry)
+
+        req.environ["cinder.limits"] = self._limiter.get_limits(username)
+
+        return self.application
+
+
+class Limiter(object):
+    """
+    Rate-limit checking class which handles limits in memory.
+    """
+
+    def __init__(self, limits, **kwargs):
+        """
+        Initialize the new `Limiter`.
+
+        @param limits: List of `Limit` objects
+        """
+        self.limits = copy.deepcopy(limits)
+        self.levels = collections.defaultdict(lambda: copy.deepcopy(limits))
+
+        # Pick up any per-user limit information
+        for key, value in kwargs.items():
+            if key.startswith('user:'):
+                username = key[5:]
+                self.levels[username] = self.parse_limits(value)
+
+    def get_limits(self, username=None):
+        """
+        Return the limits for a given user.
+        """
+        return [limit.display() for limit in self.levels[username]]
+
+    def check_for_delay(self, verb, url, username=None):
+        """
+        Check the given verb/user/user triplet for limit.
+
+        @return: Tuple of delay (in seconds) and error message (or None, None)
+        """
+        delays = []
+
+        for limit in self.levels[username]:
+            delay = limit(verb, url)
+            if delay:
+                delays.append((delay, limit.error_message))
+
+        if delays:
+            delays.sort()
+            return delays[0]
+
+        return None, None
+
+    # Note: This method gets called before the class is instantiated,
+    # so this must be either a static method or a class method.  It is
+    # used to develop a list of limits to feed to the constructor.  We
+    # put this in the class so that subclasses can override the
+    # default limit parsing.
+    @staticmethod
+    def parse_limits(limits):
+        """
+        Convert a string into a list of Limit instances.  This
+        implementation expects a semicolon-separated sequence of
+        parenthesized groups, where each group contains a
+        comma-separated sequence consisting of HTTP method,
+        user-readable URI, a URI reg-exp, an integer number of
+        requests which can be made, and a unit of measure.  Valid
+        values for the latter are "SECOND", "MINUTE", "HOUR", and
+        "DAY".
+
+        @return: List of Limit instances.
+        """
+
+        # Handle empty limit strings
+        limits = limits.strip()
+        if not limits:
+            return []
+
+        # Split up the limits by semicolon
+        result = []
+        for group in limits.split(';'):
+            group = group.strip()
+            if group[:1] != '(' or group[-1:] != ')':
+                raise ValueError("Limit rules must be surrounded by "
+                                 "parentheses")
+            group = group[1:-1]
+
+            # Extract the Limit arguments
+            args = [a.strip() for a in group.split(',')]
+            if len(args) != 5:
+                raise ValueError("Limit rules must contain the following "
+                                 "arguments: verb, uri, regex, value, unit")
+
+            # Pull out the arguments
+            verb, uri, regex, value, unit = args
+
+            # Upper-case the verb
+            verb = verb.upper()
+
+            # Convert value--raises ValueError if it's not integer
+            value = int(value)
+
+            # Convert unit
+            unit = unit.upper()
+            if unit not in Limit.UNIT_MAP:
+                raise ValueError("Invalid units specified")
+            unit = Limit.UNIT_MAP[unit]
+
+            # Build a limit
+            result.append(Limit(verb, uri, regex, value, unit))
+
+        return result
+
+
+class WsgiLimiter(object):
+    """
+    Rate-limit checking from a WSGI application. Uses an in-memory `Limiter`.
+
+    To use, POST ``/<username>`` with JSON data such as::
+
+        {
+            "verb" : GET,
+            "path" : "/servers"
+        }
+
+    and receive a 204 No Content, or a 403 Forbidden with an X-Wait-Seconds
+    header containing the number of seconds to wait before the action would
+    succeed.
+    """
+
+    def __init__(self, limits=None):
+        """
+        Initialize the new `WsgiLimiter`.
+
+        @param limits: List of `Limit` objects
+        """
+        self._limiter = Limiter(limits or DEFAULT_LIMITS)
+
+    @webob.dec.wsgify(RequestClass=wsgi.Request)
+    def __call__(self, request):
+        """
+        Handles a call to this application. Returns 204 if the request is
+        acceptable to the limiter, else a 403 is returned with a relevant
+        header indicating when the request *will* succeed.
+        """
+        if request.method != "POST":
+            raise webob.exc.HTTPMethodNotAllowed()
+
+        try:
+            info = dict(jsonutils.loads(request.body))
+        except ValueError:
+            raise webob.exc.HTTPBadRequest()
+
+        username = request.path_info_pop()
+        verb = info.get("verb")
+        path = info.get("path")
+
+        delay, error = self._limiter.check_for_delay(verb, path, username)
+
+        if delay:
+            headers = {"X-Wait-Seconds": "%.2f" % delay}
+            return webob.exc.HTTPForbidden(headers=headers, explanation=error)
+        else:
+            return webob.exc.HTTPNoContent()
+
+
+class WsgiLimiterProxy(object):
+    """
+    Rate-limit requests based on answers from a remote source.
+    """
+
+    def __init__(self, limiter_address):
+        """
+        Initialize the new `WsgiLimiterProxy`.
+
+        @param limiter_address: IP/port combination of where to request limit
+        """
+        self.limiter_address = limiter_address
+
+    def check_for_delay(self, verb, path, username=None):
+        body = jsonutils.dumps({"verb": verb, "path": path})
+        headers = {"Content-Type": "application/json"}
+
+        conn = httplib.HTTPConnection(self.limiter_address)
+
+        if username:
+            conn.request("POST", "/%s" % (username), body, headers)
+        else:
+            conn.request("POST", "/", body, headers)
+
+        resp = conn.getresponse()
+
+        if 200 >= resp.status < 300:
+            return None, None
+
+        return resp.getheader("X-Wait-Seconds"), resp.read() or None
+
+    # Note: This method gets called before the class is instantiated,
+    # so this must be either a static method or a class method.  It is
+    # used to develop a list of limits to feed to the constructor.
+    # This implementation returns an empty list, since all limit
+    # decisions are made by a remote server.
+    @staticmethod
+    def parse_limits(limits):
+        """
+        Ignore a limits string--simply doesn't apply for the limit
+        proxy.
+
+        @return: Empty list.
+        """
+
+        return []
diff --git a/cinder/api/openstack/volume/schemas/v1.1/limits.rng b/cinder/api/openstack/volume/schemas/v1.1/limits.rng
new file mode 100644 (file)
index 0000000..a66af4b
--- /dev/null
@@ -0,0 +1,28 @@
+<element name="limits" ns="http://docs.openstack.org/common/api/v1.0"
+  xmlns="http://relaxng.org/ns/structure/1.0">
+  <element name="rates">
+    <zeroOrMore>
+      <element name="rate">
+        <attribute name="uri"> <text/> </attribute>
+        <attribute name="regex"> <text/> </attribute>
+        <zeroOrMore>
+          <element name="limit">
+            <attribute name="value"> <text/> </attribute>
+            <attribute name="verb"> <text/> </attribute>
+            <attribute name="remaining"> <text/> </attribute>
+            <attribute name="unit"> <text/> </attribute>
+            <attribute name="next-available"> <text/> </attribute>
+          </element>
+        </zeroOrMore>
+      </element>
+    </zeroOrMore>
+  </element>
+  <element name="absolute">
+    <zeroOrMore>
+      <element name="limit">
+        <attribute name="name"> <text/> </attribute>
+        <attribute name="value"> <text/> </attribute>
+      </element>
+    </zeroOrMore>
+  </element>
+</element>
diff --git a/cinder/api/openstack/volume/views/limits.py b/cinder/api/openstack/volume/views/limits.py
new file mode 100644 (file)
index 0000000..81b1e79
--- /dev/null
@@ -0,0 +1,100 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010-2011 OpenStack LLC.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import datetime
+
+from cinder.openstack.common import timeutils
+
+
+class ViewBuilder(object):
+    """OpenStack API base limits view builder."""
+
+    def build(self, rate_limits, absolute_limits):
+        rate_limits = self._build_rate_limits(rate_limits)
+        absolute_limits = self._build_absolute_limits(absolute_limits)
+
+        output = {
+            "limits": {
+                "rate": rate_limits,
+                "absolute": absolute_limits,
+            },
+        }
+
+        return output
+
+    def _build_absolute_limits(self, absolute_limits):
+        """Builder for absolute limits
+
+        absolute_limits should be given as a dict of limits.
+        For example: {"ram": 512, "gigabytes": 1024}.
+
+        """
+        limit_names = {
+            "ram": ["maxTotalRAMSize"],
+            "instances": ["maxTotalInstances"],
+            "cores": ["maxTotalCores"],
+            "gigabytes": ["maxTotalVolumeGigabytes"],
+            "volumes": ["maxTotalVolumes"],
+            "key_pairs": ["maxTotalKeypairs"],
+            "floating_ips": ["maxTotalFloatingIps"],
+            "metadata_items": ["maxServerMeta", "maxImageMeta"],
+            "injected_files": ["maxPersonality"],
+            "injected_file_content_bytes": ["maxPersonalitySize"],
+        }
+        limits = {}
+        for name, value in absolute_limits.iteritems():
+            if name in limit_names and value is not None:
+                for name in limit_names[name]:
+                    limits[name] = value
+        return limits
+
+    def _build_rate_limits(self, rate_limits):
+        limits = []
+        for rate_limit in rate_limits:
+            _rate_limit_key = None
+            _rate_limit = self._build_rate_limit(rate_limit)
+
+            # check for existing key
+            for limit in limits:
+                if (limit["uri"] == rate_limit["URI"] and
+                    limit["regex"] == rate_limit["regex"]):
+                    _rate_limit_key = limit
+                    break
+
+            # ensure we have a key if we didn't find one
+            if not _rate_limit_key:
+                _rate_limit_key = {
+                    "uri": rate_limit["URI"],
+                    "regex": rate_limit["regex"],
+                    "limit": [],
+                }
+                limits.append(_rate_limit_key)
+
+            _rate_limit_key["limit"].append(_rate_limit)
+
+        return limits
+
+    def _build_rate_limit(self, rate_limit):
+        _get_utc = datetime.datetime.utcfromtimestamp
+        next_avail = _get_utc(rate_limit["resetTime"])
+        return {
+            "verb": rate_limit["verb"],
+            "value": rate_limit["value"],
+            "remaining": int(rate_limit["remaining"]),
+            "unit": rate_limit["unit"],
+            "next-available": timeutils.isotime(at=next_avail),
+        }
index a44a6fa937145f8e4c14324ef3df111bbcb6226b..f420d41b35a29c81ef8023613f907453c1127898 100644 (file)
 #    under the License.
 
 import inspect
-from xml.dom import minidom
-from xml.parsers import expat
-
-from lxml import etree
+import math
+import time
 import webob
 
 from cinder import exception
+from cinder import wsgi
 from cinder.openstack.common import log as logging
 from cinder.openstack.common import jsonutils
-from cinder import wsgi
+
+from lxml import etree
+from xml.dom import minidom
+from xml.parsers import expat
 
 
 XMLNS_V1 = 'http://docs.openstack.org/volume/api/v1'
-
 XMLNS_ATOM = 'http://www.w3.org/2005/Atom'
 
 LOG = logging.getLogger(__name__)
@@ -1060,3 +1061,50 @@ def _set_request_id_header(req, headers):
     context = req.environ.get('cinder.context')
     if context:
         headers['x-compute-request-id'] = context.request_id
+
+
+class OverLimitFault(webob.exc.HTTPException):
+    """
+    Rate-limited request response.
+    """
+
+    def __init__(self, message, details, retry_time):
+        """
+        Initialize new `OverLimitFault` with relevant information.
+        """
+        hdrs = OverLimitFault._retry_after(retry_time)
+        self.wrapped_exc = webob.exc.HTTPRequestEntityTooLarge(headers=hdrs)
+        self.content = {
+            "overLimitFault": {
+                "code": self.wrapped_exc.status_int,
+                "message": message,
+                "details": details,
+            },
+        }
+
+    @staticmethod
+    def _retry_after(retry_time):
+        delay = int(math.ceil(retry_time - time.time()))
+        retry_after = delay if delay > 0 else 0
+        headers = {'Retry-After': '%d' % retry_after}
+        return headers
+
+    @webob.dec.wsgify(RequestClass=Request)
+    def __call__(self, request):
+        """
+        Return the wrapped exception with a serialized body conforming to our
+        error format.
+        """
+        content_type = request.best_match_content_type()
+        metadata = {"attributes": {"overLimitFault": "code"}}
+
+        xml_serializer = XMLDictSerializer(metadata, XMLNS_V1)
+        serializer = {
+            'application/xml': xml_serializer,
+            'application/json': JSONDictSerializer(),
+        }[content_type]
+
+        content = serializer.serialize(self.content)
+        self.wrapped_exc.body = content
+
+        return self.wrapped_exc
index 00cd0892ae95900289926e2b24348cdc88fba8bb..96bdad03b6a7461f05a9d5d305d1720a195b7238 100644 (file)
@@ -214,9 +214,11 @@ def volume_create(context, values):
     return IMPL.volume_create(context, values)
 
 
-def volume_data_get_for_project(context, project_id):
+def volume_data_get_for_project(context, project_id, session=None):
     """Get (volume_count, gigabytes) for project."""
-    return IMPL.volume_data_get_for_project(context, project_id)
+    return IMPL.volume_data_get_for_project(context,
+                                            project_id,
+                                            session)
 
 
 def volume_destroy(context, volume_id):
index f5c5b4d308e4e36b486806cf76c38c17fb4535c2..8310359cffca05fa8052c47d621680d3235d70d3 100644 (file)
@@ -20,6 +20,7 @@
 """Implementation of SQLAlchemy backend."""
 
 import datetime
+import functools
 import warnings
 
 from cinder import db
@@ -31,7 +32,12 @@ from cinder.db.sqlalchemy import models
 from cinder.db.sqlalchemy.session import get_session
 from cinder.openstack.common import timeutils
 from sqlalchemy.exc import IntegrityError
+from sqlalchemy import or_
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload_all
+from sqlalchemy.sql.expression import asc
+from sqlalchemy.sql.expression import desc
+from sqlalchemy.sql.expression import literal_column
 from sqlalchemy.sql import func
 from sqlalchemy.sql.expression import literal_column
 
@@ -384,6 +390,514 @@ def iscsi_target_create_safe(context, values):
 ###################
 
 
+@require_context
+def quota_get(context, project_id, resource, session=None):
+    result = model_query(context, models.Quota, session=session,
+                         read_deleted="no").\
+                     filter_by(project_id=project_id).\
+                     filter_by(resource=resource).\
+                     first()
+
+    if not result:
+        raise exception.ProjectQuotaNotFound(project_id=project_id)
+
+    return result
+
+
+@require_context
+def quota_get_all_by_project(context, project_id):
+    authorize_project_context(context, project_id)
+
+    rows = model_query(context, models.Quota, read_deleted="no").\
+                   filter_by(project_id=project_id).\
+                   all()
+
+    result = {'project_id': project_id}
+    for row in rows:
+        result[row.resource] = row.hard_limit
+
+    return result
+
+
+@require_admin_context
+def quota_create(context, project_id, resource, limit):
+    quota_ref = models.Quota()
+    quota_ref.project_id = project_id
+    quota_ref.resource = resource
+    quota_ref.hard_limit = limit
+    quota_ref.save()
+    return quota_ref
+
+
+@require_admin_context
+def quota_update(context, project_id, resource, limit):
+    session = get_session()
+    with session.begin():
+        quota_ref = quota_get(context, project_id, resource, session=session)
+        quota_ref.hard_limit = limit
+        quota_ref.save(session=session)
+
+
+@require_admin_context
+def quota_destroy(context, project_id, resource):
+    session = get_session()
+    with session.begin():
+        quota_ref = quota_get(context, project_id, resource, session=session)
+        quota_ref.delete(session=session)
+
+
+###################
+
+
+@require_context
+def quota_class_get(context, class_name, resource, session=None):
+    result = model_query(context, models.QuotaClass, session=session,
+                         read_deleted="no").\
+                     filter_by(class_name=class_name).\
+                     filter_by(resource=resource).\
+                     first()
+
+    if not result:
+        raise exception.QuotaClassNotFound(class_name=class_name)
+
+    return result
+
+
+@require_context
+def quota_class_get_all_by_name(context, class_name):
+    authorize_quota_class_context(context, class_name)
+
+    rows = model_query(context, models.QuotaClass, read_deleted="no").\
+                   filter_by(class_name=class_name).\
+                   all()
+
+    result = {'class_name': class_name}
+    for row in rows:
+        result[row.resource] = row.hard_limit
+
+    return result
+
+
+@require_admin_context
+def quota_class_create(context, class_name, resource, limit):
+    quota_class_ref = models.QuotaClass()
+    quota_class_ref.class_name = class_name
+    quota_class_ref.resource = resource
+    quota_class_ref.hard_limit = limit
+    quota_class_ref.save()
+    return quota_class_ref
+
+
+@require_admin_context
+def quota_class_update(context, class_name, resource, limit):
+    session = get_session()
+    with session.begin():
+        quota_class_ref = quota_class_get(context, class_name, resource,
+                                          session=session)
+        quota_class_ref.hard_limit = limit
+        quota_class_ref.save(session=session)
+
+
+@require_admin_context
+def quota_class_destroy(context, class_name, resource):
+    session = get_session()
+    with session.begin():
+        quota_class_ref = quota_class_get(context, class_name, resource,
+                                          session=session)
+        quota_class_ref.delete(session=session)
+
+
+@require_admin_context
+def quota_class_destroy_all_by_name(context, class_name):
+    session = get_session()
+    with session.begin():
+        quota_classes = model_query(context, models.QuotaClass,
+                                    session=session, read_deleted="no").\
+                                filter_by(class_name=class_name).\
+                                all()
+
+        for quota_class_ref in quota_classes:
+            quota_class_ref.delete(session=session)
+
+
+###################
+
+
+@require_context
+def quota_usage_get(context, project_id, resource, session=None):
+    result = model_query(context, models.QuotaUsage, session=session,
+                         read_deleted="no").\
+                     filter_by(project_id=project_id).\
+                     filter_by(resource=resource).\
+                     first()
+
+    if not result:
+        raise exception.QuotaUsageNotFound(project_id=project_id)
+
+    return result
+
+
+@require_context
+def quota_usage_get_all_by_project(context, project_id):
+    authorize_project_context(context, project_id)
+
+    rows = model_query(context, models.QuotaUsage, read_deleted="no").\
+                   filter_by(project_id=project_id).\
+                   all()
+
+    result = {'project_id': project_id}
+    for row in rows:
+        result[row.resource] = dict(in_use=row.in_use, reserved=row.reserved)
+
+    return result
+
+
+@require_admin_context
+def quota_usage_create(context, project_id, resource, in_use, reserved,
+                       until_refresh, session=None):
+    quota_usage_ref = models.QuotaUsage()
+    quota_usage_ref.project_id = project_id
+    quota_usage_ref.resource = resource
+    quota_usage_ref.in_use = in_use
+    quota_usage_ref.reserved = reserved
+    quota_usage_ref.until_refresh = until_refresh
+    quota_usage_ref.save(session=session)
+
+    return quota_usage_ref
+
+
+@require_admin_context
+def quota_usage_update(context, project_id, resource, in_use, reserved,
+                       until_refresh, session=None):
+    def do_update(session):
+        quota_usage_ref = quota_usage_get(context, project_id, resource,
+                                          session=session)
+        quota_usage_ref.in_use = in_use
+        quota_usage_ref.reserved = reserved
+        quota_usage_ref.until_refresh = until_refresh
+        quota_usage_ref.save(session=session)
+
+    if session:
+        # Assume caller started a transaction
+        do_update(session)
+    else:
+        session = get_session()
+        with session.begin():
+            do_update(session)
+
+
+@require_admin_context
+def quota_usage_destroy(context, project_id, resource):
+    session = get_session()
+    with session.begin():
+        quota_usage_ref = quota_usage_get(context, project_id, resource,
+                                          session=session)
+        quota_usage_ref.delete(session=session)
+
+
+###################
+
+
+@require_context
+def reservation_get(context, uuid, session=None):
+    result = model_query(context, models.Reservation, session=session,
+                         read_deleted="no").\
+                     filter_by(uuid=uuid).\
+                     first()
+
+    if not result:
+        raise exception.ReservationNotFound(uuid=uuid)
+
+    return result
+
+
+@require_context
+def reservation_get_all_by_project(context, project_id):
+    authorize_project_context(context, project_id)
+
+    rows = model_query(context, models.QuotaUsage, read_deleted="no").\
+                   filter_by(project_id=project_id).\
+                   all()
+
+    result = {'project_id': project_id}
+    for row in rows:
+        result.setdefault(row.resource, {})
+        result[row.resource][row.uuid] = row.delta
+
+    return result
+
+
+@require_admin_context
+def reservation_create(context, uuid, usage, project_id, resource, delta,
+                       expire, session=None):
+    reservation_ref = models.Reservation()
+    reservation_ref.uuid = uuid
+    reservation_ref.usage_id = usage['id']
+    reservation_ref.project_id = project_id
+    reservation_ref.resource = resource
+    reservation_ref.delta = delta
+    reservation_ref.expire = expire
+    reservation_ref.save(session=session)
+    return reservation_ref
+
+
+@require_admin_context
+def reservation_destroy(context, uuid):
+    session = get_session()
+    with session.begin():
+        reservation_ref = reservation_get(context, uuid, session=session)
+        reservation_ref.delete(session=session)
+
+
+###################
+
+
+# NOTE(johannes): The quota code uses SQL locking to ensure races don't
+# cause under or over counting of resources. To avoid deadlocks, this
+# code always acquires the lock on quota_usages before acquiring the lock
+# on reservations.
+
+def _get_quota_usages(context, session):
+    # Broken out for testability
+    rows = model_query(context, models.QuotaUsage,
+                       read_deleted="no",
+                       session=session).\
+                   filter_by(project_id=context.project_id).\
+                   with_lockmode('update').\
+                   all()
+    return dict((row.resource, row) for row in rows)
+
+
+@require_context
+def quota_reserve(context, resources, quotas, deltas, expire,
+                  until_refresh, max_age):
+    elevated = context.elevated()
+    session = get_session()
+    with session.begin():
+        # Get the current usages
+        usages = _get_quota_usages(context, session)
+
+        # Handle usage refresh
+        work = set(deltas.keys())
+        while work:
+            resource = work.pop()
+
+            # Do we need to refresh the usage?
+            refresh = False
+            if resource not in usages:
+                usages[resource] = quota_usage_create(elevated,
+                                                      context.project_id,
+                                                      resource,
+                                                      0, 0,
+                                                      until_refresh or None,
+                                                      session=session)
+                refresh = True
+            elif usages[resource].in_use < 0:
+                # Negative in_use count indicates a desync, so try to
+                # heal from that...
+                refresh = True
+            elif usages[resource].until_refresh is not None:
+                usages[resource].until_refresh -= 1
+                if usages[resource].until_refresh <= 0:
+                    refresh = True
+            elif max_age and (usages[resource].updated_at -
+                              timeutils.utcnow()).seconds >= max_age:
+                refresh = True
+
+            # OK, refresh the usage
+            if refresh:
+                # Grab the sync routine
+                sync = resources[resource].sync
+
+                updates = sync(elevated, context.project_id, session)
+                for res, in_use in updates.items():
+                    # Make sure we have a destination for the usage!
+                    if res not in usages:
+                        usages[res] = quota_usage_create(elevated,
+                                                         context.project_id,
+                                                         res,
+                                                         0, 0,
+                                                         until_refresh or None,
+                                                         session=session)
+
+                    # Update the usage
+                    usages[res].in_use = in_use
+                    usages[res].until_refresh = until_refresh or None
+
+                    # Because more than one resource may be refreshed
+                    # by the call to the sync routine, and we don't
+                    # want to double-sync, we make sure all refreshed
+                    # resources are dropped from the work set.
+                    work.discard(res)
+
+                    # NOTE(Vek): We make the assumption that the sync
+                    #            routine actually refreshes the
+                    #            resources that it is the sync routine
+                    #            for.  We don't check, because this is
+                    #            a best-effort mechanism.
+
+        # Check for deltas that would go negative
+        unders = [resource for resource, delta in deltas.items()
+                  if delta < 0 and
+                  delta + usages[resource].in_use < 0]
+
+        # Now, let's check the quotas
+        # NOTE(Vek): We're only concerned about positive increments.
+        #            If a project has gone over quota, we want them to
+        #            be able to reduce their usage without any
+        #            problems.
+        overs = [resource for resource, delta in deltas.items()
+                 if quotas[resource] >= 0 and delta >= 0 and
+                 quotas[resource] < delta + usages[resource].total]
+
+        # NOTE(Vek): The quota check needs to be in the transaction,
+        #            but the transaction doesn't fail just because
+        #            we're over quota, so the OverQuota raise is
+        #            outside the transaction.  If we did the raise
+        #            here, our usage updates would be discarded, but
+        #            they're not invalidated by being over-quota.
+
+        # Create the reservations
+        if not overs:
+            reservations = []
+            for resource, delta in deltas.items():
+                reservation = reservation_create(elevated,
+                                                 str(utils.gen_uuid()),
+                                                 usages[resource],
+                                                 context.project_id,
+                                                 resource, delta, expire,
+                                                 session=session)
+                reservations.append(reservation.uuid)
+
+                # Also update the reserved quantity
+                # NOTE(Vek): Again, we are only concerned here about
+                #            positive increments.  Here, though, we're
+                #            worried about the following scenario:
+                #
+                #            1) User initiates resize down.
+                #            2) User allocates a new instance.
+                #            3) Resize down fails or is reverted.
+                #            4) User is now over quota.
+                #
+                #            To prevent this, we only update the
+                #            reserved value if the delta is positive.
+                if delta > 0:
+                    usages[resource].reserved += delta
+
+        # Apply updates to the usages table
+        for usage_ref in usages.values():
+            usage_ref.save(session=session)
+
+    if unders:
+        LOG.warning(_("Change will make usage less than 0 for the following "
+                      "resources: %(unders)s") % locals())
+    if overs:
+        usages = dict((k, dict(in_use=v['in_use'], reserved=v['reserved']))
+                      for k, v in usages.items())
+        raise exception.OverQuota(overs=sorted(overs), quotas=quotas,
+                                  usages=usages)
+
+    return reservations
+
+
+def _quota_reservations(session, context, reservations):
+    """Return the relevant reservations."""
+
+    # Get the listed reservations
+    return model_query(context, models.Reservation,
+                       read_deleted="no",
+                       session=session).\
+                   filter(models.Reservation.uuid.in_(reservations)).\
+                   with_lockmode('update').\
+                   all()
+
+
+@require_context
+def reservation_commit(context, reservations):
+    session = get_session()
+    with session.begin():
+        usages = _get_quota_usages(context, session)
+
+        for reservation in _quota_reservations(session, context, reservations):
+            usage = usages[reservation.resource]
+            if reservation.delta >= 0:
+                usage.reserved -= reservation.delta
+            usage.in_use += reservation.delta
+
+            reservation.delete(session=session)
+
+        for usage in usages.values():
+            usage.save(session=session)
+
+
+@require_context
+def reservation_rollback(context, reservations):
+    session = get_session()
+    with session.begin():
+        usages = _get_quota_usages(context, session)
+
+        for reservation in _quota_reservations(session, context, reservations):
+            usage = usages[reservation.resource]
+            if reservation.delta >= 0:
+                usage.reserved -= reservation.delta
+
+            reservation.delete(session=session)
+
+        for usage in usages.values():
+            usage.save(session=session)
+
+
+@require_admin_context
+def quota_destroy_all_by_project(context, project_id):
+    session = get_session()
+    with session.begin():
+        quotas = model_query(context, models.Quota, session=session,
+                             read_deleted="no").\
+                         filter_by(project_id=project_id).\
+                         all()
+
+        for quota_ref in quotas:
+            quota_ref.delete(session=session)
+
+        quota_usages = model_query(context, models.QuotaUsage,
+                                   session=session, read_deleted="no").\
+                               filter_by(project_id=project_id).\
+                               all()
+
+        for quota_usage_ref in quota_usages:
+            quota_usage_ref.delete(session=session)
+
+        reservations = model_query(context, models.Reservation,
+                                   session=session, read_deleted="no").\
+                               filter_by(project_id=project_id).\
+                               all()
+
+        for reservation_ref in reservations:
+            reservation_ref.delete(session=session)
+
+
+@require_admin_context
+def reservation_expire(context):
+    session = get_session()
+    with session.begin():
+        current_time = timeutils.utcnow()
+        results = model_query(context, models.Reservation, session=session,
+                              read_deleted="no").\
+                          filter(models.Reservation.expire < current_time).\
+                          all()
+
+        if results:
+            for reservation in results:
+                if reservation.delta >= 0:
+                    reservation.usage.reserved -= reservation.delta
+                    reservation.usage.save(session=session)
+
+                reservation.delete(session=session)
+
+
+###################
+
+
 @require_admin_context
 def volume_allocate_iscsi_target(context, volume_id, host):
     session = get_session()
@@ -447,11 +961,12 @@ def volume_create(context, values):
 
 
 @require_admin_context
-def volume_data_get_for_project(context, project_id):
+def volume_data_get_for_project(context, project_id, session=None):
     result = model_query(context,
                          func.count(models.Volume.id),
                          func.sum(models.Volume.size),
-                         read_deleted="no").\
+                         read_deleted="no",
+                         session=session).\
                      filter_by(project_id=project_id).\
                      first()
 
diff --git a/cinder/db/sqlalchemy/migrate_repo/versions/002_quota_class.py b/cinder/db/sqlalchemy/migrate_repo/versions/002_quota_class.py
new file mode 100644 (file)
index 0000000..3491734
--- /dev/null
@@ -0,0 +1,140 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack LLC.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+from sqlalchemy import Boolean, Column, DateTime
+from sqlalchemy import MetaData, Integer, String, Table, ForeignKey
+
+from cinder.openstack.common import log as logging
+
+LOG = logging.getLogger(__name__)
+
+
+def upgrade(migrate_engine):
+    meta = MetaData()
+    meta.bind = migrate_engine
+
+    # New table
+    quota_classes = Table('quota_classes', meta,
+            Column('created_at', DateTime(timezone=False)),
+            Column('updated_at', DateTime(timezone=False)),
+            Column('deleted_at', DateTime(timezone=False)),
+            Column('deleted', Boolean(create_constraint=True, name=None)),
+            Column('id', Integer(), primary_key=True),
+            Column('class_name',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False), index=True),
+            Column('resource',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False)),
+            Column('hard_limit', Integer(), nullable=True),
+            mysql_engine='InnoDB',
+            mysql_charset='utf8',
+            )
+
+    try:
+        quota_classes.create()
+    except Exception:
+        LOG.error(_("Table |%s| not created!"), repr(quota_classes))
+        raise
+
+    quota_usages = Table('quota_usages', meta,
+            Column('created_at', DateTime(timezone=False)),
+            Column('updated_at', DateTime(timezone=False)),
+            Column('deleted_at', DateTime(timezone=False)),
+            Column('deleted', Boolean(create_constraint=True, name=None)),
+            Column('id', Integer(), primary_key=True),
+            Column('project_id',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False),
+                   index=True),
+            Column('resource',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False)),
+            Column('in_use', Integer(), nullable=False),
+            Column('reserved', Integer(), nullable=False),
+            Column('until_refresh', Integer(), nullable=True),
+            mysql_engine='InnoDB',
+            mysql_charset='utf8',
+            )
+
+    try:
+        quota_usages.create()
+    except Exception:
+        LOG.error(_("Table |%s| not created!"), repr(quota_usages))
+        raise
+
+    reservations = Table('reservations', meta,
+            Column('created_at', DateTime(timezone=False)),
+            Column('updated_at', DateTime(timezone=False)),
+            Column('deleted_at', DateTime(timezone=False)),
+            Column('deleted', Boolean(create_constraint=True, name=None)),
+            Column('id', Integer(), primary_key=True),
+            Column('uuid',
+                   String(length=36, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False), nullable=False),
+            Column('usage_id', Integer(), ForeignKey('quota_usages.id'),
+                   nullable=False),
+            Column('project_id',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False),
+                   index=True),
+            Column('resource',
+                   String(length=255, convert_unicode=True,
+                          assert_unicode=None, unicode_error=None,
+                          _warn_on_bytestring=False)),
+            Column('delta', Integer(), nullable=False),
+            Column('expire', DateTime(timezone=False)),
+            mysql_engine='InnoDB',
+            mysql_charset='utf8',
+            )
+
+    try:
+        reservations.create()
+    except Exception:
+        LOG.error(_("Table |%s| not created!"), repr(reservations))
+        raise
+
+
+def downgrade(migrate_engine):
+    meta = MetaData()
+    meta.bind = migrate_engine
+
+    quota_classes = Table('quota_classes', meta, autoload=True)
+    try:
+        quota_classes.drop()
+    except Exception:
+        LOG.error(_("quota_classes table not dropped"))
+        raise
+
+    quota_usages = Table('quota_usages', meta, autoload=True)
+    try:
+        quota_usages.drop()
+    except Exception:
+        LOG.error(_("quota_usages table not dropped"))
+        raise
+
+    reservations = Table('reservations', meta, autoload=True)
+    try:
+        reservations.drop()
+    except Exception:
+        LOG.error(_("reservations table not dropped"))
+        raise
index 120ccb405b28eec0b3d4a3d043ead91d5b3b1c0f..ae2a6f51f96bf7371af89726933afc3f40958a36 100644 (file)
@@ -241,8 +241,43 @@ class QuotaClass(BASE, CinderBase):
     hard_limit = Column(Integer, nullable=True)
 
 
+class QuotaUsage(BASE, CinderBase):
+    """Represents the current usage for a given resource."""
+
+    __tablename__ = 'quota_usages'
+    id = Column(Integer, primary_key=True)
+
+    project_id = Column(String(255), index=True)
+    resource = Column(String(255))
+
+    in_use = Column(Integer)
+    reserved = Column(Integer)
+
+    @property
+    def total(self):
+        return self.in_use + self.reserved
+
+    until_refresh = Column(Integer, nullable=True)
+
+
+class Reservation(BASE, CinderBase):
+    """Represents a resource reservation for quotas."""
+
+    __tablename__ = 'reservations'
+    id = Column(Integer, primary_key=True)
+    uuid = Column(String(36), nullable=False)
+
+    usage_id = Column(Integer, ForeignKey('quota_usages.id'), nullable=False)
+
+    project_id = Column(String(255), index=True)
+    resource = Column(String(255))
+
+    delta = Column(Integer)
+    expire = Column(DateTime, nullable=False)
+
+
 class Snapshot(BASE, CinderBase):
-    """Represents a block storage device that can be attached to a vm."""
+    """Represents a block storage device that can be attached to a VM."""
     __tablename__ = 'snapshots'
     id = Column(String(36), primary_key=True)
 
index f624173d5bf2828849cc33b8bbaeea8d999cc5b5..65ecae72493dbeeed5dbba808d556da35b7ccbad 100644 (file)
@@ -295,10 +295,23 @@ class HostBinaryNotFound(NotFound):
     message = _("Could not find binary %(binary)s on host %(host)s.")
 
 
+class InvalidReservationExpiration(Invalid):
+    message = _("Invalid reservation expiration %(expire)s.")
+
+
+class InvalidQuotaValue(Invalid):
+    message = _("Change would make usage less than 0 for the following "
+                "resources: %(unders)s")
+
+
 class QuotaNotFound(NotFound):
     message = _("Quota could not be found")
 
 
+class QuotaResourceUnknown(QuotaNotFound):
+    message = _("Unknown quota resources %(unknown)s.")
+
+
 class ProjectQuotaNotFound(QuotaNotFound):
     message = _("Quota for project %(project_id)s could not be found.")
 
@@ -307,6 +320,18 @@ class QuotaClassNotFound(QuotaNotFound):
     message = _("Quota class %(class_name)s could not be found.")
 
 
+class QuotaUsageNotFound(QuotaNotFound):
+    message = _("Quota usage for project %(project_id)s could not be found.")
+
+
+class ReservationNotFound(QuotaNotFound):
+    message = _("Quota reservation %(uuid)s could not be found.")
+
+
+class OverQuota(CinderException):
+    message = _("Quota exceeded for resources: %(overs)s")
+
+
 class MigrationNotFound(NotFound):
     message = _("Migration %(migration_id)s could not be found.")
 
@@ -372,6 +397,18 @@ class QuotaError(CinderException):
     safe = True
 
 
+class VolumeSizeExceedsAvailableQuota(QuotaError):
+    message = _("Requested volume exceeds allowed volume size quota")
+
+
+class VolumeSizeExceedsQuota(QuotaError):
+    message = _("Maximum volume size exceeded")
+
+
+class VolumeLimitExceeded(QuotaError):
+    message = _("Maximum number of volumes allowed (%(allowed)d) exceeded")
+
+
 class DuplicateSfVolumeNames(Duplicate):
     message = _("Detected more than one volume with name %(vol_name)s")
 
index 2df6c89a4f257f466b342a2b05060eaab82284dd..8c6057d42ddabb7e02ed3254b99c0478a79d3f6c 100644 (file)
 
 """Quotas for instances, volumes, and floating ips."""
 
+import datetime
+
 from cinder import db
-from cinder.openstack.common import cfg
+from cinder import exception
 from cinder import flags
+from cinder.openstack.common import cfg
+from cinder.openstack.common import importutils
+from cinder.openstack.common import log as logging
+from cinder.openstack.common import timeutils
+
 
+LOG = logging.getLogger(__name__)
 
 quota_opts = [
-    cfg.IntOpt('quota_instances',
-               default=10,
-               help='number of instances allowed per project'),
-    cfg.IntOpt('quota_cores',
-               default=20,
-               help='number of instance cores allowed per project'),
-    cfg.IntOpt('quota_ram',
-               default=50 * 1024,
-               help='megabytes of instance ram allowed per project'),
     cfg.IntOpt('quota_volumes',
                default=10,
                help='number of volumes allowed per project'),
     cfg.IntOpt('quota_gigabytes',
                default=1000,
                help='number of volume gigabytes allowed per project'),
-    cfg.IntOpt('quota_floating_ips',
-               default=10,
-               help='number of floating ips allowed per project'),
-    cfg.IntOpt('quota_metadata_items',
-               default=128,
-               help='number of metadata items allowed per instance'),
-    cfg.IntOpt('quota_injected_files',
-               default=5,
-               help='number of injected files allowed'),
-    cfg.IntOpt('quota_injected_file_content_bytes',
-               default=10 * 1024,
-               help='number of bytes allowed per injected file'),
-    cfg.IntOpt('quota_injected_file_path_bytes',
-               default=255,
-               help='number of bytes allowed per injected file path'),
-    cfg.IntOpt('quota_security_groups',
-               default=10,
-               help='number of security groups per project'),
-    cfg.IntOpt('quota_security_group_rules',
-               default=20,
-               help='number of security rules per security group'),
+    cfg.IntOpt('reservation_expire',
+               default=86400,
+               help='number of seconds until a reservation expires'),
+    cfg.IntOpt('until_refresh',
+               default=0,
+               help='count of reservations until usage is refreshed'),
+    cfg.IntOpt('max_age',
+               default=0,
+               help='number of seconds between subsequent usage refreshes'),
+    cfg.StrOpt('quota_driver',
+               default='cinder.quota.DbQuotaDriver',
+               help='default driver to use for quota checks'),
     ]
 
 FLAGS = flags.FLAGS
 FLAGS.register_opts(quota_opts)
 
 
-quota_resources = ['metadata_items', 'injected_file_content_bytes',
-        'volumes', 'gigabytes', 'ram', 'floating_ips', 'instances',
-        'injected_files', 'cores', 'security_groups', 'security_group_rules']
-
-
-def _get_default_quotas():
-    defaults = {
-        'instances': FLAGS.quota_instances,
-        'cores': FLAGS.quota_cores,
-        'ram': FLAGS.quota_ram,
-        'volumes': FLAGS.quota_volumes,
-        'gigabytes': FLAGS.quota_gigabytes,
-        'floating_ips': FLAGS.quota_floating_ips,
-        'metadata_items': FLAGS.quota_metadata_items,
-        'injected_files': FLAGS.quota_injected_files,
-        'injected_file_content_bytes':
-            FLAGS.quota_injected_file_content_bytes,
-        'security_groups': FLAGS.quota_security_groups,
-        'security_group_rules': FLAGS.quota_security_group_rules,
-    }
-    # -1 in the quota flags means unlimited
-    return defaults
-
-
-def get_class_quotas(context, quota_class, defaults=None):
-    """Update defaults with the quota class values."""
-
-    if not defaults:
-        defaults = _get_default_quotas()
-
-    quota = db.quota_class_get_all_by_name(context, quota_class)
-    for key in defaults.keys():
-        if key in quota:
-            defaults[key] = quota[key]
-
-    return defaults
-
-
-def get_project_quotas(context, project_id):
-    defaults = _get_default_quotas()
-    if context.quota_class:
-        get_class_quotas(context, context.quota_class, defaults)
-    quota = db.quota_get_all_by_project(context, project_id)
-    for key in defaults.keys():
-        if key in quota:
-            defaults[key] = quota[key]
-    return defaults
-
-
-def _get_request_allotment(requested, used, quota):
-    if quota == -1:
-        return requested
-    return quota - used
-
-
-def allowed_instances(context, requested_instances, instance_type):
-    """Check quota and return min(requested_instances, allowed_instances)."""
-    project_id = context.project_id
-    context = context.elevated()
-    requested_cores = requested_instances * instance_type['vcpus']
-    requested_ram = requested_instances * instance_type['memory_mb']
-    usage = db.instance_data_get_for_project(context, project_id)
-    used_instances, used_cores, used_ram = usage
-    quota = get_project_quotas(context, project_id)
-    allowed_instances = _get_request_allotment(requested_instances,
-                                               used_instances,
-                                               quota['instances'])
-    allowed_cores = _get_request_allotment(requested_cores, used_cores,
-                                           quota['cores'])
-    allowed_ram = _get_request_allotment(requested_ram, used_ram, quota['ram'])
-    if instance_type['vcpus']:
-        allowed_instances = min(allowed_instances,
-                                allowed_cores // instance_type['vcpus'])
-    if instance_type['memory_mb']:
-        allowed_instances = min(allowed_instances,
-                                allowed_ram // instance_type['memory_mb'])
-
-    return min(requested_instances, allowed_instances)
-
-
-def allowed_volumes(context, requested_volumes, size):
-    """Check quota and return min(requested_volumes, allowed_volumes)."""
-    project_id = context.project_id
-    context = context.elevated()
-    size = int(size)
-    requested_gigabytes = requested_volumes * size
-    used_volumes, used_gigabytes = db.volume_data_get_for_project(context,
-                                                                  project_id)
-    quota = get_project_quotas(context, project_id)
-    allowed_volumes = _get_request_allotment(requested_volumes, used_volumes,
-                                             quota['volumes'])
-    allowed_gigabytes = _get_request_allotment(requested_gigabytes,
-                                               used_gigabytes,
-                                               quota['gigabytes'])
-    if size != 0:
-        allowed_volumes = min(allowed_volumes,
-                              int(allowed_gigabytes // size))
-    return min(requested_volumes, allowed_volumes)
-
-
-def allowed_floating_ips(context, requested_floating_ips):
-    """Check quota and return min(requested, allowed) floating ips."""
-    project_id = context.project_id
-    context = context.elevated()
-    used_floating_ips = db.floating_ip_count_by_project(context, project_id)
-    quota = get_project_quotas(context, project_id)
-    allowed_floating_ips = _get_request_allotment(requested_floating_ips,
-                                                  used_floating_ips,
-                                                  quota['floating_ips'])
-    return min(requested_floating_ips, allowed_floating_ips)
-
-
-def allowed_security_groups(context, requested_security_groups):
-    """Check quota and return min(requested, allowed) security groups."""
-    project_id = context.project_id
-    context = context.elevated()
-    used_sec_groups = db.security_group_count_by_project(context, project_id)
-    quota = get_project_quotas(context, project_id)
-    allowed_sec_groups = _get_request_allotment(requested_security_groups,
-                                                  used_sec_groups,
-                                                  quota['security_groups'])
-    return min(requested_security_groups, allowed_sec_groups)
-
-
-def allowed_security_group_rules(context, security_group_id,
-        requested_rules):
-    """Check quota and return min(requested, allowed) sec group rules."""
-    project_id = context.project_id
-    context = context.elevated()
-    used_rules = db.security_group_rule_count_by_group(context,
-                                                            security_group_id)
-    quota = get_project_quotas(context, project_id)
-    allowed_rules = _get_request_allotment(requested_rules,
-                                              used_rules,
-                                              quota['security_group_rules'])
-    return min(requested_rules, allowed_rules)
-
-
-def _calculate_simple_quota(context, resource, requested):
-    """Check quota for resource; return min(requested, allowed)."""
-    quota = get_project_quotas(context, context.project_id)
-    allowed = _get_request_allotment(requested, 0, quota[resource])
-    return min(requested, allowed)
-
-
-def allowed_metadata_items(context, requested_metadata_items):
-    """Return the number of metadata items allowed."""
-    return _calculate_simple_quota(context, 'metadata_items',
-                                   requested_metadata_items)
-
-
-def allowed_injected_files(context, requested_injected_files):
-    """Return the number of injected files allowed."""
-    return _calculate_simple_quota(context, 'injected_files',
-                                   requested_injected_files)
-
-
-def allowed_injected_file_content_bytes(context, requested_bytes):
-    """Return the number of bytes allowed per injected file content."""
-    resource = 'injected_file_content_bytes'
-    return _calculate_simple_quota(context, resource, requested_bytes)
-
-
-def allowed_injected_file_path_bytes(context):
-    """Return the number of bytes allowed in an injected file path."""
-    return FLAGS.quota_injected_file_path_bytes
+class DbQuotaDriver(object):
+    """
+    Driver to perform necessary checks to enforce quotas and obtain
+    quota information.  The default driver utilizes the local
+    database.
+    """
+
+    def get_by_project(self, context, project_id, resource):
+        """Get a specific quota by project."""
+
+        return db.quota_get(context, project_id, resource)
+
+    def get_by_class(self, context, quota_class, resource):
+        """Get a specific quota by quota class."""
+
+        return db.quota_class_get(context, quota_class, resource)
+
+    def get_defaults(self, context, resources):
+        """Given a list of resources, retrieve the default quotas.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        """
+
+        quotas = {}
+        for resource in resources.values():
+            quotas[resource.name] = resource.default
+
+        return quotas
+
+    def get_class_quotas(self, context, resources, quota_class,
+                         defaults=True):
+        """
+        Given a list of resources, retrieve the quotas for the given
+        quota class.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        :param quota_class: The name of the quota class to return
+                            quotas for.
+        :param defaults: If True, the default value will be reported
+                         if there is no specific value for the
+                         resource.
+        """
+
+        quotas = {}
+        class_quotas = db.quota_class_get_all_by_name(context, quota_class)
+        for resource in resources.values():
+            if defaults or resource.name in class_quotas:
+                quotas[resource.name] = class_quotas.get(resource.name,
+                                                         resource.default)
+
+        return quotas
+
+    def get_project_quotas(self, context, resources, project_id,
+                           quota_class=None, defaults=True,
+                           usages=True):
+        """
+        Given a list of resources, retrieve the quotas for the given
+        project.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        :param project_id: The ID of the project to return quotas for.
+        :param quota_class: If project_id != context.project_id, the
+                            quota class cannot be determined.  This
+                            parameter allows it to be specified.  It
+                            will be ignored if project_id ==
+                            context.project_id.
+        :param defaults: If True, the quota class value (or the
+                         default value, if there is no value from the
+                         quota class) will be reported if there is no
+                         specific value for the resource.
+        :param usages: If True, the current in_use and reserved counts
+                       will also be returned.
+        """
+
+        quotas = {}
+        project_quotas = db.quota_get_all_by_project(context, project_id)
+        if usages:
+            project_usages = db.quota_usage_get_all_by_project(context,
+                                                               project_id)
+
+        # Get the quotas for the appropriate class.  If the project ID
+        # matches the one in the context, we use the quota_class from
+        # the context, otherwise, we use the provided quota_class (if
+        # any)
+        if project_id == context.project_id:
+            quota_class = context.quota_class
+        if quota_class:
+            class_quotas = db.quota_class_get_all_by_name(context, quota_class)
+        else:
+            class_quotas = {}
+
+        for resource in resources.values():
+            # Omit default/quota class values
+            if not defaults and resource.name not in project_quotas:
+                continue
+
+            quotas[resource.name] = dict(
+                limit=project_quotas.get(resource.name, class_quotas.get(
+                        resource.name, resource.default)),
+                )
+
+            # Include usages if desired.  This is optional because one
+            # internal consumer of this interface wants to access the
+            # usages directly from inside a transaction.
+            if usages:
+                usage = project_usages.get(resource.name, {})
+                quotas[resource.name].update(
+                    in_use=usage.get('in_use', 0),
+                    reserved=usage.get('reserved', 0),
+                    )
+
+        return quotas
+
+    def _get_quotas(self, context, resources, keys, has_sync):
+        """
+        A helper method which retrieves the quotas for the specific
+        resources identified by keys, and which apply to the current
+        context.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        :param keys: A list of the desired quotas to retrieve.
+        :param has_sync: If True, indicates that the resource must
+                         have a sync attribute; if False, indicates
+                         that the resource must NOT have a sync
+                         attribute.
+        """
+
+        # Filter resources
+        if has_sync:
+            sync_filt = lambda x: hasattr(x, 'sync')
+        else:
+            sync_filt = lambda x: not hasattr(x, 'sync')
+        desired = set(keys)
+        sub_resources = dict((k, v) for k, v in resources.items()
+                             if k in desired and sync_filt(v))
+
+        # Make sure we accounted for all of them...
+        if len(keys) != len(sub_resources):
+            unknown = desired - set(sub_resources.keys())
+            raise exception.QuotaResourceUnknown(unknown=sorted(unknown))
+
+        # Grab and return the quotas (without usages)
+        quotas = self.get_project_quotas(context, sub_resources,
+                                         context.project_id,
+                                         context.quota_class, usages=False)
+
+        return dict((k, v['limit']) for k, v in quotas.items())
+
+    def limit_check(self, context, resources, values):
+        """Check simple quota limits.
+
+        For limits--those quotas for which there is no usage
+        synchronization function--this method checks that a set of
+        proposed values are permitted by the limit restriction.
+
+        This method will raise a QuotaResourceUnknown exception if a
+        given resource is unknown or if it is not a simple limit
+        resource.
+
+        If any of the proposed values is over the defined quota, an
+        OverQuota exception will be raised with the sorted list of the
+        resources which are too high.  Otherwise, the method returns
+        nothing.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        :param values: A dictionary of the values to check against the
+                       quota.
+        """
+
+        # Ensure no value is less than zero
+        unders = [key for key, val in values.items() if val < 0]
+        if unders:
+            raise exception.InvalidQuotaValue(unders=sorted(unders))
+
+        # Get the applicable quotas
+        quotas = self._get_quotas(context, resources, values.keys(),
+                                  has_sync=False)
+        # Check the quotas and construct a list of the resources that
+        # would be put over limit by the desired values
+        overs = [key for key, val in values.items()
+                 if quotas[key] >= 0 and quotas[key] < val]
+        if overs:
+            raise exception.OverQuota(overs=sorted(overs), quotas=quotas,
+                                      usages={})
+
+    def reserve(self, context, resources, deltas, expire=None):
+        """Check quotas and reserve resources.
+
+        For counting quotas--those quotas for which there is a usage
+        synchronization function--this method checks quotas against
+        current usage and the desired deltas.
+
+        This method will raise a QuotaResourceUnknown exception if a
+        given resource is unknown or if it does not have a usage
+        synchronization function.
+
+        If any of the proposed values is over the defined quota, an
+        OverQuota exception will be raised with the sorted list of the
+        resources which are too high.  Otherwise, the method returns a
+        list of reservation UUIDs which were created.
+
+        :param context: The request context, for access checks.
+        :param resources: A dictionary of the registered resources.
+        :param deltas: A dictionary of the proposed delta changes.
+        :param expire: An optional parameter specifying an expiration
+                       time for the reservations.  If it is a simple
+                       number, it is interpreted as a number of
+                       seconds and added to the current time; if it is
+                       a datetime.timedelta object, it will also be
+                       added to the current time.  A datetime.datetime
+                       object will be interpreted as the absolute
+                       expiration time.  If None is specified, the
+                       default expiration time set by
+                       --default-reservation-expire will be used (this
+                       value will be treated as a number of seconds).
+        """
+
+        # Set up the reservation expiration
+        if expire is None:
+            expire = FLAGS.reservation_expire
+        if isinstance(expire, (int, long)):
+            expire = datetime.timedelta(seconds=expire)
+        if isinstance(expire, datetime.timedelta):
+            expire = timeutils.utcnow() + expire
+        if not isinstance(expire, datetime.datetime):
+            raise exception.InvalidReservationExpiration(expire=expire)
+
+        # Get the applicable quotas.
+        # NOTE(Vek): We're not worried about races at this point.
+        #            Yes, the admin may be in the process of reducing
+        #            quotas, but that's a pretty rare thing.
+        quotas = self._get_quotas(context, resources, deltas.keys(),
+                                  has_sync=True)
+
+        # NOTE(Vek): Most of the work here has to be done in the DB
+        #            API, because we have to do it in a transaction,
+        #            which means access to the session.  Since the
+        #            session isn't available outside the DBAPI, we
+        #            have to do the work there.
+        return db.quota_reserve(context, resources, quotas, deltas, expire,
+                                FLAGS.until_refresh, FLAGS.max_age)
+
+    def commit(self, context, reservations):
+        """Commit reservations.
+
+        :param context: The request context, for access checks.
+        :param reservations: A list of the reservation UUIDs, as
+                             returned by the reserve() method.
+        """
+
+        db.reservation_commit(context, reservations)
+
+    def rollback(self, context, reservations):
+        """Roll back reservations.
+
+        :param context: The request context, for access checks.
+        :param reservations: A list of the reservation UUIDs, as
+                             returned by the reserve() method.
+        """
+
+        db.reservation_rollback(context, reservations)
+
+    def destroy_all_by_project(self, context, project_id):
+        """
+        Destroy all quotas, usages, and reservations associated with a
+        project.
+
+        :param context: The request context, for access checks.
+        :param project_id: The ID of the project being deleted.
+        """
+
+        db.quota_destroy_all_by_project(context, project_id)
+
+    def expire(self, context):
+        """Expire reservations.
+
+        Explores all currently existing reservations and rolls back
+        any that have expired.
+
+        :param context: The request context, for access checks.
+        """
+
+        db.reservation_expire(context)
+
+
+class BaseResource(object):
+    """Describe a single resource for quota checking."""
+
+    def __init__(self, name, flag=None):
+        """
+        Initializes a Resource.
+
+        :param name: The name of the resource, i.e., "instances".
+        :param flag: The name of the flag or configuration option
+                     which specifies the default value of the quota
+                     for this resource.
+        """
+
+        self.name = name
+        self.flag = flag
+
+    def quota(self, driver, context, **kwargs):
+        """
+        Given a driver and context, obtain the quota for this
+        resource.
+
+        :param driver: A quota driver.
+        :param context: The request context.
+        :param project_id: The project to obtain the quota value for.
+                           If not provided, it is taken from the
+                           context.  If it is given as None, no
+                           project-specific quota will be searched
+                           for.
+        :param quota_class: The quota class corresponding to the
+                            project, or for which the quota is to be
+                            looked up.  If not provided, it is taken
+                            from the context.  If it is given as None,
+                            no quota class-specific quota will be
+                            searched for.  Note that the quota class
+                            defaults to the value in the context,
+                            which may not correspond to the project if
+                            project_id is not the same as the one in
+                            the context.
+        """
+
+        # Get the project ID
+        project_id = kwargs.get('project_id', context.project_id)
+
+        # Ditto for the quota class
+        quota_class = kwargs.get('quota_class', context.quota_class)
+
+        # Look up the quota for the project
+        if project_id:
+            try:
+                return driver.get_by_project(context, project_id, self.name)
+            except exception.ProjectQuotaNotFound:
+                pass
+
+        # Try for the quota class
+        if quota_class:
+            try:
+                return driver.get_by_class(context, quota_class, self.name)
+            except exception.QuotaClassNotFound:
+                pass
+
+        # OK, return the default
+        return self.default
+
+    @property
+    def default(self):
+        """Return the default value of the quota."""
+
+        return FLAGS[self.flag] if self.flag else -1
+
+
+class ReservableResource(BaseResource):
+    """Describe a reservable resource."""
+
+    def __init__(self, name, sync, flag=None):
+        """
+        Initializes a ReservableResource.
+
+        Reservable resources are those resources which directly
+        correspond to objects in the database, i.e., instances, cores,
+        etc.  A ReservableResource must be constructed with a usage
+        synchronization function, which will be called to determine the
+        current counts of one or more resources.
+
+        The usage synchronization function will be passed three
+        arguments: an admin context, the project ID, and an opaque
+        session object, which should in turn be passed to the
+        underlying database function.  Synchronization functions
+        should return a dictionary mapping resource names to the
+        current in_use count for those resources; more than one
+        resource and resource count may be returned.  Note that
+        synchronization functions may be associated with more than one
+        ReservableResource.
+
+        :param name: The name of the resource, i.e., "instances".
+        :param sync: A callable which returns a dictionary to
+                     resynchronize the in_use count for one or more
+                     resources, as described above.
+        :param flag: The name of the flag or configuration option
+                     which specifies the default value of the quota
+                     for this resource.
+        """
+
+        super(ReservableResource, self).__init__(name, flag=flag)
+        self.sync = sync
+
+
+class AbsoluteResource(BaseResource):
+    """Describe a non-reservable resource."""
+
+    pass
+
+
+class CountableResource(AbsoluteResource):
+    """
+    Describe a resource where the counts aren't based solely on the
+    project ID.
+    """
+
+    def __init__(self, name, count, flag=None):
+        """
+        Initializes a CountableResource.
+
+        Countable resources are those resources which directly
+        correspond to objects in the database, i.e., instances, cores,
+        etc., but for which a count by project ID is inappropriate.  A
+        CountableResource must be constructed with a counting
+        function, which will be called to determine the current counts
+        of the resource.
+
+        The counting function will be passed the context, along with
+        the extra positional and keyword arguments that are passed to
+        Quota.count().  It should return an integer specifying the
+        count.
+
+        Note that this counting is not performed in a transaction-safe
+        manner.  This resource class is a temporary measure to provide
+        required functionality, until a better approach to solving
+        this problem can be evolved.
+
+        :param name: The name of the resource, i.e., "instances".
+        :param count: A callable which returns the count of the
+                      resource.  The arguments passed are as described
+                      above.
+        :param flag: The name of the flag or configuration option
+                     which specifies the default value of the quota
+                     for this resource.
+        """
+
+        super(CountableResource, self).__init__(name, flag=flag)
+        self.count = count
+
+
+class QuotaEngine(object):
+    """Represent the set of recognized quotas."""
+
+    def __init__(self, quota_driver_class=None):
+        """Initialize a Quota object."""
+
+        if not quota_driver_class:
+            quota_driver_class = FLAGS.quota_driver
+
+        if isinstance(quota_driver_class, basestring):
+            quota_driver_class = importutils.import_object(quota_driver_class)
+
+        self._resources = {}
+        self._driver = quota_driver_class
+
+    def __contains__(self, resource):
+        return resource in self._resources
+
+    def register_resource(self, resource):
+        """Register a resource."""
+
+        self._resources[resource.name] = resource
+
+    def register_resources(self, resources):
+        """Register a list of resources."""
+
+        for resource in resources:
+            self.register_resource(resource)
+
+    def get_by_project(self, context, project_id, resource):
+        """Get a specific quota by project."""
+
+        return self._driver.get_by_project(context, project_id, resource)
+
+    def get_by_class(self, context, quota_class, resource):
+        """Get a specific quota by quota class."""
+
+        return self._driver.get_by_class(context, quota_class, resource)
+
+    def get_defaults(self, context):
+        """Retrieve the default quotas.
+
+        :param context: The request context, for access checks.
+        """
+
+        return self._driver.get_defaults(context, self._resources)
+
+    def get_class_quotas(self, context, quota_class, defaults=True):
+        """Retrieve the quotas for the given quota class.
+
+        :param context: The request context, for access checks.
+        :param quota_class: The name of the quota class to return
+                            quotas for.
+        :param defaults: If True, the default value will be reported
+                         if there is no specific value for the
+                         resource.
+        """
+
+        return self._driver.get_class_quotas(context, self._resources,
+                                             quota_class, defaults=defaults)
+
+    def get_project_quotas(self, context, project_id, quota_class=None,
+                           defaults=True, usages=True):
+        """Retrieve the quotas for the given project.
+
+        :param context: The request context, for access checks.
+        :param project_id: The ID of the project to return quotas for.
+        :param quota_class: If project_id != context.project_id, the
+                            quota class cannot be determined.  This
+                            parameter allows it to be specified.
+        :param defaults: If True, the quota class value (or the
+                         default value, if there is no value from the
+                         quota class) will be reported if there is no
+                         specific value for the resource.
+        :param usages: If True, the current in_use and reserved counts
+                       will also be returned.
+        """
+
+        return self._driver.get_project_quotas(context, self._resources,
+                                              project_id,
+                                              quota_class=quota_class,
+                                              defaults=defaults,
+                                              usages=usages)
+
+    def count(self, context, resource, *args, **kwargs):
+        """Count a resource.
+
+        For countable resources, invokes the count() function and
+        returns its result.  Arguments following the context and
+        resource are passed directly to the count function declared by
+        the resource.
+
+        :param context: The request context, for access checks.
+        :param resource: The name of the resource, as a string.
+        """
+
+        # Get the resource
+        res = self._resources.get(resource)
+        if not res or not hasattr(res, 'count'):
+            raise exception.QuotaResourceUnknown(unknown=[resource])
+
+        return res.count(context, *args, **kwargs)
+
+    def limit_check(self, context, **values):
+        """Check simple quota limits.
+
+        For limits--those quotas for which there is no usage
+        synchronization function--this method checks that a set of
+        proposed values are permitted by the limit restriction.  The
+        values to check are given as keyword arguments, where the key
+        identifies the specific quota limit to check, and the value is
+        the proposed value.
+
+        This method will raise a QuotaResourceUnknown exception if a
+        given resource is unknown or if it is not a simple limit
+        resource.
+
+        If any of the proposed values is over the defined quota, an
+        OverQuota exception will be raised with the sorted list of the
+        resources which are too high.  Otherwise, the method returns
+        nothing.
+
+        :param context: The request context, for access checks.
+        """
+
+        return self._driver.limit_check(context, self._resources, values)
+
+    def reserve(self, context, expire=None, **deltas):
+        """Check quotas and reserve resources.
+
+        For counting quotas--those quotas for which there is a usage
+        synchronization function--this method checks quotas against
+        current usage and the desired deltas.  The deltas are given as
+        keyword arguments, and current usage and other reservations
+        are factored into the quota check.
+
+        This method will raise a QuotaResourceUnknown exception if a
+        given resource is unknown or if it does not have a usage
+        synchronization function.
+
+        If any of the proposed values is over the defined quota, an
+        OverQuota exception will be raised with the sorted list of the
+        resources which are too high.  Otherwise, the method returns a
+        list of reservation UUIDs which were created.
+
+        :param context: The request context, for access checks.
+        :param expire: An optional parameter specifying an expiration
+                       time for the reservations.  If it is a simple
+                       number, it is interpreted as a number of
+                       seconds and added to the current time; if it is
+                       a datetime.timedelta object, it will also be
+                       added to the current time.  A datetime.datetime
+                       object will be interpreted as the absolute
+                       expiration time.  If None is specified, the
+                       default expiration time set by
+                       --default-reservation-expire will be used (this
+                       value will be treated as a number of seconds).
+        """
+
+        reservations = self._driver.reserve(context, self._resources, deltas,
+                                            expire=expire)
+
+        LOG.debug(_("Created reservations %(reservations)s") % locals())
+
+        return reservations
+
+    def commit(self, context, reservations):
+        """Commit reservations.
+
+        :param context: The request context, for access checks.
+        :param reservations: A list of the reservation UUIDs, as
+                             returned by the reserve() method.
+        """
+
+        try:
+            self._driver.commit(context, reservations)
+        except Exception:
+            # NOTE(Vek): Ignoring exceptions here is safe, because the
+            # usage resynchronization and the reservation expiration
+            # mechanisms will resolve the issue.  The exception is
+            # logged, however, because this is less than optimal.
+            LOG.exception(_("Failed to commit reservations "
+                            "%(reservations)s") % locals())
+
+    def rollback(self, context, reservations):
+        """Roll back reservations.
+
+        :param context: The request context, for access checks.
+        :param reservations: A list of the reservation UUIDs, as
+                             returned by the reserve() method.
+        """
+
+        try:
+            self._driver.rollback(context, reservations)
+        except Exception:
+            # NOTE(Vek): Ignoring exceptions here is safe, because the
+            # usage resynchronization and the reservation expiration
+            # mechanisms will resolve the issue.  The exception is
+            # logged, however, because this is less than optimal.
+            LOG.exception(_("Failed to roll back reservations "
+                            "%(reservations)s") % locals())
+
+    def destroy_all_by_project(self, context, project_id):
+        """
+        Destroy all quotas, usages, and reservations associated with a
+        project.
+
+        :param context: The request context, for access checks.
+        :param project_id: The ID of the project being deleted.
+        """
+
+        self._driver.destroy_all_by_project(context, project_id)
+
+    def expire(self, context):
+        """Expire reservations.
+
+        Explores all currently existing reservations and rolls back
+        any that have expired.
+
+        :param context: The request context, for access checks.
+        """
+
+        self._driver.expire(context)
+
+    @property
+    def resources(self):
+        return sorted(self._resources.keys())
+
+
+def _sync_instances(context, project_id, session):
+    return dict(zip(('instances', 'cores', 'ram'),
+                    db.instance_data_get_for_project(
+                context, project_id, session=session)))
+
+
+def _sync_volumes(context, project_id, session):
+    return dict(zip(('volumes', 'gigabytes'),
+                    db.volume_data_get_for_project(
+                context, project_id, session=session)))
+
+
+QUOTAS = QuotaEngine()
+
+
+resources = [
+    ReservableResource('volumes', _sync_volumes, 'quota_volumes'),
+    ReservableResource('gigabytes', _sync_volumes, 'quota_gigabytes'),
+    ]
+
+
+QUOTAS.register_resources(resources)
index 509c7211a603a6b539f95bbfc6d10673d6b0093e..2c1426c331fa4f8afc7a8e1a7ec478fea3ccbb82 100644 (file)
@@ -25,6 +25,7 @@ import webob.request
 from cinder.api import auth as api_auth
 from cinder.api import openstack as openstack_api
 from cinder.api.openstack import auth
+from cinder.api.openstack.volume import limits
 from cinder.api.openstack import urlmap
 from cinder.api.openstack import volume
 from cinder.api.openstack.volume import versions
diff --git a/cinder/tests/api/openstack/volume/test_limits.py b/cinder/tests/api/openstack/volume/test_limits.py
new file mode 100644 (file)
index 0000000..aaa9eb8
--- /dev/null
@@ -0,0 +1,896 @@
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+"""
+Tests dealing with HTTP rate-limiting.
+"""
+
+import httplib
+import StringIO
+from xml.dom import minidom
+
+from lxml import etree
+import webob
+
+from cinder.api.openstack.volume import limits
+from cinder.api.openstack.volume import views
+from cinder.api.openstack import xmlutil
+import cinder.context
+from cinder.openstack.common import jsonutils
+from cinder import test
+
+
+TEST_LIMITS = [
+    limits.Limit("GET", "/delayed", "^/delayed", 1, limits.PER_MINUTE),
+    limits.Limit("POST", "*", ".*", 7, limits.PER_MINUTE),
+    limits.Limit("POST", "/volumes", "^/volumes", 3, limits.PER_MINUTE),
+    limits.Limit("PUT", "*", "", 10, limits.PER_MINUTE),
+    limits.Limit("PUT", "/volumes", "^/volumes", 5, limits.PER_MINUTE),
+]
+NS = {
+    'atom': 'http://www.w3.org/2005/Atom',
+    'ns': 'http://docs.openstack.org/common/api/v1.0'
+}
+
+
+class BaseLimitTestSuite(test.TestCase):
+    """Base test suite which provides relevant stubs and time abstraction."""
+
+    def setUp(self):
+        super(BaseLimitTestSuite, self).setUp()
+        self.time = 0.0
+        self.stubs.Set(limits.Limit, "_get_time", self._get_time)
+        self.absolute_limits = {}
+
+        def stub_get_project_quotas(context, project_id, usages=True):
+            return dict((k, dict(limit=v))
+                        for k, v in self.absolute_limits.items())
+
+        self.stubs.Set(cinder.quota.QUOTAS, "get_project_quotas",
+                       stub_get_project_quotas)
+
+    def _get_time(self):
+        """Return the "time" according to this test suite."""
+        return self.time
+
+
+class LimitsControllerTest(BaseLimitTestSuite):
+    """
+    Tests for `limits.LimitsController` class.
+    """
+
+    def setUp(self):
+        """Run before each test."""
+        super(LimitsControllerTest, self).setUp()
+        self.controller = limits.create_resource()
+
+    def _get_index_request(self, accept_header="application/json"):
+        """Helper to set routing arguments."""
+        request = webob.Request.blank("/")
+        request.accept = accept_header
+        request.environ["wsgiorg.routing_args"] = (None, {
+            "action": "index",
+            "controller": "",
+        })
+        context = cinder.context.RequestContext('testuser', 'testproject')
+        request.environ["cinder.context"] = context
+        return request
+
+    def _populate_limits(self, request):
+        """Put limit info into a request."""
+        _limits = [
+            limits.Limit("GET", "*", ".*", 10, 60).display(),
+            limits.Limit("POST", "*", ".*", 5, 60 * 60).display(),
+            limits.Limit("GET", "changes-since*", "changes-since",
+                         5, 60).display(),
+        ]
+        request.environ["cinder.limits"] = _limits
+        return request
+
+    def test_empty_index_json(self):
+        """Test getting empty limit details in JSON."""
+        request = self._get_index_request()
+        response = request.get_response(self.controller)
+        expected = {
+            "limits": {
+                "rate": [],
+                "absolute": {},
+            },
+        }
+        body = jsonutils.loads(response.body)
+        self.assertEqual(expected, body)
+
+    def test_index_json(self):
+        """Test getting limit details in JSON."""
+        request = self._get_index_request()
+        request = self._populate_limits(request)
+        self.absolute_limits = {
+            'gigabytes': 512,
+            'volumes': 5,
+        }
+        response = request.get_response(self.controller)
+        expected = {
+            "limits": {
+                "rate": [
+                    {
+                        "regex": ".*",
+                        "uri": "*",
+                        "limit": [
+                            {
+                                "verb": "GET",
+                                "next-available": "1970-01-01T00:00:00Z",
+                                "unit": "MINUTE",
+                                "value": 10,
+                                "remaining": 10,
+                            },
+                            {
+                                "verb": "POST",
+                                "next-available": "1970-01-01T00:00:00Z",
+                                "unit": "HOUR",
+                                "value": 5,
+                                "remaining": 5,
+                            },
+                        ],
+                    },
+                    {
+                        "regex": "changes-since",
+                        "uri": "changes-since*",
+                        "limit": [
+                            {
+                                "verb": "GET",
+                                "next-available": "1970-01-01T00:00:00Z",
+                                "unit": "MINUTE",
+                                "value": 5,
+                                "remaining": 5,
+                            },
+                        ],
+                    },
+
+                ],
+                "absolute": {
+                    "maxTotalVolumeGigabytes": 512,
+                    "maxTotalVolumes": 5,
+                    },
+            },
+        }
+        body = jsonutils.loads(response.body)
+        self.assertEqual(expected, body)
+
+    def _populate_limits_diff_regex(self, request):
+        """Put limit info into a request."""
+        _limits = [
+            limits.Limit("GET", "*", ".*", 10, 60).display(),
+            limits.Limit("GET", "*", "*.*", 10, 60).display(),
+        ]
+        request.environ["cinder.limits"] = _limits
+        return request
+
+    def test_index_diff_regex(self):
+        """Test getting limit details in JSON."""
+        request = self._get_index_request()
+        request = self._populate_limits_diff_regex(request)
+        response = request.get_response(self.controller)
+        expected = {
+            "limits": {
+                "rate": [
+                    {
+                        "regex": ".*",
+                        "uri": "*",
+                        "limit": [
+                            {
+                                "verb": "GET",
+                                "next-available": "1970-01-01T00:00:00Z",
+                                "unit": "MINUTE",
+                                "value": 10,
+                                "remaining": 10,
+                            },
+                        ],
+                    },
+                    {
+                        "regex": "*.*",
+                        "uri": "*",
+                        "limit": [
+                            {
+                                "verb": "GET",
+                                "next-available": "1970-01-01T00:00:00Z",
+                                "unit": "MINUTE",
+                                "value": 10,
+                                "remaining": 10,
+                            },
+                        ],
+                    },
+
+                ],
+                "absolute": {},
+            },
+        }
+        body = jsonutils.loads(response.body)
+        self.assertEqual(expected, body)
+
+    def _test_index_absolute_limits_json(self, expected):
+        request = self._get_index_request()
+        response = request.get_response(self.controller)
+        body = jsonutils.loads(response.body)
+        self.assertEqual(expected, body['limits']['absolute'])
+
+    def test_index_ignores_extra_absolute_limits_json(self):
+        self.absolute_limits = {'unknown_limit': 9001}
+        self._test_index_absolute_limits_json({})
+
+
+class TestLimiter(limits.Limiter):
+    pass
+
+
+class LimitMiddlewareTest(BaseLimitTestSuite):
+    """
+    Tests for the `limits.RateLimitingMiddleware` class.
+    """
+
+    @webob.dec.wsgify
+    def _empty_app(self, request):
+        """Do-nothing WSGI app."""
+        pass
+
+    def setUp(self):
+        """Prepare middleware for use through fake WSGI app."""
+        super(LimitMiddlewareTest, self).setUp()
+        _limits = '(GET, *, .*, 1, MINUTE)'
+        self.app = limits.RateLimitingMiddleware(self._empty_app, _limits,
+                                                 "%s.TestLimiter" %
+                                                 self.__class__.__module__)
+
+    def test_limit_class(self):
+        """Test that middleware selected correct limiter class."""
+        assert isinstance(self.app._limiter, TestLimiter)
+
+    def test_good_request(self):
+        """Test successful GET request through middleware."""
+        request = webob.Request.blank("/")
+        response = request.get_response(self.app)
+        self.assertEqual(200, response.status_int)
+
+    def test_limited_request_json(self):
+        """Test a rate-limited (413) GET request through middleware."""
+        request = webob.Request.blank("/")
+        response = request.get_response(self.app)
+        self.assertEqual(200, response.status_int)
+
+        request = webob.Request.blank("/")
+        response = request.get_response(self.app)
+        self.assertEqual(response.status_int, 413)
+
+        self.assertTrue('Retry-After' in response.headers)
+        retry_after = int(response.headers['Retry-After'])
+        self.assertAlmostEqual(retry_after, 60, 1)
+
+        body = jsonutils.loads(response.body)
+        expected = "Only 1 GET request(s) can be made to * every minute."
+        value = body["overLimitFault"]["details"].strip()
+        self.assertEqual(value, expected)
+
+    def test_limited_request_xml(self):
+        """Test a rate-limited (413) response as XML"""
+        request = webob.Request.blank("/")
+        response = request.get_response(self.app)
+        self.assertEqual(200, response.status_int)
+
+        request = webob.Request.blank("/")
+        request.accept = "application/xml"
+        response = request.get_response(self.app)
+        self.assertEqual(response.status_int, 413)
+
+        root = minidom.parseString(response.body).childNodes[0]
+        expected = "Only 1 GET request(s) can be made to * every minute."
+
+        details = root.getElementsByTagName("details")
+        self.assertEqual(details.length, 1)
+
+        value = details.item(0).firstChild.data.strip()
+        self.assertEqual(value, expected)
+
+
+class LimitTest(BaseLimitTestSuite):
+    """
+    Tests for the `limits.Limit` class.
+    """
+
+    def test_GET_no_delay(self):
+        """Test a limit handles 1 GET per second."""
+        limit = limits.Limit("GET", "*", ".*", 1, 1)
+        delay = limit("GET", "/anything")
+        self.assertEqual(None, delay)
+        self.assertEqual(0, limit.next_request)
+        self.assertEqual(0, limit.last_request)
+
+    def test_GET_delay(self):
+        """Test two calls to 1 GET per second limit."""
+        limit = limits.Limit("GET", "*", ".*", 1, 1)
+        delay = limit("GET", "/anything")
+        self.assertEqual(None, delay)
+
+        delay = limit("GET", "/anything")
+        self.assertEqual(1, delay)
+        self.assertEqual(1, limit.next_request)
+        self.assertEqual(0, limit.last_request)
+
+        self.time += 4
+
+        delay = limit("GET", "/anything")
+        self.assertEqual(None, delay)
+        self.assertEqual(4, limit.next_request)
+        self.assertEqual(4, limit.last_request)
+
+
+class ParseLimitsTest(BaseLimitTestSuite):
+    """
+    Tests for the default limits parser in the in-memory
+    `limits.Limiter` class.
+    """
+
+    def test_invalid(self):
+        """Test that parse_limits() handles invalid input correctly."""
+        self.assertRaises(ValueError, limits.Limiter.parse_limits,
+                          ';;;;;')
+
+    def test_bad_rule(self):
+        """Test that parse_limits() handles bad rules correctly."""
+        self.assertRaises(ValueError, limits.Limiter.parse_limits,
+                          'GET, *, .*, 20, minute')
+
+    def test_missing_arg(self):
+        """Test that parse_limits() handles missing args correctly."""
+        self.assertRaises(ValueError, limits.Limiter.parse_limits,
+                          '(GET, *, .*, 20)')
+
+    def test_bad_value(self):
+        """Test that parse_limits() handles bad values correctly."""
+        self.assertRaises(ValueError, limits.Limiter.parse_limits,
+                          '(GET, *, .*, foo, minute)')
+
+    def test_bad_unit(self):
+        """Test that parse_limits() handles bad units correctly."""
+        self.assertRaises(ValueError, limits.Limiter.parse_limits,
+                          '(GET, *, .*, 20, lightyears)')
+
+    def test_multiple_rules(self):
+        """Test that parse_limits() handles multiple rules correctly."""
+        try:
+            l = limits.Limiter.parse_limits('(get, *, .*, 20, minute);'
+                                            '(PUT, /foo*, /foo.*, 10, hour);'
+                                            '(POST, /bar*, /bar.*, 5, second);'
+                                            '(Say, /derp*, /derp.*, 1, day)')
+        except ValueError, e:
+            assert False, str(e)
+
+        # Make sure the number of returned limits are correct
+        self.assertEqual(len(l), 4)
+
+        # Check all the verbs...
+        expected = ['GET', 'PUT', 'POST', 'SAY']
+        self.assertEqual([t.verb for t in l], expected)
+
+        # ...the URIs...
+        expected = ['*', '/foo*', '/bar*', '/derp*']
+        self.assertEqual([t.uri for t in l], expected)
+
+        # ...the regexes...
+        expected = ['.*', '/foo.*', '/bar.*', '/derp.*']
+        self.assertEqual([t.regex for t in l], expected)
+
+        # ...the values...
+        expected = [20, 10, 5, 1]
+        self.assertEqual([t.value for t in l], expected)
+
+        # ...and the units...
+        expected = [limits.PER_MINUTE, limits.PER_HOUR,
+                    limits.PER_SECOND, limits.PER_DAY]
+        self.assertEqual([t.unit for t in l], expected)
+
+
+class LimiterTest(BaseLimitTestSuite):
+    """
+    Tests for the in-memory `limits.Limiter` class.
+    """
+
+    def setUp(self):
+        """Run before each test."""
+        super(LimiterTest, self).setUp()
+        userlimits = {'user:user3': ''}
+        self.limiter = limits.Limiter(TEST_LIMITS, **userlimits)
+
+    def _check(self, num, verb, url, username=None):
+        """Check and yield results from checks."""
+        for x in xrange(num):
+            yield self.limiter.check_for_delay(verb, url, username)[0]
+
+    def _check_sum(self, num, verb, url, username=None):
+        """Check and sum results from checks."""
+        results = self._check(num, verb, url, username)
+        return sum(item for item in results if item)
+
+    def test_no_delay_GET(self):
+        """
+        Simple test to ensure no delay on a single call for a limit verb we
+        didn"t set.
+        """
+        delay = self.limiter.check_for_delay("GET", "/anything")
+        self.assertEqual(delay, (None, None))
+
+    def test_no_delay_PUT(self):
+        """
+        Simple test to ensure no delay on a single call for a known limit.
+        """
+        delay = self.limiter.check_for_delay("PUT", "/anything")
+        self.assertEqual(delay, (None, None))
+
+    def test_delay_PUT(self):
+        """
+        Ensure the 11th PUT will result in a delay of 6.0 seconds until
+        the next request will be granced.
+        """
+        expected = [None] * 10 + [6.0]
+        results = list(self._check(11, "PUT", "/anything"))
+
+        self.assertEqual(expected, results)
+
+    def test_delay_POST(self):
+        """
+        Ensure the 8th POST will result in a delay of 6.0 seconds until
+        the next request will be granced.
+        """
+        expected = [None] * 7
+        results = list(self._check(7, "POST", "/anything"))
+        self.assertEqual(expected, results)
+
+        expected = 60.0 / 7.0
+        results = self._check_sum(1, "POST", "/anything")
+        self.failUnlessAlmostEqual(expected, results, 8)
+
+    def test_delay_GET(self):
+        """
+        Ensure the 11th GET will result in NO delay.
+        """
+        expected = [None] * 11
+        results = list(self._check(11, "GET", "/anything"))
+
+        self.assertEqual(expected, results)
+
+    def test_delay_PUT_volumes(self):
+        """
+        Ensure PUT on /volumes limits at 5 requests, and PUT elsewhere is still
+        OK after 5 requests...but then after 11 total requests, PUT limiting
+        kicks in.
+        """
+        # First 6 requests on PUT /volumes
+        expected = [None] * 5 + [12.0]
+        results = list(self._check(6, "PUT", "/volumes"))
+        self.assertEqual(expected, results)
+
+        # Next 5 request on PUT /anything
+        expected = [None] * 4 + [6.0]
+        results = list(self._check(5, "PUT", "/anything"))
+        self.assertEqual(expected, results)
+
+    def test_delay_PUT_wait(self):
+        """
+        Ensure after hitting the limit and then waiting for the correct
+        amount of time, the limit will be lifted.
+        """
+        expected = [None] * 10 + [6.0]
+        results = list(self._check(11, "PUT", "/anything"))
+        self.assertEqual(expected, results)
+
+        # Advance time
+        self.time += 6.0
+
+        expected = [None, 6.0]
+        results = list(self._check(2, "PUT", "/anything"))
+        self.assertEqual(expected, results)
+
+    def test_multiple_delays(self):
+        """
+        Ensure multiple requests still get a delay.
+        """
+        expected = [None] * 10 + [6.0] * 10
+        results = list(self._check(20, "PUT", "/anything"))
+        self.assertEqual(expected, results)
+
+        self.time += 1.0
+
+        expected = [5.0] * 10
+        results = list(self._check(10, "PUT", "/anything"))
+        self.assertEqual(expected, results)
+
+    def test_user_limit(self):
+        """
+        Test user-specific limits.
+        """
+        self.assertEqual(self.limiter.levels['user3'], [])
+
+    def test_multiple_users(self):
+        """
+        Tests involving multiple users.
+        """
+        # User1
+        expected = [None] * 10 + [6.0] * 10
+        results = list(self._check(20, "PUT", "/anything", "user1"))
+        self.assertEqual(expected, results)
+
+        # User2
+        expected = [None] * 10 + [6.0] * 5
+        results = list(self._check(15, "PUT", "/anything", "user2"))
+        self.assertEqual(expected, results)
+
+        # User3
+        expected = [None] * 20
+        results = list(self._check(20, "PUT", "/anything", "user3"))
+        self.assertEqual(expected, results)
+
+        self.time += 1.0
+
+        # User1 again
+        expected = [5.0] * 10
+        results = list(self._check(10, "PUT", "/anything", "user1"))
+        self.assertEqual(expected, results)
+
+        self.time += 1.0
+
+        # User1 again
+        expected = [4.0] * 5
+        results = list(self._check(5, "PUT", "/anything", "user2"))
+        self.assertEqual(expected, results)
+
+
+class WsgiLimiterTest(BaseLimitTestSuite):
+    """
+    Tests for `limits.WsgiLimiter` class.
+    """
+
+    def setUp(self):
+        """Run before each test."""
+        super(WsgiLimiterTest, self).setUp()
+        self.app = limits.WsgiLimiter(TEST_LIMITS)
+
+    def _request_data(self, verb, path):
+        """Get data decribing a limit request verb/path."""
+        return jsonutils.dumps({"verb": verb, "path": path})
+
+    def _request(self, verb, url, username=None):
+        """Make sure that POSTing to the given url causes the given username
+        to perform the given action.  Make the internal rate limiter return
+        delay and make sure that the WSGI app returns the correct response.
+        """
+        if username:
+            request = webob.Request.blank("/%s" % username)
+        else:
+            request = webob.Request.blank("/")
+
+        request.method = "POST"
+        request.body = self._request_data(verb, url)
+        response = request.get_response(self.app)
+
+        if "X-Wait-Seconds" in response.headers:
+            self.assertEqual(response.status_int, 403)
+            return response.headers["X-Wait-Seconds"]
+
+        self.assertEqual(response.status_int, 204)
+
+    def test_invalid_methods(self):
+        """Only POSTs should work."""
+        requests = []
+        for method in ["GET", "PUT", "DELETE", "HEAD", "OPTIONS"]:
+            request = webob.Request.blank("/", method=method)
+            response = request.get_response(self.app)
+            self.assertEqual(response.status_int, 405)
+
+    def test_good_url(self):
+        delay = self._request("GET", "/something")
+        self.assertEqual(delay, None)
+
+    def test_escaping(self):
+        delay = self._request("GET", "/something/jump%20up")
+        self.assertEqual(delay, None)
+
+    def test_response_to_delays(self):
+        delay = self._request("GET", "/delayed")
+        self.assertEqual(delay, None)
+
+        delay = self._request("GET", "/delayed")
+        self.assertEqual(delay, '60.00')
+
+    def test_response_to_delays_usernames(self):
+        delay = self._request("GET", "/delayed", "user1")
+        self.assertEqual(delay, None)
+
+        delay = self._request("GET", "/delayed", "user2")
+        self.assertEqual(delay, None)
+
+        delay = self._request("GET", "/delayed", "user1")
+        self.assertEqual(delay, '60.00')
+
+        delay = self._request("GET", "/delayed", "user2")
+        self.assertEqual(delay, '60.00')
+
+
+class FakeHttplibSocket(object):
+    """
+    Fake `httplib.HTTPResponse` replacement.
+    """
+
+    def __init__(self, response_string):
+        """Initialize new `FakeHttplibSocket`."""
+        self._buffer = StringIO.StringIO(response_string)
+
+    def makefile(self, _mode, _other):
+        """Returns the socket's internal buffer."""
+        return self._buffer
+
+
+class FakeHttplibConnection(object):
+    """
+    Fake `httplib.HTTPConnection`.
+    """
+
+    def __init__(self, app, host):
+        """
+        Initialize `FakeHttplibConnection`.
+        """
+        self.app = app
+        self.host = host
+
+    def request(self, method, path, body="", headers=None):
+        """
+        Requests made via this connection actually get translated and routed
+        into our WSGI app, we then wait for the response and turn it back into
+        an `httplib.HTTPResponse`.
+        """
+        if not headers:
+            headers = {}
+
+        req = webob.Request.blank(path)
+        req.method = method
+        req.headers = headers
+        req.host = self.host
+        req.body = body
+
+        resp = str(req.get_response(self.app))
+        resp = "HTTP/1.0 %s" % resp
+        sock = FakeHttplibSocket(resp)
+        self.http_response = httplib.HTTPResponse(sock)
+        self.http_response.begin()
+
+    def getresponse(self):
+        """Return our generated response from the request."""
+        return self.http_response
+
+
+def wire_HTTPConnection_to_WSGI(host, app):
+    """Monkeypatches HTTPConnection so that if you try to connect to host, you
+    are instead routed straight to the given WSGI app.
+
+    After calling this method, when any code calls
+
+    httplib.HTTPConnection(host)
+
+    the connection object will be a fake.  Its requests will be sent directly
+    to the given WSGI app rather than through a socket.
+
+    Code connecting to hosts other than host will not be affected.
+
+    This method may be called multiple times to map different hosts to
+    different apps.
+
+    This method returns the original HTTPConnection object, so that the caller
+    can restore the default HTTPConnection interface (for all hosts).
+    """
+    class HTTPConnectionDecorator(object):
+        """Wraps the real HTTPConnection class so that when you instantiate
+        the class you might instead get a fake instance."""
+
+        def __init__(self, wrapped):
+            self.wrapped = wrapped
+
+        def __call__(self, connection_host, *args, **kwargs):
+            if connection_host == host:
+                return FakeHttplibConnection(app, host)
+            else:
+                return self.wrapped(connection_host, *args, **kwargs)
+
+    oldHTTPConnection = httplib.HTTPConnection
+    httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)
+    return oldHTTPConnection
+
+
+class WsgiLimiterProxyTest(BaseLimitTestSuite):
+    """
+    Tests for the `limits.WsgiLimiterProxy` class.
+    """
+
+    def setUp(self):
+        """
+        Do some nifty HTTP/WSGI magic which allows for WSGI to be called
+        directly by something like the `httplib` library.
+        """
+        super(WsgiLimiterProxyTest, self).setUp()
+        self.app = limits.WsgiLimiter(TEST_LIMITS)
+        self.oldHTTPConnection = (
+            wire_HTTPConnection_to_WSGI("169.254.0.1:80", self.app))
+        self.proxy = limits.WsgiLimiterProxy("169.254.0.1:80")
+
+    def test_200(self):
+        """Successful request test."""
+        delay = self.proxy.check_for_delay("GET", "/anything")
+        self.assertEqual(delay, (None, None))
+
+    def test_403(self):
+        """Forbidden request test."""
+        delay = self.proxy.check_for_delay("GET", "/delayed")
+        self.assertEqual(delay, (None, None))
+
+        delay, error = self.proxy.check_for_delay("GET", "/delayed")
+        error = error.strip()
+
+        expected = ("60.00", "403 Forbidden\n\nOnly 1 GET request(s) can be "
+                    "made to /delayed every minute.")
+
+        self.assertEqual((delay, error), expected)
+
+    def tearDown(self):
+        # restore original HTTPConnection object
+        httplib.HTTPConnection = self.oldHTTPConnection
+
+
+class LimitsViewBuilderTest(test.TestCase):
+    def setUp(self):
+        super(LimitsViewBuilderTest, self).setUp()
+        self.view_builder = views.limits.ViewBuilder()
+        self.rate_limits = [{"URI": "*",
+                             "regex": ".*",
+                             "value": 10,
+                             "verb": "POST",
+                             "remaining": 2,
+                             "unit": "MINUTE",
+                             "resetTime": 1311272226},
+                            {"URI": "*/volumes",
+                             "regex": "^/volumes",
+                             "value": 50,
+                             "verb": "POST",
+                             "remaining": 10,
+                             "unit": "DAY",
+                             "resetTime": 1311272226}]
+        self.absolute_limits = {"metadata_items": 1,
+                                "injected_files": 5,
+                                "injected_file_content_bytes": 5}
+
+    def test_build_limits(self):
+        expected_limits = {"limits": {
+                "rate": [{
+                      "uri": "*",
+                      "regex": ".*",
+                      "limit": [{"value": 10,
+                                 "verb": "POST",
+                                 "remaining": 2,
+                                 "unit": "MINUTE",
+                                 "next-available": "2011-07-21T18:17:06Z"}]},
+                   {"uri": "*/volumes",
+                    "regex": "^/volumes",
+                    "limit": [{"value": 50,
+                               "verb": "POST",
+                               "remaining": 10,
+                               "unit": "DAY",
+                               "next-available": "2011-07-21T18:17:06Z"}]}],
+                "absolute": {"maxServerMeta": 1,
+                             "maxImageMeta": 1,
+                             "maxPersonality": 5,
+                             "maxPersonalitySize": 5}}}
+
+        output = self.view_builder.build(self.rate_limits,
+                                         self.absolute_limits)
+        self.assertDictMatch(output, expected_limits)
+
+    def test_build_limits_empty_limits(self):
+        expected_limits = {"limits": {"rate": [],
+                           "absolute": {}}}
+
+        abs_limits = {}
+        rate_limits = []
+        output = self.view_builder.build(rate_limits, abs_limits)
+        self.assertDictMatch(output, expected_limits)
+
+
+class LimitsXMLSerializationTest(test.TestCase):
+    def test_xml_declaration(self):
+        serializer = limits.LimitsTemplate()
+
+        fixture = {"limits": {
+                   "rate": [],
+                   "absolute": {}}}
+
+        output = serializer.serialize(fixture)
+        has_dec = output.startswith("<?xml version='1.0' encoding='UTF-8'?>")
+        self.assertTrue(has_dec)
+
+    def test_index(self):
+        serializer = limits.LimitsTemplate()
+        fixture = {
+            "limits": {
+                   "rate": [{
+                         "uri": "*",
+                         "regex": ".*",
+                         "limit": [{
+                              "value": 10,
+                              "verb": "POST",
+                              "remaining": 2,
+                              "unit": "MINUTE",
+                              "next-available": "2011-12-15T22:42:45Z"}]},
+                          {"uri": "*/servers",
+                           "regex": "^/servers",
+                           "limit": [{
+                              "value": 50,
+                              "verb": "POST",
+                              "remaining": 10,
+                              "unit": "DAY",
+                              "next-available": "2011-12-15T22:42:45Z"}]}],
+                    "absolute": {"maxServerMeta": 1,
+                                 "maxImageMeta": 1,
+                                 "maxPersonality": 5,
+                                 "maxPersonalitySize": 10240}}}
+
+        output = serializer.serialize(fixture)
+        root = etree.XML(output)
+        xmlutil.validate_schema(root, 'limits')
+
+        #verify absolute limits
+        absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
+        self.assertEqual(len(absolutes), 4)
+        for limit in absolutes:
+            name = limit.get('name')
+            value = limit.get('value')
+            self.assertEqual(value, str(fixture['limits']['absolute'][name]))
+
+        #verify rate limits
+        rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
+        self.assertEqual(len(rates), 2)
+        for i, rate in enumerate(rates):
+            for key in ['uri', 'regex']:
+                self.assertEqual(rate.get(key),
+                                 str(fixture['limits']['rate'][i][key]))
+            rate_limits = rate.xpath('ns:limit', namespaces=NS)
+            self.assertEqual(len(rate_limits), 1)
+            for j, limit in enumerate(rate_limits):
+                for key in ['verb', 'value', 'remaining', 'unit',
+                            'next-available']:
+                    self.assertEqual(limit.get(key),
+                         str(fixture['limits']['rate'][i]['limit'][j][key]))
+
+    def test_index_no_limits(self):
+        serializer = limits.LimitsTemplate()
+
+        fixture = {"limits": {
+                   "rate": [],
+                   "absolute": {}}}
+
+        output = serializer.serialize(fixture)
+        root = etree.XML(output)
+        xmlutil.validate_schema(root, 'limits')
+
+        #verify absolute limits
+        absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
+        self.assertEqual(len(absolutes), 0)
+
+        #verify rate limits
+        rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
+        self.assertEqual(len(rates), 0)
index f3252c7dcdee4cc830640911d5dee4216943d4dd..0b82bde99306b7b8326956285f829443da8c6d1b 100644 (file)
@@ -15,8 +15,8 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-import unittest
 import time
+import unittest
 
 from cinder import service
 from cinder.openstack.common import log as logging
index b2f63641925f32386949dea84a9931487c8e473a..ad165587db597d25b0b388c1fb6fc6aaf32bf54b 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import datetime
+
 from cinder import context
 from cinder import db
+from cinder.db.sqlalchemy import api as sqa_api
+from cinder.db.sqlalchemy import models as sqa_models
 from cinder import exception
 from cinder import flags
-from cinder import quota
 from cinder.openstack.common import rpc
+from cinder.openstack.common import timeutils
+from cinder import quota
 from cinder import test
+import cinder.tests.image.fake
 from cinder import volume
 
 
 FLAGS = flags.FLAGS
 
 
-class GetQuotaTestCase(test.TestCase):
-    def setUp(self):
-        super(GetQuotaTestCase, self).setUp()
-        self.flags(quota_instances=10,
-                   quota_cores=20,
-                   quota_ram=50 * 1024,
-                   quota_volumes=10,
-                   quota_gigabytes=1000,
-                   quota_floating_ips=10,
-                   quota_security_groups=10,
-                   quota_security_group_rules=20,
-                   quota_metadata_items=128,
-                   quota_injected_files=5,
-                   quota_injected_file_content_bytes=10 * 1024)
-        self.context = context.RequestContext('admin', 'admin', is_admin=True)
-
-    def _stub_class(self):
-        def fake_quota_class_get_all_by_name(context, quota_class):
-            result = dict(class_name=quota_class)
-            if quota_class == 'test_class':
-                result.update(
-                    instances=5,
-                    cores=10,
-                    ram=25 * 1024,
-                    volumes=5,
-                    gigabytes=500,
-                    floating_ips=5,
-                    quota_security_groups=10,
-                    quota_security_group_rules=20,
-                    metadata_items=64,
-                    injected_files=2,
-                    injected_file_content_bytes=5 * 1024,
-                    invalid_quota=100,
-                    )
-            return result
-
-        self.stubs.Set(db, 'quota_class_get_all_by_name',
-                       fake_quota_class_get_all_by_name)
-
-    def _stub_project(self, override=False):
-        def fake_quota_get_all_by_project(context, project_id):
-            result = dict(project_id=project_id)
-            if override:
-                result.update(
-                    instances=2,
-                    cores=5,
-                    ram=12 * 1024,
-                    volumes=2,
-                    gigabytes=250,
-                    floating_ips=2,
-                    security_groups=5,
-                    security_group_rules=10,
-                    metadata_items=32,
-                    injected_files=1,
-                    injected_file_content_bytes=2 * 1024,
-                    invalid_quota=50,
-                    )
-            return result
-
-        self.stubs.Set(db, 'quota_get_all_by_project',
-                       fake_quota_get_all_by_project)
-
-    def test_default_quotas(self):
-        result = quota._get_default_quotas()
-        self.assertEqual(result, dict(
-                instances=10,
-                cores=20,
-                ram=50 * 1024,
-                volumes=10,
-                gigabytes=1000,
-                floating_ips=10,
-                security_groups=10,
-                security_group_rules=20,
-                metadata_items=128,
-                injected_files=5,
-                injected_file_content_bytes=10 * 1024,
-                ))
-
-    def test_default_quotas_unlimited(self):
-        self.flags(quota_instances=-1,
-                   quota_cores=-1,
-                   quota_ram=-1,
-                   quota_volumes=-1,
-                   quota_gigabytes=-1,
-                   quota_floating_ips=-1,
-                   quota_security_groups=-1,
-                   quota_security_group_rules=-1,
-                   quota_metadata_items=-1,
-                   quota_injected_files=-1,
-                   quota_injected_file_content_bytes=-1)
-        result = quota._get_default_quotas()
-        self.assertEqual(result, dict(
-                instances=-1,
-                cores=-1,
-                ram=-1,
-                volumes=-1,
-                gigabytes=-1,
-                floating_ips=-1,
-                security_groups=-1,
-                security_group_rules=-1,
-                metadata_items=-1,
-                injected_files=-1,
-                injected_file_content_bytes=-1,
-                ))
-
-    def test_class_quotas_noclass(self):
-        self._stub_class()
-        result = quota.get_class_quotas(self.context, 'noclass')
-        self.assertEqual(result, dict(
-                instances=10,
-                cores=20,
-                ram=50 * 1024,
-                volumes=10,
-                gigabytes=1000,
-                floating_ips=10,
-                security_groups=10,
-                security_group_rules=20,
-                metadata_items=128,
-                injected_files=5,
-                injected_file_content_bytes=10 * 1024,
-                ))
-
-    def test_class_quotas(self):
-        self._stub_class()
-        result = quota.get_class_quotas(self.context, 'test_class')
-        self.assertEqual(result, dict(
-                instances=5,
-                cores=10,
-                ram=25 * 1024,
-                volumes=5,
-                gigabytes=500,
-                floating_ips=5,
-                security_groups=10,
-                security_group_rules=20,
-                metadata_items=64,
-                injected_files=2,
-                injected_file_content_bytes=5 * 1024,
-                ))
-
-    def test_project_quotas_defaults_noclass(self):
-        self._stub_class()
-        self._stub_project()
-        result = quota.get_project_quotas(self.context, 'admin')
-        self.assertEqual(result, dict(
-                instances=10,
-                cores=20,
-                ram=50 * 1024,
-                volumes=10,
-                gigabytes=1000,
-                floating_ips=10,
-                security_groups=10,
-                security_group_rules=20,
-                metadata_items=128,
-                injected_files=5,
-                injected_file_content_bytes=10 * 1024,
-                ))
-
-    def test_project_quotas_overrides_noclass(self):
-        self._stub_class()
-        self._stub_project(True)
-        result = quota.get_project_quotas(self.context, 'admin')
-        self.assertEqual(result, dict(
-                instances=2,
-                cores=5,
-                ram=12 * 1024,
-                volumes=2,
-                gigabytes=250,
-                floating_ips=2,
-                security_groups=5,
-                security_group_rules=10,
-                metadata_items=32,
-                injected_files=1,
-                injected_file_content_bytes=2 * 1024,
-                ))
-
-    def test_project_quotas_defaults_withclass(self):
-        self._stub_class()
-        self._stub_project()
-        self.context.quota_class = 'test_class'
-        result = quota.get_project_quotas(self.context, 'admin')
-        self.assertEqual(result, dict(
-                instances=5,
-                cores=10,
-                ram=25 * 1024,
-                volumes=5,
-                gigabytes=500,
-                floating_ips=5,
-                security_groups=10,
-                security_group_rules=20,
-                metadata_items=64,
-                injected_files=2,
-                injected_file_content_bytes=5 * 1024,
-                ))
-
-    def test_project_quotas_overrides_withclass(self):
-        self._stub_class()
-        self._stub_project(True)
-        self.context.quota_class = 'test_class'
-        result = quota.get_project_quotas(self.context, 'admin')
-        self.assertEqual(result, dict(
-                instances=2,
-                cores=5,
-                ram=12 * 1024,
-                volumes=2,
-                gigabytes=250,
-                floating_ips=2,
-                security_groups=5,
-                security_group_rules=10,
-                metadata_items=32,
-                injected_files=1,
-                injected_file_content_bytes=2 * 1024,
-                ))
-
-
-class QuotaTestCase(test.TestCase):
-
-    class StubImageService(object):
-
-        def show(self, *args, **kwargs):
-            return {"properties": {}}
+class QuotaIntegrationTestCase(test.TestCase):
 
     def setUp(self):
-        super(QuotaTestCase, self).setUp()
+        super(QuotaIntegrationTestCase, self).setUp()
         self.flags(quota_volumes=2,
                    quota_gigabytes=20)
+
+        # Apparently needed by the RPC tests...
+        #self.network = self.start_service('network')
+
         self.user_id = 'admin'
         self.project_id = 'admin'
         self.context = context.RequestContext(self.user_id,
@@ -266,6 +57,10 @@ class QuotaTestCase(test.TestCase):
 
         self.stubs.Set(rpc, 'call', rpc_call_wrapper)
 
+    def tearDown(self):
+        super(QuotaIntegrationTestCase, self).tearDown()
+        cinder.tests.image.fake.FakeImageService_reset()
+
     def _create_volume(self, size=10):
         """Create a test volume"""
         vol = {}
@@ -274,16 +69,6 @@ class QuotaTestCase(test.TestCase):
         vol['size'] = size
         return db.volume_create(self.context, vol)['id']
 
-    def test_unlimited_volumes(self):
-        self.flags(quota_volumes=10, quota_gigabytes=-1)
-        volumes = quota.allowed_volumes(self.context, 100, 1)
-        self.assertEqual(volumes, 10)
-        db.quota_create(self.context, self.project_id, 'volumes', -1)
-        volumes = quota.allowed_volumes(self.context, 100, 1)
-        self.assertEqual(volumes, 100)
-        volumes = quota.allowed_volumes(self.context, 101, 1)
-        self.assertEqual(volumes, 101)
-
     def test_too_many_volumes(self):
         volume_ids = []
         for i in range(FLAGS.quota_volumes):
@@ -304,3 +89,1281 @@ class QuotaTestCase(test.TestCase):
                           self.context, 10, '', '', None)
         for volume_id in volume_ids:
             db.volume_destroy(self.context, volume_id)
+
+
+class FakeContext(object):
+    def __init__(self, project_id, quota_class):
+        self.is_admin = False
+        self.user_id = 'fake_user'
+        self.project_id = project_id
+        self.quota_class = quota_class
+
+    def elevated(self):
+        elevated = self.__class__(self.project_id, self.quota_class)
+        elevated.is_admin = True
+        return elevated
+
+
+class FakeDriver(object):
+    def __init__(self, by_project=None, by_class=None, reservations=None):
+        self.called = []
+        self.by_project = by_project or {}
+        self.by_class = by_class or {}
+        self.reservations = reservations or []
+
+    def get_by_project(self, context, project_id, resource):
+        self.called.append(('get_by_project', context, project_id, resource))
+        try:
+            return self.by_project[project_id][resource]
+        except KeyError:
+            raise exception.ProjectQuotaNotFound(project_id=project_id)
+
+    def get_by_class(self, context, quota_class, resource):
+        self.called.append(('get_by_class', context, quota_class, resource))
+        try:
+            return self.by_class[quota_class][resource]
+        except KeyError:
+            raise exception.QuotaClassNotFound(class_name=quota_class)
+
+    def get_defaults(self, context, resources):
+        self.called.append(('get_defaults', context, resources))
+        return resources
+
+    def get_class_quotas(self, context, resources, quota_class,
+                         defaults=True):
+        self.called.append(('get_class_quotas', context, resources,
+                            quota_class, defaults))
+        return resources
+
+    def get_project_quotas(self, context, resources, project_id,
+                           quota_class=None, defaults=True, usages=True):
+        self.called.append(('get_project_quotas', context, resources,
+                            project_id, quota_class, defaults, usages))
+        return resources
+
+    def limit_check(self, context, resources, values):
+        self.called.append(('limit_check', context, resources, values))
+
+    def reserve(self, context, resources, deltas, expire=None):
+        self.called.append(('reserve', context, resources, deltas, expire))
+        return self.reservations
+
+    def commit(self, context, reservations):
+        self.called.append(('commit', context, reservations))
+
+    def rollback(self, context, reservations):
+        self.called.append(('rollback', context, reservations))
+
+    def destroy_all_by_project(self, context, project_id):
+        self.called.append(('destroy_all_by_project', context, project_id))
+
+    def expire(self, context):
+        self.called.append(('expire', context))
+
+
+class BaseResourceTestCase(test.TestCase):
+    def test_no_flag(self):
+        resource = quota.BaseResource('test_resource')
+
+        self.assertEqual(resource.name, 'test_resource')
+        self.assertEqual(resource.flag, None)
+        self.assertEqual(resource.default, -1)
+
+    def test_with_flag(self):
+        # We know this flag exists, so use it...
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+
+        self.assertEqual(resource.name, 'test_resource')
+        self.assertEqual(resource.flag, 'quota_volumes')
+        self.assertEqual(resource.default, 10)
+
+    def test_with_flag_no_quota(self):
+        self.flags(quota_volumes=-1)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+
+        self.assertEqual(resource.name, 'test_resource')
+        self.assertEqual(resource.flag, 'quota_volumes')
+        self.assertEqual(resource.default, -1)
+
+    def test_quota_no_project_no_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver()
+        context = FakeContext(None, None)
+        quota_value = resource.quota(driver, context)
+
+        self.assertEqual(quota_value, 10)
+
+    def test_quota_with_project_no_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver(by_project=dict(
+                test_project=dict(test_resource=15),
+                ))
+        context = FakeContext('test_project', None)
+        quota_value = resource.quota(driver, context)
+
+        self.assertEqual(quota_value, 15)
+
+    def test_quota_no_project_with_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver(by_class=dict(
+                test_class=dict(test_resource=20),
+                ))
+        context = FakeContext(None, 'test_class')
+        quota_value = resource.quota(driver, context)
+
+        self.assertEqual(quota_value, 20)
+
+    def test_quota_with_project_with_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver(by_project=dict(
+                test_project=dict(test_resource=15),
+                ),
+                            by_class=dict(
+                test_class=dict(test_resource=20),
+                ))
+        context = FakeContext('test_project', 'test_class')
+        quota_value = resource.quota(driver, context)
+
+        self.assertEqual(quota_value, 15)
+
+    def test_quota_override_project_with_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver(by_project=dict(
+                test_project=dict(test_resource=15),
+                override_project=dict(test_resource=20),
+                ))
+        context = FakeContext('test_project', 'test_class')
+        quota_value = resource.quota(driver, context,
+                                     project_id='override_project')
+
+        self.assertEqual(quota_value, 20)
+
+    def test_quota_with_project_override_class(self):
+        self.flags(quota_volumes=10)
+        resource = quota.BaseResource('test_resource', 'quota_volumes')
+        driver = FakeDriver(by_class=dict(
+                test_class=dict(test_resource=15),
+                override_class=dict(test_resource=20),
+                ))
+        context = FakeContext('test_project', 'test_class')
+        quota_value = resource.quota(driver, context,
+                                     quota_class='override_class')
+
+        self.assertEqual(quota_value, 20)
+
+
+class QuotaEngineTestCase(test.TestCase):
+    def test_init(self):
+        quota_obj = quota.QuotaEngine()
+
+        self.assertEqual(quota_obj._resources, {})
+        self.assertTrue(isinstance(quota_obj._driver, quota.DbQuotaDriver))
+
+    def test_init_override_string(self):
+        quota_obj = quota.QuotaEngine(
+            quota_driver_class='cinder.tests.test_quota.FakeDriver')
+
+        self.assertEqual(quota_obj._resources, {})
+        self.assertTrue(isinstance(quota_obj._driver, FakeDriver))
+
+    def test_init_override_obj(self):
+        quota_obj = quota.QuotaEngine(quota_driver_class=FakeDriver)
+
+        self.assertEqual(quota_obj._resources, {})
+        self.assertEqual(quota_obj._driver, FakeDriver)
+
+    def test_register_resource(self):
+        quota_obj = quota.QuotaEngine()
+        resource = quota.AbsoluteResource('test_resource')
+        quota_obj.register_resource(resource)
+
+        self.assertEqual(quota_obj._resources, dict(test_resource=resource))
+
+    def test_register_resources(self):
+        quota_obj = quota.QuotaEngine()
+        resources = [
+            quota.AbsoluteResource('test_resource1'),
+            quota.AbsoluteResource('test_resource2'),
+            quota.AbsoluteResource('test_resource3'),
+            ]
+        quota_obj.register_resources(resources)
+
+        self.assertEqual(quota_obj._resources, dict(
+                test_resource1=resources[0],
+                test_resource2=resources[1],
+                test_resource3=resources[2],
+                ))
+
+    def test_sync_predeclared(self):
+        quota_obj = quota.QuotaEngine()
+
+        def spam(*args, **kwargs):
+            pass
+
+        resource = quota.ReservableResource('test_resource', spam)
+        quota_obj.register_resource(resource)
+
+        self.assertEqual(resource.sync, spam)
+
+    def test_sync_multi(self):
+        quota_obj = quota.QuotaEngine()
+
+        def spam(*args, **kwargs):
+            pass
+
+        resources = [
+            quota.ReservableResource('test_resource1', spam),
+            quota.ReservableResource('test_resource2', spam),
+            quota.ReservableResource('test_resource3', spam),
+            quota.ReservableResource('test_resource4', spam),
+            ]
+        quota_obj.register_resources(resources[:2])
+
+        self.assertEqual(resources[0].sync, spam)
+        self.assertEqual(resources[1].sync, spam)
+        self.assertEqual(resources[2].sync, spam)
+        self.assertEqual(resources[3].sync, spam)
+
+    def test_get_by_project(self):
+        context = FakeContext('test_project', 'test_class')
+        driver = FakeDriver(by_project=dict(
+                test_project=dict(test_resource=42)))
+        quota_obj = quota.QuotaEngine(quota_driver_class=driver)
+        result = quota_obj.get_by_project(context, 'test_project',
+                                          'test_resource')
+
+        self.assertEqual(driver.called, [
+                ('get_by_project', context, 'test_project', 'test_resource'),
+                ])
+        self.assertEqual(result, 42)
+
+    def test_get_by_class(self):
+        context = FakeContext('test_project', 'test_class')
+        driver = FakeDriver(by_class=dict(
+                test_class=dict(test_resource=42)))
+        quota_obj = quota.QuotaEngine(quota_driver_class=driver)
+        result = quota_obj.get_by_class(context, 'test_class', 'test_resource')
+
+        self.assertEqual(driver.called, [
+                ('get_by_class', context, 'test_class', 'test_resource'),
+                ])
+        self.assertEqual(result, 42)
+
+    def _make_quota_obj(self, driver):
+        quota_obj = quota.QuotaEngine(quota_driver_class=driver)
+        resources = [
+            quota.AbsoluteResource('test_resource4'),
+            quota.AbsoluteResource('test_resource3'),
+            quota.AbsoluteResource('test_resource2'),
+            quota.AbsoluteResource('test_resource1'),
+            ]
+        quota_obj.register_resources(resources)
+
+        return quota_obj
+
+    def test_get_defaults(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        result = quota_obj.get_defaults(context)
+
+        self.assertEqual(driver.called, [
+                ('get_defaults', context, quota_obj._resources),
+                ])
+        self.assertEqual(result, quota_obj._resources)
+
+    def test_get_class_quotas(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        result1 = quota_obj.get_class_quotas(context, 'test_class')
+        result2 = quota_obj.get_class_quotas(context, 'test_class', False)
+
+        self.assertEqual(driver.called, [
+                ('get_class_quotas', context, quota_obj._resources,
+                 'test_class', True),
+                ('get_class_quotas', context, quota_obj._resources,
+                 'test_class', False),
+                ])
+        self.assertEqual(result1, quota_obj._resources)
+        self.assertEqual(result2, quota_obj._resources)
+
+    def test_get_project_quotas(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        result1 = quota_obj.get_project_quotas(context, 'test_project')
+        result2 = quota_obj.get_project_quotas(context, 'test_project',
+                                               quota_class='test_class',
+                                               defaults=False,
+                                               usages=False)
+
+        self.assertEqual(driver.called, [
+                ('get_project_quotas', context, quota_obj._resources,
+                 'test_project', None, True, True),
+                ('get_project_quotas', context, quota_obj._resources,
+                 'test_project', 'test_class', False, False),
+                ])
+        self.assertEqual(result1, quota_obj._resources)
+        self.assertEqual(result2, quota_obj._resources)
+
+    def test_count_no_resource(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          quota_obj.count, context, 'test_resource5',
+                          True, foo='bar')
+
+    def test_count_wrong_resource(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          quota_obj.count, context, 'test_resource1',
+                          True, foo='bar')
+
+    def test_count(self):
+        def fake_count(context, *args, **kwargs):
+            self.assertEqual(args, (True,))
+            self.assertEqual(kwargs, dict(foo='bar'))
+            return 5
+
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.register_resource(quota.CountableResource('test_resource5',
+                                                            fake_count))
+        result = quota_obj.count(context, 'test_resource5', True, foo='bar')
+
+        self.assertEqual(result, 5)
+
+    def test_limit_check(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.limit_check(context, test_resource1=4, test_resource2=3,
+                              test_resource3=2, test_resource4=1)
+
+        self.assertEqual(driver.called, [
+                ('limit_check', context, quota_obj._resources, dict(
+                        test_resource1=4,
+                        test_resource2=3,
+                        test_resource3=2,
+                        test_resource4=1,
+                        )),
+                ])
+
+    def test_reserve(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver(reservations=[
+                'resv-01', 'resv-02', 'resv-03', 'resv-04',
+                ])
+        quota_obj = self._make_quota_obj(driver)
+        result1 = quota_obj.reserve(context, test_resource1=4,
+                                    test_resource2=3, test_resource3=2,
+                                    test_resource4=1)
+        result2 = quota_obj.reserve(context, expire=3600,
+                                    test_resource1=1, test_resource2=2,
+                                    test_resource3=3, test_resource4=4)
+
+        self.assertEqual(driver.called, [
+                ('reserve', context, quota_obj._resources, dict(
+                        test_resource1=4,
+                        test_resource2=3,
+                        test_resource3=2,
+                        test_resource4=1,
+                        ), None),
+                ('reserve', context, quota_obj._resources, dict(
+                        test_resource1=1,
+                        test_resource2=2,
+                        test_resource3=3,
+                        test_resource4=4,
+                        ), 3600),
+                ])
+        self.assertEqual(result1, [
+                'resv-01', 'resv-02', 'resv-03', 'resv-04',
+                ])
+        self.assertEqual(result2, [
+                'resv-01', 'resv-02', 'resv-03', 'resv-04',
+                ])
+
+    def test_commit(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.commit(context, ['resv-01', 'resv-02', 'resv-03'])
+
+        self.assertEqual(driver.called, [
+                ('commit', context, ['resv-01', 'resv-02', 'resv-03']),
+                ])
+
+    def test_rollback(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.rollback(context, ['resv-01', 'resv-02', 'resv-03'])
+
+        self.assertEqual(driver.called, [
+                ('rollback', context, ['resv-01', 'resv-02', 'resv-03']),
+                ])
+
+    def test_destroy_all_by_project(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.destroy_all_by_project(context, 'test_project')
+
+        self.assertEqual(driver.called, [
+                ('destroy_all_by_project', context, 'test_project'),
+                ])
+
+    def test_expire(self):
+        context = FakeContext(None, None)
+        driver = FakeDriver()
+        quota_obj = self._make_quota_obj(driver)
+        quota_obj.expire(context)
+
+        self.assertEqual(driver.called, [
+                ('expire', context),
+                ])
+
+    def test_resources(self):
+        quota_obj = self._make_quota_obj(None)
+
+        self.assertEqual(quota_obj.resources,
+                         ['test_resource1', 'test_resource2',
+                          'test_resource3', 'test_resource4'])
+
+
+class DbQuotaDriverTestCase(test.TestCase):
+    def setUp(self):
+        super(DbQuotaDriverTestCase, self).setUp()
+
+        self.flags(quota_volumes=10,
+                   quota_gigabytes=1000,
+                   reservation_expire=86400,
+                   until_refresh=0,
+                   max_age=0,
+                   )
+
+        self.driver = quota.DbQuotaDriver()
+
+        self.calls = []
+
+        timeutils.set_time_override()
+
+    def tearDown(self):
+        timeutils.clear_time_override()
+        super(DbQuotaDriverTestCase, self).tearDown()
+
+    def test_get_defaults(self):
+        # Use our pre-defined resources
+        result = self.driver.get_defaults(None, quota.QUOTAS._resources)
+
+        self.assertEqual(result, dict(
+                volumes=10,
+                gigabytes=1000,
+                ))
+
+    def _stub_quota_class_get_all_by_name(self):
+        # Stub out quota_class_get_all_by_name
+        def fake_qcgabn(context, quota_class):
+            self.calls.append('quota_class_get_all_by_name')
+            self.assertEqual(quota_class, 'test_class')
+            return dict(
+                gigabytes=500,
+                volumes=10,
+                )
+        self.stubs.Set(db, 'quota_class_get_all_by_name', fake_qcgabn)
+
+    def test_get_class_quotas(self):
+        self._stub_quota_class_get_all_by_name()
+        result = self.driver.get_class_quotas(None, quota.QUOTAS._resources,
+                                              'test_class')
+
+        self.assertEqual(self.calls, ['quota_class_get_all_by_name'])
+        self.assertEqual(result, dict(
+                volumes=10,
+                gigabytes=500,
+                ))
+
+    def test_get_class_quotas_no_defaults(self):
+        self._stub_quota_class_get_all_by_name()
+        result = self.driver.get_class_quotas(None, quota.QUOTAS._resources,
+                                              'test_class', False)
+
+        self.assertEqual(self.calls, ['quota_class_get_all_by_name'])
+        self.assertEqual(result, dict(
+                volumes=10,
+                gigabytes=500,
+                ))
+
+    def _stub_get_by_project(self):
+        def fake_qgabp(context, project_id):
+            self.calls.append('quota_get_all_by_project')
+            self.assertEqual(project_id, 'test_project')
+            return dict(
+                volumes=10,
+                gigabytes=50,
+                reserved=0
+                )
+
+        def fake_qugabp(context, project_id):
+            self.calls.append('quota_usage_get_all_by_project')
+            self.assertEqual(project_id, 'test_project')
+            return dict(
+                volumes=dict(in_use=2, reserved=0),
+                gigabytes=dict(in_use=10, reserved=0),
+                )
+
+        self.stubs.Set(db, 'quota_get_all_by_project', fake_qgabp)
+        self.stubs.Set(db, 'quota_usage_get_all_by_project', fake_qugabp)
+
+        self._stub_quota_class_get_all_by_name()
+
+    def test_get_project_quotas(self):
+        self._stub_get_by_project()
+        result = self.driver.get_project_quotas(
+            FakeContext('test_project', 'test_class'),
+            quota.QUOTAS._resources, 'test_project')
+
+        self.assertEqual(self.calls, [
+                'quota_get_all_by_project',
+                'quota_usage_get_all_by_project',
+                'quota_class_get_all_by_name',
+                ])
+        self.assertEqual(result, dict(
+                volumes=dict(
+                    limit=10,
+                    in_use=2,
+                    reserved=0,
+                    ),
+                gigabytes=dict(
+                    limit=50,
+                    in_use=10,
+                    reserved=0,
+                    ),
+                ))
+
+    def test_get_project_quotas_alt_context_no_class(self):
+        self._stub_get_by_project()
+        result = self.driver.get_project_quotas(
+            FakeContext('other_project', 'other_class'),
+            quota.QUOTAS._resources, 'test_project')
+
+        self.assertEqual(self.calls, [
+                'quota_get_all_by_project',
+                'quota_usage_get_all_by_project',
+                ])
+        self.assertEqual(result, dict(
+                volumes=dict(
+                    limit=10,
+                    in_use=2,
+                    reserved=0,
+                    ),
+                gigabytes=dict(
+                    limit=50,
+                    in_use=10,
+                    reserved=0,
+                    ),
+                ))
+
+    def test_get_project_quotas_alt_context_with_class(self):
+        self._stub_get_by_project()
+        result = self.driver.get_project_quotas(
+            FakeContext('other_project', 'other_class'),
+            quota.QUOTAS._resources, 'test_project', quota_class='test_class')
+
+        self.assertEqual(self.calls, [
+                'quota_get_all_by_project',
+                'quota_usage_get_all_by_project',
+                'quota_class_get_all_by_name',
+                ])
+        self.assertEqual(result, dict(
+                volumes=dict(
+                    limit=10,
+                    in_use=2,
+                    reserved=0,
+                    ),
+                gigabytes=dict(
+                    limit=50,
+                    in_use=10,
+                    reserved=0,
+                    ),
+                ))
+
+    def test_get_project_quotas_no_defaults(self):
+        self._stub_get_by_project()
+        result = self.driver.get_project_quotas(
+            FakeContext('test_project', 'test_class'),
+            quota.QUOTAS._resources, 'test_project', defaults=False)
+
+        self.assertEqual(self.calls, [
+                'quota_get_all_by_project',
+                'quota_usage_get_all_by_project',
+                'quota_class_get_all_by_name',
+                ])
+        self.assertEqual(result, dict(
+                gigabytes=dict(
+                    limit=50,
+                    in_use=10,
+                    reserved=0,
+                    ),
+                volumes=dict(
+                    limit=10,
+                    in_use=2,
+                    reserved=0,
+                    ),
+                ))
+
+    def test_get_project_quotas_no_usages(self):
+        self._stub_get_by_project()
+        result = self.driver.get_project_quotas(
+            FakeContext('test_project', 'test_class'),
+            quota.QUOTAS._resources, 'test_project', usages=False)
+
+        self.assertEqual(self.calls, [
+                'quota_get_all_by_project',
+                'quota_class_get_all_by_name',
+                ])
+        self.assertEqual(result, dict(
+                volumes=dict(
+                    limit=10,
+                    ),
+                gigabytes=dict(
+                    limit=50,
+                    ),
+                ))
+
+    def _stub_get_project_quotas(self):
+        def fake_get_project_quotas(context, resources, project_id,
+                                    quota_class=None, defaults=True,
+                                    usages=True):
+            self.calls.append('get_project_quotas')
+            return dict((k, dict(limit=v.default))
+                        for k, v in resources.items())
+
+        self.stubs.Set(self.driver, 'get_project_quotas',
+                       fake_get_project_quotas)
+
+    def test_get_quotas_has_sync_unknown(self):
+        self._stub_get_project_quotas()
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          self.driver._get_quotas,
+                          None, quota.QUOTAS._resources,
+                          ['unknown'], True)
+        self.assertEqual(self.calls, [])
+
+    def test_get_quotas_no_sync_unknown(self):
+        self._stub_get_project_quotas()
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          self.driver._get_quotas,
+                          None, quota.QUOTAS._resources,
+                          ['unknown'], False)
+        self.assertEqual(self.calls, [])
+
+    def test_get_quotas_has_sync_no_sync_resource(self):
+        self._stub_get_project_quotas()
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          self.driver._get_quotas,
+                          None, quota.QUOTAS._resources,
+                          ['metadata_items'], True)
+        self.assertEqual(self.calls, [])
+
+    def test_get_quotas_no_sync_has_sync_resource(self):
+        self._stub_get_project_quotas()
+        self.assertRaises(exception.QuotaResourceUnknown,
+                          self.driver._get_quotas,
+                          None, quota.QUOTAS._resources,
+                          ['volumes'], False)
+        self.assertEqual(self.calls, [])
+
+    def test_get_quotas_has_sync(self):
+        self._stub_get_project_quotas()
+        result = self.driver._get_quotas(FakeContext('test_project',
+                                                     'test_class'),
+                                         quota.QUOTAS._resources,
+                                         ['volumes', 'gigabytes'],
+                                         True)
+
+        self.assertEqual(self.calls, ['get_project_quotas'])
+        self.assertEqual(result, dict(
+                volumes=10,
+                gigabytes=1000,
+                ))
+
+    def _stub_quota_reserve(self):
+        def fake_quota_reserve(context, resources, quotas, deltas, expire,
+                               until_refresh, max_age):
+            self.calls.append(('quota_reserve', expire, until_refresh,
+                               max_age))
+            return ['resv-1', 'resv-2', 'resv-3']
+        self.stubs.Set(db, 'quota_reserve', fake_quota_reserve)
+
+    def test_reserve_bad_expire(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        self.assertRaises(exception.InvalidReservationExpiration,
+                          self.driver.reserve,
+                          FakeContext('test_project', 'test_class'),
+                          quota.QUOTAS._resources,
+                          dict(volumes=2), expire='invalid')
+        self.assertEqual(self.calls, [])
+
+    def test_reserve_default_expire(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2))
+
+        expire = timeutils.utcnow() + datetime.timedelta(seconds=86400)
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 0, 0),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+    def test_reserve_int_expire(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2), expire=3600)
+
+        expire = timeutils.utcnow() + datetime.timedelta(seconds=3600)
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 0, 0),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+    def test_reserve_timedelta_expire(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        expire_delta = datetime.timedelta(seconds=60)
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2), expire=expire_delta)
+
+        expire = timeutils.utcnow() + expire_delta
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 0, 0),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+    def test_reserve_datetime_expire(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        expire = timeutils.utcnow() + datetime.timedelta(seconds=120)
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2), expire=expire)
+
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 0, 0),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+    def test_reserve_until_refresh(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        self.flags(until_refresh=500)
+        expire = timeutils.utcnow() + datetime.timedelta(seconds=120)
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2), expire=expire)
+
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 500, 0),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+    def test_reserve_max_age(self):
+        self._stub_get_project_quotas()
+        self._stub_quota_reserve()
+        self.flags(max_age=86400)
+        expire = timeutils.utcnow() + datetime.timedelta(seconds=120)
+        result = self.driver.reserve(FakeContext('test_project', 'test_class'),
+                                     quota.QUOTAS._resources,
+                                     dict(volumes=2), expire=expire)
+
+        self.assertEqual(self.calls, [
+                'get_project_quotas',
+                ('quota_reserve', expire, 0, 86400),
+                ])
+        self.assertEqual(result, ['resv-1', 'resv-2', 'resv-3'])
+
+
+class FakeSession(object):
+    def begin(self):
+        return self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        return False
+
+
+class FakeUsage(sqa_models.QuotaUsage):
+    def save(self, *args, **kwargs):
+        pass
+
+
+class QuotaReserveSqlAlchemyTestCase(test.TestCase):
+    # cinder.db.sqlalchemy.api.quota_reserve is so complex it needs its
+    # own test case, and since it's a quota manipulator, this is the
+    # best place to put it...
+
+    def setUp(self):
+        super(QuotaReserveSqlAlchemyTestCase, self).setUp()
+
+        self.sync_called = set()
+
+        def make_sync(res_name):
+            def sync(context, project_id, session):
+                self.sync_called.add(res_name)
+                if res_name in self.usages:
+                    if self.usages[res_name].in_use < 0:
+                        return {res_name: 2}
+                    else:
+                        return {res_name: self.usages[res_name].in_use - 1}
+                return {res_name: 0}
+            return sync
+
+        self.resources = {}
+        for res_name in ('volumes', 'gigabytes'):
+            res = quota.ReservableResource(res_name, make_sync(res_name))
+            self.resources[res_name] = res
+
+        self.expire = timeutils.utcnow() + datetime.timedelta(seconds=3600)
+
+        self.usages = {}
+        self.usages_created = {}
+        self.reservations_created = {}
+
+        def fake_get_session():
+            return FakeSession()
+
+        def fake_get_quota_usages(context, session):
+            return self.usages.copy()
+
+        def fake_quota_usage_create(context, project_id, resource, in_use,
+                                    reserved, until_refresh, session=None,
+                                    save=True):
+            quota_usage_ref = self._make_quota_usage(
+                project_id, resource, in_use, reserved, until_refresh,
+                timeutils.utcnow(), timeutils.utcnow())
+
+            self.usages_created[resource] = quota_usage_ref
+
+            return quota_usage_ref
+
+        def fake_reservation_create(context, uuid, usage_id, project_id,
+                                    resource, delta, expire, session=None):
+            reservation_ref = self._make_reservation(
+                uuid, usage_id, project_id, resource, delta, expire,
+                timeutils.utcnow(), timeutils.utcnow())
+
+            self.reservations_created[resource] = reservation_ref
+
+            return reservation_ref
+
+        self.stubs.Set(sqa_api, 'get_session', fake_get_session)
+        self.stubs.Set(sqa_api, '_get_quota_usages', fake_get_quota_usages)
+        self.stubs.Set(sqa_api, 'quota_usage_create', fake_quota_usage_create)
+        self.stubs.Set(sqa_api, 'reservation_create', fake_reservation_create)
+
+        timeutils.set_time_override()
+
+    def _make_quota_usage(self, project_id, resource, in_use, reserved,
+                          until_refresh, created_at, updated_at):
+        quota_usage_ref = FakeUsage()
+        quota_usage_ref.id = len(self.usages) + len(self.usages_created)
+        quota_usage_ref.project_id = project_id
+        quota_usage_ref.resource = resource
+        quota_usage_ref.in_use = in_use
+        quota_usage_ref.reserved = reserved
+        quota_usage_ref.until_refresh = until_refresh
+        quota_usage_ref.created_at = created_at
+        quota_usage_ref.updated_at = updated_at
+        quota_usage_ref.deleted_at = None
+        quota_usage_ref.deleted = False
+
+        return quota_usage_ref
+
+    def init_usage(self, project_id, resource, in_use, reserved,
+                   until_refresh=None, created_at=None, updated_at=None):
+        if created_at is None:
+            created_at = timeutils.utcnow()
+        if updated_at is None:
+            updated_at = timeutils.utcnow()
+
+        quota_usage_ref = self._make_quota_usage(project_id, resource, in_use,
+                                                 reserved, until_refresh,
+                                                 created_at, updated_at)
+
+        self.usages[resource] = quota_usage_ref
+
+    def compare_usage(self, usage_dict, expected):
+        for usage in expected:
+            resource = usage['resource']
+            for key, value in usage.items():
+                actual = getattr(usage_dict[resource], key)
+                self.assertEqual(actual, value,
+                                 "%s != %s on usage for resource %s" %
+                                 (actual, value, resource))
+
+    def _make_reservation(self, uuid, usage_id, project_id, resource,
+                          delta, expire, created_at, updated_at):
+        reservation_ref = sqa_models.Reservation()
+        reservation_ref.id = len(self.reservations_created)
+        reservation_ref.uuid = uuid
+        reservation_ref.usage_id = usage_id
+        reservation_ref.project_id = project_id
+        reservation_ref.resource = resource
+        reservation_ref.delta = delta
+        reservation_ref.expire = expire
+        reservation_ref.created_at = created_at
+        reservation_ref.updated_at = updated_at
+        reservation_ref.deleted_at = None
+        reservation_ref.deleted = False
+
+        return reservation_ref
+
+    def compare_reservation(self, reservations, expected):
+        reservations = set(reservations)
+        for resv in expected:
+            resource = resv['resource']
+            resv_obj = self.reservations_created[resource]
+
+            self.assertIn(resv_obj.uuid, reservations)
+            reservations.discard(resv_obj.uuid)
+
+            for key, value in resv.items():
+                actual = getattr(resv_obj, key)
+                self.assertEqual(actual, value,
+                                 "%s != %s on reservation for resource %s" %
+                                 (actual, value, resource))
+
+        self.assertEqual(len(reservations), 0)
+
+    def test_quota_reserve_create_usages(self):
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 0, 0)
+
+        self.assertEqual(self.sync_called, set(['volumes', 'gigabytes']))
+        self.compare_usage(self.usages_created, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=0,
+                     reserved=2,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=0,
+                     reserved=2 * 1024,
+                     until_refresh=None),
+                ])
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages_created['volumes'],
+                     project_id='test_project',
+                     delta=2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages_created['gigabytes'],
+                     delta=2 * 1024),
+                ])
+
+    def test_quota_reserve_negative_in_use(self):
+        self.init_usage('test_project', 'volumes', -1, 0, until_refresh=1)
+        self.init_usage('test_project', 'gigabytes', -1, 0, until_refresh=1)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 5, 0)
+
+        self.assertEqual(self.sync_called, set(['volumes', 'gigabytes']))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2,
+                     until_refresh=5),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2 * 1024,
+                     until_refresh=5),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     delta=2 * 1024),
+                ])
+
+    def test_quota_reserve_until_refresh(self):
+        self.init_usage('test_project', 'volumes', 3, 0, until_refresh=1)
+        self.init_usage('test_project', 'gigabytes', 3, 0, until_refresh=1)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 5, 0)
+
+        self.assertEqual(self.sync_called, set(['volumes', 'gigabytes']))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2,
+                     until_refresh=5),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2 * 1024,
+                     until_refresh=5),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     delta=2 * 1024),
+                ])
+
+    def test_quota_reserve_max_age(self):
+        max_age = 3600
+        record_created = (timeutils.utcnow() -
+                          datetime.timedelta(seconds=max_age))
+        self.init_usage('test_project', 'volumes', 3, 0,
+                        created_at=record_created, updated_at=record_created)
+        self.init_usage('test_project', 'gigabytes', 3, 0,
+                        created_at=record_created, updated_at=record_created)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 0, max_age)
+
+        self.assertEqual(self.sync_called, set(['volumes', 'gigabytes']))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=2,
+                     reserved=2 * 1024,
+                     until_refresh=None),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     delta=2 * 1024),
+                ])
+
+    def test_quota_reserve_no_refresh(self):
+        self.init_usage('test_project', 'volumes', 3, 0)
+        self.init_usage('test_project', 'gigabytes', 3, 0)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 0, 0)
+
+        self.assertEqual(self.sync_called, set([]))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=3,
+                     reserved=2,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=3,
+                     reserved=2 * 1024,
+                     until_refresh=None),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     delta=2 * 1024),
+                ])
+
+    def test_quota_reserve_unders(self):
+        self.init_usage('test_project', 'volumes', 1, 0)
+        self.init_usage('test_project', 'gigabytes', 1 * 1024, 0)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=-2,
+            gigabytes=-2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 0, 0)
+
+        self.assertEqual(self.sync_called, set([]))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=1,
+                     reserved=0,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=1 * 1024,
+                     reserved=0,
+                     until_refresh=None),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=-2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     delta=-2 * 1024),
+                ])
+
+    def test_quota_reserve_overs(self):
+        self.init_usage('test_project', 'volumes', 4, 0)
+        self.init_usage('test_project', 'gigabytes', 10 * 1024, 0)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=2,
+            gigabytes=2 * 1024,
+            )
+        self.assertRaises(exception.OverQuota,
+                          sqa_api.quota_reserve,
+                          context, self.resources, quotas,
+                          deltas, self.expire, 0, 0)
+
+        self.assertEqual(self.sync_called, set([]))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=4,
+                     reserved=0,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=10 * 1024,
+                     reserved=0,
+                     until_refresh=None),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.assertEqual(self.reservations_created, {})
+
+    def test_quota_reserve_reduction(self):
+        self.init_usage('test_project', 'volumes', 10, 0)
+        self.init_usage('test_project', 'gigabytes', 20 * 1024, 0)
+        context = FakeContext('test_project', 'test_class')
+        quotas = dict(
+            volumes=5,
+            gigabytes=10 * 1024,
+            )
+        deltas = dict(
+            volumes=-2,
+            gigabytes=-2 * 1024,
+            )
+        result = sqa_api.quota_reserve(context, self.resources, quotas,
+                                       deltas, self.expire, 0, 0)
+
+        self.assertEqual(self.sync_called, set([]))
+        self.compare_usage(self.usages, [
+                dict(resource='volumes',
+                     project_id='test_project',
+                     in_use=10,
+                     reserved=0,
+                     until_refresh=None),
+                dict(resource='gigabytes',
+                     project_id='test_project',
+                     in_use=20 * 1024,
+                     reserved=0,
+                     until_refresh=None),
+                ])
+        self.assertEqual(self.usages_created, {})
+        self.compare_reservation(result, [
+                dict(resource='volumes',
+                     usage_id=self.usages['volumes'],
+                     project_id='test_project',
+                     delta=-2),
+                dict(resource='gigabytes',
+                     usage_id=self.usages['gigabytes'],
+                     project_id='test_project',
+                     delta=-2 * 1024),
+                ])
index 45ffa73b9d233a40a3f7db50bb2b7ad4dbee10ea..f5f4865a91ca6de043a02ae5a801e9b2f1103a92 100644 (file)
@@ -40,6 +40,7 @@ from cinder import quota
 from cinder import test
 import cinder.volume.api
 
+QUOTAS = quota.QUOTAS
 FLAGS = flags.FLAGS
 
 
@@ -85,6 +86,20 @@ class VolumeTestCase(test.TestCase):
 
     def test_create_delete_volume(self):
         """Test volume can be created and deleted."""
+        # Need to stub out reserve, commit, and rollback
+        def fake_reserve(context, expire=None, **deltas):
+            return ["RESERVATION"]
+
+        def fake_commit(context, reservations):
+            pass
+
+        def fake_rollback(context, reservations):
+            pass
+
+        self.stubs.Set(QUOTAS, "reserve", fake_reserve)
+        self.stubs.Set(QUOTAS, "commit", fake_commit)
+        self.stubs.Set(QUOTAS, "rollback", fake_rollback)
+
         volume = self._create_volume()
         volume_id = volume['id']
         self.assertEquals(len(test_notifier.NOTIFICATIONS), 0)
@@ -554,10 +569,18 @@ class VolumeTestCase(test.TestCase):
             os.unlink(dst_path)
 
     def _do_test_create_volume_with_size(self, size):
-        def fake_allowed_volumes(context, requested_volumes, size):
-            return requested_volumes
+        def fake_reserve(context, expire=None, **deltas):
+            return ["RESERVATION"]
+
+        def fake_commit(context, reservations):
+            pass
+
+        def fake_rollback(context, reservations):
+            pass
 
-        self.stubs.Set(quota, 'allowed_volumes', fake_allowed_volumes)
+        self.stubs.Set(QUOTAS, "reserve", fake_reserve)
+        self.stubs.Set(QUOTAS, "commit", fake_commit)
+        self.stubs.Set(QUOTAS, "rollback", fake_rollback)
 
         volume_api = cinder.volume.api.API()
 
@@ -576,10 +599,18 @@ class VolumeTestCase(test.TestCase):
         self._do_test_create_volume_with_size('2')
 
     def test_create_volume_with_bad_size(self):
-        def fake_allowed_volumes(context, requested_volumes, size):
-            return requested_volumes
+        def fake_reserve(context, expire=None, **deltas):
+            return ["RESERVATION"]
+
+        def fake_commit(context, reservations):
+            pass
+
+        def fake_rollback(context, reservations):
+            pass
 
-        self.stubs.Set(quota, 'allowed_volumes', fake_allowed_volumes)
+        self.stubs.Set(QUOTAS, "reserve", fake_reserve)
+        self.stubs.Set(QUOTAS, "commit", fake_commit)
+        self.stubs.Set(QUOTAS, "rollback", fake_rollback)
 
         volume_api = cinder.volume.api.API()
 
index 480086c7d00ff23e0dd45eea76c3aca73f14c9c6..a13d801de70217c26580a77906dafd1bc1ab9c0c 100644 (file)
@@ -22,18 +22,17 @@ Handles all requests relating to volumes.
 
 import functools
 
-from eventlet import greenthread
-
+from cinder.db import base
 from cinder import exception
 from cinder import flags
 from cinder.openstack.common import cfg
 from cinder.image import glance
 from cinder.openstack.common import log as logging
 from cinder.openstack.common import rpc
-import cinder.policy
 from cinder.openstack.common import timeutils
+import cinder.policy
 from cinder import quota
-from cinder.db import base
+
 
 volume_host_opt = cfg.BoolOpt('snapshot_same_host',
         default=True,
@@ -45,6 +44,7 @@ flags.DECLARE('storage_availability_zone', 'cinder.volume.manager')
 
 LOG = logging.getLogger(__name__)
 GB = 1048576 * 1024
+QUOTAS = quota.QUOTAS
 
 
 def wrap_check_policy(func):
@@ -107,11 +107,30 @@ class API(base.Base):
             msg = (_("Volume size '%s' must be an integer and greater than 0")
                    % size)
             raise exception.InvalidInput(reason=msg)
-        if quota.allowed_volumes(context, 1, size) < 1:
+        try:
+            reservations = QUOTAS.reserve(context, volumes=1, gigabytes=size)
+        except exception.OverQuota as e:
+            overs = e.kwargs['overs']
+            usages = e.kwargs['usages']
+            quotas = e.kwargs['quotas']
+
+            def _consumed(name):
+                return (usages[name]['reserved'] + usages[name]['in_use'])
+
             pid = context.project_id
-            LOG.warn(_("Quota exceeded for %(pid)s, tried to create"
-                    " %(size)sG volume") % locals())
-            raise exception.QuotaError(code="VolumeSizeTooLarge")
+            if 'gigabytes' in overs:
+                consumed = _consumed('gigabytes')
+                quota = quotas['gigabytes']
+                LOG.warn(_("Quota exceeded for %(pid)s, tried to create "
+                           "%(size)sG volume (%(consumed)dG of %(quota)dG "
+                           "already consumed)") % locals())
+                raise exception.VolumeSizeExceedsAvailableQuota()
+            elif 'volumes' in overs:
+                consumed = _consumed('volumes')
+                LOG.warn(_("Quota exceeded for %(pid)s, tried to create "
+                           "volume (%(consumed)d volumes already consumed)")
+                           % locals())
+                raise exception.VolumeLimitExceeded(allowed=quotas['volumes'])
 
         if image_id:
             # check image existence
@@ -143,6 +162,7 @@ class API(base.Base):
             'volume_type_id': volume_type_id,
             'metadata': metadata,
             }
+
         volume = self.db.volume_create(context, options)
         rpc.cast(context,
                  FLAGS.scheduler_topic,
@@ -153,7 +173,8 @@ class API(base.Base):
                            "image_id": image_id}})
         return volume
 
-    def _cast_create_volume(self, context, volume_id, snapshot_id):
+    def _cast_create_volume(self, context, volume_id,
+                            snapshot_id, reservations):
 
         # NOTE(Rongze Zhu): It is a simple solution for bug 1008866
         # If snapshot_id is set, make the call create volume directly to
@@ -178,7 +199,8 @@ class API(base.Base):
                      {"method": "create_volume",
                       "args": {"topic": FLAGS.volume_topic,
                                "volume_id": volume_id,
-                               "snapshot_id": snapshot_id}})
+                               "snapshot_id": snapshot_id,
+                               "reservations": reservations}})
 
     @wrap_check_policy
     def delete(self, context, volume):
@@ -416,7 +438,7 @@ class API(base.Base):
 
     @wrap_check_policy
     def delete_volume_metadata(self, context, volume, key):
-        """Delete the given metadata item from an volume."""
+        """Delete the given metadata item from a volume."""
         self.db.volume_metadata_delete(context, volume['id'], key)
 
     @wrap_check_policy
index 660f29356e7e350be9fec9219b9b967521c20c24..9f7383a4cc9e993a5297bed575fb05df4d4aaa10 100644 (file)
 
     "volume_extension:types_manage": [["rule:admin_api"]],
     "volume_extension:types_extra_specs": [["rule:admin_api"]],
-    "volume_extension:extended_snapshot_attributes": []
+    "volume_extension:extended_snapshot_attributes": [],
+
+    "volume_extension:quotas:show": [],
+    "volume_extension:quotas:update_for_project": [["rule:admin_api"]],
+    "volume_extension:quotas:update_for_user": [["rule:admin_or_projectadmin"]],
+    "volume_extension:quota_classes": []
+
 }