Use select in wait_for so that we don't get stuck in cornercases:

* reading from a socket that gave some data we weren't looking for and
  then closed.
* read from a socket that stays open and never sends data.
* reading from a socket that sends data but not the data we're looking
  for.

Fixes #2051
This commit is contained in:
Toshio Kuratomi 2015-10-27 17:26:51 -07:00 committed by Matt Clay
commit fda9eeaa89

View file

@ -18,12 +18,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>. # along with Ansible. If not, see <http://www.gnu.org/licenses/>.
import socket
import datetime
import time
import sys
import re
import binascii import binascii
import datetime
import math
import re
import select
import socket
import sys
import time
HAS_PSUTIL = False HAS_PSUTIL = False
try: try:
@ -349,6 +351,10 @@ def main():
state = params['state'] state = params['state']
path = params['path'] path = params['path']
search_regex = params['search_regex'] search_regex = params['search_regex']
if search_regex is not None:
compiled_search_re = re.compile(search_regex, re.MULTILINE)
else:
compiled_search_re = None
if port and path: if port and path:
module.fail_json(msg="port and path parameter can not both be passed to wait_for") module.fail_json(msg="port and path parameter can not both be passed to wait_for")
@ -404,55 +410,72 @@ def main():
if path: if path:
try: try:
os.stat(path) os.stat(path)
if search_regex: except OSError, e:
# If anything except file not present, throw an error
if e.errno != 2:
elapsed = datetime.datetime.now() - start
module.fail_json(msg="Failed to stat %s, %s" % (path, e.strerror), elapsed=elapsed.seconds)
# file doesn't exist yet, so continue
else:
# File exists. Are there additional things to check?
if not compiled_search_re:
# nope, succeed!
break
try: try:
f = open(path) f = open(path)
try: try:
if re.search(search_regex, f.read(), re.MULTILINE): if re.search(compiled_search_re, f.read()):
# String found, success!
break break
else:
time.sleep(1)
finally: finally:
f.close() f.close()
except IOError: except IOError:
time.sleep(1) pass
elif port:
alt_connect_timeout = math.ceil((end - datetime.datetime.now()).total_seconds())
try:
s = _create_connection((host, port), min(connect_timeout, alt_connect_timeout))
except:
# Failed to connect by connect_timeout. wait and try again
pass pass
else: else:
break # Connected -- are there additional conditions?
except OSError, e: if compiled_search_re:
# File not present
if e.errno == 2:
time.sleep(1)
else:
elapsed = datetime.datetime.now() - start
module.fail_json(msg="Failed to stat %s, %s" % (path, e.strerror), elapsed=elapsed.seconds)
elif port:
try:
s = _create_connection( (host, port), connect_timeout)
if search_regex:
data = '' data = ''
matched = False matched = False
while 1: while datetime.datetime.now() < end:
data += s.recv(1024) max_timeout = math.ceil((end - datetime.datetime.now()).total_seconds())
if not data: (readable, w, e) = select.select([s], [], [], max_timeout)
if not readable:
# No new data. Probably means our timeout
# expired
continue
response = s.recv(1024)
if not response:
# Server shutdown
break break
elif re.search(search_regex, data, re.MULTILINE): data += response
if re.search(compiled_search_re, data):
matched = True matched = True
break break
# Shutdown the client socket
s.shutdown(socket.SHUT_RDWR)
s.close()
if matched: if matched:
# Found our string, success!
break
else:
# Connection established, success!
s.shutdown(socket.SHUT_RDWR) s.shutdown(socket.SHUT_RDWR)
s.close() s.close()
break break
else:
s.shutdown(socket.SHUT_RDWR) # Conditions not yet met, wait and try again
s.close()
break
except:
time.sleep(1) time.sleep(1)
pass
else: else: # while-else
time.sleep(1) # Timeout expired
else:
elapsed = datetime.datetime.now() - start elapsed = datetime.datetime.now() - start
if port: if port:
if search_regex: if search_regex:
@ -485,4 +508,5 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
main() if __name__ == '__main__':
main()