[cloud] ec2_group fix CIDR with host bits set - fixes #25403 (#29605)

* WIP adds network subnetting functions

* adds functions to convert between netmask and masklen
* adds functions to verify netmask and masklen
* adds function to dtermine network and subnet from address / mask pair

* network_common: add a function to get the first 48 bits in a IPv6 address.

ec2_group: only use network bits of a CIDR.

* Add tests for CIDRs with host bits set.

* ec2_group: add warning if CIDR isn't the networking address.

* Fix pep8.

* Improve wording.

* fix import for network utils

* Update tests to use pytest instead of unittest

* add test for to_ipv6_network()

* Fix PEP8
This commit is contained in:
Sloane Hertel 2017-12-20 14:57:47 -05:00 committed by Ryan Brown
commit d877c146ab
4 changed files with 385 additions and 101 deletions

View file

@ -31,8 +31,11 @@ import operator
import socket
from itertools import chain
from struct import pack
from socket import inet_aton, inet_ntoa
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils.six.moves import zip
from ansible.module_utils.basic import AnsibleFallbackNotFound
try:
@ -45,6 +48,7 @@ except ImportError:
OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le'])
ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')])
VALID_MASKS = [2**8 - 2**i for i in range(0, 9)]
def to_list(val):
@ -426,3 +430,106 @@ class Template:
if marker in data:
return True
return False
def is_netmask(val):
parts = str(val).split('.')
if not len(parts) == 4:
return False
for part in parts:
try:
if int(part) not in VALID_MASKS:
raise ValueError
except ValueError:
return False
return True
def is_masklen(val):
try:
return 0 <= int(val) <= 32
except ValueError:
return False
def to_bits(val):
""" converts a netmask to bits """
bits = ''
for octet in val.split('.'):
bits += bin(int(octet))[2:].zfill(8)
return str
def to_netmask(val):
""" converts a masklen to a netmask """
if not is_masklen(val):
raise ValueError('invalid value for masklen')
bits = 0
for i in range(32 - int(val), 32):
bits |= (1 << i)
return inet_ntoa(pack('>I', bits))
def to_masklen(val):
""" converts a netmask to a masklen """
if not is_netmask(val):
raise ValueError('invalid value for netmask: %s' % val)
bits = list()
for x in val.split('.'):
octet = bin(int(x)).count('1')
bits.append(octet)
return sum(bits)
def to_subnet(addr, mask, dotted_notation=False):
""" coverts an addr / mask pair to a subnet in cidr notation """
try:
if not is_masklen(mask):
raise ValueError
cidr = int(mask)
mask = to_netmask(mask)
except ValueError:
cidr = to_masklen(mask)
addr = addr.split('.')
mask = mask.split('.')
network = list()
for s_addr, s_mask in zip(addr, mask):
network.append(str(int(s_addr) & int(s_mask)))
if dotted_notation:
return '%s %s' % ('.'.join(network), to_netmask(cidr))
return '%s/%s' % ('.'.join(network), cidr)
def to_ipv6_network(addr):
""" IPv6 addresses are eight groupings. The first three groupings (48 bits) comprise the network address. """
# Split by :: to identify omitted zeros
ipv6_prefix = addr.split('::')[0]
# Get the first three groups, or as many as are found + ::
found_groups = []
for group in ipv6_prefix.split(':'):
found_groups.append(group)
if len(found_groups) == 3:
break
if len(found_groups) < 3:
found_groups.append('::')
# Concatenate network address parts
network_addr = ''
for group in found_groups:
if group != '::':
network_addr += str(group)
network_addr += str(':')
# Ensure network address ends with ::
if not network_addr.endswith('::'):
network_addr += str(':')
return network_addr

View file

@ -289,6 +289,7 @@ from ansible.module_utils.ec2 import camel_dict_to_snake_dict
from ansible.module_utils.ec2 import HAS_BOTO3
from ansible.module_utils.ec2 import boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list, compare_aws_tags
from ansible.module_utils.ec2 import AWSRetry
from ansible.module_utils.network.common.utils import to_ipv6_network, to_subnet
import traceback
try:
@ -521,7 +522,22 @@ def update_rules_description(module, client, rule_type, group_id, ip_permissions
def authorize_ip(type, changed, client, group, groupRules,
ip, ip_permission, module, rule, ethertype):
# If rule already exists, don't later delete it
for thisip in ip:
for this_ip in ip:
split_addr = this_ip.split('/')
if len(split_addr) == 2:
# this_ip is a IPv4 or IPv6 CIDR that may or may not have host bits set
# Get the network bits.
try:
thisip = to_subnet(split_addr[0], split_addr[1])
except ValueError:
thisip = to_ipv6_network(split_addr[0]) + "/" + split_addr[1]
if thisip != this_ip:
module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, "
"check the network mask and make sure that only network bits are set: {1}.".format(this_ip, thisip))
else:
thisip = this_ip
rule_id = make_rule_key(type, rule, group['GroupId'], thisip)
if rule_id in groupRules:

View file

@ -435,6 +435,112 @@
- 'result.changed'
- 'result.group_id.startswith("sg-")'
# ============================================================
- name: test adding a rule with a IPv4 CIDR with host bits set (expected changed=true)
ec2_group:
name: '{{ec2_group_name}}'
description: '{{ec2_group_description}}'
ec2_region: '{{ec2_region}}'
ec2_access_key: '{{ec2_access_key}}'
ec2_secret_key: '{{ec2_secret_key}}'
security_token: '{{security_token}}'
state: present
# set purge_rules to false so we don't get a false positive from previously added rules
purge_rules: false
rules:
- proto: "tcp"
ports:
- 8195
cidr_ip: 10.0.0.1/8
register: result
- name: assert state=present (expected changed=true)
assert:
that:
- 'result.changed'
- 'result.group_id.startswith("sg-")'
# ============================================================
- name: test adding the same rule with a IPv4 CIDR with host bits set (expected changed=false and a warning)
ec2_group:
name: '{{ec2_group_name}}'
description: '{{ec2_group_description}}'
ec2_region: '{{ec2_region}}'
ec2_access_key: '{{ec2_access_key}}'
ec2_secret_key: '{{ec2_secret_key}}'
security_token: '{{security_token}}'
state: present
# set purge_rules to false so we don't get a false positive from previously added rules
purge_rules: false
rules:
- proto: "tcp"
ports:
- 8195
cidr_ip: 10.0.0.1/8
register: result
- name: assert state=present (expected changed=false and a warning)
assert:
that:
# No way to assert for warnings?
- 'not result.changed'
- 'result.group_id.startswith("sg-")'
# ============================================================
- name: test adding a rule with a IPv6 CIDR with host bits set (expected changed=true)
ec2_group:
name: '{{ec2_group_name}}'
description: '{{ec2_group_description}}'
ec2_region: '{{ec2_region}}'
ec2_access_key: '{{ec2_access_key}}'
ec2_secret_key: '{{ec2_secret_key}}'
security_token: '{{security_token}}'
state: present
# set purge_rules to false so we don't get a false positive from previously added rules
purge_rules: false
rules:
- proto: "tcp"
ports:
- 8196
cidr_ipv6: '2001:db00::1/24'
register: result
- name: assert state=present (expected changed=true)
assert:
that:
- 'result.changed'
- 'result.group_id.startswith("sg-")'
# ============================================================
- name: test adding a rule again with a IPv6 CIDR with host bits set (expected changed=false and a warning)
ec2_group:
name: '{{ec2_group_name}}'
description: '{{ec2_group_description}}'
ec2_region: '{{ec2_region}}'
ec2_access_key: '{{ec2_access_key}}'
ec2_secret_key: '{{ec2_secret_key}}'
security_token: '{{security_token}}'
state: present
# set purge_rules to false so we don't get a false positive from previously added rules
purge_rules: false
rules:
- proto: "tcp"
ports:
- 8196
cidr_ipv6: '2001:db00::1/24'
register: result
- name: assert state=present (expected changed=false and a warning)
assert:
that:
# No way to assert for warnings?
- 'not result.changed'
- 'result.group_id.startswith("sg-")'
# ============================================================
- name: test state=absent (expected changed=true)
ec2_group:

View file

@ -18,134 +18,189 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.compat.tests import unittest
import pytest
from ansible.module_utils.network.common.utils import to_list, sort_list
from ansible.module_utils.network.common.utils import dict_diff, dict_merge
from ansible.module_utils.network.common.utils import conditional, Template
from ansible.module_utils.network.common.utils import to_masklen, to_netmask, to_subnet, to_ipv6_network
from ansible.module_utils.network.common.utils import is_masklen, is_netmask
class TestModuleUtilsNetworkCommon(unittest.TestCase):
def test_to_list():
for scalar in ('string', 1, True, False, None):
assert isinstance(to_list(scalar), list)
def test_to_list(self):
for scalar in ('string', 1, True, False, None):
self.assertTrue(isinstance(to_list(scalar), list))
for container in ([1, 2, 3], {'one': 1}):
assert isinstance(to_list(container), list)
for container in ([1, 2, 3], {'one': 1}):
self.assertTrue(isinstance(to_list(container), list))
test_list = [1, 2, 3]
assert id(test_list) != id(to_list(test_list))
test_list = [1, 2, 3]
self.assertNotEqual(id(test_list), id(to_list(test_list)))
def test_sort(self):
data = [3, 1, 2]
self.assertEqual([1, 2, 3], sort_list(data))
def test_sort():
data = [3, 1, 2]
assert [1, 2, 3] == sort_list(data)
string_data = '123'
self.assertEqual(string_data, sort_list(string_data))
string_data = '123'
assert string_data == sort_list(string_data)
def test_dict_diff(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
def test_dict_diff():
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
result = dict_diff(base, other)
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
# string assertions
self.assertNotIn('one', result)
self.assertNotIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)
result = dict_diff(base, other)
# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertNotIn('key2', result['obj1'])
# string assertions
assert 'one' not in result
assert 'two' not in result
assert result['three'] == 4
assert result['four'] == 4
# list assertions
self.assertEqual(result['l1'], [2, 1])
self.assertNotIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertNotIn('l4', result)
# dict assertions
assert 'obj1' in result
assert 'key1' in result['obj1']
assert 'key2' not in result['obj1']
# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertNotIn('key2', result['obj1'])
# list assertions
assert result['l1'] == [2, 1]
assert 'l2' not in result
assert result['l3'] == [1]
assert 'l4' not in result
# bool assertions
self.assertNotIn('b1', result)
self.assertNotIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])
# nested assertions
assert 'obj1' in result
assert result['obj1']['key1'] == 2
assert 'key2' not in result['obj1']
def test_dict_merge(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
# bool assertions
assert 'b1' not in result
assert 'b2' not in result
assert result['b3']
assert result['b4']
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
result = dict_merge(base, other)
def test_dict_merge():
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))
# string assertions
self.assertIn('one', result)
self.assertIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)
other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))
# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertIn('key2', result['obj1'])
result = dict_merge(base, other)
# list assertions
self.assertEqual(result['l1'], [1, 2, 3])
self.assertIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertIn('l4', result)
# string assertions
assert 'one' in result
assert 'two' in result
assert result['three'] == 4
assert result['four'] == 4
# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertIn('key2', result['obj1'])
# dict assertions
assert 'obj1' in result
assert 'key1' in result['obj1']
assert 'key2' in result['obj1']
# bool assertions
self.assertIn('b1', result)
self.assertIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])
# list assertions
assert result['l1'] == [1, 2, 3]
assert 'l2' in result
assert result['l3'] == [1]
assert 'l4' in result
def test_conditional(self):
self.assertTrue(conditional(10, 10))
self.assertTrue(conditional('10', '10'))
self.assertTrue(conditional('foo', 'foo'))
self.assertTrue(conditional(True, True))
self.assertTrue(conditional(False, False))
self.assertTrue(conditional(None, None))
self.assertTrue(conditional("ge(1)", 1))
self.assertTrue(conditional("gt(1)", 2))
self.assertTrue(conditional("le(2)", 2))
self.assertTrue(conditional("lt(3)", 2))
self.assertTrue(conditional("eq(1)", 1))
self.assertTrue(conditional("neq(0)", 1))
self.assertTrue(conditional("min(1)", 1))
self.assertTrue(conditional("max(1)", 1))
self.assertTrue(conditional("exactly(1)", 1))
# nested assertions
assert 'obj1' in result
assert result['obj1']['key1'] == 2
assert 'key2' in result['obj1']
def test_template(self):
tmpl = Template()
self.assertEqual('foo', tmpl('{{ test }}', {'test': 'foo'}))
# bool assertions
assert 'b1' in result
assert 'b2' in result
assert result['b3']
assert result['b4']
def test_conditional():
assert conditional(10, 10)
assert conditional('10', '10')
assert conditional('foo', 'foo')
assert conditional(True, True)
assert conditional(False, False)
assert conditional(None, None)
assert conditional("ge(1)", 1)
assert conditional("gt(1)", 2)
assert conditional("le(2)", 2)
assert conditional("lt(3)", 2)
assert conditional("eq(1)", 1)
assert conditional("neq(0)", 1)
assert conditional("min(1)", 1)
assert conditional("max(1)", 1)
assert conditional("exactly(1)", 1)
def test_template():
tmpl = Template()
assert 'foo' == tmpl('{{ test }}', {'test': 'foo'})
def test_to_masklen():
assert 24 == to_masklen('255.255.255.0')
def test_to_masklen_invalid():
with pytest.raises(ValueError):
to_masklen('255')
def test_to_netmask():
assert '255.0.0.0' == to_netmask(8)
assert '255.0.0.0' == to_netmask('8')
def test_to_netmask_invalid():
with pytest.raises(ValueError):
to_netmask(128)
def test_to_subnet():
result = to_subnet('192.168.1.1', 24)
assert '192.168.1.0/24' == result
result = to_subnet('192.168.1.1', 24, dotted_notation=True)
assert '192.168.1.0 255.255.255.0' == result
def test_to_subnet_invalid():
with pytest.raises(ValueError):
to_subnet('foo', 'bar')
def test_is_masklen():
assert is_masklen(32)
assert not is_masklen(33)
assert not is_masklen('foo')
def test_is_netmask():
assert is_netmask('255.255.255.255')
assert not is_netmask(24)
assert not is_netmask('foo')
def test_to_ipv6_network():
assert '2001:db8::' == to_ipv6_network('2001:db8::')
assert '2001:0db8:85a3::' == to_ipv6_network('2001:0db8:85a3:0000:0000:8a2e:0370:7334')
assert '2001:0db8:85a3::' == to_ipv6_network('2001:0db8:85a3:0:0:8a2e:0370:7334')