#!/usr/bin/env python
"""Run commands on a remote host over SSH with paramiko.

Supports plain commands (``execute``), commands that need answers typed at a
prompt such as ``su`` (``execute_interactive``), and SFTP get/put.
"""
import datetime
import logging
import socket
import sys
import time
import traceback

import paramiko


class SSHCommander(object):
    """Thin wrapper around a paramiko SSH connection to a single host.

    The connection is opened by the constructor. If it cannot be established
    (network error, SSH negotiation or authentication failure) the error is
    logged and ``transport`` stays ``None``; ``execute_interactive`` then
    returns -1.
    """

    def __init__(self, serverip, username, password):
        self._username = username
        self._password = password
        self._host = serverip
        self._port = 22
        self.client = None
        self.transport = None
        # Maximum number of bytes requested per Channel.recv() call.
        self.bufsize = 65536

        # Setup the logger. Loggers are process-wide (looked up by name), so
        # attach the handler only once; otherwise every new SSHCommander adds
        # another handler and every log line is printed several times.
        self.logger = logging.getLogger('MySSH')
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            fmt = '%(asctime)s SSHCommander:%(funcName)s:%(lineno)d %(message)s'
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(fmt))
            self.logger.addHandler(handler)
        self.info = self.logger.info
        self.debug = self.logger.debug
        self.error = self.logger.error

        self.connect()

    def __del__(self):
        # __init__ may have failed before these attributes were assigned.
        if (getattr(self, 'client', None) is not None
                or getattr(self, 'transport', None) is not None):
            self.logout()

    def logout(self):
        '''Close the connection, if any. Safe to call more than once.'''
        # Must not call isClosed(): isClosed() calls logout() for a dead
        # connection, so the two would recurse into each other forever.
        if self.client is not None:
            self.client.close()
            self.client = None
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connect(self):
        '''(Re)connect to the host; on failure log it and leave transport None.'''
        # Drop any previous connection so reconnecting does not leak it.
        self.logout()
        self.client = paramiko.SSHClient()
        self.client.load_system_host_keys()
        # NOTE(security): AutoAddPolicy trusts unknown host keys, which allows
        # man-in-the-middle attacks on first contact. Prefer RejectPolicy plus
        # a maintained known_hosts file for anything that is not local.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            # Compression has to be requested at connect time; calling
            # Transport.use_compression() after the handshake has no effect.
            self.client.connect(self._host, port=self._port,
                                username=self._username,
                                password=self._password,
                                compress=True)
        except (socket.error, paramiko.SSHException) as exc:
            self.error('Cannot connect to %s:%d: %s',
                       self._host, self._port, exc)
            self.transport = None
        else:
            self.transport = self.client.get_transport()

    def isClosed(self):
        '''Return True when there is no live connection (dropping a dead one).'''
        if self.client is None:
            return True
        transport = self.client.get_transport()
        if transport is None or not transport.is_active():
            self.logout()
            return True
        return False

    def execute(self, cmd):
        '''
        Run a command without a pty and log its output.
        stdout is logged line by line, then whatever was written to stderr.
        @param cmd The shell command to run on the remote host.
        @returns the remote exit status (-1 if the server sent none).
        '''
        self.debug('Start to execute SSH command: %s', cmd)
        _stdin, ssh_stdout, ssh_stderr = self.client.exec_command(cmd)
        linestr = ssh_stdout.readline()
        while linestr:
            # Strip the line ending: the log handler adds its own newline.
            self.info(self._to_text(linestr).rstrip('\r\n'))
            linestr = ssh_stdout.readline()
        # read() returns everything up to EOF in a single call.
        errstr = self._to_text(ssh_stderr.read())
        if errstr:
            self.info(errstr)
        return ssh_stdout.channel.recv_exit_status()

    def execute_interactive(self, cmd, input_data=None, timeout=10):
        '''
        Run a command on a pty, feeding it input lines (e.g. a password).
        @param cmd The shell command to run on the remote host.
        @param input_data Lines to type, separated by newlines or by a
               literal backslash-n sequence (default is None: no input).
        @param timeout Seconds to wait for the command to finish (default 10).
        @returns the remote exit status, or -1 when not connected or when
                 no exit status was received (e.g. on timeout).
        '''
        if self.transport is None:
            return -1
        input_lines = self._run_fix_input_data(input_data)
        session = self.transport.open_session()
        try:
            # A pty is needed so prompts such as su's "Password:" appear;
            # stderr is merged so prompts written there are captured too.
            session.set_combine_stderr(True)
            session.get_pty()
            session.exec_command(cmd)
            output = self._run_poll(session, timeout, input_lines)
            status = session.recv_exit_status()
        finally:
            session.close()
        self.info(output)
        return status

    def executeSftpGet(self, remotepath, localpath):
        '''Download remotepath to localpath over SFTP.'''
        self.debug('SFTP Get: {remotepath} {localpath} ...'.format(remotepath=remotepath, localpath=localpath))
        sftp = self.client.open_sftp()
        try:
            sftp.get(remotepath, localpath)
        finally:
            sftp.close()
        self.debug('SFTP Get: {remotepath} {localpath} done.'.format(remotepath=remotepath, localpath=localpath))

    def executeSftpPut(self, localpath, remotepath):
        '''Upload localpath to remotepath over SFTP.'''
        self.debug('SFTP Put: {localpath} {remotepath} ...'.format(localpath=localpath, remotepath=remotepath))
        sftp = self.client.open_sftp()
        try:
            sftp.put(localpath, remotepath)
        finally:
            sftp.close()
        self.debug('SFTP Put: {localpath} {remotepath} done'.format(localpath=localpath, remotepath=remotepath))

    def _run_fix_input_data(self, input_data):
        '''
        Fix the input data supplied by the user for a command.
        @param input_data The input data (default is None).
        @returns the list of input lines ([] for None or an empty string).
        '''
        if not input_data:
            return []
        # Accept a literal backslash-n (e.g. typed on a command line) as a
        # line separator as well as a real newline.
        return input_data.replace('\\n', '\n').split('\n')

    @staticmethod
    def _to_text(data):
        '''Decode bytes read from a channel as UTF-8; pass text through.'''
        if isinstance(data, bytes):
            return data.decode('utf-8', 'replace')
        return data

    def _recv_chunk(self, session, chunks, total):
        '''Receive up to bufsize bytes into chunks; return the new byte total.'''
        data = session.recv(self.bufsize)
        chunks.append(data)
        total += len(data)
        self.debug('read %d bytes, total %d', len(data), total)
        return total

    def _run_poll(self, session, timeout, input_data, interval=0.2):
        '''
        Poll until the command completes.
        @param session The session.
        @param timeout The timeout in seconds.
        @param input_data The list of input lines to send.
        @param interval Seconds to sleep between polls (default 0.2).
        @returns the output as text
        '''
        # Poll until completion or timeout.
        # Note that we cannot directly use the stdout file descriptor
        # because it stalls at 64K bytes (65536).
        deadline = time.time() + timeout
        input_idx = 0
        timeout_flag = False
        chunks = []
        total = 0
        session.setblocking(0)
        while True:
            if session.recv_ready():
                total = self._recv_chunk(session, chunks, total)
            if session.send_ready() and input_idx < len(input_data):
                # The lines are sent in order as soon as the channel accepts
                # data, not after a matching prompt. In the future this could
                # be made to work more like pexpect with pattern matching.
                session.send(input_data[input_idx] + '\n')
                input_idx += 1
            if session.exit_status_ready():
                break
            if time.time() > deadline:
                self.error('polling finished - timeout')
                timeout_flag = True
                break
            time.sleep(interval)
        self.debug('polling loop ended')
        # Drain everything still buffered; a single recv() can leave data
        # behind when more than bufsize bytes are pending.
        while session.recv_ready():
            total = self._recv_chunk(session, chunks, total)
        self.debug('polling finished - %d output bytes', total)
        # Decode once at the end so multi-byte characters split across
        # recv() calls are not mangled.
        output = self._to_text(b''.join(chunks))
        if timeout_flag:
            self.debug('appending timeout message')
            output += '\nERROR: timeout after %d seconds\n' % (timeout)
        session.close()
        return output


def main():
    """Check the tomcat service status on a remote host via su (example)."""
    # Host and credentials are example placeholders.
    cmd = None
    try:
        cmd = SSHCommander("127.0.0.1", "abc", "abc1234")
        retCode = cmd.execute_interactive(
            "su root -c 'export PATH=$PATH:/sbin; service tomcat status'",
            "root1234")
        if retCode != 0:
            sys.exit(retCode)
    except Exception:
        print("Unexpected error:" + traceback.format_exc())
        sys.exit(1)
    finally:
        if cmd is not None:
            cmd.logout()


if __name__ == "__main__":
    main()
上面的例子通过 SSH 登录到某台机器上检查 tomcat 的运行状态。它会先切换到 root 用户,这个过程中系统会提示输入密码,脚本通过 execute_interactive 的 input_data 参数自动输入该密码。运行结果如下:

    # python SSHCommander.py
    2014-09-08 20:30:32,856 SSHCommander:execute_interactive:107 Password:
    Status of Tomcat Service:
    Tomcat (pid 31745) is running
