]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Don't pass 'session' arg to public DB API methods
authorRoman Podolyaka <rpodolyaka@mirantis.com>
Thu, 27 Jun 2013 15:34:21 +0000 (18:34 +0300)
committerRoman Podolyaka <rpodolyaka@mirantis.com>
Mon, 8 Jul 2013 14:48:01 +0000 (17:48 +0300)
DB API is an abstraction layer, which is used to make it
possible to switch DB backends easily (though we've got only
SQLAlchemy backend at the moment).

Public methods of DB API should not accept any backend-specific
arguments (i. e. a Session instance, that is an SQLAlchemy entity
to work with DB transactions).

This patch removes 'session' argument from all DB API public methods
(except volume_data_get_for_project() and snapshot_data_get_fro_project(),
which are a bit tricky and will be fixed by another patch).

If a DB API method must be called by another one in the context of a started
transaction, a private method is used. It accepts the same arguments as the
corresponding public method plus one additional argument to pass the transactional
context (in case of SQLAlchemy backend it's a Session instance).

Blueprint: db-session-cleanup

Change-Id: Iabe7ea834ec07f6520614de2461b9ad1ab7a7ac2

cinder/db/api.py
cinder/db/sqlalchemy/api.py
cinder/tests/test_quota.py

index 598ed67debf4e76fc8b6808039fdc1c298b8ee92..0cd60c1c04e3fa08180ce367f994b6da31ba2b0e 100644 (file)
@@ -210,11 +210,10 @@ def volume_create(context, values):
     return IMPL.volume_create(context, values)
 
 
-def volume_data_get_for_host(context, host, session=None):
+def volume_data_get_for_host(context, host):
     """Get (volume_count, gigabytes) for project."""
     return IMPL.volume_data_get_for_host(context,
-                                         host,
-                                         session)
+                                         host)
 
 
 def volume_data_get_for_project(context, project_id, volume_type_id=None,
index 60105cb7b5f2f99f953144ec335653f041e92b9c..7259db9a20dd653be88118bd6405b98844f29f29 100644 (file)
@@ -248,12 +248,12 @@ def exact_filter(query, model, filters, legal_keys):
 def service_destroy(context, service_id):
     session = get_session()
     with session.begin():
-        service_ref = service_get(context, service_id, session=session)
+        service_ref = _service_get(context, service_id, session=session)
         service_ref.delete(session=session)
 
 
 @require_admin_context
-def service_get(context, service_id, session=None):
+def _service_get(context, service_id, session=None):
     result = model_query(
         context,
         models.Service,
@@ -266,6 +266,11 @@ def service_get(context, service_id, session=None):
     return result
 
 
+@require_admin_context
+def service_get(context, service_id):
+    return _service_get(context, service_id)
+
+
 @require_admin_context
 def service_get_all(context, disabled=None):
     query = model_query(context, models.Service)
@@ -364,7 +369,7 @@ def service_create(context, values):
 def service_update(context, service_id, values):
     session = get_session()
     with session.begin():
-        service_ref = service_get(context, service_id, session=session)
+        service_ref = _service_get(context, service_id, session=session)
         service_ref.update(values)
         service_ref.save(session=session)
 
@@ -429,7 +434,7 @@ def iscsi_target_create_safe(context, values):
 
 
 @require_context
-def quota_get(context, project_id, resource, session=None):
+def _quota_get(context, project_id, resource, session=None):
     result = model_query(context, models.Quota, session=session,
                          read_deleted="no").\
         filter_by(project_id=project_id).\
@@ -442,6 +447,11 @@ def quota_get(context, project_id, resource, session=None):
     return result
 
 
+@require_context
+def quota_get(context, project_id, resource):
+    return _quota_get(context, project_id, resource)
+
+
 @require_context
 def quota_get_all_by_project(context, project_id):
     authorize_project_context(context, project_id)
@@ -471,7 +481,7 @@ def quota_create(context, project_id, resource, limit):
 def quota_update(context, project_id, resource, limit):
     session = get_session()
     with session.begin():
-        quota_ref = quota_get(context, project_id, resource, session=session)
+        quota_ref = _quota_get(context, project_id, resource, session=session)
         quota_ref.hard_limit = limit
         quota_ref.save(session=session)
 
@@ -480,7 +490,7 @@ def quota_update(context, project_id, resource, limit):
 def quota_destroy(context, project_id, resource):
     session = get_session()
     with session.begin():
-        quota_ref = quota_get(context, project_id, resource, session=session)
+        quota_ref = _quota_get(context, project_id, resource, session=session)
         quota_ref.delete(session=session)
 
 
@@ -488,7 +498,7 @@ def quota_destroy(context, project_id, resource):
 
 
 @require_context
-def quota_class_get(context, class_name, resource, session=None):
+def _quota_class_get(context, class_name, resource, session=None):
     result = model_query(context, models.QuotaClass, session=session,
                          read_deleted="no").\
         filter_by(class_name=class_name).\
@@ -501,6 +511,11 @@ def quota_class_get(context, class_name, resource, session=None):
     return result
 
 
+@require_context
+def quota_class_get(context, class_name, resource):
+    return _quota_class_get(context, class_name, resource)
+
+
 def quota_class_get_default(context):
     rows = model_query(context, models.QuotaClass,
                        read_deleted="no").\
@@ -542,8 +557,8 @@ def quota_class_create(context, class_name, resource, limit):
 def quota_class_update(context, class_name, resource, limit):
     session = get_session()
     with session.begin():
-        quota_class_ref = quota_class_get(context, class_name, resource,
-                                          session=session)
+        quota_class_ref = _quota_class_get(context, class_name, resource,
+                                           session=session)
         quota_class_ref.hard_limit = limit
         quota_class_ref.save(session=session)
 
@@ -552,8 +567,8 @@ def quota_class_update(context, class_name, resource, limit):
 def quota_class_destroy(context, class_name, resource):
     session = get_session()
     with session.begin():
-        quota_class_ref = quota_class_get(context, class_name, resource,
-                                          session=session)
+        quota_class_ref = _quota_class_get(context, class_name, resource,
+                                           session=session)
         quota_class_ref.delete(session=session)
 
 
@@ -574,9 +589,8 @@ def quota_class_destroy_all_by_name(context, class_name):
 
 
 @require_context
-def quota_usage_get(context, project_id, resource, session=None):
-    result = model_query(context, models.QuotaUsage, session=session,
-                         read_deleted="no").\
+def quota_usage_get(context, project_id, resource):
+    result = model_query(context, models.QuotaUsage, read_deleted="no").\
         filter_by(project_id=project_id).\
         filter_by(resource=resource).\
         first()
@@ -603,8 +617,9 @@ def quota_usage_get_all_by_project(context, project_id):
 
 
 @require_admin_context
-def quota_usage_create(context, project_id, resource, in_use, reserved,
-                       until_refresh, session=None):
+def _quota_usage_create(context, project_id, resource, in_use, reserved,
+                        until_refresh, session=None):
+
     quota_usage_ref = models.QuotaUsage()
     quota_usage_ref.project_id = project_id
     quota_usage_ref.resource = resource
@@ -616,11 +631,17 @@ def quota_usage_create(context, project_id, resource, in_use, reserved,
     return quota_usage_ref
 
 
+@require_admin_context
+def quota_usage_create(context, project_id, resource, in_use, reserved,
+                       until_refresh):
+    return _quota_usage_create(context, project_id, resource, in_use, reserved,
+                               until_refresh)
+
 ###################
 
 
 @require_context
-def reservation_get(context, uuid, session=None):
+def _reservation_get(context, uuid, session=None):
     result = model_query(context, models.Reservation, session=session,
                          read_deleted="no").\
         filter_by(uuid=uuid).first()
@@ -631,6 +652,11 @@ def reservation_get(context, uuid, session=None):
     return result
 
 
+@require_context
+def reservation_get(context, uuid):
+    return _reservation_get(context, uuid)
+
+
 @require_context
 def reservation_get_all_by_project(context, project_id):
     authorize_project_context(context, project_id)
@@ -647,8 +673,8 @@ def reservation_get_all_by_project(context, project_id):
 
 
 @require_admin_context
-def reservation_create(context, uuid, usage, project_id, resource, delta,
-                       expire, session=None):
+def _reservation_create(context, uuid, usage, project_id, resource, delta,
+                        expire, session=None):
     reservation_ref = models.Reservation()
     reservation_ref.uuid = uuid
     reservation_ref.usage_id = usage['id']
@@ -660,11 +686,18 @@ def reservation_create(context, uuid, usage, project_id, resource, delta,
     return reservation_ref
 
 
+@require_admin_context
+def reservation_create(context, uuid, usage, project_id, resource, delta,
+                       expire):
+    return _reservation_create(context, uuid, usage, project_id, resource,
+                               delta, expire)
+
+
 @require_admin_context
 def reservation_destroy(context, uuid):
     session = get_session()
     with session.begin():
-        reservation_ref = reservation_get(context, uuid, session=session)
+        reservation_ref = _reservation_get(context, uuid, session=session)
         reservation_ref.delete(session=session)
 
 
@@ -707,12 +740,12 @@ def quota_reserve(context, resources, quotas, deltas, expire,
             # Do we need to refresh the usage?
             refresh = False
             if resource not in usages:
-                usages[resource] = quota_usage_create(elevated,
-                                                      project_id,
-                                                      resource,
-                                                      0, 0,
-                                                      until_refresh or None,
-                                                      session=session)
+                usages[resource] = _quota_usage_create(elevated,
+                                                       project_id,
+                                                       resource,
+                                                       0, 0,
+                                                       until_refresh or None,
+                                                       session=session)
                 refresh = True
             elif usages[resource].in_use < 0:
                 # Negative in_use count indicates a desync, so try to
@@ -735,12 +768,14 @@ def quota_reserve(context, resources, quotas, deltas, expire,
                 for res, in_use in updates.items():
                     # Make sure we have a destination for the usage!
                     if res not in usages:
-                        usages[res] = quota_usage_create(elevated,
-                                                         project_id,
-                                                         res,
-                                                         0, 0,
-                                                         until_refresh or None,
-                                                         session=session)
+                        usages[res] = _quota_usage_create(
+                            elevated,
+                            project_id,
+                            res,
+                            0, 0,
+                            until_refresh or None,
+                            session=session
+                        )
 
                     # Update the usage
                     usages[res].in_use = in_use
@@ -782,12 +817,12 @@ def quota_reserve(context, resources, quotas, deltas, expire,
         if not overs:
             reservations = []
             for resource, delta in deltas.items():
-                reservation = reservation_create(elevated,
-                                                 str(uuid.uuid4()),
-                                                 usages[resource],
-                                                 project_id,
-                                                 resource, delta, expire,
-                                                 session=session)
+                reservation = _reservation_create(elevated,
+                                                  str(uuid.uuid4()),
+                                                  usages[resource],
+                                                  project_id,
+                                                  resource, delta, expire,
+                                                  session=session)
                 reservations.append(reservation.uuid)
 
                 # Also update the reserved quantity
@@ -948,7 +983,7 @@ def volume_attached(context, volume_id, instance_uuid, host_name, mountpoint):
 
     session = get_session()
     with session.begin():
-        volume_ref = volume_get(context, volume_id, session=session)
+        volume_ref = _volume_get(context, volume_id, session=session)
         volume_ref['status'] = 'in-use'
         volume_ref['mountpoint'] = mountpoint
         volume_ref['attach_status'] = 'attached'
@@ -970,16 +1005,15 @@ def volume_create(context, values):
     with session.begin():
         volume_ref.save(session=session)
 
-    return volume_get(context, values['id'], session=session)
+    return _volume_get(context, values['id'], session=session)
 
 
 @require_admin_context
-def volume_data_get_for_host(context, host, session=None):
+def volume_data_get_for_host(context, host):
     result = model_query(context,
                          func.count(models.Volume.id),
                          func.sum(models.Volume.size),
-                         read_deleted="no",
-                         session=session).\
+                         read_deleted="no").\
         filter_by(host=host).\
         first()
 
@@ -988,8 +1022,8 @@ def volume_data_get_for_host(context, host, session=None):
 
 
 @require_admin_context
-def volume_data_get_for_project(context, project_id, volume_type_id=None,
-                                session=None):
+def _volume_data_get_for_project(context, project_id, volume_type_id=None,
+                                 session=None):
     query = model_query(context,
                         func.count(models.Volume.id),
                         func.sum(models.Volume.size),
@@ -1006,6 +1040,13 @@ def volume_data_get_for_project(context, project_id, volume_type_id=None,
     return (result[0] or 0, result[1] or 0)
 
 
+@require_admin_context
+def volume_data_get_for_project(context, project_id, volume_type_id=None,
+                                session=None):
+    return _volume_data_get_for_project(context, project_id, volume_type_id,
+                                        session)
+
+
 @require_admin_context
 def volume_destroy(context, volume_id):
     session = get_session()
@@ -1030,7 +1071,7 @@ def volume_destroy(context, volume_id):
 def volume_detached(context, volume_id):
     session = get_session()
     with session.begin():
-        volume_ref = volume_get(context, volume_id, session=session)
+        volume_ref = _volume_get(context, volume_id, session=session)
         volume_ref['status'] = 'available'
         volume_ref['mountpoint'] = None
         volume_ref['attach_status'] = 'detached'
@@ -1048,7 +1089,7 @@ def _volume_get_query(context, session=None, project_only=False):
 
 
 @require_context
-def volume_get(context, volume_id, session=None):
+def _volume_get(context, volume_id, session=None):
     result = _volume_get_query(context, session=session, project_only=True).\
         filter_by(id=volume_id).\
         first()
@@ -1059,13 +1100,18 @@ def volume_get(context, volume_id, session=None):
     return result
 
 
+@require_context
+def volume_get(context, volume_id):
+    return _volume_get(context, volume_id)
+
+
 @require_admin_context
 def volume_get_all(context, marker, limit, sort_key, sort_dir):
     query = _volume_get_query(context)
 
     marker_volume = None
     if marker is not None:
-        marker_volume = volume_get(context, marker)
+        marker_volume = _volume_get(context, marker)
 
     query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
                                            [sort_key, 'created_at', 'id'],
@@ -1102,7 +1148,7 @@ def volume_get_all_by_project(context, project_id, marker, limit, sort_key,
 
     marker_volume = None
     if marker is not None:
-        marker_volume = volume_get(context, marker)
+        marker_volume = _volume_get(context, marker)
 
     query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
                                            [sort_key, 'created_at', 'id'],
@@ -1134,7 +1180,7 @@ def volume_update(context, volume_id, values):
                                values.pop('metadata'),
                                delete=True)
     with session.begin():
-        volume_ref = volume_get(context, volume_id, session=session)
+        volume_ref = _volume_get(context, volume_id, session=session)
         volume_ref.update(values)
         volume_ref.save(session=session)
         return volume_ref
@@ -1171,7 +1217,7 @@ def volume_metadata_delete(context, volume_id, key):
 
 @require_context
 @require_volume_exists
-def volume_metadata_get_item(context, volume_id, key, session=None):
+def _volume_metadata_get_item(context, volume_id, key, session=None):
     result = _volume_metadata_get_query(context, volume_id, session=session).\
         filter_by(key=key).\
         first()
@@ -1182,6 +1228,12 @@ def volume_metadata_get_item(context, volume_id, key, session=None):
     return result
 
 
+@require_context
+@require_volume_exists
+def volume_metadata_get_item(context, volume_id, key):
+    return _volume_metadata_get_item(context, volume_id, key)
+
+
 @require_context
 @require_volume_exists
 def volume_metadata_update(context, volume_id, metadata, delete):
@@ -1192,8 +1244,8 @@ def volume_metadata_update(context, volume_id, metadata, delete):
         original_metadata = volume_metadata_get(context, volume_id)
         for meta_key, meta_value in original_metadata.iteritems():
             if meta_key not in metadata:
-                meta_ref = volume_metadata_get_item(context, volume_id,
-                                                    meta_key, session)
+                meta_ref = _volume_metadata_get_item(context, volume_id,
+                                                     meta_key, session)
                 meta_ref.update({'deleted': True})
                 meta_ref.save(session=session)
 
@@ -1206,8 +1258,8 @@ def volume_metadata_update(context, volume_id, metadata, delete):
         item = {"value": meta_value}
 
         try:
-            meta_ref = volume_metadata_get_item(context, volume_id,
-                                                meta_key, session)
+            meta_ref = _volume_metadata_get_item(context, volume_id,
+                                                 meta_key, session)
         except exception.VolumeMetadataNotFound as e:
             meta_ref = models.VolumeMetadata()
             item.update({"key": meta_key, "volume_id": volume_id})
@@ -1234,7 +1286,7 @@ def snapshot_create(context, values):
     with session.begin():
         snapshot_ref.save(session=session)
 
-    return snapshot_get(context, values['id'], session=session)
+    return _snapshot_get(context, values['id'], session=session)
 
 
 @require_admin_context
@@ -1250,7 +1302,7 @@ def snapshot_destroy(context, snapshot_id):
 
 
 @require_context
-def snapshot_get(context, snapshot_id, session=None):
+def _snapshot_get(context, snapshot_id, session=None):
     result = model_query(context, models.Snapshot, session=session,
                          project_only=True).\
         options(joinedload('volume')).\
@@ -1263,6 +1315,11 @@ def snapshot_get(context, snapshot_id, session=None):
     return result
 
 
+@require_context
+def snapshot_get(context, snapshot_id):
+    return _snapshot_get(context, snapshot_id)
+
+
 @require_admin_context
 def snapshot_get_all(context):
     return model_query(context, models.Snapshot).\
@@ -1289,8 +1346,8 @@ def snapshot_get_all_by_project(context, project_id):
 
 
 @require_context
-def snapshot_data_get_for_project(context, project_id, volume_type_id=None,
-                                  session=None):
+def _snapshot_data_get_for_project(context, project_id, volume_type_id=None,
+                                   session=None):
     authorize_project_context(context, project_id)
     query = model_query(context,
                         func.count(models.Snapshot.id),
@@ -1308,6 +1365,13 @@ def snapshot_data_get_for_project(context, project_id, volume_type_id=None,
     return (result[0] or 0, result[1] or 0)
 
 
+@require_context
+def snapshot_data_get_for_project(context, project_id, volume_type_id=None,
+                                  session=None):
+    return _snapshot_data_get_for_project(context, project_id, volume_type_id,
+                                          session)
+
+
 @require_context
 def snapshot_get_active_by_window(context, begin, end=None, project_id=None):
     """Return snapshots that were active during window."""
@@ -1328,7 +1392,7 @@ def snapshot_get_active_by_window(context, begin, end=None, project_id=None):
 def snapshot_update(context, snapshot_id, values):
     session = get_session()
     with session.begin():
-        snapshot_ref = snapshot_get(context, snapshot_id, session=session)
+        snapshot_ref = _snapshot_get(context, snapshot_id, session=session)
         snapshot_ref.update(values)
         snapshot_ref.save(session=session)
 
@@ -1364,7 +1428,7 @@ def snapshot_metadata_delete(context, snapshot_id, key):
 
 @require_context
 @require_snapshot_exists
-def snapshot_metadata_get_item(context, snapshot_id, key, session=None):
+def _snapshot_metadata_get_item(context, snapshot_id, key, session=None):
     result = _snapshot_metadata_get_query(context,
                                           snapshot_id,
                                           session=session).\
@@ -1377,6 +1441,12 @@ def snapshot_metadata_get_item(context, snapshot_id, key, session=None):
     return result
 
 
+@require_context
+@require_snapshot_exists
+def snapshot_metadata_get_item(context, snapshot_id, key):
+    return _snapshot_metadata_get_item(context, snapshot_id, key)
+
+
 @require_context
 @require_snapshot_exists
 def snapshot_metadata_update(context, snapshot_id, metadata, delete):
@@ -1387,8 +1457,8 @@ def snapshot_metadata_update(context, snapshot_id, metadata, delete):
         original_metadata = snapshot_metadata_get(context, snapshot_id)
         for meta_key, meta_value in original_metadata.iteritems():
             if meta_key not in metadata:
-                meta_ref = snapshot_metadata_get_item(context, snapshot_id,
-                                                      meta_key, session)
+                meta_ref = _snapshot_metadata_get_item(context, snapshot_id,
+                                                       meta_key, session)
                 meta_ref.update({'deleted': True})
                 meta_ref.save(session=session)
 
@@ -1401,8 +1471,8 @@ def snapshot_metadata_update(context, snapshot_id, metadata, delete):
         item = {"value": meta_value}
 
         try:
-            meta_ref = snapshot_metadata_get_item(context, snapshot_id,
-                                                  meta_key, session)
+            meta_ref = _snapshot_metadata_get_item(context, snapshot_id,
+                                                   meta_key, session)
         except exception.SnapshotMetadataNotFound as e:
             meta_ref = models.SnapshotMetadata()
             item.update({"key": meta_key, "snapshot_id": snapshot_id})
@@ -1427,14 +1497,14 @@ def migration_create(context, values):
 def migration_update(context, id, values):
     session = get_session()
     with session.begin():
-        migration = migration_get(context, id, session=session)
+        migration = _migration_get(context, id, session=session)
         migration.update(values)
         migration.save(session=session)
         return migration
 
 
 @require_admin_context
-def migration_get(context, id, session=None):
+def _migration_get(context, id, session=None):
     result = model_query(context, models.Migration, session=session,
                          read_deleted="yes").\
         filter_by(id=id).\
@@ -1446,6 +1516,11 @@ def migration_get(context, id, session=None):
     return result
 
 
+@require_admin_context
+def migration_get(context, id):
+    return _migration_get(context, id)
+
+
 @require_admin_context
 def migration_get_by_instance_and_status(context, instance_uuid, status):
     result = model_query(context, models.Migration, read_deleted="yes").\
@@ -1461,12 +1536,11 @@ def migration_get_by_instance_and_status(context, instance_uuid, status):
 
 
 @require_admin_context
-def migration_get_all_unconfirmed(context, confirm_window, session=None):
+def migration_get_all_unconfirmed(context, confirm_window):
     confirm_window = timeutils.utcnow() - datetime.timedelta(
         seconds=confirm_window)
 
-    return model_query(context, models.Migration, session=session,
-                       read_deleted="yes").\
+    return model_query(context, models.Migration, read_deleted="yes").\
         filter(models.Migration.updated_at <= confirm_window).\
         filter_by(status="finished").\
         all()
@@ -1489,12 +1563,12 @@ def volume_type_create(context, values):
     session = get_session()
     with session.begin():
         try:
-            volume_type_get_by_name(context, values['name'], session)
+            _volume_type_get_by_name(context, values['name'], session)
             raise exception.VolumeTypeExists(id=values['name'])
         except exception.VolumeTypeNotFoundByName:
             pass
         try:
-            volume_type_get(context, values['id'], session)
+            _volume_type_get(context, values['id'], session)
             raise exception.VolumeTypeExists(id=values['id'])
         except exception.VolumeTypeNotFound:
             pass
@@ -1533,8 +1607,7 @@ def volume_type_get_all(context, inactive=False, filters=None):
 
 
 @require_context
-def volume_type_get(context, id, session=None):
-    """Returns a dict describing specific volume_type"""
+def _volume_type_get(context, id, session=None):
     result = model_query(context, models.VolumeTypes, session=session).\
         options(joinedload('extra_specs')).\
         filter_by(id=id).\
@@ -1547,8 +1620,14 @@ def volume_type_get(context, id, session=None):
 
 
 @require_context
-def volume_type_get_by_name(context, name, session=None):
+def volume_type_get(context, id):
     """Returns a dict describing specific volume_type"""
+
+    return _volume_type_get(context, id)
+
+
+@require_context
+def _volume_type_get_by_name(context, name, session=None):
     result = model_query(context, models.VolumeTypes, session=session).\
         options(joinedload('extra_specs')).\
         filter_by(name=name).\
@@ -1560,9 +1639,16 @@ def volume_type_get_by_name(context, name, session=None):
         return _dict_with_extra_specs(result)
 
 
+@require_context
+def volume_type_get_by_name(context, name):
+    """Returns a dict describing specific volume_type"""
+
+    return _volume_type_get_by_name(context, name)
+
+
 @require_admin_context
 def volume_type_destroy(context, id):
-    volume_type_get(context, id)
+    _volume_type_get(context, id)
 
     session = get_session()
     with session.begin():
@@ -1621,7 +1707,7 @@ def volume_type_extra_specs_get(context, volume_type_id):
 @require_context
 def volume_type_extra_specs_delete(context, volume_type_id, key):
     session = get_session()
-    volume_type_extra_specs_get_item(context, volume_type_id, key, session)
+    _volume_type_extra_specs_get_item(context, volume_type_id, key, session)
     _volume_type_extra_specs_query(context, volume_type_id).\
         filter_by(key=key).\
         update({'deleted': True,
@@ -1630,8 +1716,8 @@ def volume_type_extra_specs_delete(context, volume_type_id, key):
 
 
 @require_context
-def volume_type_extra_specs_get_item(context, volume_type_id, key,
-                                     session=None):
+def _volume_type_extra_specs_get_item(context, volume_type_id, key,
+                                      session=None):
     result = _volume_type_extra_specs_query(
         context, volume_type_id, session=session).\
         filter_by(key=key).\
@@ -1645,6 +1731,11 @@ def volume_type_extra_specs_get_item(context, volume_type_id, key,
     return result
 
 
+@require_context
+def volume_type_extra_specs_get_item(context, volume_type_id, key):
+    return _volume_type_extra_specs_get_item(context, volume_type_id, key)
+
+
 @require_context
 def volume_type_extra_specs_update_or_create(context, volume_type_id,
                                              specs):
@@ -1652,7 +1743,7 @@ def volume_type_extra_specs_update_or_create(context, volume_type_id,
     spec_ref = None
     for key, value in specs.iteritems():
         try:
-            spec_ref = volume_type_extra_specs_get_item(
+            spec_ref = _volume_type_extra_specs_get_item(
                 context, volume_type_id, key, session)
         except exception.VolumeTypeExtraSpecsNotFound as e:
             spec_ref = models.VolumeTypeExtraSpecs()
@@ -1668,40 +1759,48 @@ def volume_type_extra_specs_update_or_create(context, volume_type_id,
 
 @require_context
 @require_volume_exists
-def volume_glance_metadata_get(context, volume_id, session=None):
+def _volume_glance_metadata_get(context, volume_id, session=None):
+    return model_query(context, models.VolumeGlanceMetadata, session=session).\
+        filter_by(volume_id=volume_id).\
+        filter_by(deleted=False).\
+        all()
+
+
+@require_context
+@require_volume_exists
+def volume_glance_metadata_get(context, volume_id):
     """Return the Glance metadata for the specified volume."""
-    if not session:
-        session = get_session()
 
-    return session.query(models.VolumeGlanceMetadata).\
-        filter_by(volume_id=volume_id).\
-        filter_by(deleted=False).all()
+    return _volume_glance_metadata_get(context, volume_id)
 
 
 @require_context
 @require_snapshot_exists
-def volume_snapshot_glance_metadata_get(context, snapshot_id, session=None):
+def _volume_snapshot_glance_metadata_get(context, snapshot_id, session=None):
+    return model_query(context, models.VolumeGlanceMetadata, session=session).\
+        filter_by(snapshot_id=snapshot_id).\
+        filter_by(deleted=False).\
+        all()
+
+
+@require_context
+@require_snapshot_exists
+def volume_snapshot_glance_metadata_get(context, snapshot_id):
     """Return the Glance metadata for the specified snapshot."""
-    if not session:
-        session = get_session()
 
-    return session.query(models.VolumeGlanceMetadata).\
-        filter_by(snapshot_id=snapshot_id).\
-        filter_by(deleted=False).all()
+    return _volume_snapshot_glance_metadata_get(context, snapshot_id)
 
 
 @require_context
 @require_volume_exists
-def volume_glance_metadata_create(context, volume_id, key, value,
-                                  session=None):
+def volume_glance_metadata_create(context, volume_id, key, value):
     """
     Update the Glance metadata for a volume by adding a new key:value pair.
     This API does not support changing the value of a key once it has been
     created.
     """
-    if session is None:
-        session = get_session()
 
+    session = get_session()
     with session.begin():
         rows = session.query(models.VolumeGlanceMetadata).\
             filter_by(volume_id=volume_id).\
@@ -1724,17 +1823,15 @@ def volume_glance_metadata_create(context, volume_id, key, value,
 
 @require_context
 @require_snapshot_exists
-def volume_glance_metadata_copy_to_snapshot(context, snapshot_id, volume_id,
-                                            session=None):
+def volume_glance_metadata_copy_to_snapshot(context, snapshot_id, volume_id):
     """
     Update the Glance metadata for a snapshot by copying all of the key:value
     pairs from the originating volume. This is so that a volume created from
     the snapshot will retain the original metadata.
     """
-    if session is None:
-        session = get_session()
 
-    metadata = volume_glance_metadata_get(context, volume_id, session=session)
+    session = get_session()
+    metadata = _volume_glance_metadata_get(context, volume_id, session=session)
     with session.begin():
         for meta in metadata:
             vol_glance_metadata = models.VolumeGlanceMetadata()
@@ -1749,19 +1846,17 @@ def volume_glance_metadata_copy_to_snapshot(context, snapshot_id, volume_id,
 @require_volume_exists
 def volume_glance_metadata_copy_from_volume_to_volume(context,
                                                       src_volume_id,
-                                                      volume_id,
-                                                      session=None):
+                                                      volume_id):
     """
     Update the Glance metadata for a volume by copying all of the key:value
     pairs from the originating volume. This is so that a volume created from
     the volume (clone) will retain the original metadata.
     """
-    if session is None:
-        session = get_session()
 
-    metadata = volume_glance_metadata_get(context,
-                                          src_volume_id,
-                                          session=session)
+    session = get_session()
+    metadata = _volume_glance_metadata_get(context,
+                                           src_volume_id,
+                                           session=session)
     with session.begin():
         for meta in metadata:
             vol_glance_metadata = models.VolumeGlanceMetadata()
@@ -1774,18 +1869,16 @@ def volume_glance_metadata_copy_from_volume_to_volume(context,
 
 @require_context
 @require_volume_exists
-def volume_glance_metadata_copy_to_volume(context, volume_id, snapshot_id,
-                                          session=None):
+def volume_glance_metadata_copy_to_volume(context, volume_id, snapshot_id):
     """
     Update the Glance metadata from a volume (created from a snapshot) by
     copying all of the key:value pairs from the originating snapshot. This is
     so that the Glance metadata from the original volume is retained.
     """
-    if session is None:
-        session = get_session()
 
-    metadata = volume_snapshot_glance_metadata_get(context, snapshot_id,
-                                                   session=session)
+    session = get_session()
+    metadata = _volume_snapshot_glance_metadata_get(context, snapshot_id,
+                                                    session=session)
     with session.begin():
         for meta in metadata:
             vol_glance_metadata = models.VolumeGlanceMetadata()
@@ -1981,9 +2074,8 @@ def sm_volume_get_all(context):
 
 
 @require_context
-def backup_get(context, backup_id, session=None):
-    result = model_query(context, models.Backup,
-                         session=session, project_only=True).\
+def backup_get(context, backup_id):
+    result = model_query(context, models.Backup, project_only=True).\
         filter_by(id=backup_id).\
         first()
 
@@ -2054,7 +2146,7 @@ def backup_destroy(context, backup_id):
 
 
 @require_context
-def transfer_get(context, transfer_id, session=None):
+def _transfer_get(context, transfer_id, session=None):
     query = model_query(context, models.Transfer,
                         session=session).\
         filter_by(id=transfer_id)
@@ -2072,6 +2164,11 @@ def transfer_get(context, transfer_id, session=None):
     return result
 
 
+@require_context
+def transfer_get(context, transfer_id):
+    return _transfer_get(context, transfer_id)
+
+
 def _translate_transfers(transfers):
     results = []
     for transfer in transfers:
@@ -2110,9 +2207,9 @@ def transfer_create(context, values):
         values['id'] = str(uuid.uuid4())
     session = get_session()
     with session.begin():
-        volume_ref = volume_get(context,
-                                values['volume_id'],
-                                session=session)
+        volume_ref = _volume_get(context,
+                                 values['volume_id'],
+                                 session=session)
         if volume_ref['status'] != 'available':
             msg = _('Volume must be available')
             LOG.error(msg)
@@ -2129,12 +2226,12 @@ def transfer_create(context, values):
 def transfer_destroy(context, transfer_id):
     session = get_session()
     with session.begin():
-        transfer_ref = transfer_get(context,
-                                    transfer_id,
-                                    session=session)
-        volume_ref = volume_get(context,
-                                transfer_ref['volume_id'],
-                                session=session)
+        transfer_ref = _transfer_get(context,
+                                     transfer_id,
+                                     session=session)
+        volume_ref = _volume_get(context,
+                                 transfer_ref['volume_id'],
+                                 session=session)
         # If the volume state is not 'awaiting-transfer' don't change it, but
         # we can still mark the transfer record as deleted.
         if volume_ref['status'] != 'awaiting-transfer':
@@ -2156,9 +2253,9 @@ def transfer_destroy(context, transfer_id):
 def transfer_accept(context, transfer_id, user_id, project_id):
     session = get_session()
     with session.begin():
-        transfer_ref = transfer_get(context, transfer_id, session)
+        transfer_ref = _transfer_get(context, transfer_id, session)
         volume_id = transfer_ref['volume_id']
-        volume_ref = volume_get(context, volume_id, session=session)
+        volume_ref = _volume_get(context, volume_id, session=session)
         if volume_ref['status'] != 'awaiting-transfer':
             volume_status = volume_ref['status']
             msg = _('Transfer %(transfer_id)s: Volume id %(volume_id)s in '
index 58de0b398fec766e931093654e91b585a77b2da6..8913dfc3c98141dbd1a4e7f486ec46624d47acc0 100644 (file)
@@ -1149,8 +1149,8 @@ class QuotaReserveSqlAlchemyTestCase(test.TestCase):
 
         self.stubs.Set(sqa_api, 'get_session', fake_get_session)
         self.stubs.Set(sqa_api, '_get_quota_usages', fake_get_quota_usages)
-        self.stubs.Set(sqa_api, 'quota_usage_create', fake_quota_usage_create)
-        self.stubs.Set(sqa_api, 'reservation_create', fake_reservation_create)
+        self.stubs.Set(sqa_api, '_quota_usage_create', fake_quota_usage_create)
+        self.stubs.Set(sqa_api, '_reservation_create', fake_reservation_create)
 
         timeutils.set_time_override()