mirror of
https://github.com/ansible-collections/community.mysql.git
synced 2025-04-08 11:40:33 -07:00
Refactor user implementation to host get_tls_requires
This commit is contained in:
parent
5460dec642
commit
d77be1ba03
6 changed files with 107 additions and 61 deletions
|
@ -29,3 +29,47 @@ def server_supports_password_expire(cursor):
|
||||||
version = get_server_version(cursor)
|
version = get_server_version(cursor)
|
||||||
|
|
||||||
return LooseVersion(version) >= LooseVersion("10.4.3")
|
return LooseVersion(version) >= LooseVersion("10.4.3")
|
||||||
|
|
||||||
|
def get_tls_requires(cursor, user, host):
|
||||||
|
"""Get user TLS requirements.
|
||||||
|
Reads directly from mysql.user table allowing for a more
|
||||||
|
readable code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cursor (cursor): DB driver cursor object.
|
||||||
|
user (str): User name.
|
||||||
|
host (str): User host name.
|
||||||
|
|
||||||
|
Returns: Dictionary containing current TLS required
|
||||||
|
"""
|
||||||
|
tls_requires = dict()
|
||||||
|
|
||||||
|
query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject '
|
||||||
|
'FROM mysql.user WHERE User = %s AND Host = %s')
|
||||||
|
cursor.execute(query, (user, host))
|
||||||
|
res = cursor.fetchone()
|
||||||
|
|
||||||
|
# Mysql_info use a DictCursor so we must convert back to a list
|
||||||
|
# otherwise we get KeyError 0
|
||||||
|
if isinstance(res, dict):
|
||||||
|
res = list(res.values())
|
||||||
|
|
||||||
|
# When user don't require SSL, res value is: ('', '', '', '')
|
||||||
|
if not any(res):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if res[0] == 'ANY':
|
||||||
|
tls_requires['SSL'] = None
|
||||||
|
|
||||||
|
if res[0] == 'X509':
|
||||||
|
tls_requires['X509'] = None
|
||||||
|
|
||||||
|
if res[1]:
|
||||||
|
tls_requires['CIPHER'] = res[1]
|
||||||
|
|
||||||
|
if res[2]:
|
||||||
|
tls_requires['ISSUER'] = res[2]
|
||||||
|
|
||||||
|
if res[3]:
|
||||||
|
tls_requires['SUBJECT'] = res[3]
|
||||||
|
return tls_requires
|
||||||
|
|
|
@ -8,6 +8,9 @@ __metaclass__ = type
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.version import LooseVersion
|
from ansible_collections.community.mysql.plugins.module_utils.version import LooseVersion
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version
|
from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version
|
||||||
|
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
|
||||||
def use_old_user_mgmt(cursor):
|
def use_old_user_mgmt(cursor):
|
||||||
version = get_server_version(cursor)
|
version = get_server_version(cursor)
|
||||||
|
@ -30,3 +33,42 @@ def server_supports_password_expire(cursor):
|
||||||
version = get_server_version(cursor)
|
version = get_server_version(cursor)
|
||||||
|
|
||||||
return LooseVersion(version) >= LooseVersion("5.7")
|
return LooseVersion(version) >= LooseVersion("5.7")
|
||||||
|
|
||||||
|
def get_tls_requires(cursor, user, host):
|
||||||
|
"""Get user TLS requirements.
|
||||||
|
We must use SHOW GRANTS because some tls fileds are encoded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cursor (cursor): DB driver cursor object.
|
||||||
|
user (str): User name.
|
||||||
|
host (str): User host name.
|
||||||
|
|
||||||
|
Returns: Dictionary containing current TLS required
|
||||||
|
"""
|
||||||
|
if user:
|
||||||
|
if not use_old_user_mgmt(cursor):
|
||||||
|
query = "SHOW CREATE USER '%s'@'%s'" % (user, host)
|
||||||
|
else:
|
||||||
|
query = "SHOW GRANTS for '%s'@'%s'" % (user, host)
|
||||||
|
|
||||||
|
cursor.execute(query)
|
||||||
|
grants = cursor.fetchone()
|
||||||
|
|
||||||
|
# Mysql_info use a DictCursor so we must convert back to a list
|
||||||
|
# otherwise we get KeyError 0
|
||||||
|
if isinstance(grants, dict):
|
||||||
|
grants = list(grants.values())
|
||||||
|
grants_str = ''.join(grants)
|
||||||
|
|
||||||
|
pattern = r"(?<=\bREQUIRE\b)(.*?)(?=(?:\bPASSWORD\b|$))"
|
||||||
|
requires_match = re.search(pattern, grants_str)
|
||||||
|
requires = requires_match.group().strip() if requires_match else ""
|
||||||
|
|
||||||
|
if any((requires.startswith(req) for req in ('SSL', 'X509', 'NONE'))):
|
||||||
|
requires = requires.split()[0]
|
||||||
|
if requires == 'NONE':
|
||||||
|
requires = None
|
||||||
|
|
||||||
|
items = iter(shlex.split(requires))
|
||||||
|
requires = dict(zip(items, items))
|
||||||
|
return requires or None
|
||||||
|
|
|
@ -17,6 +17,7 @@ from ansible.module_utils.six import iteritems
|
||||||
|
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.mysql import (
|
from ansible_collections.community.mysql.plugins.module_utils.mysql import (
|
||||||
mysql_driver,
|
mysql_driver,
|
||||||
|
get_server_implementation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,49 +81,6 @@ def do_not_mogrify_requires(query, params, tls_requires):
|
||||||
return query, params
|
return query, params
|
||||||
|
|
||||||
|
|
||||||
def get_tls_requires(cursor, user, host):
|
|
||||||
"""Get user TLS requirements.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cursor (cursor): DB driver cursor object.
|
|
||||||
user (str): User name.
|
|
||||||
host (str): User host name.
|
|
||||||
|
|
||||||
Returns: Dictionary containing current TLS required
|
|
||||||
"""
|
|
||||||
tls_requires = dict()
|
|
||||||
|
|
||||||
query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject '
|
|
||||||
'FROM mysql.user WHERE User = %s AND Host = %s')
|
|
||||||
cursor.execute(query, (user, host))
|
|
||||||
res = cursor.fetchone()
|
|
||||||
|
|
||||||
# Mysql_info use a DictCursor so we must convert back to a list
|
|
||||||
# otherwise we get KeyError 0
|
|
||||||
if isinstance(res, dict):
|
|
||||||
res = list(res.values())
|
|
||||||
|
|
||||||
# When user don't require SSL, res value is: ('', '', '', '')
|
|
||||||
if not any(res):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if res[0] == 'ANY':
|
|
||||||
return 'SSL'
|
|
||||||
|
|
||||||
if res[0] == 'X509':
|
|
||||||
return 'X509'
|
|
||||||
|
|
||||||
if res[1]:
|
|
||||||
tls_requires['CIPHER'] = res[1]
|
|
||||||
|
|
||||||
if res[2]:
|
|
||||||
tls_requires['ISSUER'] = res[2]
|
|
||||||
|
|
||||||
if res[3]:
|
|
||||||
tls_requires['SUBJECT'] = res[3]
|
|
||||||
return tls_requires
|
|
||||||
|
|
||||||
|
|
||||||
def get_grants(cursor, user, host):
|
def get_grants(cursor, user, host):
|
||||||
cursor.execute("SHOW GRANTS FOR %s@%s", (user, host))
|
cursor.execute("SHOW GRANTS FOR %s@%s", (user, host))
|
||||||
grants_line = list(filter(lambda x: "ON *.*" in x[0], cursor.fetchall()))[0]
|
grants_line = list(filter(lambda x: "ON *.*" in x[0], cursor.fetchall()))[0]
|
||||||
|
@ -184,6 +142,7 @@ def user_add(cursor, user, host, host_all, password, encrypted,
|
||||||
return {'changed': True, 'password_changed': None, 'attributes': attributes}
|
return {'changed': True, 'password_changed': None, 'attributes': attributes}
|
||||||
|
|
||||||
# Determine what user management method server uses
|
# Determine what user management method server uses
|
||||||
|
impl = get_user_implementation(cursor)
|
||||||
old_user_mgmt = impl.use_old_user_mgmt(cursor)
|
old_user_mgmt = impl.use_old_user_mgmt(cursor)
|
||||||
|
|
||||||
mogrify = do_not_mogrify_requires if old_user_mgmt else mogrify_requires
|
mogrify = do_not_mogrify_requires if old_user_mgmt else mogrify_requires
|
||||||
|
@ -262,6 +221,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted,
|
||||||
grant_option = False
|
grant_option = False
|
||||||
|
|
||||||
# Determine what user management method server uses
|
# Determine what user management method server uses
|
||||||
|
impl = get_user_implementation(cursor)
|
||||||
old_user_mgmt = impl.use_old_user_mgmt(cursor)
|
old_user_mgmt = impl.use_old_user_mgmt(cursor)
|
||||||
|
|
||||||
if host_all and not role:
|
if host_all and not role:
|
||||||
|
@ -517,7 +477,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted,
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle TLS requirements
|
# Handle TLS requirements
|
||||||
current_requires = get_tls_requires(cursor, user, host)
|
current_requires = sanitize_requires(impl.get_tls_requires(cursor, user, host))
|
||||||
if current_requires != tls_requires:
|
if current_requires != tls_requires:
|
||||||
msg = "TLS requires updated"
|
msg = "TLS requires updated"
|
||||||
if not module.check_mode:
|
if not module.check_mode:
|
||||||
|
@ -855,6 +815,7 @@ def privileges_grant(cursor, user, host, db_table, priv, tls_requires, maria_rol
|
||||||
query.append("TO %s")
|
query.append("TO %s")
|
||||||
params = (user)
|
params = (user)
|
||||||
|
|
||||||
|
impl = get_user_implementation(cursor)
|
||||||
if tls_requires and impl.use_old_user_mgmt(cursor):
|
if tls_requires and impl.use_old_user_mgmt(cursor):
|
||||||
query, params = mogrify_requires(" ".join(query), params, tls_requires)
|
query, params = mogrify_requires(" ".join(query), params, tls_requires)
|
||||||
query = [query]
|
query = [query]
|
||||||
|
@ -991,6 +952,7 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode):
|
||||||
|
|
||||||
Returns: True, if changed, False otherwise.
|
Returns: True, if changed, False otherwise.
|
||||||
"""
|
"""
|
||||||
|
impl = get_user_implementation(cursor)
|
||||||
if not impl.server_supports_alter_user(cursor):
|
if not impl.server_supports_alter_user(cursor):
|
||||||
module.fail_json(msg="The server version does not match the requirements "
|
module.fail_json(msg="The server version does not match the requirements "
|
||||||
"for resource_limits parameter. See module's documentation.")
|
"for resource_limits parameter. See module's documentation.")
|
||||||
|
@ -1126,12 +1088,11 @@ def attributes_get(cursor, user, host):
|
||||||
return j if j else None
|
return j if j else None
|
||||||
|
|
||||||
|
|
||||||
def get_impl(cursor):
|
def get_user_implementation(cursor):
|
||||||
global impl
|
db_engine = get_server_implementation(cursor)
|
||||||
cursor.execute("SELECT VERSION()")
|
if db_engine == 'mariadb':
|
||||||
if 'mariadb' in cursor.fetchone()[0].lower():
|
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.implementations.mariadb import user as mariauser
|
from ansible_collections.community.mysql.plugins.module_utils.implementations.mariadb import user as mariauser
|
||||||
impl = mariauser
|
return mariauser
|
||||||
else:
|
else:
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.implementations.mysql import user as mysqluser
|
from ansible_collections.community.mysql.plugins.module_utils.implementations.mysql import user as mysqluser
|
||||||
impl = mysqluser
|
return mysqluser
|
||||||
|
|
|
@ -300,8 +300,7 @@ from ansible_collections.community.mysql.plugins.module_utils.user import (
|
||||||
privileges_get,
|
privileges_get,
|
||||||
get_resource_limits,
|
get_resource_limits,
|
||||||
get_existing_authentication,
|
get_existing_authentication,
|
||||||
get_tls_requires,
|
get_user_implementation,
|
||||||
sanitize_requires,
|
|
||||||
)
|
)
|
||||||
from ansible.module_utils.six import iteritems
|
from ansible.module_utils.six import iteritems
|
||||||
from ansible.module_utils._text import to_native
|
from ansible.module_utils._text import to_native
|
||||||
|
@ -329,10 +328,11 @@ class MySQL_Info(object):
|
||||||
5. add info about the new subset with an example to RETURN block
|
5. add info about the new subset with an example to RETURN block
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, module, cursor, server_implementation):
|
def __init__(self, module, cursor, server_implementation, user_implementation):
|
||||||
self.module = module
|
self.module = module
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
self.server_implementation = server_implementation
|
self.server_implementation = server_implementation
|
||||||
|
self.user_implementation = user_implementation
|
||||||
self.info = {
|
self.info = {
|
||||||
'version': {},
|
'version': {},
|
||||||
'databases': {},
|
'databases': {},
|
||||||
|
@ -606,14 +606,15 @@ class MySQL_Info(object):
|
||||||
resource_limits = get_resource_limits(self.cursor, user, host)
|
resource_limits = get_resource_limits(self.cursor, user, host)
|
||||||
copy_ressource_limits = dict.copy(resource_limits)
|
copy_ressource_limits = dict.copy(resource_limits)
|
||||||
|
|
||||||
tls_requires = get_tls_requires(self.cursor, user, host)
|
tls_requires = self.user_implementation.get_tls_requires(
|
||||||
|
self.cursor, user, host)
|
||||||
|
|
||||||
output_dict = {
|
output_dict = {
|
||||||
'name': user,
|
'name': user,
|
||||||
'host': host,
|
'host': host,
|
||||||
'priv': '/'.join(priv_string),
|
'priv': '/'.join(priv_string),
|
||||||
'resource_limits': copy_ressource_limits,
|
'resource_limits': copy_ressource_limits,
|
||||||
'tls_requires': sanitize_requires(tls_requires),
|
'tls_requires': tls_requires,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prevent returning a resource limit if empty
|
# Prevent returning a resource limit if empty
|
||||||
|
@ -754,11 +755,12 @@ def main():
|
||||||
module.fail_json(msg)
|
module.fail_json(msg)
|
||||||
|
|
||||||
server_implementation = get_server_implementation(cursor)
|
server_implementation = get_server_implementation(cursor)
|
||||||
|
user_implementation = get_user_implementation(cursor)
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Create object and do main job
|
# Create object and do main job
|
||||||
|
|
||||||
mysql = MySQL_Info(module, cursor, server_implementation)
|
mysql = MySQL_Info(module, cursor, server_implementation, user_implementation)
|
||||||
|
|
||||||
module.exit_json(changed=False,
|
module.exit_json(changed=False,
|
||||||
connector_name=connector_name,
|
connector_name=connector_name,
|
||||||
|
|
|
@ -309,7 +309,7 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import (
|
||||||
)
|
)
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.user import (
|
from ansible_collections.community.mysql.plugins.module_utils.user import (
|
||||||
convert_priv_dict_to_str,
|
convert_priv_dict_to_str,
|
||||||
get_impl,
|
get_user_implementation,
|
||||||
get_mode,
|
get_mode,
|
||||||
user_mod,
|
user_mod,
|
||||||
privileges_grant,
|
privileges_grant,
|
||||||
|
@ -1054,7 +1054,7 @@ def main():
|
||||||
# Set defaults
|
# Set defaults
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
get_impl(cursor)
|
impl = get_user_implementation(cursor)
|
||||||
|
|
||||||
if priv is not None:
|
if priv is not None:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -401,7 +401,6 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import (
|
||||||
)
|
)
|
||||||
from ansible_collections.community.mysql.plugins.module_utils.user import (
|
from ansible_collections.community.mysql.plugins.module_utils.user import (
|
||||||
convert_priv_dict_to_str,
|
convert_priv_dict_to_str,
|
||||||
get_impl,
|
|
||||||
get_mode,
|
get_mode,
|
||||||
InvalidPrivsError,
|
InvalidPrivsError,
|
||||||
limit_resources,
|
limit_resources,
|
||||||
|
@ -528,8 +527,6 @@ def main():
|
||||||
if session_vars:
|
if session_vars:
|
||||||
set_session_vars(module, cursor, session_vars)
|
set_session_vars(module, cursor, session_vars)
|
||||||
|
|
||||||
get_impl(cursor)
|
|
||||||
|
|
||||||
if priv is not None:
|
if priv is not None:
|
||||||
try:
|
try:
|
||||||
mode = get_mode(cursor)
|
mode = get_mode(cursor)
|
||||||
|
|
Loading…
Add table
Reference in a new issue