00febc6295c2b930aafe3938ab636e26fcb22ea2
[openstack-build/neutron-build.git] / neutron / tests / tempest / common / ssh.py
1 # Copyright 2012 OpenStack Foundation
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16
17 import cStringIO
18 import select
19 import socket
20 import time
21 import warnings
22
23 from oslo_log import log as logging
24 import six
25
26 from neutron.tests.tempest import exceptions
27
28
# Import paramiko with every warning filtered out so nothing it emits while
# being imported reaches the test output; catch_warnings() restores the
# previous warning filters when the block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import paramiko


# Module-level logger used by the SSH client below.
LOG = logging.getLogger(__name__)
35
36
class Client(object):
    """SSH helper for running commands on a remote host.

    Every public call opens its own paramiko connection (retrying until
    ``timeout`` expires) and closes it again before returning.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None):
        """Store the connection parameters; no connection is opened here.

        :param host: server to connect to.
        :param username: login user name.
        :param password: optional password, forwarded to ``connect()``.
        :param timeout: overall deadline in seconds, used both while
            establishing a connection and while waiting on command output.
        :param pkey: optional private key. A string is parsed as an RSA
            private key; any other value is used as-is.
        :param channel_timeout: per-operation timeout in seconds, used as
            the ``connect()`` timeout and as the output poll interval.
        :param look_for_keys: forwarded unchanged to ``connect()``.
        :param key_filename: forwarded unchanged to ``connect()``.
        """
        self.host = host
        self.username = username
        self.password = password
        if isinstance(pkey, six.string_types):
            pkey = paramiko.RSAKey.from_private_key(
                cStringIO.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        # Maximum number of bytes read from the channel per recv call.
        self.buf_size = 1024

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Attempts failing with ``socket.error`` or ``paramiko.SSHException``
        are retried. The delay grows linearly: the first retry waits
        ``sleep + backoff`` seconds and each later one adds ``backoff``.

        :param sleep: base delay in seconds between attempts.
        :param backoff: seconds added to the delay after every failure.
        :returns: a connected ``paramiko.SSHClient``.
        :raises: SSHTimeout if no attempt succeeds within ``self.timeout``
                 seconds of starting.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        # Automatically accept host keys not seen before.
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with public key authentication",
                     self.host, self.username)
        else:
            # NOTE(review): this writes the plain-text password to the log
            # at INFO level; consider masking it if these logs are shared.
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with password %s",
                     self.host, self.username, str(self.password))
        attempts = 0
        while True:
            try:
                ssh.connect(self.host, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey)
                LOG.info("ssh connection to %s@%s successfuly created",
                         self.username, self.host)
                return ssh
            except (socket.error,
                    paramiko.SSHException) as e:
                # Count the attempt that just failed before reporting, so the
                # timeout message states how many attempts were really made.
                attempts += 1
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts",
                                  self.username, self.host, attempts)
                    # Do not leave the last failed attempt's transport open.
                    ssh.close()
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                # Release whatever the failed attempt left open before
                # retrying on the same client.
                ssh.close()
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed
        since ``start_time`` (a ``time.time()`` value)."""
        return (time.time() - self.timeout) > start_time

    def exec_command(self, cmd):
        """
        Execute the specified command on the server.

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs. The connection opened for the
        command is always closed before this method returns or raises.

        :returns: data read from standard output of the command.
        :raises: SSHExecCommandFailed if command returns nonzero
                 status. The exception contains command status stderr content.
        :raises: TimeoutException if a whole poll interval passes without
                 channel activity after ``self.timeout`` seconds have elapsed
                 since the command was started.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            channel = transport.open_session()
            channel.fileno()  # Register event pipe
            channel.exec_command(cmd)
            channel.shutdown_write()
            out_data = []
            err_data = []
            poll = select.poll()
            poll.register(channel, select.POLLIN)
            # poll() expects milliseconds; channel_timeout is in seconds.
            poll_timeout_ms = int(self.channel_timeout * 1000)
            start_time = time.time()

            while True:
                ready = poll.poll(poll_timeout_ms)
                if not ready:
                    if not self._is_timed_out(start_time):
                        continue
                    raise exceptions.TimeoutException(
                        "Command: '{0}' executed on host '{1}'.".format(
                            cmd, self.host))
                out_chunk = err_chunk = None
                if channel.recv_ready():
                    out_chunk = channel.recv(self.buf_size)
                    out_data.append(out_chunk)
                if channel.recv_stderr_ready():
                    err_chunk = channel.recv_stderr(self.buf_size)
                    err_data.append(err_chunk)
                # Stop only once the channel is closed and nothing was read
                # in this iteration, i.e. all buffered output is drained.
                if channel.closed and not err_chunk and not out_chunk:
                    break
            exit_status = channel.recv_exit_status()
        finally:
            # Always release the connection, including on timeout/failure.
            ssh.close()
        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                strerror=''.join(err_data))
        return ''.join(out_data)

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh.

        Opens a connection (with the same retry/timeout behaviour as every
        other call) and closes it immediately.
        """
        connection = self._get_ssh_connection()
        connection.close()