]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Validate filters in snapshot*, backup* in db.api
authorMichal Dulko <michal.dulko@intel.com>
Wed, 26 Aug 2015 12:10:00 +0000 (14:10 +0200)
committerMichal Dulko <michal.dulko@intel.com>
Thu, 27 Aug 2015 15:23:44 +0000 (17:23 +0200)
In db.sqlalchemy.api methods filters passed from c-api are applied
directly to the DB query. From all snapshot_get_all* methods filter
validation was done only for snapshot_get_all method. Backup methods
are missing the validation completely. This is causing an exception
about unknown DB column and returns 500 HTTP error when calling API
with an incorrect filter from an admin context (without admin context
filters are validated on an c-api level). This commit adds such
validation to snapshot_get_by_host, snapshot_get_all_by_project and
backup_get_all* methods to prevent such failures. Regression unit tests
are also added.

APIImpact
Closes-Bug: 1469678
Change-Id: I3a9dc6a430f2a149073592487437721a39f0afc5

cinder/db/sqlalchemy/api.py
cinder/tests/unit/test_db_api.py

index 5906fbd426306f570005874b22a0ae782f5dd055..f1c5d79dce1ffebfad8658fa11163e6f89120af1 100644 (file)
@@ -2089,6 +2089,9 @@ def snapshot_get_all(context, filters=None, marker=None, limit=None,
                       paired with corresponding item in sort_keys
     :returns: list of matching snapshots
     """
+    if filters and not is_valid_model_filters(models.Snapshot, filters):
+        return []
+
     session = get_session()
     with session.begin():
         query = _generate_paginate_query(context, session, marker, limit,
@@ -2127,6 +2130,9 @@ def snapshot_get_all_for_volume(context, volume_id):
 
 @require_context
 def snapshot_get_by_host(context, host, filters=None):
+    if filters and not is_valid_model_filters(models.Snapshot, filters):
+        return []
+
     query = model_query(context, models.Snapshot, read_deleted='no',
                         project_only=True)
     if filters:
@@ -2169,6 +2175,9 @@ def snapshot_get_all_by_project(context, project_id, filters=None, marker=None,
                       paired with corresponding item in sort_keys
     :returns: list of matching snapshots
     """
+    if filters and not is_valid_model_filters(models.Snapshot, filters):
+        return []
+
     authorize_project_context(context, project_id)
 
     # Add project_id to filters
@@ -3399,6 +3408,9 @@ def backup_get(context, backup_id):
 
 
 def _backup_get_all(context, filters=None):
+    if filters and not is_valid_model_filters(models.Backup, filters):
+        return []
+
     session = get_session()
     with session.begin():
         # Generate the query
index 06d68d5eda5ac448a59798dd228b1ff4a76d54c1..8723a7d8cd820d1c4b539c7de14752f125be509a 100644 (file)
@@ -1143,6 +1143,49 @@ class DBAPISnapshotTestCase(BaseTest):
                                             self.ctxt,
                                             'host2', {'status': 'error'}),
                                         ignored_keys='volume')
+        self._assertEqualListsOfObjects([],
+                                        db.snapshot_get_by_host(
+                                            self.ctxt,
+                                            'host2', {'fake_key': 'fake'}),
+                                        ignored_keys='volume')
+
+    def test_snapshot_get_all_by_project(self):
+        db.volume_create(self.ctxt, {'id': 1})
+        db.volume_create(self.ctxt, {'id': 2})
+        snapshot1 = db.snapshot_create(self.ctxt, {'id': 1, 'volume_id': 1,
+                                                   'project_id': 'project1'})
+        snapshot2 = db.snapshot_create(self.ctxt, {'id': 2, 'volume_id': 2,
+                                                   'status': 'error',
+                                                   'project_id': 'project2'})
+
+        self._assertEqualListsOfObjects([snapshot1],
+                                        db.snapshot_get_all_by_project(
+                                            self.ctxt,
+                                            'project1'),
+                                        ignored_keys='volume')
+        self._assertEqualListsOfObjects([snapshot2],
+                                        db.snapshot_get_all_by_project(
+                                            self.ctxt,
+                                            'project2'),
+                                        ignored_keys='volume')
+        self._assertEqualListsOfObjects([],
+                                        db.snapshot_get_all_by_project(
+                                            self.ctxt,
+                                            'project2',
+                                            {'status': 'available'}),
+                                        ignored_keys='volume')
+        self._assertEqualListsOfObjects([snapshot2],
+                                        db.snapshot_get_all_by_project(
+                                            self.ctxt,
+                                            'project2',
+                                            {'status': 'error'}),
+                                        ignored_keys='volume')
+        self._assertEqualListsOfObjects([],
+                                        db.snapshot_get_all_by_project(
+                                            self.ctxt,
+                                            'project2',
+                                            {'fake_key': 'fake'}),
+                                        ignored_keys='volume')
 
     def test_snapshot_metadata_get(self):
         metadata = {'a': 'b', 'c': 'd'}
@@ -1765,6 +1808,10 @@ class DBAPIBackupTestCase(BaseTest):
         filtered_backups = db.backup_get_all(self.ctxt, filters=filters)
         self._assertEqualListsOfObjects([self.created[1]], filtered_backups)
 
+        filters = {'fake_key': 'fake'}
+        filtered_backups = db.backup_get_all(self.ctxt, filters=filters)
+        self._assertEqualListsOfObjects([], filtered_backups)
+
     def test_backup_get_all_by_host(self):
         byhost = db.backup_get_all_by_host(self.ctxt,
                                            self.created[1]['host'])
@@ -1775,6 +1822,21 @@ class DBAPIBackupTestCase(BaseTest):
                                               self.created[1]['project_id'])
         self._assertEqualObjects(self.created[1], byproj[0])
 
+        byproj = db.backup_get_all_by_project(self.ctxt,
+                                              self.created[1]['project_id'],
+                                              {'fake_key': 'fake'})
+        self._assertEqualListsOfObjects([], byproj)
+
+    def test_backup_get_all_by_volume(self):
+        byvol = db.backup_get_all_by_volume(self.ctxt,
+                                            self.created[1]['volume_id'])
+        self._assertEqualObjects(self.created[1], byvol[0])
+
+        byvol = db.backup_get_all_by_volume(self.ctxt,
+                                            self.created[1]['volume_id'],
+                                            {'fake_key': 'fake'})
+        self._assertEqualListsOfObjects([], byvol)
+
     def test_backup_update_nonexistent(self):
         self.assertRaises(exception.BackupNotFound,
                           db.backup_update,