Remove Black formatting in an attempt to make codecov happy

This commit is contained in:
Jorge-Rodriguez 2020-07-26 10:52:37 +03:00
parent 69d3f8b3c2
commit 25441dfaaf

View file

@ -6,11 +6,10 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
__metaclass__ = type __metaclass__ = type
DOCUMENTATION = r""" DOCUMENTATION = r'''
--- ---
module: mysql_user module: mysql_user
short_description: Adds or removes a user from a MySQL database short_description: Adds or removes a user from a MySQL database
@ -142,9 +141,9 @@ author:
extends_documentation_fragment: extends_documentation_fragment:
- community.mysql.mysql - community.mysql.mysql
""" '''
EXAMPLES = r""" EXAMPLES = r'''
- name: Removes anonymous user account for localhost - name: Removes anonymous user account for localhost
community.mysql.mysql_user: community.mysql.mysql_user:
name: '' name: ''
@ -286,109 +285,55 @@ EXAMPLES = r"""
# [client] # [client]
# user=root # user=root
# password=n<_665{vS43y # password=n<_665{vS43y
""" '''
import re import re
import string import string
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
from ansible_collections.community.mysql.plugins.module_utils.database import SQLParseError from ansible_collections.community.mysql.plugins.module_utils.database import SQLParseError
from ansible_collections.community.mysql.plugins.module_utils.mysql import ( from ansible_collections.community.mysql.plugins.module_utils.mysql import mysql_connect, mysql_driver, mysql_driver_fail_msg
mysql_connect,
mysql_driver,
mysql_driver_fail_msg,
)
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
VALID_PRIVS = frozenset( VALID_PRIVS = frozenset(('CREATE', 'DROP', 'GRANT', 'GRANT OPTION',
( 'LOCK TABLES', 'REFERENCES', 'EVENT', 'ALTER',
"CREATE", 'DELETE', 'INDEX', 'INSERT', 'SELECT', 'UPDATE',
"DROP", 'CREATE TEMPORARY TABLES', 'TRIGGER', 'CREATE VIEW',
"GRANT", 'SHOW VIEW', 'ALTER ROUTINE', 'CREATE ROUTINE',
"GRANT OPTION", 'EXECUTE', 'FILE', 'CREATE TABLESPACE', 'CREATE USER',
"LOCK TABLES", 'PROCESS', 'PROXY', 'RELOAD', 'REPLICATION CLIENT',
"REFERENCES", 'REPLICATION SLAVE', 'SHOW DATABASES', 'SHUTDOWN',
"EVENT", 'SUPER', 'ALL', 'ALL PRIVILEGES', 'USAGE', 'REQUIRESSL',
"ALTER", 'CREATE ROLE', 'DROP ROLE', 'APPLICATION_PASSWORD_ADMIN',
"DELETE", 'AUDIT_ADMIN', 'BACKUP_ADMIN', 'BINLOG_ADMIN',
"INDEX", 'BINLOG_ENCRYPTION_ADMIN', 'CLONE_ADMIN', 'CONNECTION_ADMIN',
"INSERT", 'ENCRYPTION_KEY_ADMIN', 'FIREWALL_ADMIN', 'FIREWALL_USER',
"SELECT", 'GROUP_REPLICATION_ADMIN', 'INNODB_REDO_LOG_ARCHIVE',
"UPDATE", 'NDB_STORED_USER', 'PERSIST_RO_VARIABLES_ADMIN',
"CREATE TEMPORARY TABLES", 'REPLICATION_APPLIER', 'REPLICATION_SLAVE_ADMIN',
"TRIGGER", 'RESOURCE_GROUP_ADMIN', 'RESOURCE_GROUP_USER',
"CREATE VIEW", 'ROLE_ADMIN', 'SESSION_VARIABLES_ADMIN', 'SET_USER_ID',
"SHOW VIEW", 'SYSTEM_USER', 'SYSTEM_VARIABLES_ADMIN', 'SYSTEM_USER',
"ALTER ROUTINE", 'TABLE_ENCRYPTION_ADMIN', 'VERSION_TOKEN_ADMIN',
"CREATE ROUTINE", 'XA_RECOVER_ADMIN', 'LOAD FROM S3', 'SELECT INTO S3',
"EXECUTE", 'INVOKE LAMBDA',
"FILE", 'ALTER ROUTINE',
"CREATE TABLESPACE", 'BINLOG ADMIN',
"CREATE USER", 'BINLOG MONITOR',
"PROCESS", 'BINLOG REPLAY',
"PROXY", 'CONNECTION ADMIN',
"RELOAD", 'READ_ONLY ADMIN',
"REPLICATION CLIENT", 'REPLICATION MASTER ADMIN',
"REPLICATION SLAVE", 'REPLICATION SLAVE',
"SHOW DATABASES", 'REPLICATION SLAVE ADMIN',
"SHUTDOWN", 'SET USER',))
"SUPER",
"ALL",
"ALL PRIVILEGES",
"USAGE",
"REQUIRESSL",
"CREATE ROLE",
"DROP ROLE",
"APPLICATION_PASSWORD_ADMIN",
"AUDIT_ADMIN",
"BACKUP_ADMIN",
"BINLOG_ADMIN",
"BINLOG_ENCRYPTION_ADMIN",
"CLONE_ADMIN",
"CONNECTION_ADMIN",
"ENCRYPTION_KEY_ADMIN",
"FIREWALL_ADMIN",
"FIREWALL_USER",
"GROUP_REPLICATION_ADMIN",
"INNODB_REDO_LOG_ARCHIVE",
"NDB_STORED_USER",
"PERSIST_RO_VARIABLES_ADMIN",
"REPLICATION_APPLIER",
"REPLICATION_SLAVE_ADMIN",
"RESOURCE_GROUP_ADMIN",
"RESOURCE_GROUP_USER",
"ROLE_ADMIN",
"SESSION_VARIABLES_ADMIN",
"SET_USER_ID",
"SYSTEM_USER",
"SYSTEM_VARIABLES_ADMIN",
"SYSTEM_USER",
"TABLE_ENCRYPTION_ADMIN",
"VERSION_TOKEN_ADMIN",
"XA_RECOVER_ADMIN",
"LOAD FROM S3",
"SELECT INTO S3",
"INVOKE LAMBDA",
"ALTER ROUTINE",
"BINLOG ADMIN",
"BINLOG MONITOR",
"BINLOG REPLAY",
"CONNECTION ADMIN",
"READ_ONLY ADMIN",
"REPLICATION MASTER ADMIN",
"REPLICATION SLAVE",
"REPLICATION SLAVE ADMIN",
"SET USER",
)
)
class InvalidPrivsError(Exception): class InvalidPrivsError(Exception):
pass pass
# =========================================== # ===========================================
# MySQL module specific support methods. # MySQL module specific support methods.
# #
@ -399,9 +344,9 @@ def use_old_user_mgmt(cursor):
cursor.execute("SELECT VERSION()") cursor.execute("SELECT VERSION()")
result = cursor.fetchone() result = cursor.fetchone()
version_str = result[0] version_str = result[0]
version = version_str.split(".") version = version_str.split('.')
if "mariadb" in version_str.lower(): if 'mariadb' in version_str.lower():
# Prior to MariaDB 10.2 # Prior to MariaDB 10.2
if int(version[0]) * 1000 + int(version[1]) < 10002: if int(version[0]) * 1000 + int(version[1]) < 10002:
return True return True
@ -416,13 +361,13 @@ def use_old_user_mgmt(cursor):
def get_mode(cursor): def get_mode(cursor):
cursor.execute("SELECT @@GLOBAL.sql_mode") cursor.execute('SELECT @@GLOBAL.sql_mode')
result = cursor.fetchone() result = cursor.fetchone()
mode_str = result[0] mode_str = result[0]
if "ANSI" in mode_str: if 'ANSI' in mode_str:
mode = "ANSI" mode = 'ANSI'
else: else:
mode = "NOTANSI" mode = 'NOTANSI'
return mode return mode
@ -500,20 +445,9 @@ def get_grants(cursor, user, host):
return grants.split(", ") return grants.split(", ")
def user_add( def user_add(cursor, user, host, host_all, password, encrypted,
cursor, plugin, plugin_hash_string, plugin_auth_string, new_priv,
user, tls_requires, check_mode):
host,
host_all,
password,
encrypted,
plugin,
plugin_hash_string,
plugin_auth_string,
new_priv,
tls_requires,
check_mode,
):
# we cannot create users without a proper hostname # we cannot create users without a proper hostname
if host_all: if host_all:
return False return False
@ -567,27 +501,15 @@ def user_add(
def is_hash(password): def is_hash(password):
ishash = False ishash = False
if len(password) == 41 and password[0] == "*": if len(password) == 41 and password[0] == '*':
if frozenset(password[1:]).issubset(string.hexdigits): if frozenset(password[1:]).issubset(string.hexdigits):
ishash = True ishash = True
return ishash return ishash
def user_mod( def user_mod(cursor, user, host, host_all, password, encrypted,
cursor, plugin, plugin_hash_string, plugin_auth_string, new_priv,
user, append_privs, tls_requires, module):
host,
host_all,
password,
encrypted,
plugin,
plugin_hash_string,
plugin_auth_string,
new_priv,
append_privs,
tls_requires,
module,
):
changed = False changed = False
msg = "User unchanged" msg = "User unchanged"
grant_option = False grant_option = False
@ -604,46 +526,36 @@ def user_mod(
old_user_mgmt = use_old_user_mgmt(cursor) old_user_mgmt = use_old_user_mgmt(cursor)
# Get a list of valid columns in mysql.user table to check if Password and/or authentication_string exist # Get a list of valid columns in mysql.user table to check if Password and/or authentication_string exist
cursor.execute( cursor.execute("""
"""
SELECT COLUMN_NAME FROM information_schema.COLUMNS SELECT COLUMN_NAME FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = 'mysql' AND TABLE_NAME = 'user' AND COLUMN_NAME IN ('Password', 'authentication_string') WHERE TABLE_SCHEMA = 'mysql' AND TABLE_NAME = 'user' AND COLUMN_NAME IN ('Password', 'authentication_string')
ORDER BY COLUMN_NAME DESC LIMIT 1 ORDER BY COLUMN_NAME DESC LIMIT 1
""" """)
)
colA = cursor.fetchone() colA = cursor.fetchone()
cursor.execute( cursor.execute("""
"""
SELECT COLUMN_NAME FROM information_schema.COLUMNS SELECT COLUMN_NAME FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = 'mysql' AND TABLE_NAME = 'user' AND COLUMN_NAME IN ('Password', 'authentication_string') WHERE TABLE_SCHEMA = 'mysql' AND TABLE_NAME = 'user' AND COLUMN_NAME IN ('Password', 'authentication_string')
ORDER BY COLUMN_NAME ASC LIMIT 1 ORDER BY COLUMN_NAME ASC LIMIT 1
""" """)
)
colB = cursor.fetchone() colB = cursor.fetchone()
# Select hash from either Password or authentication_string, depending which one exists and/or is filled # Select hash from either Password or authentication_string, depending which one exists and/or is filled
cursor.execute( cursor.execute("""
"""
SELECT COALESCE( SELECT COALESCE(
CASE WHEN %s = '' THEN NULL ELSE %s END, CASE WHEN %s = '' THEN NULL ELSE %s END,
CASE WHEN %s = '' THEN NULL ELSE %s END CASE WHEN %s = '' THEN NULL ELSE %s END
) )
FROM mysql.user WHERE user = %%s AND host = %%s FROM mysql.user WHERE user = %%s AND host = %%s
""" """ % (colA[0], colA[0], colB[0], colB[0]), (user, host))
% (colA[0], colA[0], colB[0], colB[0]),
(user, host),
)
current_pass_hash = cursor.fetchone()[0] current_pass_hash = cursor.fetchone()[0]
if isinstance(current_pass_hash, bytes): if isinstance(current_pass_hash, bytes):
current_pass_hash = current_pass_hash.decode("ascii") current_pass_hash = current_pass_hash.decode('ascii')
if encrypted: if encrypted:
encrypted_password = password encrypted_password = password
if not is_hash(encrypted_password): if not is_hash(encrypted_password):
module.fail_json( module.fail_json(msg="encrypted was specified however it does not appear to be a valid hash expecting: *SHA1(SHA1(your_password))")
msg="encrypted was specified however it does not appear to be a valid hash expecting: *SHA1(SHA1(your_password))"
)
else: else:
if old_user_mgmt: if old_user_mgmt:
cursor.execute("SELECT PASSWORD(%s)", (password,)) cursor.execute("SELECT PASSWORD(%s)", (password,))
@ -660,10 +572,7 @@ def user_mod(
msg = "Password updated (old style)" msg = "Password updated (old style)"
else: else:
try: try:
cursor.execute( cursor.execute("ALTER USER %s@%s IDENTIFIED WITH mysql_native_password AS %s", (user, host, encrypted_password))
"ALTER USER %s@%s IDENTIFIED WITH mysql_native_password AS %s",
(user, host, encrypted_password),
)
msg = "Password updated (new style)" msg = "Password updated (new style)"
except (mysql_driver.Error) as e: except (mysql_driver.Error) as e:
# https://stackoverflow.com/questions/51600000/authentication-string-of-root-user-on-mysql # https://stackoverflow.com/questions/51600000/authentication-string-of-root-user-on-mysql
@ -671,7 +580,7 @@ def user_mod(
if e.args[0] == 1396: if e.args[0] == 1396:
cursor.execute( cursor.execute(
"UPDATE mysql.user SET plugin = %s, authentication_string = %s, Password = '' WHERE User = %s AND Host = %s", "UPDATE mysql.user SET plugin = %s, authentication_string = %s, Password = '' WHERE User = %s AND Host = %s",
("mysql_native_password", encrypted_password, user, host), ('mysql_native_password', encrypted_password, user, host)
) )
cursor.execute("FLUSH PRIVILEGES") cursor.execute("FLUSH PRIVILEGES")
msg = "Password forced update" msg = "Password forced update"
@ -681,9 +590,8 @@ def user_mod(
# Handle plugin authentication # Handle plugin authentication
if plugin: if plugin:
cursor.execute( cursor.execute("SELECT plugin, authentication_string FROM mysql.user "
"SELECT plugin, authentication_string FROM mysql.user " "WHERE user = %s AND host = %s", (user, host) "WHERE user = %s AND host = %s", (user, host))
)
current_plugin = cursor.fetchone() current_plugin = cursor.fetchone()
update = False update = False
@ -703,13 +611,9 @@ def user_mod(
if update: if update:
if plugin_hash_string: if plugin_hash_string:
cursor.execute( cursor.execute("ALTER USER %s@%s IDENTIFIED WITH %s AS %s", (user, host, plugin, plugin_hash_string))
"ALTER USER %s@%s IDENTIFIED WITH %s AS %s", (user, host, plugin, plugin_hash_string)
)
elif plugin_auth_string: elif plugin_auth_string:
cursor.execute( cursor.execute("ALTER USER %s@%s IDENTIFIED WITH %s BY %s", (user, host, plugin, plugin_auth_string))
"ALTER USER %s@%s IDENTIFIED WITH %s BY %s", (user, host, plugin, plugin_auth_string)
)
else: else:
cursor.execute("ALTER USER %s@%s IDENTIFIED WITH %s", (user, host, plugin)) cursor.execute("ALTER USER %s@%s IDENTIFIED WITH %s", (user, host, plugin))
changed = True changed = True
@ -819,23 +723,21 @@ def privileges_get(cursor, user, host):
grants = cursor.fetchall() grants = cursor.fetchall()
def pick(x): def pick(x):
if x == "ALL PRIVILEGES": if x == 'ALL PRIVILEGES':
return "ALL" return 'ALL'
else: else:
return x return x
for grant in grants: for grant in grants:
res = re.match( res = re.match("""GRANT (.+) ON (.+) TO (['`"]).*\\3@(['`"]).*\\4( IDENTIFIED BY PASSWORD (['`"]).+\\6)? ?(.*)""", grant[0])
"""GRANT (.+) ON (.+) TO (['`"]).*\\3@(['`"]).*\\4( IDENTIFIED BY PASSWORD (['`"]).+\\6)? ?(.*)""", grant[0]
)
if res is None: if res is None:
raise InvalidPrivsError("unable to parse the MySQL grant string: %s" % grant[0]) raise InvalidPrivsError('unable to parse the MySQL grant string: %s' % grant[0])
privileges = res.group(1).split(",") privileges = res.group(1).split(",")
privileges = [pick(x.strip()) for x in privileges] privileges = [pick(x.strip()) for x in privileges]
if "WITH GRANT OPTION" in res.group(7): if "WITH GRANT OPTION" in res.group(7):
privileges.append("GRANT") privileges.append('GRANT')
if "REQUIRE SSL" in res.group(7): if 'REQUIRE SSL' in res.group(7):
privileges.append("REQUIRESSL") privileges.append('REQUIRESSL')
db = res.group(2) db = res.group(2)
output.setdefault(db, []).extend(privileges) output.setdefault(db, []).extend(privileges)
return output return output
@ -852,83 +754,83 @@ def privileges_unpack(priv, mode):
The privilege USAGE stands for no privileges, so we add that in on *.* if it's The privilege USAGE stands for no privileges, so we add that in on *.* if it's
not specified in the string, as MySQL will always provide this by default. not specified in the string, as MySQL will always provide this by default.
""" """
if mode == "ANSI": if mode == 'ANSI':
quote = '"' quote = '"'
else: else:
quote = "`" quote = '`'
output = {} output = {}
privs = [] privs = []
for item in priv.strip().split("/"): for item in priv.strip().split('/'):
pieces = item.strip().rsplit(":", 1) pieces = item.strip().rsplit(':', 1)
dbpriv = pieces[0].rsplit(".", 1) dbpriv = pieces[0].rsplit(".", 1)
# Check for FUNCTION or PROCEDURE object types # Check for FUNCTION or PROCEDURE object types
parts = dbpriv[0].split(" ", 1) parts = dbpriv[0].split(" ", 1)
object_type = "" object_type = ''
if len(parts) > 1 and (parts[0] == "FUNCTION" or parts[0] == "PROCEDURE"): if len(parts) > 1 and (parts[0] == 'FUNCTION' or parts[0] == 'PROCEDURE'):
object_type = parts[0] + " " object_type = parts[0] + ' '
dbpriv[0] = parts[1] dbpriv[0] = parts[1]
# Do not escape if privilege is for database or table, i.e. # Do not escape if privilege is for database or table, i.e.
# neither quote *. nor .* # neither quote *. nor .*
for i, side in enumerate(dbpriv): for i, side in enumerate(dbpriv):
if side.strip("`") != "*": if side.strip('`') != '*':
dbpriv[i] = "%s%s%s" % (quote, side.strip("`"), quote) dbpriv[i] = '%s%s%s' % (quote, side.strip('`'), quote)
pieces[0] = object_type + ".".join(dbpriv) pieces[0] = object_type + '.'.join(dbpriv)
if "(" in pieces[1]: if '(' in pieces[1]:
output[pieces[0]] = re.split(r",\s*(?=[^)]*(?:\(|$))", pieces[1].upper()) output[pieces[0]] = re.split(r',\s*(?=[^)]*(?:\(|$))', pieces[1].upper())
for i in output[pieces[0]]: for i in output[pieces[0]]:
privs.append(re.sub(r"\s*\(.*\)", "", i)) privs.append(re.sub(r'\s*\(.*\)', '', i))
else: else:
output[pieces[0]] = pieces[1].upper().split(",") output[pieces[0]] = pieces[1].upper().split(',')
privs = output[pieces[0]] privs = output[pieces[0]]
new_privs = frozenset(privs) new_privs = frozenset(privs)
if not new_privs.issubset(VALID_PRIVS): if not new_privs.issubset(VALID_PRIVS):
raise InvalidPrivsError("Invalid privileges specified: %s" % new_privs.difference(VALID_PRIVS)) raise InvalidPrivsError('Invalid privileges specified: %s' % new_privs.difference(VALID_PRIVS))
if "*.*" not in output: if '*.*' not in output:
output["*.*"] = ["USAGE"] output['*.*'] = ['USAGE']
# if we are only specifying something like REQUIRESSL and/or GRANT (=WITH GRANT OPTION) in *.* # if we are only specifying something like REQUIRESSL and/or GRANT (=WITH GRANT OPTION) in *.*
# we still need to add USAGE as a privilege to avoid syntax errors # we still need to add USAGE as a privilege to avoid syntax errors
if "REQUIRESSL" in priv and not set(output["*.*"]).difference(set(["GRANT", "REQUIRESSL"])): if 'REQUIRESSL' in priv and not set(output['*.*']).difference(set(['GRANT', 'REQUIRESSL'])):
output["*.*"].append("USAGE") output['*.*'].append('USAGE')
return output return output
def privileges_revoke(cursor, user, host, db_table, priv, grant_option): def privileges_revoke(cursor, user, host, db_table, priv, grant_option):
# Escape '%' since mysql db.execute() uses a format string # Escape '%' since mysql db.execute() uses a format string
db_table = db_table.replace("%", "%%") db_table = db_table.replace('%', '%%')
if grant_option: if grant_option:
query = ["REVOKE GRANT OPTION ON %s" % db_table] query = ["REVOKE GRANT OPTION ON %s" % db_table]
query.append("FROM %s@%s") query.append("FROM %s@%s")
query = " ".join(query) query = ' '.join(query)
cursor.execute(query, (user, host)) cursor.execute(query, (user, host))
priv_string = ",".join([p for p in priv if p not in ("GRANT", "REQUIRESSL")]) priv_string = ",".join([p for p in priv if p not in ('GRANT', 'REQUIRESSL')])
query = ["REVOKE %s ON %s" % (priv_string, db_table)] query = ["REVOKE %s ON %s" % (priv_string, db_table)]
query.append("FROM %s@%s") query.append("FROM %s@%s")
query = " ".join(query) query = ' '.join(query)
cursor.execute(query, (user, host)) cursor.execute(query, (user, host))
def privileges_grant(cursor, user, host, db_table, priv, tls_requires): def privileges_grant(cursor, user, host, db_table, priv, tls_requires):
# Escape '%' since mysql db.execute uses a format string and the # Escape '%' since mysql db.execute uses a format string and the
# specification of db and table often use a % (SQL wildcard) # specification of db and table often use a % (SQL wildcard)
db_table = db_table.replace("%", "%%") db_table = db_table.replace('%', '%%')
priv_string = ",".join([p for p in priv if p not in ("GRANT", "REQUIRESSL")]) priv_string = ",".join([p for p in priv if p not in ('GRANT', 'REQUIRESSL')])
query = ["GRANT %s ON %s" % (priv_string, db_table)] query = ["GRANT %s ON %s" % (priv_string, db_table)]
query.append("TO %s@%s") query.append("TO %s@%s")
params = (user, host) params = (user, host)
if tls_requires and not server_suports_requires_create(cursor): if tls_requires and not server_suports_requires_create(cursor):
query, params = mogrify_requires(" ".join(query), params, tls_requires) query, params = mogrify_requires(" ".join(query), params, tls_requires)
query = [query] query = [query]
if "REQUIRESSL" in priv and not tls_requires: if 'REQUIRESSL' in priv and not tls_requires:
query.append("REQUIRE SSL") query.append("REQUIRE SSL")
if "GRANT" in priv: if 'GRANT' in priv:
query.append("WITH GRANT OPTION") query.append("WITH GRANT OPTION")
query = " ".join(query) query = ' '.join(query)
cursor.execute(query, params) cursor.execute(query, params)
@ -941,9 +843,9 @@ def convert_priv_dict_to_str(priv):
Returns: Returns:
priv (str): String representation of input argument. priv (str): String representation of input argument.
""" """
priv_list = ["%s:%s" % (key, val) for key, val in iteritems(priv)] priv_list = ['%s:%s' % (key, val) for key, val in iteritems(priv)]
return "/".join(priv_list) return '/'.join(priv_list)
# TLS requires on user create statement is supported since MySQL 5.7 and MariaDB 10.2 # TLS requires on user create statement is supported since MySQL 5.7 and MariaDB 10.2
@ -984,9 +886,9 @@ def server_supports_alter_user(cursor):
""" """
cursor.execute("SELECT VERSION()") cursor.execute("SELECT VERSION()")
version_str = cursor.fetchone()[0] version_str = cursor.fetchone()[0]
version = version_str.split(".") version = version_str.split('.')
if "mariadb" in version_str.lower(): if 'mariadb' in version_str.lower():
# MariaDB 10.2 and later # MariaDB 10.2 and later
if int(version[0]) * 1000 + int(version[1]) >= 10002: if int(version[0]) * 1000 + int(version[1]) >= 10002:
return True return True
@ -1011,13 +913,11 @@ def get_resource_limits(cursor, user, host):
Returns: Dictionary containing current resource limits. Returns: Dictionary containing current resource limits.
""" """
query = ( query = ('SELECT max_questions AS MAX_QUERIES_PER_HOUR, '
"SELECT max_questions AS MAX_QUERIES_PER_HOUR, " 'max_updates AS MAX_UPDATES_PER_HOUR, '
"max_updates AS MAX_UPDATES_PER_HOUR, " 'max_connections AS MAX_CONNECTIONS_PER_HOUR, '
"max_connections AS MAX_CONNECTIONS_PER_HOUR, " 'max_user_connections AS MAX_USER_CONNECTIONS '
"max_user_connections AS MAX_USER_CONNECTIONS " 'FROM mysql.user WHERE User = %s AND Host = %s')
"FROM mysql.user WHERE User = %s AND Host = %s"
)
cursor.execute(query, (user, host)) cursor.execute(query, (user, host))
res = cursor.fetchone() res = cursor.fetchone()
@ -1025,10 +925,10 @@ def get_resource_limits(cursor, user, host):
return None return None
current_limits = { current_limits = {
"MAX_QUERIES_PER_HOUR": res[0], 'MAX_QUERIES_PER_HOUR': res[0],
"MAX_UPDATES_PER_HOUR": res[1], 'MAX_UPDATES_PER_HOUR': res[1],
"MAX_CONNECTIONS_PER_HOUR": res[2], 'MAX_CONNECTIONS_PER_HOUR': res[2],
"MAX_USER_CONNECTIONS": res[3], 'MAX_USER_CONNECTIONS': res[3],
} }
return current_limits return current_limits
@ -1083,10 +983,8 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode):
Returns: True, if changed, False otherwise. Returns: True, if changed, False otherwise.
""" """
if not server_supports_alter_user(cursor): if not server_supports_alter_user(cursor):
module.fail_json( module.fail_json(msg="The server version does not match the requirements "
msg="The server version does not match the requirements " "for resource_limits parameter. See module's documentation.")
"for resource_limits parameter. See module's documentation."
)
current_limits = get_resource_limits(cursor, user, host) current_limits = get_resource_limits(cursor, user, host)
@ -1101,10 +999,10 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode):
# If not check_mode # If not check_mode
tmp = [] tmp = []
for key, val in iteritems(needs_to_change): for key, val in iteritems(needs_to_change):
tmp.append("%s %s" % (key, val)) tmp.append('%s %s' % (key, val))
query = "ALTER USER %s@%s" query = "ALTER USER %s@%s"
query += " WITH %s" % " ".join(tmp) query += ' WITH %s' % ' '.join(tmp)
cursor.execute(query, (user, host)) cursor.execute(query, (user, host))
return True return True
@ -1117,32 +1015,32 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode):
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
login_user=dict(type="str"), login_user=dict(type='str'),
login_password=dict(type="str", no_log=True), login_password=dict(type='str', no_log=True),
login_host=dict(type="str", default="localhost"), login_host=dict(type='str', default='localhost'),
login_port=dict(type="int", default=3306), login_port=dict(type='int', default=3306),
login_unix_socket=dict(type="str"), login_unix_socket=dict(type='str'),
user=dict(type="str", required=True, aliases=["name"]), user=dict(type='str', required=True, aliases=['name']),
password=dict(type="str", no_log=True), password=dict(type='str', no_log=True),
encrypted=dict(type="bool", default=False), encrypted=dict(type='bool', default=False),
host=dict(type="str", default="localhost"), host=dict(type='str', default='localhost'),
host_all=dict(type="bool", default=False), host_all=dict(type="bool", default=False),
state=dict(type="str", default="present", choices=["absent", "present"]), state=dict(type='str', default='present', choices=['absent', 'present']),
priv=dict(type="raw"), priv=dict(type='raw'),
tls_requires=dict(type="dict"), tls_requires=dict(type='dict'),
append_privs=dict(type="bool", default=False), append_privs=dict(type='bool', default=False),
check_implicit_admin=dict(type="bool", default=False), check_implicit_admin=dict(type='bool', default=False),
update_password=dict(type="str", default="always", choices=["always", "on_create"], no_log=False), update_password=dict(type='str', default='always', choices=['always', 'on_create'], no_log=False),
connect_timeout=dict(type="int", default=30), connect_timeout=dict(type='int', default=30),
config_file=dict(type="path", default="~/.my.cnf"), config_file=dict(type='path', default='~/.my.cnf'),
sql_log_bin=dict(type="bool", default=True), sql_log_bin=dict(type='bool', default=True),
client_cert=dict(type="path", aliases=["ssl_cert"]), client_cert=dict(type='path', aliases=['ssl_cert']),
client_key=dict(type="path", aliases=["ssl_key"]), client_key=dict(type='path', aliases=['ssl_key']),
ca_cert=dict(type="path", aliases=["ssl_ca"]), ca_cert=dict(type='path', aliases=['ssl_ca']),
plugin=dict(default=None, type="str"), plugin=dict(default=None, type='str'),
plugin_hash_string=dict(default=None, type="str"), plugin_hash_string=dict(default=None, type='str'),
plugin_auth_string=dict(default=None, type="str"), plugin_auth_string=dict(default=None, type='str'),
resource_limits=dict(type="dict"), resource_limits=dict(type='dict'),
), ),
supports_check_mode=True, supports_check_mode=True,
) )
@ -1160,11 +1058,11 @@ def main():
connect_timeout = module.params["connect_timeout"] connect_timeout = module.params["connect_timeout"]
config_file = module.params["config_file"] config_file = module.params["config_file"]
append_privs = module.boolean(module.params["append_privs"]) append_privs = module.boolean(module.params["append_privs"])
update_password = module.params["update_password"] update_password = module.params['update_password']
ssl_cert = module.params["client_cert"] ssl_cert = module.params["client_cert"]
ssl_key = module.params["client_key"] ssl_key = module.params["client_key"]
ssl_ca = module.params["ca_cert"] ssl_ca = module.params["ca_cert"]
db = "" db = ''
sql_log_bin = module.params["sql_log_bin"] sql_log_bin = module.params["sql_log_bin"]
plugin = module.params["plugin"] plugin = module.params["plugin"]
plugin_hash_string = module.params["plugin_hash_string"] plugin_hash_string = module.params["plugin_hash_string"]
@ -1183,29 +1081,18 @@ def main():
try: try:
if check_implicit_admin: if check_implicit_admin:
try: try:
cursor, db_conn = mysql_connect( cursor, db_conn = mysql_connect(module, "root", "", config_file, ssl_cert, ssl_key, ssl_ca, db,
module, "root", "", config_file, ssl_cert, ssl_key, ssl_ca, db, connect_timeout=connect_timeout connect_timeout=connect_timeout)
)
except Exception: except Exception:
pass pass
if not cursor: if not cursor:
cursor, db_conn = mysql_connect( cursor, db_conn = mysql_connect(module, login_user, login_password, config_file, ssl_cert, ssl_key, ssl_ca, db,
module, connect_timeout=connect_timeout,
login_user,
login_password,
config_file,
ssl_cert,
ssl_key,
ssl_ca,
db,
connect_timeout=connect_timeout,
) )
except Exception as e: except Exception as e:
module.fail_json( module.fail_json(msg="unable to connect to database, check login_user and login_password are correct or %s has the credentials. "
msg="unable to connect to database, check login_user and login_password are correct or %s has the credentials. " "Exception message: %s" % (config_file, to_native(e)))
"Exception message: %s" % (config_file, to_native(e))
)
if not sql_log_bin: if not sql_log_bin:
cursor.execute("SET SQL_LOG_BIN=0;") cursor.execute("SET SQL_LOG_BIN=0;")
@ -1224,37 +1111,13 @@ def main():
if user_exists(cursor, user, host, host_all): if user_exists(cursor, user, host, host_all):
try: try:
if update_password == "always": if update_password == "always":
changed, msg = user_mod( changed, msg = user_mod(cursor, user, host, host_all, password, encrypted,
cursor, plugin, plugin_hash_string, plugin_auth_string,
user, priv, append_privs, tls_requires, module)
host,
host_all,
password,
encrypted,
plugin,
plugin_hash_string,
plugin_auth_string,
priv,
append_privs,
tls_requires,
module,
)
else: else:
changed, msg = user_mod( changed, msg = user_mod(cursor, user, host, host_all, None, encrypted,
cursor, plugin, plugin_hash_string, plugin_auth_string,
user, priv, append_privs, tls_requires, module)
host,
host_all,
None,
encrypted,
plugin,
plugin_hash_string,
plugin_auth_string,
priv,
append_privs,
tls_requires,
module,
)
except (SQLParseError, InvalidPrivsError, mysql_driver.Error) as e: except (SQLParseError, InvalidPrivsError, mysql_driver.Error) as e:
module.fail_json(msg=to_native(e)) module.fail_json(msg=to_native(e))
@ -1262,20 +1125,9 @@ def main():
if host_all: if host_all:
module.fail_json(msg="host_all parameter cannot be used when adding a user") module.fail_json(msg="host_all parameter cannot be used when adding a user")
try: try:
changed = user_add( changed = user_add(cursor, user, host, host_all, password, encrypted,
cursor, plugin, plugin_hash_string, plugin_auth_string,
user, priv, tls_requires, module.check_mode)
host,
host_all,
password,
encrypted,
plugin,
plugin_hash_string,
plugin_auth_string,
priv,
tls_requires,
module.check_mode,
)
if changed: if changed:
msg = "User added" msg = "User added"
@ -1295,5 +1147,5 @@ def main():
module.exit_json(changed=changed, user=user, msg=msg) module.exit_json(changed=changed, user=user, msg=msg)
if __name__ == "__main__": if __name__ == '__main__':
main() main()