HEAD_FILENAME = 'HEAD'
HEADS_FILENAME = 'HEADS'
CURRENT_RELEASE = "liberty"
+RELEASES = (CURRENT_RELEASE,)
MIGRATION_BRANCHES = ('expand', 'contract')
MIGRATION_ENTRYPOINTS = 'neutron.db.alembic_migrations'
def do_check_migration(config, cmd):
do_alembic_command(config, 'branches')
+ validate_labels(config)
validate_heads_file(config)
sql=CONF.command.sql)
-def _get_branch_label(branch):
+def _get_branch_label(branch, release=None):
'''Get the latest branch label corresponding to release cycle.'''
- return '%s_%s' % (CURRENT_RELEASE, branch)
+ return '%s_%s' % (release or CURRENT_RELEASE, branch)
def _get_branch_head(branch):
update_heads_file(config)
+def _compare_labels(revision, expected_labels):
+ # validate that the script has the only label that corresponds to path
+ bad_labels = revision.branch_labels - expected_labels
+ if bad_labels:
+ script_name = os.path.basename(revision.path)
+ alembic_util.err(
+ _('Unexpected label for script %(script_name)s: %(labels)s') %
+ {'script_name': script_name,
+ 'labels': bad_labels}
+ )
+
+
+def _validate_single_revision_labels(script_dir, revision,
+ release=None, branch=None):
+ if branch is not None:
+ branch_label = _get_branch_label(branch, release=release)
+ expected_labels = set([branch_label])
+ else:
+ expected_labels = set()
+
+ _compare_labels(revision, expected_labels)
+
+ # if it's not the root element of the branch, expect the parent of the
+ # script to have the same label
+ if revision.down_revision is not None:
+ down_revision = script_dir.get_revision(revision.down_revision)
+ _compare_labels(down_revision, expected_labels)
+
+
+def _validate_revision(script_dir, revision):
+ for branch in MIGRATION_BRANCHES:
+ for release in RELEASES:
+ marker = os.path.join(release, branch)
+ if marker in revision.path:
+ _validate_single_revision_labels(
+ script_dir, revision, release=release, branch=branch)
+ return
+
+ # validate script from branchless part of migration rules
+ _validate_single_revision_labels(script_dir, revision)
+
+
+def validate_labels(config):
+ script_dir = alembic_script.ScriptDirectory.from_config(config)
+ revisions = [v for v in script_dir.walk_revisions(base='base',
+ head='heads')]
+ for revision in revisions:
+ _validate_revision(script_dir, revision)
+
+
def _get_sorted_heads(script):
'''Get the list of heads for all branches, sorted.'''
heads = script.get_heads()
service = ''
+class FakeRevision(object):
+ path = 'fakepath'
+
+ def __init__(self, labels=None, down_revision=None):
+ if not labels:
+ labels = set()
+ self.branch_labels = labels
+ self.down_revision = down_revision
+
+
class MigrationEntrypointsMemento(fixtures.Fixture):
'''Create a copy of the migration entrypoints map so it can be restored
during test cleanup.
cli.migration_entrypoints[project] = entrypoint
def _main_test_helper(self, argv, func_name, exp_args=(), exp_kwargs=[{}]):
- with mock.patch.object(sys, 'argv', argv), mock.patch.object(
- cli, 'run_sanity_checks'):
+ with mock.patch.object(sys, 'argv', argv),\
+ mock.patch.object(cli, 'run_sanity_checks'),\
+ mock.patch.object(cli, 'validate_labels'):
+
cli.main()
self.do_alembic_cmd.assert_has_calls(
[mock.call(mock.ANY, func_name, *exp_args, **kwargs)
def test_get_subproject_base_not_installed(self):
self.assertRaises(
SystemExit, cli._get_subproject_base, 'not-installed')
+
+ def test__get_branch_label_current(self):
+ self.assertEqual('%s_fakebranch' % cli.CURRENT_RELEASE,
+ cli._get_branch_label('fakebranch'))
+
+ def test__get_branch_label_other_release(self):
+ self.assertEqual('fakerelease_fakebranch',
+ cli._get_branch_label('fakebranch',
+ release='fakerelease'))
+
+ def test__compare_labels_ok(self):
+ labels = {'label1', 'label2'}
+ fake_revision = FakeRevision(labels)
+ cli._compare_labels(fake_revision, {'label1', 'label2'})
+
+ def test__compare_labels_fail_unexpected_labels(self):
+ labels = {'label1', 'label2', 'label3'}
+ fake_revision = FakeRevision(labels)
+ self.assertRaises(
+ SystemExit,
+ cli._compare_labels, fake_revision, {'label1', 'label2'})
+
+ @mock.patch.object(cli, '_compare_labels')
+ def test__validate_single_revision_labels_branchless_fail_different_labels(
+ self, compare_mock):
+
+ fake_down_revision = FakeRevision()
+ fake_revision = FakeRevision(down_revision=fake_down_revision)
+
+ script_dir = mock.Mock()
+ script_dir.get_revision.return_value = fake_down_revision
+ cli._validate_single_revision_labels(script_dir, fake_revision,
+ branch=None)
+
+ expected_labels = set()
+ compare_mock.assert_has_calls(
+ [mock.call(revision, expected_labels)
+ for revision in (fake_revision, fake_down_revision)]
+ )
+
+ @mock.patch.object(cli, '_compare_labels')
+ def test__validate_single_revision_labels_branches_fail_different_labels(
+ self, compare_mock):
+
+ fake_down_revision = FakeRevision()
+ fake_revision = FakeRevision(down_revision=fake_down_revision)
+
+ script_dir = mock.Mock()
+ script_dir.get_revision.return_value = fake_down_revision
+ cli._validate_single_revision_labels(
+ script_dir, fake_revision,
+ release='fakerelease', branch='fakebranch')
+
+ expected_labels = {'fakerelease_fakebranch'}
+ compare_mock.assert_has_calls(
+ [mock.call(revision, expected_labels)
+ for revision in (fake_revision, fake_down_revision)]
+ )
+
+ @mock.patch.object(cli, '_validate_single_revision_labels')
+ def test__validate_revision_validates_branches(self, validate_mock):
+ script_dir = mock.Mock()
+ fake_revision = FakeRevision()
+ release = cli.RELEASES[0]
+ branch = cli.MIGRATION_BRANCHES[0]
+ fake_revision.path = os.path.join('/fake/path', release, branch)
+ cli._validate_revision(script_dir, fake_revision)
+ validate_mock.assert_called_with(
+ script_dir, fake_revision, release=release, branch=branch)
+
+ @mock.patch.object(cli, '_validate_single_revision_labels')
+ def test__validate_revision_validates_branchless_migrations(
+ self, validate_mock):
+
+ script_dir = mock.Mock()
+ fake_revision = FakeRevision()
+ cli._validate_revision(script_dir, fake_revision)
+ validate_mock.assert_called_with(script_dir, fake_revision)
+
+ @mock.patch.object(cli, '_validate_revision')
+ @mock.patch('alembic.script.ScriptDirectory.walk_revisions')
+ def test_validate_labels_walks_thru_all_revisions(
+ self, walk_mock, validate_mock):
+
+ revisions = [mock.Mock() for i in range(10)]
+ walk_mock.return_value = revisions
+ cli.validate_labels(self.configs[0])
+ validate_mock.assert_has_calls(
+ [mock.call(mock.ANY, revision) for revision in revisions]
+ )
+
+
+class TestSafetyChecks(base.BaseTestCase):
+
+ def test_validate_labels(self, *mocks):
+ cli.validate_labels(cli.get_neutron_config())