"""Utilities related to SSH connection management."""
-import os.path
+import os
+import string
from eventlet import pools
+from oslo.config import cfg
import paramiko
from cinder import exception
LOG = logging.getLogger(__name__)
+ssh_opts = [
+ cfg.BoolOpt('strict_ssh_host_key_policy',
+ default=False,
+ help='Option to enable strict host key checking. When '
+ 'set to "True" Cinder will only connect to systems '
+ 'with a host key present in the configured '
+ '"ssh_hosts_key_file". When set to "False" the host key '
+ 'will be saved upon first connection and used for '
+ 'subsequent connections. Default=False'),
+ cfg.StrOpt('ssh_hosts_key_file',
+ default='$state_path/ssh_known_hosts',
+ help='File containing SSH host keys for the systems with which '
+ 'Cinder needs to communicate. OPTIONAL: '
+ 'Default=$state_path/known_hosts'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(ssh_opts)
+
class SSHPool(pools.Pool):
"""A simple eventlet pool to hold ssh connections."""
self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None
self.privatekey = privatekey
- if 'missing_key_policy' in kwargs.keys():
- self.missing_key_policy = kwargs.pop('missing_key_policy')
- else:
- self.missing_key_policy = paramiko.AutoAddPolicy()
+ self.hosts_key_file = None
+
+ # Validate good config setting here.
+ # Paramiko handles the case where the file is inaccessible.
+ if not CONF.ssh_hosts_key_file:
+ raise exception.ParameterNotFound(param='ssh_hosts_key_file')
+ elif not os.path.isfile(CONF.ssh_hosts_key_file):
+ # If using the default path, just create the file.
+ if CONF.state_path in CONF.ssh_hosts_key_file:
+ open(CONF.ssh_hosts_key_file, 'a').close()
+ else:
+ msg = (_("Unable to find ssh_hosts_key_file: %s") %
+ CONF.ssh_hosts_key_file)
+ raise exception.InvalidInput(reason=msg)
+
if 'hosts_key_file' in kwargs.keys():
self.hosts_key_file = kwargs.pop('hosts_key_file')
+ LOG.info(_("Secondary ssh hosts key file %(kwargs)s will be "
+ "loaded along with %(conf)s from /etc/cinder.conf.") %
+ {'kwargs': self.hosts_key_file,
+ 'conf': CONF.ssh_hosts_key_file})
+
+ LOG.debug("Setting strict_ssh_host_key_policy to '%(policy)s' "
+ "using ssh_hosts_key_file '%(key_file)s'." %
+ {'policy': CONF.strict_ssh_host_key_policy,
+ 'key_file': CONF.ssh_hosts_key_file})
+
+ self.strict_ssh_host_key_policy = CONF.strict_ssh_host_key_policy
+
+ if not self.hosts_key_file:
+ self.hosts_key_file = CONF.ssh_hosts_key_file
else:
- self.hosts_key_file = None
+ self.hosts_key_file += ',' + CONF.ssh_hosts_key_file
+
super(SSHPool, self).__init__(*args, **kwargs)
def create(self):
try:
ssh = paramiko.SSHClient()
- ssh.set_missing_host_key_policy(self.missing_key_policy)
- if not self.hosts_key_file:
- ssh.load_system_host_keys()
+ if ',' in self.hosts_key_file:
+ files = string.split(self.hosts_key_file, ',')
+ for f in files:
+ ssh.load_host_keys(f)
else:
ssh.load_host_keys(self.hosts_key_file)
+ # If strict_ssh_host_key_policy is set we want to reject, by
+ # default if there is not entry in the known_hosts file.
+ # Otherwise we use AutoAddPolicy which accepts on the first
+ # Connect but fails if the keys change. load_host_keys can
+ # handle hashed known_host entries.
+ if self.strict_ssh_host_key_policy:
+ ssh.set_missing_host_key_policy(paramiko.RejectPolicy())
+ else:
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
if self.password:
ssh.connect(self.ip,
port=self.port,
def get_policy(self):
return self.policy
+ def get_host_keys(self):
+ return '127.0.0.1 ssh-rsa deadbeef'
+
def close(self):
pass
class SSHPoolTestCase(test.TestCase):
"""Unit test for SSH Connection Pool."""
+ @mock.patch('__builtin__.open')
@mock.patch('paramiko.SSHClient')
- def test_ssh_key_policy(self, mock_sshclient):
- mock_sshclient.return_value = FakeSSHClient()
+ @mock.patch('os.path.isfile', return_value=True)
+ def test_ssh_default_hosts_key_file(self, mock_isfile, mock_sshclient,
+ mock_open):
+ mock_ssh = mock.MagicMock()
+ mock_sshclient.return_value = mock_ssh
# create with customized setting
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test",
password="test",
min_size=1,
- max_size=1,
- missing_key_policy=paramiko.RejectPolicy(),
- hosts_key_file='dummy_host_keyfile')
- with sshpool.item() as ssh:
- self.assertTrue(isinstance(ssh.get_policy(),
- paramiko.RejectPolicy))
- self.assertEqual(ssh.hosts_key_file, 'dummy_host_keyfile')
+ max_size=1)
+
+ host_key_files = sshpool.hosts_key_file
+
+ self.assertEqual('/var/lib/cinder/ssh_known_hosts', host_key_files)
+
+ mock_ssh.load_host_keys.assert_called_once_with(
+ '/var/lib/cinder/ssh_known_hosts')
- # create with default setting
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ @mock.patch('os.path.isfile', return_value=True)
+ def test_ssh_host_key_file_kwargs(self, mock_isfile, mock_sshclient,
+ mock_open):
+ mock_ssh = mock.MagicMock()
+ mock_sshclient.return_value = mock_ssh
+
+ # create with customized setting
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test",
password="test",
min_size=1,
- max_size=1)
- with sshpool.item() as ssh:
- self.assertTrue(isinstance(ssh.get_policy(),
- paramiko.AutoAddPolicy))
- self.assertEqual(ssh.system_host_keys, 'system_host_keys')
+ max_size=1,
+ hosts_key_file='dummy_host_keyfile')
+
+ host_key_files = sshpool.hosts_key_file
+ self.assertIn('dummy_host_keyfile', host_key_files)
+ self.assertIn('/var/lib/cinder/ssh_known_hosts', host_key_files)
+
+ expected = [
+ mock.call.load_host_keys('dummy_host_keyfile'),
+ mock.call.load_host_keys('/var/lib/cinder/ssh_known_hosts')]
+
+ mock_ssh.assert_has_calls(expected, any_order=True)
+
+ @mock.patch('__builtin__.open')
+ @mock.patch('os.path.isfile', return_value=True)
@mock.patch('paramiko.RSAKey.from_private_key_file')
@mock.patch('paramiko.SSHClient')
- def test_single_ssh_connect(self, mock_sshclient, mock_pkey):
+ def test_single_ssh_connect(self, mock_sshclient, mock_pkey, mock_isfile,
+ mock_open):
mock_sshclient.return_value = FakeSSHClient()
+ CONF.ssh_hosts_key_file = '/var/lib/cinder/ssh_known_hosts'
+
# create with password
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test",
min_size=1,
max_size=1)
+ @mock.patch('__builtin__.open')
@mock.patch('paramiko.SSHClient')
- def test_closed_reopend_ssh_connections(self, mock_sshclient):
+ def test_closed_reopened_ssh_connections(self, mock_sshclient, mock_open):
mock_sshclient.return_value = eval('FakeSSHClient')()
sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
"test",
self.assertNotEqual(first_id, third_id)
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ def test_missing_ssh_hosts_key_config(self, mock_sshclient, mock_open):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ CONF.ssh_hosts_key_file = None
+ # create with password
+ self.assertRaises(exception.ParameterNotFound,
+ ssh_utils.SSHPool,
+ "127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ def test_create_default_known_hosts_file(self, mock_sshclient,
+ mock_open):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ CONF.state_path = '/var/lib/cinder'
+ CONF.ssh_hosts_key_file = '/var/lib/cinder/ssh_known_hosts'
+
+ default_file = '/var/lib/cinder/ssh_known_hosts'
+
+ ssh_pool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+
+ with ssh_pool.item() as ssh:
+ mock_open.assert_called_once_with(default_file, 'a')
+ ssh_pool.remove(ssh)
+
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ def test_ssh_missing_hosts_key_file(self, mock_sshclient, mock_open):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ CONF.ssh_hosts_key_file = '/tmp/blah'
+
+ self.assertRaises(exception.InvalidInput,
+ ssh_utils.SSHPool,
+ "127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ @mock.patch('os.path.isfile', return_value=True)
+ def test_ssh_strict_host_key_policy(self, mock_isfile, mock_sshclient,
+ mock_open):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ CONF.strict_ssh_host_key_policy = True
+
+ # create with customized setting
+ sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+
+ with sshpool.item() as ssh:
+ self.assertTrue(isinstance(ssh.get_policy(),
+ paramiko.RejectPolicy))
+
+ @mock.patch('__builtin__.open')
+ @mock.patch('paramiko.SSHClient')
+ @mock.patch('os.path.isfile', return_value=True)
+ def test_ssh_not_strict_host_key_policy(self, mock_isfile, mock_sshclient,
+ mock_open):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ CONF.strict_ssh_host_key_policy = False
+
+ # create with customized setting
+ sshpool = ssh_utils.SSHPool("127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+
+ with sshpool.item() as ssh:
+ self.assertTrue(isinstance(ssh.get_policy(),
+ paramiko.AutoAddPolicy))
+
class BrickUtils(test.TestCase):
"""Unit test to test the brick utility