]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Adding a SSH Connection Pool.
authorNirmal Ranganathan <rnirmal@gmail.com>
Fri, 2 Nov 2012 04:05:35 +0000 (23:05 -0500)
committerNirmal Ranganathan <rnirmal@gmail.com>
Fri, 2 Nov 2012 04:05:35 +0000 (23:05 -0500)
Adds a connection pool using eventlet.pools
which works well with greenthreads. Adding new
parameters ssh_min_pool_conn and ssh_max_pool_conn.

Updating the _run_ssh method in san.py to use the
connection pool and cleaning up the existing
single connections, also added retries.

Fixes bug 1074185

Change-Id: I90dd89ffc025d09fc6ad060c4273508103b85456

cinder/tests/test_utils.py
cinder/utils.py
cinder/volume/san/san.py
etc/cinder/cinder.conf.sample

index 92be797b8fcfe5b76e335c4325ae697372230cf6..c1298f1bb3dcd9a890613932ef0dbaf58b5e11ef 100644 (file)
@@ -19,6 +19,7 @@ import datetime
 import hashlib
 import os
 import os.path
+import paramiko
 import StringIO
 import tempfile
 
@@ -667,3 +668,88 @@ class AuditPeriodTest(test.TestCase):
                                            day=1,
                                            month=6,
                                            year=2011))
+
+
+class FakeSSHClient(object):
+
+    def __init__(self):
+        self.id = utils.gen_uuid()
+        self.transport = FakeTransport()
+
+    def set_missing_host_key_policy(self, policy):
+        pass
+
+    def connect(self, ip, port=22, username=None, password=None,
+                pkey=None, timeout=10):
+        pass
+
+    def get_transport(self):
+        return self.transport
+
+    def close(self):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class FakeSock(object):
+    def settimeout(self, timeout):
+        pass
+
+
+class FakeTransport(object):
+
+    def __init__(self):
+        self.active = True
+        self.sock = FakeSock()
+
+    def set_keepalive(self, timeout):
+        pass
+
+    def is_active(self):
+        return self.active
+
+
+class SSHPoolTestCase(test.TestCase):
+    """Unit test for SSH Connection Pool."""
+
+    def setup(self):
+        self.mox.StubOutWithMock(paramiko, "SSHClient")
+        paramiko.SSHClient().AndReturn(FakeSSHClient())
+        self.mox.ReplayAll()
+
+    def test_single_ssh_connect(self):
+        self.setup()
+        sshpool = utils.SSHPool("127.0.0.1", 22, 10, "test", password="test",
+                                min_size=1, max_size=1)
+        with sshpool.item() as ssh:
+            first_id = ssh.id
+
+        with sshpool.item() as ssh:
+            second_id = ssh.id
+
+        self.assertEqual(first_id, second_id)
+
+    def test_closed_reopend_ssh_connections(self):
+        self.setup()
+        sshpool = utils.SSHPool("127.0.0.1", 22, 10, "test", password="test",
+                                min_size=1, max_size=2)
+        with sshpool.item() as ssh:
+            first_id = ssh.id
+        with sshpool.item() as ssh:
+            second_id = ssh.id
+            # Close the connection and test for a new connection
+            ssh.get_transport().active = False
+
+        self.assertEqual(first_id, second_id)
+
+        # The mox items are not getting setup in a new pool connection,
+        # so had to reset and set again.
+        self.mox.UnsetStubs()
+        self.setup()
+
+        with sshpool.item() as ssh:
+            third_id = ssh.id
+
+        self.assertNotEqual(first_id, third_id)
index 752f11e04c219e55ac85989d13ac404a8c47ad50..e8035ea58deee5269aa49a9e04b0595b87655499 100644 (file)
@@ -28,6 +28,7 @@ import hashlib
 import inspect
 import itertools
 import os
+import paramiko
 import pyclbr
 import random
 import re
@@ -46,6 +47,7 @@ from xml.sax import saxutils
 
 from eventlet import event
 from eventlet import greenthread
+from eventlet import pools
 from eventlet.green import subprocess
 
 from cinder.common import deprecated
@@ -232,7 +234,7 @@ def trycmd(*args, **kwargs):
 
 def ssh_execute(ssh, cmd, process_input=None,
                 addl_env=None, check_exit_code=True):
-    LOG.debug(_('Running cmd (SSH): %s'), ' '.join(cmd))
+    LOG.debug(_('Running cmd (SSH): %s'), cmd)
     if addl_env:
         raise exception.Error(_('Environment not supported over SSH'))
 
@@ -251,6 +253,8 @@ def ssh_execute(ssh, cmd, process_input=None,
     stdout = stdout_stream.read()
     stderr = stderr_stream.read()
     stdin_stream.close()
+    stdout_stream.close()
+    stderr_stream.close()
 
     exit_status = channel.recv_exit_status()
 
@@ -261,11 +265,84 @@ def ssh_execute(ssh, cmd, process_input=None,
             raise exception.ProcessExecutionError(exit_code=exit_status,
                                                   stdout=stdout,
                                                   stderr=stderr,
-                                                  cmd=' '.join(cmd))
-
+                                                  cmd=cmd)
+    channel.close()
     return (stdout, stderr)
 
 
+class SSHPool(pools.Pool):
+    """A simple eventlet pool to hold ssh connections."""
+
+    def __init__(self, ip, port, conn_timeout, login, password=None,
+                 privatekey=None, *args, **kwargs):
+        self.ip = ip
+        self.port = port
+        self.login = login
+        self.password = password
+        self.conn_timeout = conn_timeout
+        self.privatekey = privatekey
+        super(SSHPool, self).__init__(*args, **kwargs)
+
+    def create(self):
+        try:
+            ssh = paramiko.SSHClient()
+            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            if self.password:
+                ssh.connect(self.ip,
+                            port=self.port,
+                            username=self.login,
+                            password=self.password,
+                            timeout=self.conn_timeout)
+            elif self.privatekey:
+                pkfile = os.path.expanduser(self.privatekey)
+                privatekey = paramiko.RSAKey.from_private_key_file(pkfile)
+                ssh.connect(self.ip,
+                            port=self.port,
+                            username=self.login,
+                            pkey=privatekey,
+                            timeout=self.conn_timeout)
+            else:
+                msg = _("Specify a password or private_key")
+                raise exception.CinderException(msg)
+
+            # Paramiko by default sets the socket timeout to 0.1 seconds,
+            # ignoring what we set thru the sshclient. This doesn't help for
+            # keeping long lived connections. Hence we have to bypass it, by
+            # overriding it after the transport is initialized. We are setting
+            # the sockettimeout to None and setting a keepalive packet so that,
+            # the server will keep the connection open. All that does is send
+            # a keepalive packet every ssh_conn_timeout seconds.
+            transport = ssh.get_transport()
+            transport.sock.settimeout(None)
+            transport.set_keepalive(self.conn_timeout)
+            return ssh
+        except Exception as e:
+            msg = "Error connecting via ssh: %s" % e
+            LOG.error(_(msg))
+            raise paramiko.SSHException(msg)
+
+    def get(self):
+        """
+        Return an item from the pool, when one is available.  This may
+        cause the calling greenthread to block. Check if a connection is active
+        before returning it. For dead connections create and return a new
+        connection.
+        """
+        if self.free_items:
+            conn = self.free_items.popleft()
+            if conn:
+                if conn.get_transport().is_active():
+                    return conn
+                else:
+                    conn.close()
+            return self.create()
+        if self.current_size < self.max_size:
+            created = self.create()
+            self.current_size += 1
+            return created
+        return self.channel.get()
+
+
 def cinderdir():
     import cinder
     return os.path.abspath(cinder.__file__).split('cinder/__init__.py')[0]
index c57eb3a21ce4594337432b81b3990f3e4f8a9919..feda4fc1243e1ebac2963213341e8c0084f44ae5 100644 (file)
@@ -21,8 +21,10 @@ The unique thing about a SAN is that we don't expect that we can run the volume
 controller on the SAN hardware.  We expect to access it over SSH or some API.
 """
 
-import os
 import paramiko
+import random
+
+from eventlet import greenthread
 
 from cinder import exception
 from cinder import flags
@@ -60,6 +62,15 @@ san_opts = [
                 default=False,
                 help='Execute commands locally instead of over SSH; '
                      'use if the volume service is running on the SAN device'),
+    cfg.IntOpt('ssh_conn_timeout',
+               default=30,
+               help="SSH connection timeout in seconds"),
+    cfg.IntOpt('ssh_min_pool_conn',
+               default=1,
+               help='Minimum ssh connections in the pool'),
+    cfg.IntOpt('ssh_max_pool_conn',
+               default=5,
+               help='Maximum ssh connections in the pool'),
 ]
 
 FLAGS = flags.FLAGS
@@ -77,32 +88,11 @@ class SanISCSIDriver(ISCSIDriver):
     def __init__(self, *args, **kwargs):
         super(SanISCSIDriver, self).__init__(*args, **kwargs)
         self.run_local = FLAGS.san_is_local
+        self.sshpool = None
 
     def _build_iscsi_target_name(self, volume):
         return "%s%s" % (FLAGS.iscsi_target_prefix, volume['name'])
 
-    def _connect_to_ssh(self):
-        ssh = paramiko.SSHClient()
-        #TODO(justinsb): We need a better SSH key policy
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        if FLAGS.san_password:
-            ssh.connect(FLAGS.san_ip,
-                        port=FLAGS.san_ssh_port,
-                        username=FLAGS.san_login,
-                        password=FLAGS.san_password)
-        elif FLAGS.san_private_key:
-            privatekeyfile = os.path.expanduser(FLAGS.san_private_key)
-            # It sucks that paramiko doesn't support DSA keys
-            privatekey = paramiko.RSAKey.from_private_key_file(privatekeyfile)
-            ssh.connect(FLAGS.san_ip,
-                        port=FLAGS.san_ssh_port,
-                        username=FLAGS.san_login,
-                        pkey=privatekey)
-        else:
-            msg = _("Specify san_password or san_private_key")
-            raise exception.InvalidInput(reason=msg)
-        return ssh
-
     def _execute(self, *cmd, **kwargs):
         if self.run_local:
             return utils.execute(*cmd, **kwargs)
@@ -111,16 +101,33 @@ class SanISCSIDriver(ISCSIDriver):
             command = ' '.join(cmd)
             return self._run_ssh(command, check_exit_code)
 
-    def _run_ssh(self, command, check_exit_code=True):
-        #TODO(justinsb): SSH connection caching (?)
-        ssh = self._connect_to_ssh()
-
-        #TODO(justinsb): Reintroduce the retry hack
-        ret = utils.ssh_execute(ssh, command, check_exit_code=check_exit_code)
-
-        ssh.close()
-
-        return ret
+    def _run_ssh(self, command, check_exit_code=True, attempts=1):
+        if not self.sshpool:
+            self.sshpool = utils.SSHPool(FLAGS.san_ip,
+                                         FLAGS.san_ssh_port,
+                                         FLAGS.ssh_conn_timeout,
+                                         FLAGS.san_login,
+                                         password=FLAGS.san_password,
+                                         privatekey=FLAGS.san_private_key,
+                                         min_size=FLAGS.ssh_min_pool_conn,
+                                         max_size=FLAGS.ssh_max_pool_conn)
+        try:
+            total_attempts = attempts
+            with self.sshpool.item() as ssh:
+                while attempts > 0:
+                    attempts -= 1
+                    try:
+                        return utils.ssh_execute(ssh, command,
+                                               check_exit_code=check_exit_code)
+                    except Exception as e:
+                        LOG.error(e)
+                        greenthread.sleep(random.randint(20, 500) / 100.0)
+                raise paramiko.SSHException(_("SSH Command failed after '%r' "
+                                              "attempts: '%s'"
+                                              % (total_attempts, command)))
+        except Exception as e:
+            LOG.error(_("Error running ssh command: %s" % command))
+            raise e
 
     def ensure_export(self, context, volume):
         """Synchronously recreates an export for a logical volume."""
index 93137047f1833d2b23055c6588f0d792c08ec98c..f976d05ea06ed175fe13c15753768660280bab59 100644 (file)
 #### (BoolOpt) Execute commands locally instead of over SSH; use if the
 ####           volume service is running on the SAN device
 
+# ssh_conn_timeout=30
+#### (IntOpt) SSH connection timeout in seconds
+
+# ssh_min_pool_conn=1
+#### (IntOpt) Minimum ssh connections in the pool
+
+# ssh_max_pool_conn=5
+#### (IntOpt) Maximum ssh connections in the pool
+
 
 ######## defined in cinder.volume.solaris ########