diff --git a/ansible_testing/modules.py b/ansible_testing/modules.py index cc3d762b5c..26b12d5f5f 100644 --- a/ansible_testing/modules.py +++ b/ansible_testing/modules.py @@ -214,31 +214,32 @@ class ModuleValidator(Validator): def _find_module_utils(self, main): linenos = [] + found_basic = False for child in self.ast.body: - found_module_utils_import = False - if isinstance(child, ast.ImportFrom): - if child.module.startswith('ansible.module_utils.'): - found_module_utils_import = True + if isinstance(child, (ast.Import, ast.ImportFrom)): + names = [] + try: + names.append(child.module) + if child.module.endswith('.basic'): + found_basic = True + except AttributeError: + pass + names.extend([n.name for n in child.names]) + if [n for n in names if n.startswith('ansible.module_utils')]: linenos.append(child.lineno) - if not child.names: - self.errors.append('%s: not a "from" import"' % - child.module) - - found_alias = False for name in child.names: - if isinstance(name, ast.alias): - found_alias = True - if name.asname or name.name != '*': - self.errors.append('%s: did not import "*"' % - child.module) - - if found_module_utils_import and not found_alias: - self.errors.append('%s: did not import "*"' % child.module) + print(name.name) + if (isinstance(name, ast.alias) and + name.name == 'basic'): + found_basic = True if not linenos: self.errors.append('Did not find a module_utils import') + elif not found_basic: + self.errors.append('Did not find "ansible.module_utils.basic" ' + 'import') return linenos