]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
SSHPool in utils should allow customized host key missing policy
authorLynxzh <jmzhang@cn.ibm.com>
Mon, 19 May 2014 10:47:16 +0000 (18:47 +0800)
committerLynxzh <jmzhang@cn.ibm.com>
Wed, 21 May 2014 17:13:05 +0000 (01:13 +0800)
The cinder/utils SSHPool should allow  missing key policy and host key
file being customized so that any caller can determine by their own
scenario if the host key file can be customized, or if an 'AutoAdd' is
appropriate, or just reject the key when mismatch. This will give more
flexible customization and also prevent any security issue as a middle
man.

Closes-Bug: #1320056
Change-Id: I3c72b0d042de719ecd45429d376bd88d0aefb2cc

cinder/tests/test_utils.py
cinder/utils.py

index 64d9140365dacb70959b9ac74c39cfa4f1854384..5ba73cfaec695b53d75beaddc69b576c7f4c5f41 100644 (file)
@@ -767,7 +767,13 @@ class FakeSSHClient(object):
         self.transport = FakeTransport()
 
     def set_missing_host_key_policy(self, policy):
-        pass
+        self.policy = policy
+
+    def load_system_host_keys(self):
+        self.system_host_keys = 'system_host_keys'
+
+    def load_host_keys(self, hosts_key_file):
+        self.hosts_key_file = hosts_key_file
 
     def connect(self, ip, port=22, username=None, password=None,
                 pkey=None, timeout=10):
@@ -776,6 +782,9 @@ class FakeSSHClient(object):
     def get_transport(self):
         return self.transport
 
+    def get_policy(self):
+        return self.policy
+
     def close(self):
         pass
 
@@ -803,6 +812,33 @@ class FakeTransport(object):
 
 class SSHPoolTestCase(test.TestCase):
     """Unit test for SSH Connection Pool."""
+    @mock.patch('paramiko.SSHClient')
+    def test_ssh_key_policy(self, mock_sshclient):
+        mock_sshclient.return_value = FakeSSHClient()
+
+        # create with customized setting
+        sshpool = utils.SSHPool("127.0.0.1", 22, 10,
+                                "test",
+                                password="test",
+                                min_size=1,
+                                max_size=1,
+                                missing_key_policy=paramiko.RejectPolicy(),
+                                hosts_key_file='dummy_host_keyfile')
+        with sshpool.item() as ssh:
+            self.assertTrue(isinstance(ssh.get_policy(),
+                                       paramiko.RejectPolicy))
+            self.assertEqual(ssh.hosts_key_file, 'dummy_host_keyfile')
+
+        # create with default setting
+        sshpool = utils.SSHPool("127.0.0.1", 22, 10,
+                                "test",
+                                password="test",
+                                min_size=1,
+                                max_size=1)
+        with sshpool.item() as ssh:
+            self.assertTrue(isinstance(ssh.get_policy(),
+                                       paramiko.AutoAddPolicy))
+            self.assertEqual(ssh.system_host_keys, 'system_host_keys')
 
     @mock.patch('paramiko.RSAKey.from_private_key_file')
     @mock.patch('paramiko.SSHClient')
index a0a3e086f7e1017c35a066d4173679efc0a7f045..3736163306e6ef1e9818a39e365b82a6c5eafcc1 100644 (file)
@@ -189,12 +189,24 @@ class SSHPool(pools.Pool):
         self.password = password
         self.conn_timeout = conn_timeout if conn_timeout else None
         self.privatekey = privatekey
+        if 'missing_key_policy' in kwargs.keys():
+            self.missing_key_policy = kwargs.pop('missing_key_policy')
+        else:
+            self.missing_key_policy = paramiko.AutoAddPolicy()
+        if 'hosts_key_file' in kwargs.keys():
+            self.hosts_key_file = kwargs.pop('hosts_key_file')
+        else:
+            self.hosts_key_file = None
         super(SSHPool, self).__init__(*args, **kwargs)
 
     def create(self):
         try:
             ssh = paramiko.SSHClient()
-            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            ssh.set_missing_host_key_policy(self.missing_key_policy)
+            if not self.hosts_key_file:
+                ssh.load_system_host_keys()
+            else:
+                ssh.load_host_keys(self.hosts_key_file)
             if self.password:
                 ssh.connect(self.ip,
                             port=self.port,