mirror of
https://github.com/ansible-collections/community.mysql.git
synced 2025-05-30 04:49:10 -07:00
initial commit
This commit is contained in:
parent
9fcbbaad81
commit
76a1adffef
108 changed files with 8729 additions and 42 deletions
189
plugins/module_utils/database.py
Normal file
189
plugins/module_utils/database.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
# This code is part of Ansible, but is an independent component.
|
||||
# This particular file snippet, and this file snippet only, is BSD licensed.
|
||||
# Modules you write using this snippet, which is embedded dynamically by Ansible
|
||||
# still belong to the author of the module, and may assign their own license
|
||||
# to the complete work.
|
||||
#
|
||||
# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
|
||||
#
|
||||
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
|
||||
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import re
|
||||
|
||||
|
||||
# Input patterns for is_input_dangerous function:
|
||||
#
|
||||
# 1. '"' in string and '--' in string or
|
||||
# "'" in string and '--' in string
|
||||
PATTERN_1 = re.compile(r'(\'|\").*--')
|
||||
|
||||
# 2. union \ intersect \ except + select
|
||||
PATTERN_2 = re.compile(r'(UNION|INTERSECT|EXCEPT).*SELECT', re.IGNORECASE)
|
||||
|
||||
# 3. ';' and any KEY_WORDS
|
||||
PATTERN_3 = re.compile(r';.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)', re.IGNORECASE)
|
||||
|
||||
|
||||
class SQLParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnclosedQuoteError(SQLParseError):
|
||||
pass
|
||||
|
||||
|
||||
# maps a type of identifier to the maximum number of dot levels that are
|
||||
# allowed to specify that identifier. For example, a database column can be
|
||||
# specified by up to 4 levels: database.schema.table.column
|
||||
_PG_IDENTIFIER_TO_DOT_LEVEL = dict(
|
||||
database=1,
|
||||
schema=2,
|
||||
table=3,
|
||||
column=4,
|
||||
role=1,
|
||||
tablespace=1,
|
||||
sequence=3,
|
||||
publication=1,
|
||||
)
|
||||
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1)
|
||||
|
||||
|
||||
def _find_end_quote(identifier, quote_char):
|
||||
accumulate = 0
|
||||
while True:
|
||||
try:
|
||||
quote = identifier.index(quote_char)
|
||||
except ValueError:
|
||||
raise UnclosedQuoteError
|
||||
accumulate = accumulate + quote
|
||||
try:
|
||||
next_char = identifier[quote + 1]
|
||||
except IndexError:
|
||||
return accumulate
|
||||
if next_char == quote_char:
|
||||
try:
|
||||
identifier = identifier[quote + 2:]
|
||||
accumulate = accumulate + 2
|
||||
except IndexError:
|
||||
raise UnclosedQuoteError
|
||||
else:
|
||||
return accumulate
|
||||
|
||||
|
||||
def _identifier_parse(identifier, quote_char):
|
||||
if not identifier:
|
||||
raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
|
||||
|
||||
already_quoted = False
|
||||
if identifier.startswith(quote_char):
|
||||
already_quoted = True
|
||||
try:
|
||||
end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
|
||||
except UnclosedQuoteError:
|
||||
already_quoted = False
|
||||
else:
|
||||
if end_quote < len(identifier) - 1:
|
||||
if identifier[end_quote + 1] == '.':
|
||||
dot = end_quote + 1
|
||||
first_identifier = identifier[:dot]
|
||||
next_identifier = identifier[dot + 1:]
|
||||
further_identifiers = _identifier_parse(next_identifier, quote_char)
|
||||
further_identifiers.insert(0, first_identifier)
|
||||
else:
|
||||
raise SQLParseError('User escaped identifiers must escape extra quotes')
|
||||
else:
|
||||
further_identifiers = [identifier]
|
||||
|
||||
if not already_quoted:
|
||||
try:
|
||||
dot = identifier.index('.')
|
||||
except ValueError:
|
||||
identifier = identifier.replace(quote_char, quote_char * 2)
|
||||
identifier = ''.join((quote_char, identifier, quote_char))
|
||||
further_identifiers = [identifier]
|
||||
else:
|
||||
if dot == 0 or dot >= len(identifier) - 1:
|
||||
identifier = identifier.replace(quote_char, quote_char * 2)
|
||||
identifier = ''.join((quote_char, identifier, quote_char))
|
||||
further_identifiers = [identifier]
|
||||
else:
|
||||
first_identifier = identifier[:dot]
|
||||
next_identifier = identifier[dot + 1:]
|
||||
further_identifiers = _identifier_parse(next_identifier, quote_char)
|
||||
first_identifier = first_identifier.replace(quote_char, quote_char * 2)
|
||||
first_identifier = ''.join((quote_char, first_identifier, quote_char))
|
||||
further_identifiers.insert(0, first_identifier)
|
||||
|
||||
return further_identifiers
|
||||
|
||||
|
||||
def pg_quote_identifier(identifier, id_type):
|
||||
identifier_fragments = _identifier_parse(identifier, quote_char='"')
|
||||
if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
|
||||
raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
|
||||
return '.'.join(identifier_fragments)
|
||||
|
||||
|
||||
def mysql_quote_identifier(identifier, id_type):
|
||||
identifier_fragments = _identifier_parse(identifier, quote_char='`')
|
||||
if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
|
||||
raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]))
|
||||
|
||||
special_cased_fragments = []
|
||||
for fragment in identifier_fragments:
|
||||
if fragment == '`*`':
|
||||
special_cased_fragments.append('*')
|
||||
else:
|
||||
special_cased_fragments.append(fragment)
|
||||
|
||||
return '.'.join(special_cased_fragments)
|
||||
|
||||
|
||||
def is_input_dangerous(string):
|
||||
"""Check if the passed string is potentially dangerous.
|
||||
Can be used to prevent SQL injections.
|
||||
|
||||
Note: use this function only when you can't use
|
||||
psycopg2's cursor.execute method parametrized
|
||||
(typically with DDL queries).
|
||||
"""
|
||||
if not string:
|
||||
return False
|
||||
|
||||
for pattern in (PATTERN_1, PATTERN_2, PATTERN_3):
|
||||
if re.search(pattern, string):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_input(module, *args):
|
||||
"""Wrapper for is_input_dangerous function."""
|
||||
needs_to_check = args
|
||||
|
||||
dangerous_elements = []
|
||||
|
||||
for elem in needs_to_check:
|
||||
if isinstance(elem, str):
|
||||
if is_input_dangerous(elem):
|
||||
dangerous_elements.append(elem)
|
||||
|
||||
elif isinstance(elem, list):
|
||||
for e in elem:
|
||||
if is_input_dangerous(e):
|
||||
dangerous_elements.append(e)
|
||||
|
||||
elif elem is None or isinstance(elem, bool):
|
||||
pass
|
||||
|
||||
else:
|
||||
elem = str(elem)
|
||||
if is_input_dangerous(elem):
|
||||
dangerous_elements.append(elem)
|
||||
|
||||
if dangerous_elements:
|
||||
module.fail_json(msg="Passed input '%s' is "
|
||||
"potentially dangerous" % ', '.join(dangerous_elements))
|
110
plugins/module_utils/mysql.py
Normal file
110
plugins/module_utils/mysql.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# This code is part of Ansible, but is an independent component.
|
||||
# This particular file snippet, and this file snippet only, is BSD licensed.
|
||||
# Modules you write using this snippet, which is embedded dynamically by Ansible
|
||||
# still belong to the author of the module, and may assign their own license
|
||||
# to the complete work.
|
||||
#
|
||||
# Copyright (c), Jonathan Mainguy <jon@soh.re>, 2015
|
||||
# Most of this was originally added by Sven Schliesing @muffl0n in the mysql_user.py module
|
||||
#
|
||||
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
|
||||
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import os
|
||||
|
||||
from ansible.module_utils.six.moves import configparser
|
||||
|
||||
try:
|
||||
import pymysql as mysql_driver
|
||||
_mysql_cursor_param = 'cursor'
|
||||
except ImportError:
|
||||
try:
|
||||
import MySQLdb as mysql_driver
|
||||
import MySQLdb.cursors
|
||||
_mysql_cursor_param = 'cursorclass'
|
||||
except ImportError:
|
||||
mysql_driver = None
|
||||
|
||||
mysql_driver_fail_msg = 'The PyMySQL (Python 2.7 and Python 3.X) or MySQL-python (Python 2.X) module is required.'
|
||||
|
||||
|
||||
def parse_from_mysql_config_file(cnf):
|
||||
cp = configparser.ConfigParser()
|
||||
cp.read(cnf)
|
||||
return cp
|
||||
|
||||
|
||||
def mysql_connect(module, login_user=None, login_password=None, config_file='', ssl_cert=None,
|
||||
ssl_key=None, ssl_ca=None, db=None, cursor_class=None,
|
||||
connect_timeout=30, autocommit=False, config_overrides_defaults=False):
|
||||
config = {}
|
||||
|
||||
if config_file and os.path.exists(config_file):
|
||||
config['read_default_file'] = config_file
|
||||
cp = parse_from_mysql_config_file(config_file)
|
||||
# Override some commond defaults with values from config file if needed
|
||||
if cp and cp.has_section('client') and config_overrides_defaults:
|
||||
try:
|
||||
module.params['login_host'] = cp.get('client', 'host', fallback=module.params['login_host'])
|
||||
module.params['login_port'] = cp.getint('client', 'port', fallback=module.params['login_port'])
|
||||
except Exception as e:
|
||||
if "got an unexpected keyword argument 'fallback'" in e.message:
|
||||
module.fail_json('To use config_overrides_defaults, '
|
||||
'it needs Python 3.5+ as the default interpreter on a target host')
|
||||
|
||||
if ssl_ca is not None or ssl_key is not None or ssl_cert is not None:
|
||||
config['ssl'] = {}
|
||||
|
||||
if module.params['login_unix_socket']:
|
||||
config['unix_socket'] = module.params['login_unix_socket']
|
||||
else:
|
||||
config['host'] = module.params['login_host']
|
||||
config['port'] = module.params['login_port']
|
||||
|
||||
# If login_user or login_password are given, they should override the
|
||||
# config file
|
||||
if login_user is not None:
|
||||
config['user'] = login_user
|
||||
if login_password is not None:
|
||||
config['passwd'] = login_password
|
||||
if ssl_cert is not None:
|
||||
config['ssl']['cert'] = ssl_cert
|
||||
if ssl_key is not None:
|
||||
config['ssl']['key'] = ssl_key
|
||||
if ssl_ca is not None:
|
||||
config['ssl']['ca'] = ssl_ca
|
||||
if db is not None:
|
||||
config['db'] = db
|
||||
if connect_timeout is not None:
|
||||
config['connect_timeout'] = connect_timeout
|
||||
|
||||
if _mysql_cursor_param == 'cursor':
|
||||
# In case of PyMySQL driver:
|
||||
db_connection = mysql_driver.connect(autocommit=autocommit, **config)
|
||||
else:
|
||||
# In case of MySQLdb driver
|
||||
db_connection = mysql_driver.connect(**config)
|
||||
if autocommit:
|
||||
db_connection.autocommit(True)
|
||||
|
||||
if cursor_class == 'DictCursor':
|
||||
return db_connection.cursor(**{_mysql_cursor_param: mysql_driver.cursors.DictCursor}), db_connection
|
||||
else:
|
||||
return db_connection.cursor(), db_connection
|
||||
|
||||
|
||||
def mysql_common_argument_spec():
|
||||
return dict(
|
||||
login_user=dict(type='str', default=None),
|
||||
login_password=dict(type='str', no_log=True),
|
||||
login_host=dict(type='str', default='localhost'),
|
||||
login_port=dict(type='int', default=3306),
|
||||
login_unix_socket=dict(type='str'),
|
||||
config_file=dict(type='path', default='~/.my.cnf'),
|
||||
connect_timeout=dict(type='int', default=30),
|
||||
client_cert=dict(type='path', aliases=['ssl_cert']),
|
||||
client_key=dict(type='path', aliases=['ssl_key']),
|
||||
ca_cert=dict(type='path', aliases=['ssl_ca']),
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue