1 # Copyright 2012 OpenStack Foundation
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may
5 # not use this file except in compliance with the License. You may obtain
6 # a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations
23 from oslo_log import log as logging
26 from neutron.tests.tempest import exceptions
29 with warnings.catch_warnings():
30 warnings.simplefilter("ignore")
# Module-level logger; the SSH client in this module reports connection
# attempts, retries and failures through it.
LOG = logging.getLogger(__name__)
def __init__(self, host, username, password=None, timeout=300, pkey=None,
             channel_timeout=10, look_for_keys=False, key_filename=None):
    """Remember the settings used to open SSH connections later on.

    :param host: server to connect to.
    :param username: login name used for authentication.
    :param password: optional password handed to paramiko on connect.
    :param timeout: overall time limit in seconds, stored as an int and
        checked by ``_is_timed_out``.
    :param pkey: optional private key. A string is parsed with
        ``paramiko.RSAKey.from_private_key``; any other value (including
        None) is stored unchanged.
    :param channel_timeout: per-operation timeout in seconds, stored as a
        float; passed to paramiko's ``connect()`` and used when polling the
        command channel.
    :param look_for_keys: forwarded to paramiko's ``connect()``.
    :param key_filename: forwarded to paramiko's ``connect()``.
    """
    # Where to connect and who to authenticate as.
    self.host = host
    self.username = username
    self.password = password

    # Key material supplied as text is wrapped in a file-like buffer so
    # paramiko can parse it into an RSAKey object.
    if isinstance(pkey, six.string_types):
        key_file = cStringIO.StringIO(str(pkey))
        pkey = paramiko.RSAKey.from_private_key(key_file)
    self.pkey = pkey

    self.look_for_keys, self.key_filename = look_for_keys, key_filename
    self.timeout, self.channel_timeout = int(timeout), float(channel_timeout)

    # Maximum number of bytes requested per recv() in exec_command.
    self.buf_size = 1024
def _get_ssh_connection(self, sleep=1.5, backoff=1):
    """Returns an ssh connection to the specified host.

    Connection attempts are retried until more than ``self.timeout``
    seconds have passed since the first one.

    :param sleep: base delay, in seconds, between attempts.
    :param backoff: seconds added to the delay after every failed attempt,
        so the pause after the n-th failure is ``sleep + n * backoff``.
    :returns: a connected ``paramiko.SSHClient``.
    :raises: SSHTimeout when no attempt succeeds before the timeout.
    """
    bsleep = sleep
    ssh = paramiko.SSHClient()
    # Host keys not seen before are accepted automatically.
    ssh.set_missing_host_key_policy(
        paramiko.AutoAddPolicy())
    _start_time = time.time()
    if self.pkey is not None:
        LOG.info("Creating ssh connection to '%s' as '%s'"
                 " with public key authentication",
                 self.host, self.username)
    else:
        # Do not write the password itself to the log.
        LOG.info("Creating ssh connection to '%s' as '%s'"
                 " with password authentication",
                 self.host, self.username)
    attempts = 0
    while True:
        try:
            ssh.connect(self.host, username=self.username,
                        password=self.password,
                        look_for_keys=self.look_for_keys,
                        key_filename=self.key_filename,
                        timeout=self.channel_timeout, pkey=self.pkey)
            LOG.info("ssh connection to %s@%s successfuly created",
                     self.username, self.host)
            return ssh
        except (socket.error,
                paramiko.SSHException) as e:
            # Release whatever transport the failed attempt left open so
            # retries (and the final timeout) do not leak sockets/threads.
            # connect() creates a fresh transport on the next attempt.
            ssh.close()
            if self._is_timed_out(_start_time):
                LOG.exception("Failed to establish authenticated ssh"
                              " connection to %s@%s after %d attempts",
                              self.username, self.host, attempts)
                raise exceptions.SSHTimeout(host=self.host,
                                            user=self.username,
                                            password=self.password)
            bsleep += backoff
            attempts += 1
            LOG.warning("Failed to establish authenticated ssh"
                        " connection to %s@%s (%s). Number attempts: %s."
                        " Retry after %d seconds.",
                        self.username, self.host, e, attempts, bsleep)
            time.sleep(bsleep)
def _is_timed_out(self, start_time):
    """Report whether ``self.timeout`` seconds have elapsed since a moment.

    :param start_time: reference timestamp, as produced by ``time.time()``.
    :returns: True when the current time is more than ``self.timeout``
        seconds past ``start_time``, False otherwise.
    """
    # Any start time earlier than this cutoff is outside the timeout window.
    cutoff = time.time() - self.timeout
    return start_time < cutoff
def exec_command(self, cmd):
    """
    Execute the specified command on the server.

    Note that this method is reading whole command outputs to memory, thus
    shouldn't be used for large outputs.

    A new SSH connection is opened for the call and closed before
    returning, whether the command succeeds or not.

    :returns: data read from standard output of the command.
    :raises: SSHExecCommandFailed if command returns nonzero
             status. The exception contains command status stderr content.
    :raises: TimeoutException if the channel produces nothing and more than
             ``self.timeout`` seconds have passed since the command started.
    """
    ssh = self._get_ssh_connection()
    try:
        transport = ssh.get_transport()
        channel = transport.open_session()
        channel.fileno()  # Register event pipe
        channel.exec_command(cmd)
        channel.shutdown_write()
        out_data = []
        err_data = []
        poll = select.poll()
        poll.register(channel, select.POLLIN)
        # select.poll().poll() expects milliseconds; channel_timeout is in
        # seconds.
        poll_timeout = int(self.channel_timeout * 1000)
        start_time = time.time()

        while True:
            ready = poll.poll(poll_timeout)
            if not ready:
                # Nothing happened within one poll interval; keep waiting
                # until the overall timeout is exceeded.
                if not self._is_timed_out(start_time):
                    continue
                raise exceptions.TimeoutException(
                    "Command: '{0}' executed on host '{1}'.".format(
                        cmd, self.host))
            out_chunk = err_chunk = None
            if channel.recv_ready():
                out_chunk = channel.recv(self.buf_size)
                out_data.append(out_chunk)
            if channel.recv_stderr_ready():
                err_chunk = channel.recv_stderr(self.buf_size)
                err_data.append(err_chunk)
            # Stop only once the channel is closed and both streams have
            # been drained.
            if channel.closed and not err_chunk and not out_chunk:
                break
        exit_status = channel.recv_exit_status()
        if exit_status != 0:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                strerror=''.join(err_data))
        return ''.join(out_data)
    finally:
        # The connection is private to this call; closing it releases the
        # paramiko transport (socket and background thread).
        ssh.close()
149 def test_connection_auth(self):
150 """Raises an exception when we can not connect to server via ssh."""
151 connection = self._get_ssh_connection()