# This code is part of Ansible, but is an independent component. # This particular file snippet, and this file snippet only, is BSD licensed. # Modules you write using this snippet, which is embedded dynamically by Ansible # still belong to the author of the module, and may assign their own license # to the complete work. # # Copyright (c), Jonathan Mainguy , 2015 # Most of this was originally added by Sven Schliesing @muffl0n in the mysql_user.py module # # Simplified BSD License (see simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) from __future__ import (absolute_import, division, print_function) from functools import reduce __metaclass__ = type import os from ansible.module_utils.six.moves import configparser from ansible.module_utils._text import to_native from ansible_collections.community.mysql.plugins.module_utils import pymysql as mysql_driver from ansible_collections.community.mysql.plugins.module_utils.database import mysql_quote_identifier def parse_from_mysql_config_file(cnf): # Default values of comment_prefix is '#' and ';'. # '!' added to prevent a parsing error # when a config file contains !includedir parameter. cp = configparser.ConfigParser(comment_prefixes=('#', ';', '!')) cp.read(cnf) return cp def mysql_connect(module, login_user=None, login_password=None, config_file='', ssl_cert=None, ssl_key=None, ssl_ca=None, db=None, cursor_class=None, connect_timeout=30, autocommit=False, config_overrides_defaults=False, check_hostname=None): config = {} if config_file and os.path.exists(config_file): config['read_default_file'] = config_file if config_overrides_defaults: try: cp = parse_from_mysql_config_file(config_file) except Exception as e: module.fail_json(msg="Failed to parse %s: %s" % (config_file, to_native(e))) # Override some commond defaults with values from config file if needed if cp and cp.has_section('client'): try: module.params['login_host'] = cp.get('client', 'host', fallback=module.params['login_host']) module.params['login_port'] = cp.getint('client', 'port', fallback=module.params['login_port']) except Exception as e: if "got an unexpected keyword argument 'fallback'" in e.message: module.fail_json(msg='To use config_overrides_defaults, ' 'it needs Python 3.5+ as the default interpreter on a target host') if ssl_ca is not None or ssl_key is not None or ssl_cert is not None or check_hostname is not None: config['ssl'] = {} if module.params['login_unix_socket']: config['unix_socket'] = module.params['login_unix_socket'] else: config['host'] = module.params['login_host'] config['port'] = module.params['login_port'] # If login_user or login_password are given, they should override the # config file if login_user is not None: config['user'] = login_user if login_password is not None: config['password'] = login_password if ssl_cert is not None: config['ssl']['cert'] = ssl_cert if ssl_key is not None: config['ssl']['key'] = ssl_key if ssl_ca is not None: config['ssl']['ca'] = ssl_ca if db is not None: config['database'] = db if connect_timeout is not None: config['connect_timeout'] = connect_timeout if check_hostname is not None: config['ssl']['check_hostname'] = check_hostname db_connection = mysql_driver.connect(autocommit=autocommit, **config) # Monkey patch the Connection class to close the connection when garbage collected def _conn_patch(conn_self): conn_self.close() db_connection.__class__.__del__ = _conn_patch # Patched if cursor_class == 'DictCursor': return db_connection.cursor(**{'cursor': mysql_driver.cursors.DictCursor}), db_connection else: return db_connection.cursor(), db_connection def mysql_common_argument_spec(): return dict( login_user=dict(type='str', default=None), login_password=dict(type='str', no_log=True), login_host=dict(type='str', default='localhost'), login_port=dict(type='int', default=3306), login_unix_socket=dict(type='str'), config_file=dict(type='path', default='~/.my.cnf'), connect_timeout=dict(type='int', default=30), client_cert=dict(type='path', aliases=['ssl_cert']), client_key=dict(type='path', aliases=['ssl_key']), ca_cert=dict(type='path', aliases=['ssl_ca']), check_hostname=dict(type='bool', default=None), ) def get_server_version(cursor): """Returns a string representation of the server version.""" cursor.execute("SELECT VERSION() AS version") result = cursor.fetchone() if isinstance(result, dict): version_str = result['version'] else: version_str = result[0] return version_str def get_server_implementation(cursor): if 'mariadb' in get_server_version(cursor).lower(): return "mariadb" else: return "mysql" def set_session_vars(module, cursor, session_vars): """Set session vars.""" for var, value in session_vars.items(): query = "SET SESSION %s = " % mysql_quote_identifier(var, 'vars') try: cursor.execute(query + "%s", (value,)) except Exception as e: module.fail_json(msg='Failed to execute %s%s: %s' % (query, value, e))