diff --git a/plugins/module_utils/implementations/mariadb/user.py b/plugins/module_utils/implementations/mariadb/user.py index cdc14b2..ecf867e 100644 --- a/plugins/module_utils/implementations/mariadb/user.py +++ b/plugins/module_utils/implementations/mariadb/user.py @@ -29,3 +29,47 @@ def server_supports_password_expire(cursor): version = get_server_version(cursor) return LooseVersion(version) >= LooseVersion("10.4.3") + +def get_tls_requires(cursor, user, host): + """Get user TLS requirements. + Reads directly from mysql.user table allowing for a more + readable code. + + Args: + cursor (cursor): DB driver cursor object. + user (str): User name. + host (str): User host name. + + Returns: Dictionary containing current TLS required + """ + tls_requires = dict() + + query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject ' + 'FROM mysql.user WHERE User = %s AND Host = %s') + cursor.execute(query, (user, host)) + res = cursor.fetchone() + + # Mysql_info use a DictCursor so we must convert back to a list + # otherwise we get KeyError 0 + if isinstance(res, dict): + res = list(res.values()) + + # When user don't require SSL, res value is: ('', '', '', '') + if not any(res): + return None + + if res[0] == 'ANY': + tls_requires['SSL'] = None + + if res[0] == 'X509': + tls_requires['X509'] = None + + if res[1]: + tls_requires['CIPHER'] = res[1] + + if res[2]: + tls_requires['ISSUER'] = res[2] + + if res[3]: + tls_requires['SUBJECT'] = res[3] + return tls_requires diff --git a/plugins/module_utils/implementations/mysql/user.py b/plugins/module_utils/implementations/mysql/user.py index 4e41c05..40176ef 100644 --- a/plugins/module_utils/implementations/mysql/user.py +++ b/plugins/module_utils/implementations/mysql/user.py @@ -8,6 +8,9 @@ __metaclass__ = type from ansible_collections.community.mysql.plugins.module_utils.version import LooseVersion from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version +import re +import shlex + def use_old_user_mgmt(cursor): version = get_server_version(cursor) @@ -30,3 +33,42 @@ def server_supports_password_expire(cursor): version = get_server_version(cursor) return LooseVersion(version) >= LooseVersion("5.7") + +def get_tls_requires(cursor, user, host): + """Get user TLS requirements. + We must use SHOW GRANTS because some tls fileds are encoded. + + Args: + cursor (cursor): DB driver cursor object. + user (str): User name. + host (str): User host name. + + Returns: Dictionary containing current TLS required + """ + if user: + if not use_old_user_mgmt(cursor): + query = "SHOW CREATE USER '%s'@'%s'" % (user, host) + else: + query = "SHOW GRANTS for '%s'@'%s'" % (user, host) + + cursor.execute(query) + grants = cursor.fetchone() + + # Mysql_info use a DictCursor so we must convert back to a list + # otherwise we get KeyError 0 + if isinstance(grants, dict): + grants = list(grants.values()) + grants_str = ''.join(grants) + + pattern = r"(?<=\bREQUIRE\b)(.*?)(?=(?:\bPASSWORD\b|$))" + requires_match = re.search(pattern, grants_str) + requires = requires_match.group().strip() if requires_match else "" + + if any((requires.startswith(req) for req in ('SSL', 'X509', 'NONE'))): + requires = requires.split()[0] + if requires == 'NONE': + requires = None + + items = iter(shlex.split(requires)) + requires = dict(zip(items, items)) + return requires or None diff --git a/plugins/module_utils/user.py b/plugins/module_utils/user.py index 28174f8..d4ae9dd 100644 --- a/plugins/module_utils/user.py +++ b/plugins/module_utils/user.py @@ -17,6 +17,7 @@ from ansible.module_utils.six import iteritems from ansible_collections.community.mysql.plugins.module_utils.mysql import ( mysql_driver, + get_server_implementation, ) @@ -80,49 +81,6 @@ def do_not_mogrify_requires(query, params, tls_requires): return query, params -def get_tls_requires(cursor, user, host): - """Get user TLS requirements. - - Args: - cursor (cursor): DB driver cursor object. - user (str): User name. - host (str): User host name. - - Returns: Dictionary containing current TLS required - """ - tls_requires = dict() - - query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject ' - 'FROM mysql.user WHERE User = %s AND Host = %s') - cursor.execute(query, (user, host)) - res = cursor.fetchone() - - # Mysql_info use a DictCursor so we must convert back to a list - # otherwise we get KeyError 0 - if isinstance(res, dict): - res = list(res.values()) - - # When user don't require SSL, res value is: ('', '', '', '') - if not any(res): - return None - - if res[0] == 'ANY': - return 'SSL' - - if res[0] == 'X509': - return 'X509' - - if res[1]: - tls_requires['CIPHER'] = res[1] - - if res[2]: - tls_requires['ISSUER'] = res[2] - - if res[3]: - tls_requires['SUBJECT'] = res[3] - return tls_requires - - def get_grants(cursor, user, host): cursor.execute("SHOW GRANTS FOR %s@%s", (user, host)) grants_line = list(filter(lambda x: "ON *.*" in x[0], cursor.fetchall()))[0] @@ -184,6 +142,7 @@ def user_add(cursor, user, host, host_all, password, encrypted, return {'changed': True, 'password_changed': None, 'attributes': attributes} # Determine what user management method server uses + impl = get_user_implementation(cursor) old_user_mgmt = impl.use_old_user_mgmt(cursor) mogrify = do_not_mogrify_requires if old_user_mgmt else mogrify_requires @@ -262,6 +221,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, grant_option = False # Determine what user management method server uses + impl = get_user_implementation(cursor) old_user_mgmt = impl.use_old_user_mgmt(cursor) if host_all and not role: @@ -517,7 +477,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, continue # Handle TLS requirements - current_requires = get_tls_requires(cursor, user, host) + current_requires = sanitize_requires(impl.get_tls_requires(cursor, user, host)) if current_requires != tls_requires: msg = "TLS requires updated" if not module.check_mode: @@ -855,6 +815,7 @@ def privileges_grant(cursor, user, host, db_table, priv, tls_requires, maria_rol query.append("TO %s") params = (user) + impl = get_user_implementation(cursor) if tls_requires and impl.use_old_user_mgmt(cursor): query, params = mogrify_requires(" ".join(query), params, tls_requires) query = [query] @@ -991,6 +952,7 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode): Returns: True, if changed, False otherwise. """ + impl = get_user_implementation(cursor) if not impl.server_supports_alter_user(cursor): module.fail_json(msg="The server version does not match the requirements " "for resource_limits parameter. See module's documentation.") @@ -1126,12 +1088,11 @@ def attributes_get(cursor, user, host): return j if j else None -def get_impl(cursor): - global impl - cursor.execute("SELECT VERSION()") - if 'mariadb' in cursor.fetchone()[0].lower(): +def get_user_implementation(cursor): + db_engine = get_server_implementation(cursor) + if db_engine == 'mariadb': from ansible_collections.community.mysql.plugins.module_utils.implementations.mariadb import user as mariauser - impl = mariauser + return mariauser else: from ansible_collections.community.mysql.plugins.module_utils.implementations.mysql import user as mysqluser - impl = mysqluser + return mysqluser diff --git a/plugins/modules/mysql_info.py b/plugins/modules/mysql_info.py index 10f4548..17732fd 100644 --- a/plugins/modules/mysql_info.py +++ b/plugins/modules/mysql_info.py @@ -300,8 +300,7 @@ from ansible_collections.community.mysql.plugins.module_utils.user import ( privileges_get, get_resource_limits, get_existing_authentication, - get_tls_requires, - sanitize_requires, + get_user_implementation, ) from ansible.module_utils.six import iteritems from ansible.module_utils._text import to_native @@ -329,10 +328,11 @@ class MySQL_Info(object): 5. add info about the new subset with an example to RETURN block """ - def __init__(self, module, cursor, server_implementation): + def __init__(self, module, cursor, server_implementation, user_implementation): self.module = module self.cursor = cursor self.server_implementation = server_implementation + self.user_implementation = user_implementation self.info = { 'version': {}, 'databases': {}, @@ -606,14 +606,15 @@ class MySQL_Info(object): resource_limits = get_resource_limits(self.cursor, user, host) copy_ressource_limits = dict.copy(resource_limits) - tls_requires = get_tls_requires(self.cursor, user, host) + tls_requires = self.user_implementation.get_tls_requires( + self.cursor, user, host) output_dict = { 'name': user, 'host': host, 'priv': '/'.join(priv_string), 'resource_limits': copy_ressource_limits, - 'tls_requires': sanitize_requires(tls_requires), + 'tls_requires': tls_requires, } # Prevent returning a resource limit if empty @@ -754,11 +755,12 @@ def main(): module.fail_json(msg) server_implementation = get_server_implementation(cursor) + user_implementation = get_user_implementation(cursor) ############################### # Create object and do main job - mysql = MySQL_Info(module, cursor, server_implementation) + mysql = MySQL_Info(module, cursor, server_implementation, user_implementation) module.exit_json(changed=False, connector_name=connector_name, diff --git a/plugins/modules/mysql_role.py b/plugins/modules/mysql_role.py index 3e3462a..65ed894 100644 --- a/plugins/modules/mysql_role.py +++ b/plugins/modules/mysql_role.py @@ -309,7 +309,7 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import ( ) from ansible_collections.community.mysql.plugins.module_utils.user import ( convert_priv_dict_to_str, - get_impl, + get_user_implementation, get_mode, user_mod, privileges_grant, @@ -1054,7 +1054,7 @@ def main(): # Set defaults changed = False - get_impl(cursor) + impl = get_user_implementation(cursor) if priv is not None: try: diff --git a/plugins/modules/mysql_user.py b/plugins/modules/mysql_user.py index e02b153..fa54c7d 100644 --- a/plugins/modules/mysql_user.py +++ b/plugins/modules/mysql_user.py @@ -401,7 +401,6 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import ( ) from ansible_collections.community.mysql.plugins.module_utils.user import ( convert_priv_dict_to_str, - get_impl, get_mode, InvalidPrivsError, limit_resources, @@ -528,8 +527,6 @@ def main(): if session_vars: set_session_vars(module, cursor, session_vars) - get_impl(cursor) - if priv is not None: try: mode = get_mode(cursor)