diff --git a/tests/unit/plugins/modules/helper.py b/tests/unit/plugins/modules/helper.py
index b02fe88635..94ccfb4f38 100644
--- a/tests/unit/plugins/modules/helper.py
+++ b/tests/unit/plugins/modules/helper.py
@@ -18,6 +18,93 @@ ModuleTestCase = namedtuple("ModuleTestCase", ["id", "input", "output", "run_com
 RunCmdCall = namedtuple("RunCmdCall", ["command", "environ", "rc", "out", "err"])
 
 
+class _BaseContext(object):
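+    """Base context manager that runs one ModuleTestCase and checks its results."""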
+    def __init__(self, helper, testcase, mocker, capfd):
+        self.helper = helper
+        self.testcase = testcase
+        self.mocker = mocker
+        self.capfd = capfd
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return False
+
+    def _run(self):
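+        # module_main() is expected to terminate via SystemExit (exit_json or
+        # fail_json); the results are printed to stdout as JSON, so capture
+        # and decode them from there.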
+        with pytest.raises(SystemExit):
+            self.helper.module_main()
+
+        out, err = self.capfd.readouterr()
+        results = json.loads(out)
+
+        self.check_results(results)
+
+    def test_flags(self, flag=None):
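+        """Return all flags of the testcase, or only ``flag`` when given."""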
+        flags = self.testcase.flags
+        if flag:
+            flags = flags.get(flag)
+        return flags
+
+    def run(self):
+        func = self._run
+
+        test_flags = self.test_flags()
+        if test_flags.get("skip"):
+            pytest.skip(test_flags["skip"])
+        if test_flags.get("xfail"):
+            pytest.xfail(test_flags["xfail"])
+
+        func()
+
+    def check_results(self, results):
+        print("testcase =\n%s" % str(self.testcase))
+        print("results =\n%s" % results)
+        if "exception" in results:
+            print("exception =\n%s" % results["exception"])
+
+        for test_result in self.testcase.output:
+            assert results[test_result] == self.testcase.output[test_result], \
+                "'{0}': '{1}' != '{2}'".format(test_result, results[test_result], self.testcase.output[test_result])
+
+
+class _RunCmdContext(_BaseContext):
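+    """Context that mocks AnsibleModule.run_command and verifies its calls."""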
+    def __init__(self, *args, **kwargs):
+        super(_RunCmdContext, self).__init__(*args, **kwargs)
+        self.run_cmd_calls = self.testcase.run_command_calls
+        self.mock_run_cmd = self._make_mock_run_cmd()
+
+    def _make_mock_run_cmd(self):
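+        # Replay the testcase's (rc, out, err) results in order; once they are
+        # exhausted, keep returning an error tuple so that unexpected extra
+        # run_command() calls fail the test instead of passing silently.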
+        call_results = [(x.rc, x.out, x.err) for x in self.run_cmd_calls]
+        error_call_results = (123,
+                              "OUT: testcase does not have enough run_command calls",
+                              "ERR: testcase does not have enough run_command calls")
+        mock_run_command = self.mocker.patch('ansible.module_utils.basic.AnsibleModule.run_command',
+                                             side_effect=chain(call_results, repeat(error_call_results)))
+        return mock_run_command
+
+    def check_results(self, results):
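+        # On top of the base output checks, assert that run_command() was
+        # invoked with exactly the commands and environments declared.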
+        super(_RunCmdContext, self).check_results(results)
+        call_args_list = [(item[0][0], item[1]) for item in self.mock_run_cmd.call_args_list]
+        expected_call_args_list = [(item.command, item.environ) for item in self.run_cmd_calls]
+        print("call args list =\n%s" % call_args_list)
+        print("expected args list =\n%s" % expected_call_args_list)
+
+        assert self.mock_run_cmd.call_count == len(self.run_cmd_calls)
+        if self.mock_run_cmd.call_count:
+            assert call_args_list == expected_call_args_list
+
+
 class Helper(object):
     @staticmethod
     def from_list(module_main, list_):
@@ -73,7 +160,7 @@ class Helper(object):
         return [item.id for item in self.testcases]
 
     def __call__(self, *args, **kwargs):
-        return _Context(self, *args, **kwargs)
+        return _RunCmdContext(self, *args, **kwargs)
 
     @property
     def test_module(self):
@@ -92,74 +179,3 @@ class Helper(object):
                 testcase_context.run()
 
         return _test_module
-
-
-class _Context(object):
-    def __init__(self, helper, testcase, mocker, capfd):
-        self.helper = helper
-        self.testcase = testcase
-        self.mocker = mocker
-        self.capfd = capfd
-
-        self.run_cmd_calls = self.testcase.run_command_calls
-        self.mock_run_cmd = self._make_mock_run_cmd()
-
-    def _make_mock_run_cmd(self):
-        call_results = [(x.rc, x.out, x.err) for x in self.run_cmd_calls]
-        error_call_results = (123,
-                              "OUT: testcase has not enough run_command calls",
-                              "ERR: testcase has not enough run_command calls")
-        mock_run_command = self.mocker.patch('ansible.module_utils.basic.AnsibleModule.run_command',
-                                             side_effect=chain(call_results, repeat(error_call_results)))
-        return mock_run_command
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        return False
-
-    def _run(self):
-        with pytest.raises(SystemExit):
-            self.helper.module_main()
-
-        out, err = self.capfd.readouterr()
-        results = json.loads(out)
-
-        self.check_results(results)
-
-    def test_flags(self, flag=None):
-        flags = self.testcase.flags
-        if flag:
-            flags = flags.get(flag)
-        return flags
-
-    def run(self):
-        func = self._run
-
-        test_flags = self.test_flags()
-        if test_flags.get("skip"):
-            pytest.skip(reason=test_flags["skip"])
-        if test_flags.get("xfail"):
-            pytest.xfail(reason=test_flags["xfail"])
-
-        func()
-
-    def check_results(self, results):
-        print("testcase =\n%s" % str(self.testcase))
-        print("results =\n%s" % results)
-        if 'exception' in results:
-            print("exception = \n%s" % results["exception"])
-
-        for test_result in self.testcase.output:
-            assert results[test_result] == self.testcase.output[test_result], \
-                "'{0}': '{1}' != '{2}'".format(test_result, results[test_result], self.testcase.output[test_result])
-
-        call_args_list = [(item[0][0], item[1]) for item in self.mock_run_cmd.call_args_list]
-        expected_call_args_list = [(item.command, item.environ) for item in self.run_cmd_calls]
-        print("call args list =\n%s" % call_args_list)
-        print("expected args list =\n%s" % expected_call_args_list)
-
-        assert self.mock_run_cmd.call_count == len(self.run_cmd_calls)
-        if self.mock_run_cmd.call_count:
-            assert call_args_list == expected_call_args_list