self.transport = FakeTransport()
def set_missing_host_key_policy(self, policy):
- pass
+ self.policy = policy
+
+ def load_system_host_keys(self):
+ self.system_host_keys = 'system_host_keys'
+
+ def load_host_keys(self, hosts_key_file):
+ self.hosts_key_file = hosts_key_file
def connect(self, ip, port=22, username=None, password=None,
pkey=None, timeout=10):
def get_transport(self):
return self.transport
+ def get_policy(self):
+ return self.policy
+
def close(self):
pass
class SSHPoolTestCase(test.TestCase):
"""Unit test for SSH Connection Pool."""
+ @mock.patch('paramiko.SSHClient')
+ def test_ssh_key_policy(self, mock_sshclient):
+ mock_sshclient.return_value = FakeSSHClient()
+
+ # create with customized setting
+ sshpool = utils.SSHPool("127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1,
+ missing_key_policy=paramiko.RejectPolicy(),
+ hosts_key_file='dummy_host_keyfile')
+ with sshpool.item() as ssh:
+ self.assertTrue(isinstance(ssh.get_policy(),
+ paramiko.RejectPolicy))
+ self.assertEqual(ssh.hosts_key_file, 'dummy_host_keyfile')
+
+ # create with default setting
+ sshpool = utils.SSHPool("127.0.0.1", 22, 10,
+ "test",
+ password="test",
+ min_size=1,
+ max_size=1)
+ with sshpool.item() as ssh:
+ self.assertTrue(isinstance(ssh.get_policy(),
+ paramiko.AutoAddPolicy))
+ self.assertEqual(ssh.system_host_keys, 'system_host_keys')
@mock.patch('paramiko.RSAKey.from_private_key_file')
@mock.patch('paramiko.SSHClient')
self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None
self.privatekey = privatekey
+ if 'missing_key_policy' in kwargs.keys():
+ self.missing_key_policy = kwargs.pop('missing_key_policy')
+ else:
+ self.missing_key_policy = paramiko.AutoAddPolicy()
+ if 'hosts_key_file' in kwargs.keys():
+ self.hosts_key_file = kwargs.pop('hosts_key_file')
+ else:
+ self.hosts_key_file = None
super(SSHPool, self).__init__(*args, **kwargs)
def create(self):
try:
ssh = paramiko.SSHClient()
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ ssh.set_missing_host_key_policy(self.missing_key_policy)
+ if not self.hosts_key_file:
+ ssh.load_system_host_keys()
+ else:
+ ssh.load_host_keys(self.hosts_key_file)
if self.password:
ssh.connect(self.ip,
port=self.port,