mirror of
				https://github.com/ansible-collections/community.general.git
				synced 2025-10-25 21:44:00 -07:00 
			
		
		
		
* replaces the persistent connection digest with _create_control_path()
* adds _ansible_socket to _legal_inputs in basic.py
* adds connection_user to play_context
* maps remote_user to connection_user when the connection is local
* maps ansible_socket in task_vars to the module_args _ansible_socket, if it exists
		
			
				
	
	
		
			338 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable file
		
	
	
	
	
			
		
		
	
	
			338 lines
		
	
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable file
		
	
	
	
	
| #!/usr/bin/env python
 | |
| 
 | |
| # (c) 2016, Ansible, Inc. <support@ansible.com>
 | |
| #
 | |
| # This file is part of Ansible
 | |
| #
 | |
| # Ansible is free software: you can redistribute it and/or modify
 | |
| # it under the terms of the GNU General Public License as published by
 | |
| # the Free Software Foundation, either version 3 of the License, or
 | |
| # (at your option) any later version.
 | |
| #
 | |
| # Ansible is distributed in the hope that it will be useful,
 | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
| # GNU General Public License for more details.
 | |
| #
 | |
| # You should have received a copy of the GNU General Public License
 | |
| # along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
 | |
| 
 | |
| ########################################################
 | |
| from __future__ import (absolute_import, division, print_function)
 | |
| 
 | |
| __metaclass__ = type
 | |
| __requires__ = ['ansible']
 | |
| 
 | |
| try:
 | |
|     import pkg_resources
 | |
| except Exception:
 | |
|     pass
 | |
| 
 | |
| import fcntl
 | |
| import os
 | |
| import shlex
 | |
| import signal
 | |
| import socket
 | |
| import struct
 | |
| import sys
 | |
| import time
 | |
| import traceback
 | |
| import syslog
 | |
| import datetime
 | |
| 
 | |
| from io import BytesIO
 | |
| 
 | |
| from ansible import constants as C
 | |
| from ansible.module_utils._text import to_bytes, to_native
 | |
| from ansible.module_utils.six.moves import cPickle, StringIO
 | |
| from ansible.playbook.play_context import PlayContext
 | |
| from ansible.plugins import connection_loader
 | |
| from ansible.utils.path import unfrackpath, makedirs_safe
 | |
| from ansible.errors import AnsibleConnectionFailure
 | |
| from ansible.utils.display import Display
 | |
| 
 | |
| 
 | |
def do_fork():
    '''
    Does the required double fork for a daemon process. Based on
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Returns the pid of the first child in the calling process and 0 in the
    fully detached grandchild (the daemon). The intermediate child never
    returns: it exits with status 0 once the grandchild has been forked.
    If an OSError occurs at any step, the process in which it happened
    exits with status 1.

    In the grandchild, stdin/stdout/stderr are redirected to /dev/null
    rather than closed. If fds 0-2 were closed, the next file or socket the
    daemon opens could be assigned one of them, and any stray write to
    stdout/stderr would then be sent into that file or socket.
    '''
    try:
        pid = os.fork()
        if pid > 0:
            return pid

        # first child: start a new session (detaching from the caller's
        # controlling terminal) and stop pinning the caller's cwd
        os.chdir("/")
        os.setsid()
        os.umask(0)

        try:
            pid = os.fork()
            if pid > 0:
                # intermediate child: its only job was to fork the daemon
                sys.exit(0)

            # grandchild (daemon): point the standard fds at /dev/null
            std_fds = [stream.fileno() for stream in (sys.stdin, sys.stdout, sys.stderr)]
            devnull = os.open(os.devnull, os.O_RDWR)
            for fd in std_fds:
                os.dup2(devnull, fd)
            # if a std fd was already closed, os.open may have reused it;
            # only close devnull when it is not one of the std fds
            if devnull not in std_fds:
                os.close(devnull)

            return pid
        except OSError:
            sys.exit(1)
    except OSError:
        sys.exit(1)
 | |
| 
 | |
def send_data(s, data):
    '''
    Write one framed message to the socket ``s``.

    A frame is an 8-byte, network-order (big-endian) unsigned length header
    immediately followed by ``data``; the whole frame is handed to a single
    ``sendall`` call.

    :arg s: connected socket (any object with a ``sendall`` method)
    :arg data: message payload, as bytes
    :returns: the return value of ``s.sendall``
    '''
    length_header = struct.pack('!Q', len(data))
    frame = length_header + data
    return s.sendall(frame)
 | |
| 
 | |
def _recv_exactly(s, size):
    '''Read exactly ``size`` bytes from ``s``; return None if EOF comes first.'''
    # accumulate into a bytearray: repeated ``bytes +=`` re-copies the whole
    # buffer on every chunk, which is quadratic for large payloads
    buf = bytearray()
    while len(buf) < size:
        chunk = s.recv(size - len(buf))
        if not chunk:
            return None
        buf.extend(chunk)
    return bytes(buf)


def recv_data(s):
    '''
    Receive one framed message (as written by ``send_data``) from ``s``.

    Reads the 8-byte big-endian length header, then exactly that many
    payload bytes. It loops because ``recv`` may return fewer bytes than
    requested.

    :arg s: connected socket (any object with a ``recv`` method)
    :returns: the payload as bytes (``b''`` for a zero-length message), or
        None if the peer closed the connection before a complete message
        arrived.
    '''
    header_len = struct.calcsize('!Q')  # size of a packed unsigned long long (8)
    header = _recv_exactly(s, header_len)
    if header is None:
        return None
    data_len = struct.unpack('!Q', header)[0]
    return _recv_exactly(s, data_len)
 | |
| 
 | |
def log(msg, host=None, **kwargs):
    '''
    Send ``msg`` to syslog at LOG_INFO under the ``ansible-connection`` ident.

    :arg msg: the message; it is passed through ``str()`` before logging
    :kwarg host: when truthy, the message is prefixed with ``<host> ``
    :kwarg kwargs: accepted and ignored, so that this function can be
        assigned over ``Display.display`` (see the ``__main__`` block)
    '''
    text = '<%s> %s' % (host, msg) if host else msg

    # fall back to LOG_USER if the configured facility name is unknown
    facility = getattr(syslog, C.DEFAULT_SYSLOG_FACILITY, syslog.LOG_USER)
    syslog.openlog('ansible-connection', 0, facility)
    syslog.syslog(syslog.LOG_INFO, str(text))
 | |
| 
 | |
| 
 | |
class Server():
    '''
    Persistent connection daemon for a single remote host.

    Opens the connection plugin named by ``play_context.connection`` once,
    then serves ``ansible-connection`` clients over a UNIX domain socket at
    ``path``. Messages are framed with ``send_data``/``recv_data``. Each
    message is one of:

    * ``EXEC: <cmd>`` -- run ``cmd`` via the plugin's ``exec_command``
    * ``PUT: <src> <dst>`` / ``FETCH: <src> <dst>`` -- copy a file
    * ``CONTEXT: <pickled PlayContext data>`` -- hand a new PlayContext to
      the plugin's ``update_play_context`` (if it has one); no reply is sent

    Every other request is answered with three frames: rc, stdout, stderr.

    ``signal.alarm`` bounds both waiting for a client
    (``C.PERSISTENT_CONNECT_TIMEOUT``) and running a request
    (``C.DEFAULT_TIMEOUT``). When it fires, ``alarm_handler`` closes the
    listening socket. The next ``accept()`` then fails, and ``run`` shuts
    the daemon down.
    '''

    def __init__(self, path, play_context):
        self.path = path
        self.play_context = play_context

        self._start_time = datetime.datetime.now()

        self.log("setup connection %s" % self.play_context.connection)

        self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
        self.conn._connect()
        if not self.conn.connected:
            raise AnsibleConnectionFailure('unable to connect to remote host')

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # do_fork() sets the umask to 0, so bind() would otherwise create the
        # socket file world-writable and let other local users connect to it.
        # Anyone who can connect can run commands on the remote host and feed
        # pickled data to _update_play_context(), so restrict it to 0600.
        old_umask = os.umask(0o177)
        try:
            self.socket.bind(path)
        finally:
            os.umask(old_umask)
        self.socket.listen(1)

        signal.signal(signal.SIGALRM, self.alarm_handler)

    def dispatch(self, obj, name, *args, **kwargs):
        '''Call ``obj.<name>(*args, **kwargs)`` if that attribute exists; else return None.'''
        meth = getattr(obj, name, None)
        if meth:
            return meth(*args, **kwargs)

    def log(self, msg):
        '''Log ``msg`` to syslog, prefixed with this connection's remote address.'''
        log(msg, host=self.play_context.remote_addr)

    def alarm_handler(self, signum, frame):
        '''
        Alarm handler
        '''
        # FIXME: this should also set internal flags for other
        #        areas of code to check, so they can terminate
        #        earlier than the socket going back to the accept
        #        call and failing there.
        #
        # hooks the connection plugin to handle any cleanup
        self.dispatch(self.conn, 'alarm_handler', signum, frame)
        self.socket.close()

    def _update_play_context(self, payload):
        '''Unpickle a serialized PlayContext and pass it to the connection plugin.'''
        # NOTE: unpickling runs code embedded in the payload; this is only
        # acceptable because the socket is created owner-only in __init__.
        src = BytesIO(payload)
        try:
            pc_data = cPickle.load(src)
        finally:
            src.close()

        pc = PlayContext()
        pc.deserialize(pc_data)

        self.dispatch(self.conn, 'update_play_context', pc)

    def _handle_request(self, data):
        '''
        Execute a single client request.

        :arg data: one framed message, as bytes
        :returns: an ``(rc, stdout, stderr)`` tuple to send back to the
            client, or None after a successful CONTEXT update (the protocol
            does not acknowledge those). Any failure is reported as rc 255
            with the traceback in stderr.
        '''
        rc = 255
        stdout = stderr = ''
        try:
            if data.startswith(b'EXEC: '):
                # slice the prefix off rather than splitting on it, so a
                # command that itself contains 'EXEC: ' is not truncated
                (rc, stdout, stderr) = self.conn.exec_command(data[len(b'EXEC: '):])
            elif data.startswith((b'PUT: ', b'FETCH: ')):
                (op, src, dst) = shlex.split(to_native(data))
                try:
                    if op == 'FETCH:':
                        self.conn.fetch_file(src, dst)
                    elif op == 'PUT:':
                        self.conn.put_file(src, dst)
                    rc = 0
                except Exception:
                    # report the failure instead of returning a bare rc 255
                    stderr = traceback.format_exc()
            elif data.startswith(b'CONTEXT: '):
                self._update_play_context(data[len(b'CONTEXT: '):])
                return None
            else:
                stderr = 'Invalid action specified'
        except Exception:
            stdout = ''
            stderr = traceback.format_exc()
        return rc, stdout, stderr

    def _handle_client(self, s):
        '''
        Serve requests on the accepted socket ``s`` until the client closes
        it or sends an empty message, then close ``s``. Socket errors
        propagate to ``run``.
        '''
        try:
            while True:
                data = recv_data(s)
                if not data:
                    break

                signal.alarm(C.DEFAULT_TIMEOUT)
                try:
                    response = self._handle_request(data)
                finally:
                    # disarm after every request, including CONTEXT updates,
                    # so the alarm cannot fire while waiting for the next one
                    signal.alarm(0)

                if response is None:
                    continue

                (rc, stdout, stderr) = response
                send_data(s, to_bytes(str(rc)))
                send_data(s, to_bytes(stdout))
                send_data(s, to_bytes(stderr))
        finally:
            s.close()

    def run(self):
        '''
        Accept clients and serve their requests until the listening socket is
        closed or an unexpected error occurs. Then close the connection
        plugin and the socket, and remove the socket file so that a new
        daemon can be started for this host.
        '''
        try:
            while True:
                # set the alarm, if we don't get an accept before it
                # goes off we exit (via an exception caused by the socket
                # getting closed while waiting on accept())
                # FIXME: is this the best way to exit? as noted above in the
                #        handler we should probably be setting a flag to check
                #        here and in other parts of the code
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                try:
                    (s, _addr) = self.socket.accept()
                except Exception:
                    break
                # clear the alarm
                # FIXME: potential race condition here between the accept and
                #        time to this call.
                signal.alarm(0)

                self._handle_client(s)
        except Exception:
            self.log(traceback.format_exc())
        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            delta = datetime.datetime.now() - self._start_time
            self.log('shutting down connection, connection was active for %0.2f secs'
                     % delta.total_seconds())
            for closeable in (self.conn, self.socket):
                try:
                    closeable.close()
                except Exception:
                    pass
            try:
                os.remove(self.path)
            except OSError:
                pass
 | |
| 
 | |
def main():
    '''
    Entry point for ``ansible-connection``.

    Steps:

    1. Read a pickled, serialized PlayContext from stdin, terminated by a
       ``#END_INIT#`` line.
    2. Ensure a persistent connection daemon (``Server``) is listening on
       the control socket for that host/port/user, forking one if the
       socket file does not exist yet.
    3. Forward the next non-blank stdin line to the daemon as a request,
       preceded by a CONTEXT message carrying the same PlayContext.
    4. Copy the daemon's stdout/stderr to this process. Its return code
       becomes this process's exit status.

    Exits with status 255 if the daemon's socket cannot be connected to, or
    if the daemon closes the connection before replying.
    '''
    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        init_lines = []
        cur_line = sys.stdin.readline()
        while cur_line.strip() != '#END_INIT#':
            if cur_line == '':
                raise Exception("EOL found before init data was complete")
            init_lines.append(cur_line)
            cur_line = sys.stdin.readline()
        src = BytesIO(to_bytes(''.join(init_lines)))
        pc_data = cPickle.load(src)

        pc = PlayContext()
        pc.deserialize(pc_data)
    except Exception as e:
        # FIXME: better error message/handling/logging
        sys.stderr.write(traceback.format_exc())
        sys.exit("FAIL: %s" % e)

    display.verbosity = pc.verbosity

    ssh = connection_loader.get('ssh', class_only=True)
    control_path = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user)

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath("$HOME/.ansible/pc")
    makedirs_safe(tmp_path)
    lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    sf_path = unfrackpath(control_path % dict(directory=tmp_path))

    # if the socket file doesn't exist, spin up the daemon process
    # NOTE(review): fcntl.lockf() takes a POSIX record lock, and such locks
    # are not inherited by fork() children. So the daemon's unlock below is a
    # no-op, and this process drops the lock as soon as do_fork() returns --
    # before the daemon has bound the socket. A concurrent invocation can
    # therefore still find no socket file and start a second daemon.
    lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)
    if not os.path.exists(sf_path):
        pid = do_fork()
        if pid == 0:
            # detached daemon process
            try:
                server = Server(sf_path, pc)
            except Exception:
                log(traceback.format_exc(), pc.remote_addr)
                server = None
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            if server is None:
                # setup failed (already logged); the client below will time
                # out connecting and point the user at syslog
                sys.exit(1)
            server.run()
            sys.exit(0)
    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    rc = 0
    while rc == 0:
        data = sys.stdin.readline()
        if data == '':
            break
        if data.strip() == '':
            continue
        sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        attempts = 1
        while True:
            try:
                sf.connect(sf_path)
                break
            except socket.error:
                # FIXME: better error handling/logging/message here
                time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                attempts += 1
                if attempts > C.PERSISTENT_CONNECT_RETRIES:
                    sys.stderr.write('failed to connect to the listener socket, '
                                     'connection timeout.  See syslog for more details')
                    sys.exit(255)

        # send the play_context back into the connection so the connection
        # can handle any privilege escalation activities. Build the frame as
        # bytes: interpolating bytes into a str with '%s' would send the
        # repr (b'...') on Python 3.
        send_data(sf, b'CONTEXT: ' + src.getvalue())
        src.close()

        send_data(sf, to_bytes(data.strip()))

        rc_data = recv_data(sf)
        if rc_data is None:
            sys.stderr.write('the persistent connection closed before sending a '
                             'response.  See syslog for more details')
            sys.exit(255)
        rc = int(rc_data, 10)
        # a missing stdout/stderr frame means the daemon hung up mid-reply;
        # treat it as empty output
        stdout = recv_data(sf) or b''
        stderr = recv_data(sf) or b''

        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))

        sf.close()
        break

    sys.exit(rc)
 | |
| 
 | |
if __name__ == '__main__':
    # main() uses this module-level ``display`` (it sets display.verbosity),
    # so it must be bound here before main() runs.
    display = Display()
    # Replace Display's output method with log(), so anything displayed is
    # written to syslog instead of the terminal; the daemon that main() forks
    # is detached from the terminal by do_fork().
    display.display = log
    main()
 |