mirror of
https://github.com/ansible-collections/community.mysql.git
synced 2025-04-20 09:21:25 -07:00
some changes and integration tests
This commit is contained in:
parent
10780aee98
commit
6d73c24526
5 changed files with 230 additions and 20 deletions
|
@ -202,7 +202,7 @@ def user_add(cursor, user, host, host_all, password, encrypted,
|
|||
cursor.execute(*mogrify(*query_with_args_and_tls_requires))
|
||||
|
||||
if password_expire:
|
||||
if not impl.supports_identified_by_password(cursor):
|
||||
if not impl.server_supports_password_expire(cursor):
|
||||
module.fail_json(msg="The server version does not match the requirements "
|
||||
"for password_expire parameter. See module's documentation.")
|
||||
set_password_expire(cursor, user, host, password_expire, password_expire_interval)
|
||||
|
@ -315,9 +315,12 @@ def user_mod(cursor, user, host, host_all, password, encrypted,
|
|||
update = False
|
||||
mariadb_role = True if "mariadb" in str(impl.__name__) else False
|
||||
current_password_policy = get_password_expiration_policy(cursor, user, host, maria_role=mariadb_role)
|
||||
if not (current_password_policy == -1 and password_expire == "default" or
|
||||
current_password_policy == 0 and password_expire == "never" or
|
||||
current_password_policy == password_expire_interval):
|
||||
password_expired = is_password_expired(cursor, user, host)
|
||||
# Check if changes needed to be applied.
|
||||
if not ((current_password_policy == -1 and password_expire == "default") or
|
||||
(current_password_policy == 0 and password_expire == "never") or
|
||||
(current_password_policy == password_expire_interval and password_expire == "interval") or
|
||||
(password_expire == 'now' and password_expired)):
|
||||
|
||||
update = True
|
||||
|
||||
|
@ -968,14 +971,19 @@ def set_password_expire(cursor, user, host, password_expire, password_expire_int
|
|||
elif password_expire.lower() == "default":
|
||||
statment = "PASSWORD EXPIRE DEFAULT"
|
||||
elif password_expire.lower() == "interval":
|
||||
if password_expire_interval > 0:
|
||||
statment = "PASSWORD EXPIRE INTERVAL %d DAY" % (password_expire_interval)
|
||||
else:
|
||||
# expire password now if days <=0
|
||||
if isinstance(password_expire_interval, int):
|
||||
statment = "PASSWORD EXPIRE"
|
||||
query = "ALTER USER %s@%s %s" % (user, host, statment)
|
||||
cursor.execute(query)
|
||||
statment = "PASSWORD EXPIRE INTERVAL %d DAY" % (password_expire_interval)
|
||||
elif password_expire.lower() == "now":
|
||||
statment = "PASSWORD EXPIRE"
|
||||
if host:
|
||||
params = (user, host)
|
||||
query = ["ALTER USER %s@%s"]
|
||||
else:
|
||||
params = (user,)
|
||||
query = ["ALTER USER %s"]
|
||||
|
||||
query.append(statment)
|
||||
query = ' '.join(query)
|
||||
cursor.execute(query, params)
|
||||
|
||||
|
||||
def get_password_expiration_policy(cursor, user, host, maria_role=False):
|
||||
|
@ -991,7 +999,7 @@ def get_password_expiration_policy(cursor, user, host, maria_role=False):
|
|||
policy (int): Current users password policy.
|
||||
"""
|
||||
if not maria_role:
|
||||
statment = "SELECT password_lifetime FROM mysql.user \
|
||||
statment = "SELECT IFNULL(password_lifetime, -1) FROM mysql.user \
|
||||
WHERE User = %s AND Host = %s", (user, host)
|
||||
else:
|
||||
statment = "SELECT JSON_EXTRACT(Priv, '$.password_lifetime') AS password_lifetime \
|
||||
|
@ -999,11 +1007,28 @@ def get_password_expiration_policy(cursor, user, host, maria_role=False):
|
|||
WHERE User = %s AND Host = %s", (user, host)
|
||||
cursor.execute(*statment)
|
||||
policy = cursor.fetchone()[0]
|
||||
if not policy:
|
||||
policy = -1
|
||||
return int(policy)
|
||||
|
||||
|
||||
def is_password_expired(cursor, user, host):
|
||||
"""Function to check if password is expired
|
||||
|
||||
Args:
|
||||
cursor (cursor): DB driver cursor object.
|
||||
user (str): User name.
|
||||
host (str): User hostname.
|
||||
|
||||
Returns:
|
||||
expired (bool): True if expired, else False.
|
||||
"""
|
||||
statment = "SELECT password_expired FROM mysql.user \
|
||||
WHERE User = %s AND Host = %s", (user, host)
|
||||
cursor.execute(*statment)
|
||||
expired = cursor.fetchone()[0]
|
||||
if str(expired) == "Y":
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_impl(cursor):
|
||||
global impl
|
||||
cursor.execute("SELECT VERSION()")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue