diff --git a/lib/ansible/modules/utilities/logic/wait_for.py b/lib/ansible/modules/utilities/logic/wait_for.py index 00f8d0a88d..a3ae8e17c6 100644 --- a/lib/ansible/modules/utilities/logic/wait_for.py +++ b/lib/ansible/modules/utilities/logic/wait_for.py @@ -85,6 +85,11 @@ options: - port number to poll required: false default: null + active_connection_states: + description: + - The list of tcp connection states which are counted as active connections + default: ['ESTABLISHED','SYN_SENT','SYN_RECV','FIN_WAIT1','FIN_WAIT2','TIME_WAIT'] + version_added: "2.3" state: description: - either C(present), C(started), or C(stopped), C(absent), or C(drained) @@ -194,14 +199,6 @@ class TCPConnectionInfo(object): 'prefix': '::ffff', 'match_all': '::ffff:0.0.0.0' } - connection_states = { - '01': 'ESTABLISHED', - '02': 'SYN_SENT', - '03': 'SYN_RECV', - '04': 'FIN_WAIT1', - '05': 'FIN_WAIT2', - '06': 'TIME_WAIT', - } def __new__(cls, *args, **kwargs): return load_platform_subclass(TCPConnectionInfo, args, kwargs) @@ -227,7 +224,7 @@ class TCPConnectionInfo(object): for p in psutil.process_iter(): connections = p.get_connections(kind='inet') for conn in connections: - if conn.status not in self.connection_states.values(): + if conn.status not in self.module.params['active_connection_states']: continue (local_ip, local_port) = conn.local_address if self.port != local_port: @@ -295,7 +292,7 @@ class LinuxTCPConnectionInfo(TCPConnectionInfo): tcp_connection = tcp_connection.strip().split() if tcp_connection[self.local_address_field] == 'local_address': continue - if tcp_connection[self.connection_state_field] not in self.connection_states: + if tcp_connection[self.connection_state_field] not in [ get_connection_state_id(_connection_state) for _connection_state in self.module.params['active_connection_states'] ]: continue (local_ip, local_port) = tcp_connection[self.local_address_field].split(':') if self.port != local_port: @@ -383,6 +380,17 @@ def _timedelta_total_seconds(timedelta): timedelta.microseconds + 0.0 + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6) / 10 ** 6 +def get_connection_state_id(state): + connection_state_id = { + 'ESTABLISHED': '01', + 'SYN_SENT': '02', + 'SYN_RECV': '03', + 'FIN_WAIT1': '04', + 'FIN_WAIT2': '05', + 'TIME_WAIT': '06', + } + return connection_state_id[state] + def main(): module = AnsibleModule( @@ -392,6 +400,7 @@ def main(): connect_timeout=dict(default=5, type='int'), delay=dict(default=0, type='int'), port=dict(default=None, type='int'), + active_connection_states=dict(default=['ESTABLISHED','SYN_SENT','SYN_RECV','FIN_WAIT1','FIN_WAIT2','TIME_WAIT'], type='list'), path=dict(default=None, type='path'), search_regex=dict(default=None), state=dict(default='started', choices=['started', 'stopped', 'present', 'absent', 'drained']), @@ -423,7 +432,11 @@ def main(): module.fail_json(msg="state=drained should only be used for checking a port in the wait_for module") if params['exclude_hosts'] is not None and state != 'drained': module.fail_json(msg="exclude_hosts should only be with state=drained") - + for _connection_state in params['active_connection_states']: + try: + get_connection_state_id(_connection_state) + except: + module.fail_json(msg="unknown active_connection_state ("+_connection_state+") defined") start = datetime.datetime.now()