]> review.fuel-infra Code Review - openstack-build/neutron-build.git/commitdiff
[neutron-db-manage] check_migration: validate labels
authorIhar Hrachyshka <ihrachys@redhat.com>
Tue, 28 Jul 2015 22:16:17 +0000 (00:16 +0200)
committerIhar Hrachyshka <ihrachys@redhat.com>
Thu, 13 Aug 2015 14:51:44 +0000 (16:51 +0200)
Guard against potential down_revision interleave by checking that each
revision has the only revision that corresponds to its location in the
migration tree, and that its parent also has that same single label.

Partially-Implements: blueprint online-schema-migrations
Change-Id: Ia812e8283f4da955610fe043aba3ad0298ede24b

neutron/db/migration/cli.py
neutron/tests/unit/db/test_migration.py

index 53c5393bd891160a07ff864f5659af0341edcd4f..521af0172df030f75393e422b866985df1a92019 100644 (file)
@@ -31,6 +31,7 @@ from neutron.common import utils
 HEAD_FILENAME = 'HEAD'
 HEADS_FILENAME = 'HEADS'
 CURRENT_RELEASE = "liberty"
+RELEASES = (CURRENT_RELEASE,)
 MIGRATION_BRANCHES = ('expand', 'contract')
 
 MIGRATION_ENTRYPOINTS = 'neutron.db.alembic_migrations'
@@ -114,6 +115,7 @@ def _get_alembic_entrypoint(project):
 
 def do_check_migration(config, cmd):
     do_alembic_command(config, 'branches')
+    validate_labels(config)
     validate_heads_file(config)
 
 
@@ -158,9 +160,9 @@ def do_stamp(config, cmd):
                        sql=CONF.command.sql)
 
 
-def _get_branch_label(branch):
+def _get_branch_label(branch, release=None):
     '''Get the latest branch label corresponding to release cycle.'''
-    return '%s_%s' % (CURRENT_RELEASE, branch)
+    return '%s_%s' % (release or CURRENT_RELEASE, branch)
 
 
 def _get_branch_head(branch):
@@ -201,6 +203,56 @@ def do_revision(config, cmd):
     update_heads_file(config)
 
 
+def _compare_labels(revision, expected_labels):
+    # validate that the script has the only label that corresponds to path
+    bad_labels = revision.branch_labels - expected_labels
+    if bad_labels:
+        script_name = os.path.basename(revision.path)
+        alembic_util.err(
+            _('Unexpected label for script %(script_name)s: %(labels)s') %
+            {'script_name': script_name,
+             'labels': bad_labels}
+        )
+
+
+def _validate_single_revision_labels(script_dir, revision,
+                                     release=None, branch=None):
+    if branch is not None:
+        branch_label = _get_branch_label(branch, release=release)
+        expected_labels = set([branch_label])
+    else:
+        expected_labels = set()
+
+    _compare_labels(revision, expected_labels)
+
+    # if it's not the root element of the branch, expect the parent of the
+    # script to have the same label
+    if revision.down_revision is not None:
+        down_revision = script_dir.get_revision(revision.down_revision)
+        _compare_labels(down_revision, expected_labels)
+
+
+def _validate_revision(script_dir, revision):
+    for branch in MIGRATION_BRANCHES:
+        for release in RELEASES:
+            marker = os.path.join(release, branch)
+            if marker in revision.path:
+                _validate_single_revision_labels(
+                    script_dir, revision, release=release, branch=branch)
+                return
+
+    # validate script from branchless part of migration rules
+    _validate_single_revision_labels(script_dir, revision)
+
+
+def validate_labels(config):
+    script_dir = alembic_script.ScriptDirectory.from_config(config)
+    revisions = [v for v in script_dir.walk_revisions(base='base',
+                                                      head='heads')]
+    for revision in revisions:
+        _validate_revision(script_dir, revision)
+
+
 def _get_sorted_heads(script):
     '''Get the list of heads for all branches, sorted.'''
     heads = script.get_heads()
index 87f57f7e16c0690f18bd3929a7ecbc53e90a90d5..9cffe4e843e8f92d34274c9a7d377a62dc689264 100644 (file)
@@ -31,6 +31,16 @@ class FakeConfig(object):
     service = ''
 
 
+class FakeRevision(object):
+    path = 'fakepath'
+
+    def __init__(self, labels=None, down_revision=None):
+        if not labels:
+            labels = set()
+        self.branch_labels = labels
+        self.down_revision = down_revision
+
+
 class MigrationEntrypointsMemento(fixtures.Fixture):
     '''Create a copy of the migration entrypoints map so it can be restored
        during test cleanup.
@@ -126,8 +136,10 @@ class TestCli(base.BaseTestCase):
             cli.migration_entrypoints[project] = entrypoint
 
     def _main_test_helper(self, argv, func_name, exp_args=(), exp_kwargs=[{}]):
-        with mock.patch.object(sys, 'argv', argv), mock.patch.object(
-                cli, 'run_sanity_checks'):
+        with mock.patch.object(sys, 'argv', argv),\
+            mock.patch.object(cli, 'run_sanity_checks'),\
+            mock.patch.object(cli, 'validate_labels'):
+
             cli.main()
             self.do_alembic_cmd.assert_has_calls(
                 [mock.call(mock.ANY, func_name, *exp_args, **kwargs)
@@ -369,3 +381,99 @@ class TestCli(base.BaseTestCase):
     def test_get_subproject_base_not_installed(self):
         self.assertRaises(
             SystemExit, cli._get_subproject_base, 'not-installed')
+
+    def test__get_branch_label_current(self):
+        self.assertEqual('%s_fakebranch' % cli.CURRENT_RELEASE,
+                         cli._get_branch_label('fakebranch'))
+
+    def test__get_branch_label_other_release(self):
+        self.assertEqual('fakerelease_fakebranch',
+                         cli._get_branch_label('fakebranch',
+                                               release='fakerelease'))
+
+    def test__compare_labels_ok(self):
+        labels = {'label1', 'label2'}
+        fake_revision = FakeRevision(labels)
+        cli._compare_labels(fake_revision, {'label1', 'label2'})
+
+    def test__compare_labels_fail_unexpected_labels(self):
+        labels = {'label1', 'label2', 'label3'}
+        fake_revision = FakeRevision(labels)
+        self.assertRaises(
+            SystemExit,
+            cli._compare_labels, fake_revision, {'label1', 'label2'})
+
+    @mock.patch.object(cli, '_compare_labels')
+    def test__validate_single_revision_labels_branchless_fail_different_labels(
+        self, compare_mock):
+
+        fake_down_revision = FakeRevision()
+        fake_revision = FakeRevision(down_revision=fake_down_revision)
+
+        script_dir = mock.Mock()
+        script_dir.get_revision.return_value = fake_down_revision
+        cli._validate_single_revision_labels(script_dir, fake_revision,
+                                             branch=None)
+
+        expected_labels = set()
+        compare_mock.assert_has_calls(
+            [mock.call(revision, expected_labels)
+             for revision in (fake_revision, fake_down_revision)]
+        )
+
+    @mock.patch.object(cli, '_compare_labels')
+    def test__validate_single_revision_labels_branches_fail_different_labels(
+        self, compare_mock):
+
+        fake_down_revision = FakeRevision()
+        fake_revision = FakeRevision(down_revision=fake_down_revision)
+
+        script_dir = mock.Mock()
+        script_dir.get_revision.return_value = fake_down_revision
+        cli._validate_single_revision_labels(
+            script_dir, fake_revision,
+            release='fakerelease', branch='fakebranch')
+
+        expected_labels = {'fakerelease_fakebranch'}
+        compare_mock.assert_has_calls(
+            [mock.call(revision, expected_labels)
+             for revision in (fake_revision, fake_down_revision)]
+        )
+
+    @mock.patch.object(cli, '_validate_single_revision_labels')
+    def test__validate_revision_validates_branches(self, validate_mock):
+        script_dir = mock.Mock()
+        fake_revision = FakeRevision()
+        release = cli.RELEASES[0]
+        branch = cli.MIGRATION_BRANCHES[0]
+        fake_revision.path = os.path.join('/fake/path', release, branch)
+        cli._validate_revision(script_dir, fake_revision)
+        validate_mock.assert_called_with(
+            script_dir, fake_revision, release=release, branch=branch)
+
+    @mock.patch.object(cli, '_validate_single_revision_labels')
+    def test__validate_revision_validates_branchless_migrations(
+        self, validate_mock):
+
+        script_dir = mock.Mock()
+        fake_revision = FakeRevision()
+        cli._validate_revision(script_dir, fake_revision)
+        validate_mock.assert_called_with(script_dir, fake_revision)
+
+    @mock.patch.object(cli, '_validate_revision')
+    @mock.patch('alembic.script.ScriptDirectory.walk_revisions')
+    def test_validate_labels_walks_thru_all_revisions(
+        self, walk_mock, validate_mock):
+
+        revisions = [mock.Mock() for i in range(10)]
+        walk_mock.return_value = revisions
+        cli.validate_labels(self.configs[0])
+        validate_mock.assert_has_calls(
+            [mock.call(mock.ANY, revision) for revision in revisions]
+        )
+
+
+class TestSafetyChecks(base.BaseTestCase):
+
+    def test_validate_labels(self, *mocks):
+        cli.validate_labels(cli.get_neutron_config())