diff --git a/plugins/module_utils/mysql.py b/plugins/module_utils/mysql.py index b95d20d..d0048bf 100644 --- a/plugins/module_utils/mysql.py +++ b/plugins/module_utils/mysql.py @@ -206,6 +206,23 @@ def get_server_version(cursor): return version_str +def get_server_implementation(cursor): + if 'mariadb' in get_server_version(cursor).lower(): + return "mariadb" + else: + return "mysql" + +def is_mariadb(implementation): + if implementation == "mariadb": + return True + else: + return False + +def is_mysql(implementation): + if implementation == "mysql": + return True + else: + return False def set_session_vars(module, cursor, session_vars): """Set session vars.""" diff --git a/plugins/modules/mysql_info.py b/plugins/modules/mysql_info.py index 10f9dcf..26b9cec 100644 --- a/plugins/modules/mysql_info.py +++ b/plugins/modules/mysql_info.py @@ -293,6 +293,8 @@ from ansible_collections.community.mysql.plugins.module_utils.mysql import ( mysql_driver_fail_msg, get_connector_name, get_connector_version, + get_server_implementation, + is_mariadb, ) from ansible_collections.community.mysql.plugins.module_utils.user import ( @@ -389,18 +391,6 @@ class MySQL_Info(object): self.__collect(exclude_fields, return_empty_dbs, set(self.info)) return self.info - def is_mariadb(self): - if self.server_implementation == "mariadb": - return True - else: - return False - - def is_mysql(self): - if self.server_implementation == "mysql": - return True - else: - return False - def __collect(self, exclude_fields, return_empty_dbs, wanted): """Collect all possible subsets.""" if 'version' in wanted or 'settings' in wanted: @@ -511,7 +501,7 @@ class MySQL_Info(object): def __get_slave_status(self): """Get slave status if the instance is a slave.""" - if self.is_mariadb(): + if is_mariadb(self.server_implementation): res = self.__exec_sql('SHOW ALL SLAVES STATUS') else: res = self.__exec_sql('SHOW SLAVE STATUS') @@ -755,11 +745,7 @@ def main(): 'Exception message: %s' % (connector_name, connector_version, config_file, to_native(e))) module.fail_json(msg) - cursor.execute("SELECT VERSION()") - if 'mariadb' in cursor.fetchone()["VERSION()"].lower(): - server_implementation = "mariadb" - else: - server_implementation = "mysql" + server_implementation = get_server_implementation(cursor) ############################### # Create object and do main job diff --git a/tests/unit/plugins/module_utils/test_mysql.py b/tests/unit/plugins/module_utils/test_mysql.py index ac4de24..8557ff8 100644 --- a/tests/unit/plugins/module_utils/test_mysql.py +++ b/tests/unit/plugins/module_utils/test_mysql.py @@ -1,9 +1,10 @@ from __future__ import (absolute_import, division, print_function) + __metaclass__ = type import pytest -from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version +from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version, get_server_implementation, is_mariadb, is_mysql from ..utils import dummy_cursor_class @@ -22,3 +23,34 @@ def test_get_server_version(cursor_return_version, cursor_return_type): """ cursor = dummy_cursor_class(cursor_return_version, cursor_return_type) assert get_server_version(cursor) == cursor_return_version + +@pytest.mark.parametrize( + 'cursor_return_version,cursor_return_type,server_implementation', + [ + ('5.7.0-mysql', 'dict', 'mysql'), + ('8.0.0-mysql', 'list', 'mysql'), + ('10.5.0-mariadb', 'dict', 'mariadb'), + ('10.5.1-mariadb', 'list', 'mariadb'), + ] +) +def test_get_server_implamentation(cursor_return_version, cursor_return_type, server_implementation): + """ + Test that server implementation are handled properly by get_server_implementation() whether the server version returned as a list or dict. + """ + cursor = dummy_cursor_class(cursor_return_version, cursor_return_type) + + assert get_server_implementation(cursor) == server_implementation + +def test_is_mysql(): + """ + Test that server is_mysql return expect results + """ + assert is_mysql("mysql") == True + assert is_mysql("mariadb") == False + +def test_is_mariadb(): + """ + Test that server is_mariadb return expect results + """ + assert is_mariadb("mariadb") == True + assert is_mariadb("mysql") == False diff --git a/tests/unit/plugins/modules/test_mysql_info.py b/tests/unit/plugins/modules/test_mysql_info.py index d1fb070..6aaf66e 100644 --- a/tests/unit/plugins/modules/test_mysql_info.py +++ b/tests/unit/plugins/modules/test_mysql_info.py @@ -35,33 +35,3 @@ def test_get_info_suffix(suffix, cursor_output, server_implementation): info = MySQL_Info(MagicMock(), cursor, server_implementation) assert info.get_info([], [], False)['version']['suffix'] == suffix - - -@pytest.mark.parametrize( - 'server_implementation', - [ - ('mysql'), - ('mariadb'), - ] -) -def test_is_mariadb(server_implementation): - cursor = MagicMock() - - info = MySQL_Info(MagicMock(), cursor, server_implementation) - - assert info.is_mariadb() == (server_implementation == "mariadb") - - -@pytest.mark.parametrize( - 'server_implementation', - [ - ('mysql'), - ('mariadb'), - ] -) -def test_is_mysql(server_implementation): - cursor = MagicMock() - - info = MySQL_Info(MagicMock(), cursor, server_implementation) - - assert info.is_mysql() == (server_implementation == "mysql")