diff --git a/lib/ansible/module_utils/netcfg.py b/lib/ansible/module_utils/netcfg.py index 28bef48c03..139064d972 100644 --- a/lib/ansible/module_utils/netcfg.py +++ b/lib/ansible/module_utils/netcfg.py @@ -4,7 +4,7 @@ # still belong to the author of the module, and may assign their own license # to the complete work. # -# Copyright (c) 2015 Peter Sprygada, +# (c) 2016 Red Hat Inc. # # Redistribution and use in source and binary forms, with or without modification, # are permitted provided that the following conditions are met: @@ -25,57 +25,20 @@ # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # - -import itertools import re -from ansible.module_utils.six import string_types -from ansible.module_utils.six.moves import zip, zip_longest +from ansible.module_utils.six.moves import zip DEFAULT_COMMENT_TOKENS = ['#', '!', '/*', '*/'] -def to_list(val): - if isinstance(val, (list, tuple)): - return list(val) - elif val is not None: - return [val] - else: - return list() - - -class Config(object): - - def __init__(self, connection): - self.connection = connection - - def __call__(self, commands, **kwargs): - lines = to_list(commands) - return self.connection.configure(lines, **kwargs) - - def load_config(self, commands, **kwargs): - commands = to_list(commands) - return self.connection.load_config(commands, **kwargs) - - def get_config(self, **kwargs): - return self.connection.get_config(**kwargs) - - def save_config(self): - return self.connection.save_config() - class ConfigLine(object): - def __init__(self, text): - self.text = text - self.children = list() - self.parents = list() - self.raw = None - - @property - def line(self): - line = [p.text for p in self.parents] - line.append(self.text) - return ' '.join(line) + def __init__(self, raw): + self.text = str(raw).strip() + self.raw = raw + self._children = list() + self._parents = list() def __str__(self): return self.raw @@ -86,295 +49,217 @@ class ConfigLine(object): def __ne__(self, other): return not self.__eq__(other) + def __getitem__(self, key): + for item in self._children: + if item.text == key: + return item + raise KeyError(key) + + @property + def line(self): + line = self.parents + line.append(self.text) + return ' '.join(line) + + @property + def children(self): + return _obj_to_text(self._children) + + @property + def parents(self): + return _obj_to_text(self._parents) + + @property + def path(self): + config = _obj_to_raw(self._parents) + config.append(self.raw) + return '\n'.join(config) + + def add_child(self, obj): + assert isinstance(obj, ConfigLine), 'child must be of type `ConfigLine`' + self._children.append(obj) + def ignore_line(text, tokens=None): for item in (tokens or DEFAULT_COMMENT_TOKENS): if text.startswith(item): return True -def get_next(iterable): - item, next_item = itertools.tee(iterable, 2) - next_item = itertools.islice(next_item, 1, None) - return zip_longest(item, next_item) - -def parse(lines, indent, comment_tokens=None): - toplevel = re.compile(r'\S') - childline = re.compile(r'^\s*(.+)$') - - ancestors = list() - config = list() - - for line in str(lines).split('\n'): - text = str(re.sub(r'([{};])', '', line)).strip() - - cfg = ConfigLine(text) - cfg.raw = line - - if not text or ignore_line(text, comment_tokens): - continue - - # handle top level commands - if toplevel.match(line): - ancestors = [cfg] - - # handle sub level commands - else: - match = childline.match(line) - line_indent = match.start(1) - level = int(line_indent / indent) - parent_level = level - 1 - - cfg.parents = ancestors[:level] - - if level > len(ancestors): - config.append(cfg) - continue - - for i in range(level, len(ancestors)): - ancestors.pop() - - ancestors.append(cfg) - ancestors[parent_level].children.append(cfg) - - config.append(cfg) - - return config +_obj_to_text = lambda x: [o.text for o in x] +_obj_to_raw = lambda x: [o.raw for o in x] def dumps(objects, output='block'): if output == 'block': - items = [c.raw for c in objects] + item = _obj_to_raw(objects) elif output == 'commands': - items = [c.text for c in objects] - elif output == 'lines': - items = list() - for obj in objects: - line = list() - line.extend([p.text for p in obj.parents]) - line.append(obj.text) - items.append(' '.join(line)) + items = _obj_to_text(objects) else: raise TypeError('unknown value supplied for keyword output') return '\n'.join(items) class NetworkConfig(object): - def __init__(self, indent=None, contents=None, device_os=None): - self.indent = indent or 1 - self._config = list() - self._device_os = device_os - self._syntax = 'block' # block, lines, junos - - if self._device_os == 'junos': - self._syntax = 'junos' + def __init__(self, indent=1, contents=None): + self._indent = indent + self._items = list() if contents: self.load(contents) @property def items(self): - return self._config + return self._items + + def __getitem__(self, key): + for line in self: + if line.text == key: + return line + raise KeyError(key) + + def __iter__(self): + return iter(self._items) def __str__(self): - if self._device_os == 'junos': - return dumps(self.expand_line(self.items), 'lines') - return dumps(self.expand_line(self.items)) + return '\n'.join([c.raw for c in self.items]) - def load(self, contents): - # Going to start adding device profiles post 2.2 - tokens = list(DEFAULT_COMMENT_TOKENS) - if self._device_os == 'sros': - tokens.append('echo') - self._config = parse(contents, indent=4, comment_tokens=tokens) - else: - self._config = parse(contents, indent=self.indent) + def load(self, s): + self._items = self.parse(s) - def load_from_file(self, filename): - self.load(open(filename).read()) + def loadfp(self, fp): + return self.load(open(fp).read()) - def get(self, path): - if isinstance(path, string_types): - path = [path] - for item in self._config: - if item.text == path[-1]: - parents = [p.text for p in item.parents] - if parents == path[:-1]: - return item + def parse(self, lines, comment_tokens=None): + toplevel = re.compile(r'\S') + childline = re.compile(r'^\s*(.+)$') + + ancestors = list() + config = list() + + curlevel = 0 + prevlevel = 0 + + for linenum, line in enumerate(str(lines).split('\n')): + text = str(re.sub(r'([{};])', '', line)).strip() + + cfg = ConfigLine(line) + + if not text or ignore_line(text, comment_tokens): + continue + + # handle top level commands + if toplevel.match(line): + ancestors = [cfg] + prevlevel = curlevel + curlevel = 0 + + # handle sub level commands + else: + match = childline.match(line) + line_indent = match.start(1) + + prevlevel = curlevel + curlevel = int(line_indent / self._indent) + + if (curlevel - 1) > prevlevel: + curlevel = prevlevel + 1 + + parent_level = curlevel - 1 + + cfg._parents = ancestors[:curlevel] + + if curlevel > len(ancestors): + config.append(cfg) + continue + + for i in range(curlevel, len(ancestors)): + ancestors.pop() + + ancestors.append(cfg) + ancestors[parent_level].add_child(cfg) + + config.append(cfg) + + return config def get_object(self, path): for item in self.items: if item.text == path[-1]: - parents = [p.text for p in item.parents] - if parents == path[:-1]: + if item.parents == path[:-1]: return item - def get_section_objects(self, path): - if not isinstance(path, list): - path = [path] + def get_section(self, path): + assert isinstance(path, list), 'path argument must be a list object' obj = self.get_object(path) if not obj: raise ValueError('path does not exist in config') - return self.expand_section(obj) + return self._expand_section(obj) - def search(self, regexp, path=None): - regex = re.compile(r'^%s' % regexp, re.M) - - if path: - parent = self.get(path) - if not parent or not parent.children: - return - children = [c.text for c in parent.children] - data = '\n'.join(children) - else: - data = str(self) - - match = regex.search(data) - if match: - if match.groups(): - values = match.groupdict().values() - groups = list(set(match.groups()).difference(values)) - return (groups, match.groupdict()) - else: - return match.group() - - def findall(self, regexp): - regexp = r'%s' % regexp - return re.findall(regexp, str(self)) - - def expand_line(self, objs): - visited = set() - expanded = list() - for o in objs: - for p in o.parents: - if p not in visited: - visited.add(p) - expanded.append(p) - expanded.append(o) - visited.add(o) - return expanded - - def expand_section(self, configobj, S=None): + def _expand_section(self, configobj, S=None): if S is None: S = list() S.append(configobj) - for child in configobj.children: + for child in configobj._children: if child in S: continue - self.expand_section(child, S) + self._expand_section(child, S) return S - def expand_block(self, objects, visited=None): - items = list() - - if not visited: - visited = set() - - for o in objects: - items.append(o) - visited.add(o) - for child in o.children: - items.extend(self.expand_block([child], visited)) - - return items - - def diff_line(self, other, path=None): - diff = list() + def _diff_line(self, other): + updates = list() for item in self.items: if item not in other: - diff.append(item) - return diff + updates.append(item) + return updates - def diff_strict(self, other, path=None): - diff = list() - for index, item in enumerate(self.items): + def _diff_strict(self, other): + updates = list() + for index, line in enumerate(self._items): try: - if item != other[index]: - diff.append(item) + if line != other._lines[index]: + updates.append(line) except IndexError: - diff.append(item) - return diff + updates.append(line) + return updates - def diff_exact(self, other, path=None): - diff = list() - if len(other) != len(self.items): - diff.extend(self.items) + def _diff_exact(self, other): + updates = list() + if len(other) != len(self._items): + updates.extend(self._items) else: - for ours, theirs in zip(self.items, other): + for ours, theirs in zip(self._items, other): if ours != theirs: - diff.extend(self.items) + updates.extend(self._items) break - return diff + return updates - def difference(self, other, path=None, match='line', replace='line'): + def difference(self, other, match='line', path=None, replace=None): try: - if path and match != 'line': - try: - other = other.get_section_objects(path) - except ValueError: - other = list() - else: - other = other.items - func = getattr(self, 'diff_%s' % match) - updates = func(other, path=path) + meth = getattr(self, '_diff_%s' % match) + updates = meth(other) except AttributeError: - raise - raise TypeError('invalid value for match keyword') + raise TypeError('invalid value for match keyword argument, ' + 'valid values are line, strict, or exact') - if self._device_os == 'junos': - return updates + visited = set() + expanded = list() - if replace == 'block': - parents = list() - for u in updates: - if u.parents is None: - if u not in parents: - parents.append(u) - else: - for p in u.parents: - if p not in parents: - parents.append(p) - - return self.expand_block(parents) - - return self.expand_line(updates) - - def replace(self, patterns, repl, parents=None, add_if_missing=False, - ignore_whitespace=True): - - match = None - - parents = to_list(parents) or list() - patterns = [re.compile(r, re.I) for r in to_list(patterns)] - - for item in self.items: - for regexp in patterns: - text = item.text - if not ignore_whitespace: - text = item.raw - if regexp.search(text): - if item.text != repl: - if parents == [p.text for p in item.parents]: - match = item - break - - if match: - match.text = repl - indent = len(match.raw) - len(match.raw.lstrip()) - match.raw = repl.rjust(len(repl) + indent) - - elif add_if_missing: - self.add(repl, parents=parents) + for item in updates: + for p in item._parents: + if p not in visited: + visited.add(p) + expanded.append(p) + expanded.append(item) + visited.add(item) + return expanded def add(self, lines, parents=None): - """Adds one or lines of configuration - """ - ancestors = list() offset = 0 obj = None ## global config command if not parents: - for line in to_list(lines): + for line in lines: item = ConfigLine(line) item.raw = line if item not in self.items: @@ -384,12 +269,12 @@ class NetworkConfig(object): for index, p in enumerate(parents): try: i = index + 1 - obj = self.get_section_objects(parents[:i])[0] + obj = self.get_section(parents[:i])[0] ancestors.append(obj) except ValueError: # add parent to config - offset = index * self.indent + offset = index * self._indent obj = ConfigLine(p) obj.raw = p.rjust(len(p) + offset) if ancestors: @@ -399,15 +284,15 @@ class NetworkConfig(object): ancestors.append(obj) # add child objects - for line in to_list(lines): + for line in lines: # check if child already exists for child in ancestors[-1].children: if child.text == line: break else: - offset = len(parents) * self.indent + offset = len(parents) * self._indent item = ConfigLine(line) item.raw = line.rjust(len(line) + offset) - item.parents = ancestors + item._parents = ancestors ancestors[-1].children.append(item) self.items.append(item) diff --git a/lib/ansible/module_utils/network.py b/lib/ansible/module_utils/network.py index 1443acbf8b..03ba6ea289 100644 --- a/lib/ansible/module_utils/network.py +++ b/lib/ansible/module_utils/network.py @@ -31,7 +31,6 @@ import itertools from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import env_fallback, get_exception from ansible.module_utils.netcli import Cli, Command -from ansible.module_utils.netcfg import Config from ansible.module_utils._text import to_native NET_TRANSPORT_ARGS = dict( @@ -78,6 +77,25 @@ class NetworkError(Exception): super(NetworkError, self).__init__(msg) self.kwargs = kwargs +class Config(object): + + def __init__(self, connection): + self.connection = connection + + def __call__(self, commands, **kwargs): + lines = to_list(commands) + return self.connection.configure(lines, **kwargs) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.connection.load_config(commands, **kwargs) + + def get_config(self, **kwargs): + return self.connection.get_config(**kwargs) + + def save_config(self): + return self.connection.save_config() + class NetworkModule(AnsibleModule): @@ -174,8 +192,3 @@ def register_transport(transport, default=False): def add_argument(key, value): NET_CONNECTION_ARGS[key] = value -def get_module(*args, **kwargs): - # This is a temporary factory function to avoid break all modules - # until the modules are updated. This function *will* be removed - # before 2.2 final - return NetworkModule(*args, **kwargs)