mirror of
				https://github.com/ansible-collections/community.mysql.git
				synced 2025-10-26 05:50:39 -07:00 
			
		
		
		
	Refactor user implementation to host get_tls_requires
This commit is contained in:
		
					parent
					
						
							
								5460dec642
							
						
					
				
			
			
				commit
				
					
						d77be1ba03
					
				
			
		
					 6 changed files with 107 additions and 61 deletions
				
			
		|  | @ -29,3 +29,47 @@ def server_supports_password_expire(cursor): | |||
|     version = get_server_version(cursor) | ||||
| 
 | ||||
|     return LooseVersion(version) >= LooseVersion("10.4.3") | ||||
| 
 | ||||
| def get_tls_requires(cursor, user, host): | ||||
|     """Get user TLS requirements. | ||||
|     Reads directly from mysql.user table allowing for a more | ||||
|     readable code. | ||||
| 
 | ||||
|     Args: | ||||
|         cursor (cursor): DB driver cursor object. | ||||
|         user (str): User name. | ||||
|         host (str): User host name. | ||||
| 
 | ||||
|     Returns: Dictionary containing current TLS required | ||||
|     """ | ||||
|     tls_requires = dict() | ||||
| 
 | ||||
|     query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject ' | ||||
|              'FROM mysql.user WHERE User = %s AND Host = %s') | ||||
|     cursor.execute(query, (user, host)) | ||||
|     res = cursor.fetchone() | ||||
| 
 | ||||
|     # Mysql_info use a DictCursor so we must convert back to a list | ||||
|     # otherwise we get KeyError 0 | ||||
|     if isinstance(res, dict): | ||||
|         res = list(res.values()) | ||||
| 
 | ||||
|     # When user don't require SSL, res value is: ('', '', '', '') | ||||
|     if not any(res): | ||||
|         return None | ||||
| 
 | ||||
|     if res[0] == 'ANY': | ||||
|         tls_requires['SSL'] = None | ||||
| 
 | ||||
|     if res[0] == 'X509': | ||||
|         tls_requires['X509'] = None | ||||
| 
 | ||||
|     if res[1]: | ||||
|         tls_requires['CIPHER'] = res[1] | ||||
| 
 | ||||
|     if res[2]: | ||||
|         tls_requires['ISSUER'] = res[2] | ||||
| 
 | ||||
|     if res[3]: | ||||
|         tls_requires['SUBJECT'] = res[3] | ||||
|     return tls_requires | ||||
|  |  | |||
|  | @ -8,6 +8,9 @@ __metaclass__ = type | |||
| from ansible_collections.community.mysql.plugins.module_utils.version import LooseVersion | ||||
| from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version | ||||
| 
 | ||||
| import re | ||||
| import shlex | ||||
| 
 | ||||
| 
 | ||||
| def use_old_user_mgmt(cursor): | ||||
|     version = get_server_version(cursor) | ||||
|  | @ -30,3 +33,42 @@ def server_supports_password_expire(cursor): | |||
|     version = get_server_version(cursor) | ||||
| 
 | ||||
|     return LooseVersion(version) >= LooseVersion("5.7") | ||||
| 
 | ||||
| def get_tls_requires(cursor, user, host): | ||||
|     """Get user TLS requirements. | ||||
|     We must use SHOW GRANTS because some tls fileds are encoded. | ||||
| 
 | ||||
|     Args: | ||||
|         cursor (cursor): DB driver cursor object. | ||||
|         user (str): User name. | ||||
|         host (str): User host name. | ||||
| 
 | ||||
|     Returns: Dictionary containing current TLS required | ||||
|     """ | ||||
|     if user: | ||||
|         if not use_old_user_mgmt(cursor): | ||||
|             query = "SHOW CREATE USER '%s'@'%s'" % (user, host) | ||||
|         else: | ||||
|             query = "SHOW GRANTS for '%s'@'%s'" % (user, host) | ||||
| 
 | ||||
|         cursor.execute(query) | ||||
|         grants = cursor.fetchone() | ||||
| 
 | ||||
|         # Mysql_info use a DictCursor so we must convert back to a list | ||||
|         # otherwise we get KeyError 0 | ||||
|         if isinstance(grants, dict): | ||||
|             grants = list(grants.values()) | ||||
|         grants_str = ''.join(grants) | ||||
| 
 | ||||
|         pattern = r"(?<=\bREQUIRE\b)(.*?)(?=(?:\bPASSWORD\b|$))" | ||||
|         requires_match = re.search(pattern, grants_str) | ||||
|         requires = requires_match.group().strip() if requires_match else "" | ||||
| 
 | ||||
|         if any((requires.startswith(req) for req in ('SSL', 'X509', 'NONE'))): | ||||
|             requires = requires.split()[0] | ||||
|             if requires == 'NONE': | ||||
|                 requires = None | ||||
| 
 | ||||
|         items = iter(shlex.split(requires)) | ||||
|         requires = dict(zip(items, items)) | ||||
|         return requires or None | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ from ansible.module_utils.six import iteritems | |||
| 
 | ||||
| from ansible_collections.community.mysql.plugins.module_utils.mysql import ( | ||||
|     mysql_driver, | ||||
|     get_server_implementation, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -80,49 +81,6 @@ def do_not_mogrify_requires(query, params, tls_requires): | |||
|     return query, params | ||||
| 
 | ||||
| 
 | ||||
| def get_tls_requires(cursor, user, host): | ||||
|     """Get user TLS requirements. | ||||
| 
 | ||||
|     Args: | ||||
|         cursor (cursor): DB driver cursor object. | ||||
|         user (str): User name. | ||||
|         host (str): User host name. | ||||
| 
 | ||||
|     Returns: Dictionary containing current TLS required | ||||
|     """ | ||||
|     tls_requires = dict() | ||||
| 
 | ||||
|     query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject ' | ||||
|              'FROM mysql.user WHERE User = %s AND Host = %s') | ||||
|     cursor.execute(query, (user, host)) | ||||
|     res = cursor.fetchone() | ||||
| 
 | ||||
|     # Mysql_info use a DictCursor so we must convert back to a list | ||||
|     # otherwise we get KeyError 0 | ||||
|     if isinstance(res, dict): | ||||
|         res = list(res.values()) | ||||
| 
 | ||||
|     # When user don't require SSL, res value is: ('', '', '', '') | ||||
|     if not any(res): | ||||
|         return None | ||||
| 
 | ||||
|     if res[0] == 'ANY': | ||||
|         return 'SSL' | ||||
| 
 | ||||
|     if res[0] == 'X509': | ||||
|         return 'X509' | ||||
| 
 | ||||
|     if res[1]: | ||||
|         tls_requires['CIPHER'] = res[1] | ||||
| 
 | ||||
|     if res[2]: | ||||
|         tls_requires['ISSUER'] = res[2] | ||||
| 
 | ||||
|     if res[3]: | ||||
|         tls_requires['SUBJECT'] = res[3] | ||||
|     return tls_requires | ||||
| 
 | ||||
| 
 | ||||
| def get_grants(cursor, user, host): | ||||
|     cursor.execute("SHOW GRANTS FOR %s@%s", (user, host)) | ||||
|     grants_line = list(filter(lambda x: "ON *.*" in x[0], cursor.fetchall()))[0] | ||||
|  | @ -184,6 +142,7 @@ def user_add(cursor, user, host, host_all, password, encrypted, | |||
|         return {'changed': True, 'password_changed': None, 'attributes': attributes} | ||||
| 
 | ||||
|     # Determine what user management method server uses | ||||
|     impl = get_user_implementation(cursor) | ||||
|     old_user_mgmt = impl.use_old_user_mgmt(cursor) | ||||
| 
 | ||||
|     mogrify = do_not_mogrify_requires if old_user_mgmt else mogrify_requires | ||||
|  | @ -262,6 +221,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, | |||
|     grant_option = False | ||||
| 
 | ||||
|     # Determine what user management method server uses | ||||
|     impl = get_user_implementation(cursor) | ||||
|     old_user_mgmt = impl.use_old_user_mgmt(cursor) | ||||
| 
 | ||||
|     if host_all and not role: | ||||
|  | @ -517,7 +477,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, | |||
|             continue | ||||
| 
 | ||||
|         # Handle TLS requirements | ||||
|         current_requires = get_tls_requires(cursor, user, host) | ||||
|         current_requires = sanitize_requires(impl.get_tls_requires(cursor, user, host)) | ||||
|         if current_requires != tls_requires: | ||||
|             msg = "TLS requires updated" | ||||
|             if not module.check_mode: | ||||
|  | @ -855,6 +815,7 @@ def privileges_grant(cursor, user, host, db_table, priv, tls_requires, maria_rol | |||
|         query.append("TO %s") | ||||
|         params = (user) | ||||
| 
 | ||||
|     impl = get_user_implementation(cursor) | ||||
|     if tls_requires and impl.use_old_user_mgmt(cursor): | ||||
|         query, params = mogrify_requires(" ".join(query), params, tls_requires) | ||||
|         query = [query] | ||||
|  | @ -991,6 +952,7 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode): | |||
| 
 | ||||
|     Returns: True, if changed, False otherwise. | ||||
|     """ | ||||
|     impl = get_user_implementation(cursor) | ||||
|     if not impl.server_supports_alter_user(cursor): | ||||
|         module.fail_json(msg="The server version does not match the requirements " | ||||
|                              "for resource_limits parameter. See module's documentation.") | ||||
|  | @ -1126,12 +1088,11 @@ def attributes_get(cursor, user, host): | |||
|     return j if j else None | ||||
| 
 | ||||
| 
 | ||||
| def get_impl(cursor): | ||||
|     global impl | ||||
|     cursor.execute("SELECT VERSION()") | ||||
|     if 'mariadb' in cursor.fetchone()[0].lower(): | ||||
| def get_user_implementation(cursor): | ||||
|     db_engine = get_server_implementation(cursor) | ||||
|     if db_engine == 'mariadb': | ||||
|         from ansible_collections.community.mysql.plugins.module_utils.implementations.mariadb import user as mariauser | ||||
|         impl = mariauser | ||||
|         return mariauser | ||||
|     else: | ||||
|         from ansible_collections.community.mysql.plugins.module_utils.implementations.mysql import user as mysqluser | ||||
|         impl = mysqluser | ||||
|         return mysqluser | ||||
|  |  | |||
|  | @ -300,8 +300,7 @@ from ansible_collections.community.mysql.plugins.module_utils.user import ( | |||
|     privileges_get, | ||||
|     get_resource_limits, | ||||
|     get_existing_authentication, | ||||
|     get_tls_requires, | ||||
|     sanitize_requires, | ||||
|     get_user_implementation, | ||||
| ) | ||||
| from ansible.module_utils.six import iteritems | ||||
| from ansible.module_utils._text import to_native | ||||
|  | @ -329,10 +328,11 @@ class MySQL_Info(object): | |||
|         5. add info about the new subset with an example to RETURN block | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, module, cursor, server_implementation): | ||||
|     def __init__(self, module, cursor, server_implementation, user_implementation): | ||||
|         self.module = module | ||||
|         self.cursor = cursor | ||||
|         self.server_implementation = server_implementation | ||||
|         self.user_implementation = user_implementation | ||||
|         self.info = { | ||||
|             'version': {}, | ||||
|             'databases': {}, | ||||
|  | @ -606,14 +606,15 @@ class MySQL_Info(object): | |||
|             resource_limits = get_resource_limits(self.cursor, user, host) | ||||
|             copy_ressource_limits = dict.copy(resource_limits) | ||||
| 
 | ||||
|             tls_requires = get_tls_requires(self.cursor, user, host) | ||||
|             tls_requires = self.user_implementation.get_tls_requires( | ||||
|                 self.cursor, user, host) | ||||
| 
 | ||||
|             output_dict = { | ||||
|                 'name': user, | ||||
|                 'host': host, | ||||
|                 'priv': '/'.join(priv_string), | ||||
|                 'resource_limits': copy_ressource_limits, | ||||
|                 'tls_requires': sanitize_requires(tls_requires), | ||||
|                 'tls_requires': tls_requires, | ||||
|             } | ||||
| 
 | ||||
|             # Prevent returning a resource limit if empty | ||||
|  | @ -754,11 +755,12 @@ def main(): | |||
|         module.fail_json(msg) | ||||
| 
 | ||||
|     server_implementation = get_server_implementation(cursor) | ||||
|     user_implementation = get_user_implementation(cursor) | ||||
| 
 | ||||
|     ############################### | ||||
|     # Create object and do main job | ||||
| 
 | ||||
|     mysql = MySQL_Info(module, cursor, server_implementation) | ||||
|     mysql = MySQL_Info(module, cursor, server_implementation, user_implementation) | ||||
| 
 | ||||
|     module.exit_json(changed=False, | ||||
|                      connector_name=connector_name, | ||||
|  |  | |||
|  | @ -309,7 +309,7 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import ( | |||
| ) | ||||
| from ansible_collections.community.mysql.plugins.module_utils.user import ( | ||||
|     convert_priv_dict_to_str, | ||||
|     get_impl, | ||||
|     get_user_implementation, | ||||
|     get_mode, | ||||
|     user_mod, | ||||
|     privileges_grant, | ||||
|  | @ -1054,7 +1054,7 @@ def main(): | |||
|     # Set defaults | ||||
|     changed = False | ||||
| 
 | ||||
|     get_impl(cursor) | ||||
|     impl = get_user_implementation(cursor) | ||||
| 
 | ||||
|     if priv is not None: | ||||
|         try: | ||||
|  |  | |||
|  | @ -401,7 +401,6 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import ( | |||
| ) | ||||
| from ansible_collections.community.mysql.plugins.module_utils.user import ( | ||||
|     convert_priv_dict_to_str, | ||||
|     get_impl, | ||||
|     get_mode, | ||||
|     InvalidPrivsError, | ||||
|     limit_resources, | ||||
|  | @ -528,8 +527,6 @@ def main(): | |||
|     if session_vars: | ||||
|         set_session_vars(module, cursor, session_vars) | ||||
| 
 | ||||
|     get_impl(cursor) | ||||
| 
 | ||||
|     if priv is not None: | ||||
|         try: | ||||
|             mode = get_mode(cursor) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue