mirror of
https://github.com/ansible-collections/community.general.git
synced 2025-04-25 11:51:26 -07:00
Allow setting/unsetting BYPASSRLS Postgres role attribute (#24625)
* Allow setting/unsetting BYPASSRLS role attr * Build valid role attrs against version * Add integration tests
This commit is contained in:
parent
603366863f
commit
723c8f06ab
4 changed files with 243 additions and 93 deletions
|
@ -103,7 +103,7 @@ options:
|
|||
required: false
|
||||
default: ""
|
||||
choices: [ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEUSER", "[NO]CREATEDB",
|
||||
"[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION" ]
|
||||
"[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION", "[NO]BYPASSRLS" ]
|
||||
state:
|
||||
description:
|
||||
- The user (role) state
|
||||
|
@ -210,6 +210,8 @@ EXAMPLES = '''
|
|||
import re
|
||||
import itertools
|
||||
|
||||
from distutils.version import StrictVersion
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
|
@ -220,7 +222,7 @@ else:
|
|||
from ansible.module_utils.six import iteritems
|
||||
|
||||
_flags = ('SUPERUSER', 'CREATEROLE', 'CREATEUSER', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION')
|
||||
VALID_FLAGS = frozenset(itertools.chain(_flags, ('NO%s' % f for f in _flags)))
|
||||
_flags_by_version = {'BYPASSRLS': '9.5.0'}
|
||||
|
||||
VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')),
|
||||
database=frozenset(('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')),
|
||||
|
@ -230,7 +232,7 @@ VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRU
|
|||
PRIV_TO_AUTHID_COLUMN = dict(SUPERUSER='rolsuper', CREATEROLE='rolcreaterole',
|
||||
CREATEUSER='rolcreateuser', CREATEDB='rolcreatedb',
|
||||
INHERIT='rolinherit', LOGIN='rolcanlogin',
|
||||
REPLICATION='rolreplication')
|
||||
REPLICATION='rolreplication', BYPASSRLS='rolbypassrls')
|
||||
|
||||
class InvalidFlagsError(Exception):
|
||||
pass
|
||||
|
@ -558,7 +560,7 @@ def grant_privileges(cursor, user, privs):
|
|||
changed = True
|
||||
return changed
|
||||
|
||||
def parse_role_attrs(role_attr_flags):
|
||||
def parse_role_attrs(cursor, role_attr_flags):
|
||||
"""
|
||||
Parse role attributes string for user creation.
|
||||
Format:
|
||||
|
@ -569,18 +571,24 @@ def parse_role_attrs(role_attr_flags):
|
|||
|
||||
attributes := CREATEDB,CREATEROLE,NOSUPERUSER,...
|
||||
[ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEUSER", "[NO]CREATEDB",
|
||||
"[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION" ]
|
||||
"[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION",
|
||||
"[NO]BYPASSRLS" ]
|
||||
|
||||
Note: "[NO]BYPASSRLS" role attribute introduced in 9.5
|
||||
|
||||
"""
|
||||
flags = frozenset(itertools.chain(_flags, get_valid_flags_by_version(cursor)))
|
||||
valid_flags = frozenset(itertools.chain(flags, ('NO%s' % f for f in flags)))
|
||||
|
||||
if ',' in role_attr_flags:
|
||||
flag_set = frozenset(r.upper() for r in role_attr_flags.split(","))
|
||||
elif role_attr_flags:
|
||||
flag_set = frozenset((role_attr_flags.upper(),))
|
||||
else:
|
||||
flag_set = frozenset()
|
||||
if not flag_set.issubset(VALID_FLAGS):
|
||||
if not flag_set.issubset(valid_flags):
|
||||
raise InvalidFlagsError('Invalid role_attr_flags specified: %s' %
|
||||
' '.join(flag_set.difference(VALID_FLAGS)))
|
||||
' '.join(flag_set.difference(valid_flags)))
|
||||
o_flags = ' '.join(flag_set)
|
||||
return o_flags
|
||||
|
||||
|
@ -633,6 +641,35 @@ def parse_privs(privs, db):
|
|||
|
||||
return o_privs
|
||||
|
||||
def get_pg_server_version(cursor):
|
||||
"""
|
||||
Queries Postgres for its server version.
|
||||
|
||||
server_version should be just the server version itself:
|
||||
|
||||
postgres=# SHOW SERVER_VERSION;
|
||||
server_version
|
||||
----------------
|
||||
9.6.2
|
||||
(1 row)
|
||||
"""
|
||||
cursor.execute("SHOW SERVER_VERSION")
|
||||
return cursor.fetchone()['server_version']
|
||||
|
||||
def get_valid_flags_by_version(cursor):
|
||||
"""
|
||||
Some role attributes were introduced after certain versions. We want to
|
||||
compile a list of valid flags against the current Postgres version.
|
||||
"""
|
||||
current_version = StrictVersion(get_pg_server_version(cursor))
|
||||
|
||||
return [
|
||||
flag
|
||||
for flag, version_introduced in _flags_by_version.items()
|
||||
if current_version >= StrictVersion(version_introduced)
|
||||
]
|
||||
|
||||
|
||||
# ===========================================
|
||||
# Module execution.
|
||||
#
|
||||
|
@ -671,11 +708,6 @@ def main():
|
|||
privs = parse_privs(module.params["priv"], db)
|
||||
port = module.params["port"]
|
||||
no_password_changes = module.params["no_password_changes"]
|
||||
try:
|
||||
role_attr_flags = parse_role_attrs(module.params["role_attr_flags"])
|
||||
except InvalidFlagsError:
|
||||
e = get_exception()
|
||||
module.fail_json(msg=str(e))
|
||||
if module.params["encrypted"]:
|
||||
encrypted = "ENCRYPTED"
|
||||
else:
|
||||
|
@ -723,6 +755,12 @@ def main():
|
|||
e = get_exception()
|
||||
module.fail_json(msg="unable to connect to database: %s" % e)
|
||||
|
||||
try:
|
||||
role_attr_flags = parse_role_attrs(cursor, module.params["role_attr_flags"])
|
||||
except InvalidFlagsError:
|
||||
e = get_exception()
|
||||
module.fail_json(msg=str(e))
|
||||
|
||||
kw = dict(user=user)
|
||||
changed = False
|
||||
user_removed = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue