]> review.fuel-infra Code Review - openstack-build/cinder-build.git/commitdiff
Use a SSH pool to manage SSH connection
authorzhangchao010 <zhangchao010@huawei.com>
Thu, 4 Apr 2013 00:45:21 +0000 (08:45 +0800)
committerzhangchao010 <zhangchao010@huawei.com>
Fri, 12 Apr 2013 16:57:12 +0000 (00:57 +0800)
Use a SSH pool to hold all SSH clients.It allows 4 SSH clients
at most to connect to the SSH server at the same time.
This patchset also enables every SSH client connect to the other
controller when they failed to connect to the current controller.
For example,failed to A,then to B,or,failed to B,then to A.

Fixes bug: 1162251
Change-Id: I86f7f684639034be97ddf2031e61ac6bf3a196ad

cinder/utils.py
cinder/volume/drivers/huawei/huawei_iscsi.py

index f5164acc35ddea8d7920b3a920cb1b1e9a367e37..65128be4a9f83eb32497222afe6d12961d752de4 100644 (file)
@@ -271,6 +271,13 @@ def ssh_execute(ssh, cmd, process_input=None,
     return (stdout, stderr)
 
 
+def create_channel(client, width, height):
+    """Invoke an interactive shell session on server."""
+    channel = client.invoke_shell()
+    channel.resize_pty(width, height)
+    return channel
+
+
 class SSHPool(pools.Pool):
     """A simple eventlet pool to hold ssh connections."""
 
@@ -344,6 +351,15 @@ class SSHPool(pools.Pool):
             return created
         return self.channel.get()
 
+    def remove(self, ssh):
+        """Close an ssh client and remove it from free_items."""
+        ssh.close()
+        ssh = None
+        if ssh in self.free_items:
+            self.free_items.pop(ssh)
+        if self.current_size > 0:
+            self.current_size -= 1
+
 
 def cinderdir():
     import cinder
index ad2961eba324a45355de4fd972f41aa52a90b8e6..b6f51bcdbe621b6fd3c3803ca8f2f03c9650b403 100644 (file)
 """
 Volume driver for HUAWEI T series and Dorado storage systems.
 """
-
+import os
+import paramiko
 import re
 import socket
+import threading
 import time
 
 from oslo.config import cfg
@@ -45,57 +47,78 @@ VOL_AND_SNAP_NAME_PREFIX = 'OpenStack_'
 READBUFFERSIZE = 8192
 
 
-class SSHConnection(utils.SSHPool):
-    """An SSH connetion class .
+class SSHConn(utils.SSHPool):
+    """Define a new class inherited to SSHPool.
 
-    For some reasons, we can not use method ssh_execute defined in utils to
-    send CLI commands. Here we define a new class inherited to SSHPool. Use
-    method create() to build a new SSH client and use invoke_shell() to start
-    an interactive shell session on the storage system.
+    This class rewrites method create() and defines a private method
+    ssh_read() which reads results of ssh commands.
     """
 
-    def __init__(self, ip, port, login, password, conn_timeout,
+    def __init__(self, ip, port, conn_timeout, login, password,
                  privatekey=None, *args, **kwargs):
-        self.ssh = None
-        super(SSHConnection, self).__init__(ip, port, conn_timeout, login,
-                                            password, privatekey=None,
-                                            *args, **kwargs)
-
-    def connect(self):
-        """Create an SSH client and open an interactive SSH channel."""
-        self.ssh = self.create()
-        self.channel = self.ssh.invoke_shell()
-        self.channel.resize_pty(600, 800)
-
-    def close(self):
-        """Close SSH connection."""
-        self.channel.close()
-        self.ssh.close()
-
-    def read(self, timeout=None):
-        """Read data from SSH channel."""
+
+        super(SSHConn, self).__init__(ip, port, conn_timeout, login,
+                                      password, privatekey=None,
+                                      *args, **kwargs)
+        self.lock = threading.Lock()
+
+    def create(self):
+        """Create an SSH client.
+
+        Because seting socket timeout to be None will cause client.close()
+        blocking, here we have to rewrite method create() and use default
+        socket timeout value 0.1.
+        """
+        try:
+            ssh = paramiko.SSHClient()
+            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            if self.password:
+                ssh.connect(self.ip,
+                            port=self.port,
+                            username=self.login,
+                            password=self.password,
+                            timeout=self.conn_timeout)
+            elif self.privatekey:
+                pkfile = os.path.expanduser(self.privatekey)
+                privatekey = paramiko.RSAKey.from_private_key_file(pkfile)
+                ssh.connect(self.ip,
+                            port=self.port,
+                            username=self.login,
+                            pkey=privatekey,
+                            timeout=self.conn_timeout)
+            else:
+                msg = _("Specify a password or private_key")
+                raise exception.CinderException(msg)
+
+            if self.conn_timeout:
+                transport = ssh.get_transport()
+                transport.set_keepalive(self.conn_timeout)
+            return ssh
+        except Exception as e:
+            msg = _("Error connecting via ssh: %s") % e
+            LOG.error(msg)
+            raise paramiko.SSHException(msg)
+
+    def ssh_read(self, channel, cmd, timeout):
+        """Get results of CLI commands."""
         result = ''
-        user_flg = self.login + ':/>$'
-        self.channel.settimeout(timeout)
+        user = self.login
+        user_flg = user + ':/>$'
+        channel.settimeout(timeout)
         while True:
             try:
-                result = result + self.channel.recv(READBUFFERSIZE)
+                result = result + channel.recv(READBUFFERSIZE)
             except socket.timeout:
-                break
+                raise exception.VolumeBackendAPIException(_('read timed out'))
             else:
-                # If we get the complete result string, then no need to wait
-                # until time out.
-                if re.search(user_flg, result) or re.search('(y/n)', result):
+                if re.search(cmd, result) and re.search(user_flg, result):
+                    if not re.search('Welcome', result):
+                        break
+                    elif re.search(user + ':/>' + cmd, result):
+                        break
+                elif re.search('(y/n)', result):
                     break
-        return result
-
-    def send_cmd(self, strcmd, timeout, waitstr=None):
-        """Send SSH commands and return results."""
-        info = ''
-        self.channel.send(strcmd + '\n')
-        result = self.read(timeout)
-        info = '\r\n'.join(result.split('\r\n')[1:-1])
-        return info
+        return '\r\n'.join(result.split('\r\n')[:-1])
 
 
 class HuaweiISCSIDriver(driver.ISCSIDriver):
@@ -107,10 +130,7 @@ class HuaweiISCSIDriver(driver.ISCSIDriver):
         self.device_type = {}
         self.login_info = {}
         self.hostgroup_id = None
-
-        # Flag to tell whether the other controller is available
-        # if the current controller can not be connected to.
-        self.controller_alterable = True
+        self.ssh_pool = None
 
     def do_setup(self, context):
         """Check config file."""
@@ -719,46 +739,76 @@ class HuaweiISCSIDriver(driver.ISCSIDriver):
         If the connection to first controller time out,
         try to connect to the other controller.
         """
+        LOG.debug(_('CLI command:%s') % cmd)
+        connect_times = 0
+        ip0 = self.login_info['ControllerIP0']
+        ip1 = self.login_info['ControllerIP1']
         user = self.login_info['UserName']
         pwd = self.login_info['UserPassword']
+        if not self.ssh_pool:
+            self.ssh_pool = SSHConn(ip0, 22, 30, user, pwd)
+        ssh_client = None
         while True:
-            if self.controller_alterable:
-                ip = self.login_info['ControllerIP0']
-            else:
-                ip = self.login_info['ControllerIP1']
-
+            if connect_times == 1:
+                # Switch to the other controller.
+                self.ssh_pool.lock.acquire()
+                if ssh_client:
+                    if ssh_client.server_ip == self.ssh_pool.ip:
+                        if self.ssh_pool.ip == ip0:
+                            self.ssh_pool.ip = ip1
+                        else:
+                            self.ssh_pool.ip = ip0
+                    # Create a new client.
+                    if ssh_client.chan:
+                        ssh_client.chan.close()
+                        ssh_client.chan = None
+                        ssh_client.server_ip = None
+                        ssh_client.close()
+                        ssh_client = None
+                        ssh_client = self.ssh_pool.create()
+                else:
+                    self.ssh_pool.ip = ip1
+                self.ssh_pool.lock.release()
             try:
-                ssh = SSHConnection(ip, 22, user, pwd, 30)
-                ssh.connect()
+                if not ssh_client:
+                    ssh_client = self.ssh_pool.get()
+                # "server_ip" shows controller connecting with the ssh client.
+                if ('server_ip' not in ssh_client.__dict__ or
+                        not ssh_client.server_ip):
+                    self.ssh_pool.lock.acquire()
+                    ssh_client.server_ip = self.ssh_pool.ip
+                    self.ssh_pool.lock.release()
+                # An SSH client owns one "chan".
+                if ('chan' not in ssh_client.__dict__ or
+                        not ssh_client.chan):
+                    ssh_client.chan =\
+                        utils.create_channel(ssh_client, 600, 800)
+
                 while True:
-                    out = ssh.send_cmd(cmd, 30)
+                    ssh_client.chan.send(cmd + '\n')
+                    out = self.ssh_pool.ssh_read(ssh_client.chan, cmd, 20)
                     if out.find('(y/n)') > -1:
                         cmd = 'y'
                     else:
                         break
-                ssh.close()
-            except Exception as err:
-                if ((not self.controller_alterable) and
-                        (str(err).find('timed out') > -1)):
-                    self.controller_alterable = False
+                self.ssh_pool.put(ssh_client)
 
-                    LOG.debug(_('_execute_cli:Connect to controller0 %(ctr0)s'
-                                ' time out.Try to Connect to controller1 '
-                                '%(ctr1)s.')
-                              % {'ctr0': self.login_info['ControllerIP0'],
-                                 'ctr1': self.login_info['ControllerIP1']})
+                index = out.find(user + ':/>')
+                if index > -1:
+                    return out[index:]
+                else:
+                    return out
 
+            except Exception as err:
+                if connect_times < 1:
+                    connect_times += 1
                     continue
                 else:
+                    if ssh_client:
+                        self.ssh_pool.remove(ssh_client)
                     LOG.error(_('_execute_cli:%s') % err)
                     raise exception.VolumeBackendAPIException(data=err)
 
-            index = out.find(user + ':/>')
-            if index > -1:
-                return out[index:]
-            else:
-                return out
-
     def _name_translate(self, name):
         """Form new names because of the 32-character limit on names."""
         newname = VOL_AND_SNAP_NAME_PREFIX + str(hash(name))