diff --git a/lib/ansible/playbook/__init__.py b/lib/ansible/playbook/__init__.py index 3f9130e153..106f400419 100644 --- a/lib/ansible/playbook/__init__.py +++ b/lib/ansible/playbook/__init__.py @@ -312,7 +312,7 @@ class PlayBook(object): conditional=task.only_if, callbacks=self.runner_callbacks, sudo=task.sudo, sudo_user=task.sudo_user, transport=task.transport, sudo_pass=task.sudo_pass, is_playbook=True, - check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args, + check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args, accelerate=task.play.accelerate, error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR ) diff --git a/lib/ansible/playbook/play.py b/lib/ansible/playbook/play.py index 3f968e2d34..5ed4525c47 100644 --- a/lib/ansible/playbook/play.py +++ b/lib/ansible/playbook/play.py @@ -29,7 +29,7 @@ class Play(object): __slots__ = [ 'hosts', 'name', 'vars', 'vars_prompt', 'vars_files', - 'handlers', 'remote_user', 'remote_port', + 'handlers', 'remote_user', 'remote_port', 'accelerate', 'sudo', 'sudo_user', 'transport', 'playbook', 'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks', 'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct' @@ -39,7 +39,7 @@ class Play(object): # and don't line up 1:1 with how they are stored VALID_KEYS = [ 'hosts', 'name', 'vars', 'vars_prompt', 'vars_files', - 'tasks', 'handlers', 'user', 'port', 'include', + 'tasks', 'handlers', 'user', 'port', 'include', 'accelerate', 'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial', 'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage' ] @@ -103,6 +103,7 @@ class Play(object): self.gather_facts = ds.get('gather_facts', None) self.remote_port = self.remote_port self.any_errors_fatal = ds.get('any_errors_fatal', False) + self.accelerate = ds.get('accelerate', False) self.max_fail_pct = int(ds.get('max_fail_percentage', 100)) load_vars = {} diff --git a/lib/ansible/runner/__init__.py b/lib/ansible/runner/__init__.py index 45e6cb8d07..0980c721a5 100644 --- a/lib/ansible/runner/__init__.py +++ b/lib/ansible/runner/__init__.py @@ -138,7 +138,8 @@ class Runner(object): diff=False, # whether to show diffs for template files that change environment=None, # environment variables (as dict) to use inside the command complex_args=None, # structured data in addition to module_args, must be a dict - error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False + error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR, # ex. False + accelerate=False, # use fireball acceleration ): # used to lock multiprocess inputs and outputs at various levels @@ -179,11 +180,16 @@ class Runner(object): self.environment = environment self.complex_args = complex_args self.error_on_undefined_vars = error_on_undefined_vars + self.accelerate = accelerate self.callbacks.runner = self - # if the transport is 'smart' see if SSH can support ControlPersist if not use paramiko - # 'smart' is the default since 1.2.1/1.3 - if self.transport == 'smart': + if self.accelerate: + # if we're using accelerated mode, force the local + # transport to fireball2 + self.transport = "fireball2" + elif self.transport == 'smart': + # if the transport is 'smart' see if SSH can support ControlPersist if not use paramiko + # 'smart' is the default since 1.2.1/1.3 cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) (out, err) = cmd.communicate() if "Bad configuration option" in err: diff --git a/lib/ansible/runner/connection_plugins/fireball2.py b/lib/ansible/runner/connection_plugins/fireball2.py index dbc881e4e6..9524084512 100644 --- a/lib/ansible/runner/connection_plugins/fireball2.py +++ b/lib/ansible/runner/connection_plugins/fireball2.py @@ -19,7 +19,9 @@ import json import os import base64 import socket +import struct from ansible.callbacks import vvv +from ansible.runner.connection_plugins.ssh import Connection as SSHConnection from ansible import utils from ansible import errors from ansible import constants @@ -27,32 +29,68 @@ from ansible import constants class Connection(object): ''' raw socket accelerated connection ''' - def __init__(self, runner, host, port, *args, **kwargs): + def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs): + + self.ssh = SSHConnection( + runner=runner, + host=host, + port=port, + user=user, + password=password, + private_key_file=private_key_file + ) self.runner = runner + self.host = host + self.context = None + self.conn = None + self.key = utils.key_for_hostname(host) + self.fbport = constants.FIREBALL2_PORT + self.is_connected = False # attempt to work around shared-memory funness if getattr(self.runner, 'aes_keys', None): utils.AES_KEYS = self.runner.aes_keys - self.host = host - self.context = None - self.conn = None - self.cipher = AES256Cipher() + def _execute_fb_module(self): + args = "password=%s" % base64.b64encode(self.key.__str__()) + self.ssh.connect() + return self.runner._execute_module(self.ssh, "/root/.ansible/tmp", 'fireball2', args, inject={"password":self.key}) - if port is None: - self.port = constants.FIREBALL2_PORT - else: - self.port = port - - def connect(self): + def connect(self, allow_ssh=True): ''' activates the connection object ''' - self.conn = socket.socket() - self.conn.connect((self.host,self.port)) + if self.is_connected: + return self + try: + self.conn = socket.socket() + self.conn.connect((self.host,self.fbport)) + except: + if allow_ssh: + print "Falling back to ssh to startup accelerated mode" + res = self._execute_fb_module() + return self.connect(allow_ssh=False) + else: + raise errors.AnsibleError("Failed to connect to %s:%s" % (self.host,self.fbport)) + self.is_connected = True return self + def send_data(self, data): + packed_len = struct.pack('Q',len(data)) + return self.conn.sendall(packed_len + data) + + def recv_data(self): + header_len = 8 # size of a packed unsigned long long + data = b"" + while len(data) < header_len: + data += self.conn.recv(1024) + data_len = struct.unpack('Q',data[:header_len])[0] + data = data[header_len:] + while len(data) < data_len: + data += self.conn.recv(1024) + return data + def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'): ''' run a command on the remote host ''' @@ -65,12 +103,12 @@ class Connection(object): executable=executable, ) data = utils.jsonify(data) - data = self.cipher.encrypt(data) - if self.conn.sendall(data): + data = utils.encrypt(self.key, data) + if self.send_data(data): raise errors.AnisbleError("Failed to send command to %s:%s" % (self.host,self.port)) - response = self.conn.recv(2048) - response = self.cipher.decrypt(response) + response = self.recv_data() + response = utils.decrypt(self.key, response) response = utils.parse_json(response) return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr','')) @@ -83,18 +121,18 @@ class Connection(object): if not os.path.exists(in_path): raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path) - data = base64.file(in_path).read() + data = file(in_path).read() data = base64.b64encode(data) data = dict(mode='put', data=data, out_path=out_path) # TODO: support chunked file transfer data = utils.jsonify(data) - data = self.cipher.encrypt(data) - if self.conn.sendall(data): + data = utils.encrypt(self.key, data) + if self.send_data(data): raise errors.AnsibleError("failed to send the file to %s:%s" % (self.host,self.port)) - response = self.conn.recv(2048) - response = self.cipher.decrypt(response) + response = self.recv_data() + response = utils.decrypt(self.key, data) response = utils.parse_json(response) # no meaningful response needed for this @@ -105,12 +143,12 @@ class Connection(object): data = dict(mode='fetch', in_path=in_path) data = utils.jsonify(data) - data = self.cipher.encrypt(data) - if self.conn.sendall(data): + data = utils.encrypt(self.key, data) + if self.send_data(data): raise errors.AnsibleError("failed to initiate the file fetch with %s:%s" % (self.host,self.port)) - response = self.socket.recv(2048) - response = self.cipher.decrypt(response) + response = self.recv_data() + response = utils.decrypt(self.key, data) response = utils.parse_json(response) response = response['data'] response = base64.b64decode(response) diff --git a/lib/ansible/utils/__init__.py b/lib/ansible/utils/__init__.py index f431b2a73c..63bdd3671c 100644 --- a/lib/ansible/utils/__init__.py +++ b/lib/ansible/utils/__init__.py @@ -31,7 +31,6 @@ import ansible.constants as C import time import StringIO import stat -import string import termios import tty import pipes @@ -41,11 +40,6 @@ import warnings import traceback import getpass -import hmac -from Crypto.Cipher import -from Crypto import Random -from Crypto.Random.random import StrongRandom - VERBOSITY=0 MAX_FILE_SIZE_FOR_DIFF=1*1024*1024 @@ -57,10 +51,8 @@ except ImportError: try: from hashlib import md5 as _md5 - from hashlib import sha1 as _sha1 except ImportError: from md5 import md5 as _md5 - from sha1 import sha1 as _sha1 PASSLIB_AVAILABLE = False try: @@ -69,128 +61,51 @@ try: except: pass +KEYCZAR_AVAILABLE=False +try: + import keyczar.errors as key_errors + from keyczar.keys import AesKey + KEYCZAR_AVAILABLE=True +except ImportError: + pass + ############################################################### -# Abstractions around PyCrypto +# Abstractions around keyczar ############################################################### -class AES256Cipher(object): - """ - Class abstraction of an AES 256 cipher. This class - also keeps track of the time since the key was last - generated, so you know when to rekey. Rekeying would - be done as follows: +def key_for_hostname(hostname): + # fireball mode is an implementation of ansible firing up zeromq via SSH + # to use no persistent daemons or key management - k = AES256Cipher.gen_key() - - AES26Cipher.set_key(k) + if not KEYCZAR_AVAILABLE: + raise errors.AnsibleError("python-keyczar must be installed to use fireball mode") - From this point on the new key would be used until - the lifetime is exceeded. - """ - def __init__(self, lifetime=60*30, mode=AES.MODE_CFB): - self.lifetime = lifetime - self.mode = mode - self.set_key(self.gen_key()) + key_path = os.path.expanduser("~/.fireball.keys") + if not os.path.exists(key_path): + os.makedirs(key_path) + key_path = os.path.expanduser("~/.fireball.keys/%s" % hostname) - def gen_key(self): - """ - Generates a 256-bit (32 byte) key to be used for the - AES block encryption. - """ - return b"".join(StrongRandom().sample(string.letters+string.digits+string.punctuation,32)) + # use new AES keys every 2 hours, which means fireball must not allow running for longer either + if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2): + key = AesKey.Generate() + fh = open(key_path, "w") + fh.write(str(key)) + fh.close() + return key + else: + fh = open(key_path) + key = AesKey.Read(fh.read()) + fh.close() + return key - def set_key(self,key): - """ - Sets the internal key to the one provided and resets the - internal time to now. This key should ONLY be set to one - generated by gen_key() - """ - self.init_time = time.time() - self.key = key +def encrypt(key, msg): + return key.Encrypt(msg) - def should_rekey(self): - """ - Returns true if the lifetime of the current key has - exceeded the set lifetime. - """ - if (time.time() - self.init_time) > self.lifetime: - return True - else: - return False - - def _pad(self, msg): - """ - Adds padding to the message so that it is a full - AES block size. Used during encryption of the message. - """ - pad = AES.block_size - len(msg) % AES.block_size - return msg + pad * chr(pad) - - def _unpad(self, msg): - """ - Strips out the padding that _pad added. Used during - the decryption of the message. - """ - pad = ord(msg[-1]) - return msg[:-pad] - - def gen_sig(self, msg): - """ - Generates an HMAC-SHA1 signature for the message - """ - return hmac.new(self.key, msg, _sha1).digest() - - def validate_sig(self, msg, sig): - """ - Verifies the generated signature of the message matches - the signature provided. - """ - new_sig = self.gen_sig(msg) - return (new_sig == sig) - - def encrypt(self, msg): - """ - Encrypt the message using AES. The signature - is appended to the end of the message and is - used to verify the integrity of the IV and data. - - Returns a base64-encoded version of the following: - - rval[0:16] = initialization vector - rval[16:-20] = cipher text - rval[-20:] = signature - """ - msg = self._pad(msg) - iv = Random.new().read(AES.block_size) - cipher = AES.new(self.key, self.mode, iv) - data = iv + cipher.encrypt(msg) - sig = self.gen_sig(data) - return (data + sig).encode('base64') - - def decrypt(self, msg): - """ - Decrypt the message using AES. The signature is - used to verify the IV and data before decoding to - ensure the integrity of the message. This is an - HMAC-SHA1 hash, so it is always 20 characters - - The incoming message format (after base64 decoding) - is as follows: - - msg[0:16] = initialization vector - msg[16:-20] = cipher text - msg[-20:] = signature (HMAC-SHA1) - - Returns the plain-text of the cipher. - """ - msg = msg.decode('base64') - data = msg[0:-20] # iv + cipher text - msig = msg[-20:] # hmac-sha1 hash - if not self.validate_sig(data,msig): - raise Exception("Failed to validate the message signature") - iv = msg[:AES.block_size] - cipher = AES.new(self.key, self.mode, iv) - return self._unpad(cipher.decrypt(msg)[AES.block_size:]) +def decrypt(key, msg): + try: + return key.Decrypt(msg) + except key_errors.InvalidSignatureError: + raise errors.AnsibleError("decryption failed") ############################################################### # UTILITY FUNCTIONS FOR COMMAND LINE TOOLS diff --git a/library/utilities/fireball2 b/library/utilities/fireball2 new file mode 100644 index 0000000000..e92d0817d8 --- /dev/null +++ b/library/utilities/fireball2 @@ -0,0 +1,284 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, James Cammarata +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +DOCUMENTATION = ''' +--- +module: fireball2 +short_description: Enable fireball2 mode on remote node +description: + - This modules launches an ephemeral I(fireball2) daemon on the remote node which + Ansible can use to communicate with nodes at high speed. + - The daemon listens on a configurable port for a configurable amount of time. + - Starting a new fireball2 as a given user terminates any existing user fireballs2. + - Fireball mode is AES encrypted +version_added: "1.3" +options: + port: + description: + - TCP port for the socket connection + required: false + default: 5099 + aliases: [] + minutes: + description: + - The I(fireball2) listener daemon is started on nodes and will stay around for + this number of minutes before turning itself off. + required: false + default: 30 +notes: + - See the advanced playbooks chapter for more about using fireball2 mode. +requirements: [ "pycrypto" ] +author: James Cammarata +''' + +EXAMPLES = ''' +# To use fireball2 mode, simple add "accelerated: true" to your play. The initial +# key exchange and starting up of the daemon will occur over SSH, but all commands and +# subsequent actions will be conducted over the raw socket connection using AES encryption + +- hosts: devservers + accelerated: true + tasks: + - command: /usr/bin/anything +''' + +import os +import sys +import shutil +import socket +import struct +import time +import base64 +import syslog +import signal +import time +import signal +import traceback + +import SocketServer + +syslog.openlog('ansible-%s' % os.path.basename(__file__)) +PIDFILE = os.path.expanduser("~/.fireball2.pid") + +def log(msg): + syslog.syslog(syslog.LOG_NOTICE, msg) + +if os.path.exists(PIDFILE): + try: + data = int(open(PIDFILE).read()) + try: + os.kill(data, signal.SIGKILL) + except OSError: + pass + except ValueError: + pass + os.unlink(PIDFILE) + +HAS_KEYCZAR = False +try: + from keyczar.keys import AesKey + HAS_KEYCZAR = True +except ImportError: + pass + +# NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move +# this into utils.module_common and probably should anyway + +def daemonize_self(module, password, port, minutes): + # daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012 + try: + pid = os.fork() + if pid > 0: + log("exiting pid %s" % pid) + # exit first parent + module.exit_json(msg="daemonized fireball2 on port %s for %s minutes" % (port, minutes)) + except OSError, e: + log("fork #1 failed: %d (%s)" % (e.errno, e.strerror)) + sys.exit(1) + + # decouple from parent environment + os.chdir("/") + os.setsid() + os.umask(022) + + # do second fork + try: + pid = os.fork() + if pid > 0: + log("daemon pid %s, writing %s" % (pid, PIDFILE)) + pid_file = open(PIDFILE, "w") + pid_file.write("%s" % pid) + pid_file.close() + log("pidfile written") + sys.exit(0) + except OSError, e: + log("fork #2 failed: %d (%s)" % (e.errno, e.strerror)) + sys.exit(1) + + dev_null = file('/dev/null','rw') + os.dup2(dev_null.fileno(), sys.stdin.fileno()) + os.dup2(dev_null.fileno(), sys.stdout.fileno()) + os.dup2(dev_null.fileno(), sys.stderr.fileno()) + log("daemonizing successful") + +#class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): +class ThreadedTCPServer(SocketServer.ThreadingTCPServer): + def __init__(self, server_address, RequestHandlerClass, module, password): + self.module = module + self.key = AesKey.Read(password) + self.allow_reuse_address = True + self.timeout = None + SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass) + +class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): + def send_data(self, data): + packed_len = struct.pack('Q', len(data)) + return self.request.sendall(packed_len + data) + + def recv_data(self): + header_len = 8 # size of a packed unsigned long long + data = b"" + while len(data) < header_len: + d = self.request.recv(1024) + if not d: + return None + data += d + data_len = struct.unpack('Q',data[:header_len])[0] + data = data[header_len:] + while len(data) < data_len: + data += self.request.recv(1024) + return data + + def handle(self): + while True: + data = self.recv_data() + if not data: + break + try: + data = self.server.key.Decrypt(data) + except: + log("bad decrypt, skipping...") + data2 = json.dumps(dict(rc=1)) + data2 = self.server.key.Encrypt(data2) + send_data(client, data2) + return + + data = json.loads(data) + + mode = data['mode'] + response = {} + if mode == 'command': + response = self.command(data) + elif mode == 'put': + response = self.put(data) + elif mode == 'fetch': + response = self.fetch(data) + + data2 = json.dumps(response) + data2 = self.server.key.Encrypt(data2) + self.send_data(data2) + + def command(self, data): + if 'cmd' not in data: + return dict(failed=True, msg='internal error: cmd is required') + if 'tmp_path' not in data: + return dict(failed=True, msg='internal error: tmp_path is required') + if 'executable' not in data: + return dict(failed=True, msg='internal error: executable is required') + + log("executing: %s" % data['cmd']) + rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True) + if stdout is None: + stdout = '' + if stderr is None: + stderr = '' + log("got stdout: %s" % stdout) + + return dict(rc=rc, stdout=stdout, stderr=stderr) + + def fetch(self, data): + if 'in_path' not in data: + return dict(failed=True, msg='internal error: in_path is required') + + # FIXME: should probably support chunked file transfer for binary files + # at some point. For now, just base64 encodes the file + # so don't use it to move ISOs, use rsync. + + fh = open(data['in_path']) + data = base64.b64encode(fh.read()) + return dict(data=data) + + def put(self, data): + if 'data' not in data: + return dict(failed=True, msg='internal error: data is required') + if 'out_path' not in data: + return dict(failed=True, msg='internal error: out_path is required') + + # FIXME: should probably support chunked file transfer for binary files + # at some point. For now, just base64 encodes the file + # so don't use it to move ISOs, use rsync. + + fh = open(data['out_path'], 'w') + fh.write(base64.b64decode(data['data'])) + fh.close() + + return dict() + +def daemonize(module, password, port, minutes): + try: + daemonize_self(module, password, port, minutes) + + def catcher(signum, _): + module.exit_json(msg='timer expired') + + signal.signal(signal.SIGALRM, catcher) + signal.setitimer(signal.ITIMER_REAL, 60 * minutes) + + server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password) + server.allow_reuse_address = True + + server.serve_forever(poll_interval=1.0) + except Exception, e: + tb = traceback.format_exc() + log("exception caught, exiting fireball mode: %s\n%s" % (e, tb)) + sys.exit(0) + +def main(): + module = AnsibleModule( + argument_spec = dict( + port=dict(required=False, default=5099), + password=dict(required=True), + minutes=dict(required=False, default=30), + ) + ) + + password = base64.b64decode(module.params['password']) + port = module.params['port'] + minutes = int(module.params['minutes']) + + if not HAS_KEYCZAR: + module.fail_json(msg="keyczar is not installed") + + daemonize(module, password, port, minutes) + + +# this is magic, see lib/ansible/module_common.py +#<> +main()