Add module common code to allow it to be easier to indicate whether arguments are mutually exclusive, required in conjunction, or whether one of a list of arguments is required. This simplifies writing Python modules.

This commit is contained in:
Michael DeHaan 2012-08-11 18:13:29 -04:00
parent 98c350a6ac
commit 1e4d45af1e
4 changed files with 61 additions and 52 deletions

View file

@ -55,11 +55,14 @@ except ImportError:
class AnsibleModule(object):
def __init__(self, argument_spec, bypass_checks=False, no_log=False, check_invalid_arguments=True):
def __init__(self, argument_spec, bypass_checks=False, no_log=False,
check_invalid_arguments=True, mutually_exclusive=None, required_together=None,
required_one_of=None):
'''
common code for quickly building an ansible module in Python
(although you can write modules in anything that can return JSON)
see library/slurp and others for examples
see library/* for examples
'''
self.argument_spec = argument_spec
@ -77,12 +80,14 @@ class AnsibleModule(object):
if not bypass_checks:
self._check_required_arguments()
self._check_argument_types()
self._check_mutually_exclusive(mutually_exclusive)
self._check_required_together(required_together)
self._check_required_one_of(required_one_of)
self._set_defaults(pre=False)
if not no_log:
self._log_invocation()
def _handle_aliases(self):
for (k,v) in self.argument_spec.iteritems():
self._legal_inputs.append(k)
@ -106,6 +111,39 @@ class AnsibleModule(object):
if k not in self._legal_inputs:
self.fail_json(msg="unsupported parameter for module: %s" % k)
def _count_terms(self, check):
count = 0
for term in check:
if term in self.params:
count += 1
return count
def _check_mutually_exclusive(self, spec):
if spec is None:
return
for check in spec:
count = self._count_terms(check)
if count > 1:
self.fail_json(msg="parameters are mutually exclusive: %s" % check)
def _check_required_one_of(self, spec):
if spec is None:
return
for check in spec:
count = self._count_terms(check)
if count == 0:
self.fail_json(msg="one of the following is required: %s" % check)
def _check_required_together(self, spec):
if spec is None:
return
for check in spec:
counts = [ self.count_terms([field]) for field in check ]
non_zero = [ c for c in counts if c > 0 ]
if len(non_zero) > 0:
if 0 in counts:
self.fail_json(msg="parameters are required together: %s" % check)
def _check_required_arguments(self):
''' ensure all required arguments are present '''
missing = []