node_util.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  # -*- coding: utf-8 -*-
  """
  Part of the astor library for Python AST manipulation.
  License: 3-clause BSD
  Copyright 2012-2015 (c) Patrick Maupin
  Copyright 2013-2015 (c) Berker Peksag
  Utilities for node (and, by extension, tree) manipulation.
  For a whole-tree approach, see the treewalk submodule.
  """
  10. import ast
  11. import itertools
  # Python 3 exposes ``itertools.zip_longest``; Python 2 only provides the
  # same iterator under the name ``izip_longest``.
  if hasattr(itertools, 'zip_longest'):
      zip_longest = itertools.zip_longest
  else:
      zip_longest = itertools.izip_longest
  class NonExistent(object):
      """Sentinel class standing in for "this attribute is not there".

      The class object itself is used as a marker and compared by identity,
      so that a real attribute value such as ``None`` is never mistaken for
      a missing one.
      """
  def iter_node(node, name='', unknown=None,
                # Runtime optimization
                list=list, getattr=getattr, isinstance=isinstance,
                enumerate=enumerate, missing=NonExistent):
      """Yield ``(value, name)`` pairs for the children of ``node``.

      - When ``node`` has a ``_fields`` attribute, every field that is
        actually present on the object is yielded, in ``_fields`` order,
        together with the field name.
      - Otherwise, when ``node`` is a list, each item is yielded paired
        with the ``name`` passed in by the caller (blank by default).
      - Any other object yields nothing.

      If ``unknown`` is a set, it is updated (once the fields have been
      exhausted) with the names of instance attributes of ``node`` that are
      not listed in ``_fields``.

      The ``list`` ... ``missing`` keyword parameters only pre-bind
      globals as locals; callers are not expected to pass them.
      """
      field_names = getattr(node, '_fields', None)

      if field_names is None:
          # Not a node-like object: only plain lists produce children.
          if isinstance(node, list):
              for item in node:
                  yield item, name
          return

      for name in field_names:
          attr = getattr(node, name, missing)
          if attr is not missing:
              yield attr, name

      if unknown is not None:
          unknown.update(set(vars(node)).difference(field_names))
  def dump_tree(node, name=None, initial_indent='', indentation=' ',
                maxline=120, maxmerged=80,
                # Runtime optimization
                iter_node=iter_node, special=ast.AST,
                list=list, isinstance=isinstance, type=type, len=len):
      """Return a pretty-printed, indented dump of an AST-like structure.

      Only what ``iter_node`` reports is printed (fields named in
      ``_fields`` and list items), and any child named ``ctx`` is left out.

      ``name`` is an optional label rendered as ``name=`` in front of the
      top-level value.  ``initial_indent`` is the indentation the caller has
      already emitted and ``indentation`` is what gets added per nesting
      level.  A node is kept on a single line while that line (plus the
      current indent) is shorter than ``maxline``; otherwise its children
      are placed on separate lines, with the first child kept on the
      opening line when that stays shorter than ``maxmerged``.
      """

      def render(item, label=None, indent=''):
          child_indent = indent + indentation
          label = (label + '=') if label else ''
          children = list(iter_node(item))

          if isinstance(item, list):
              opener = '%s[' % label
              closer = ']'
          elif children:
              opener = '%s%s(' % (label, type(item).__name__)
              closer = ')'
          elif isinstance(item, special):
              # A node without any populated field: print the bare name.
              opener = label + type(item).__name__
              closer = ''
          else:
              # Plain value (string, number, None, ...).
              return '%s%s' % (label, repr(item))

          parts = [render(value, field, child_indent)
                   for value, field in children if field != 'ctx']

          flat = '%s%s%s' % (opener, ', '.join(parts), closer)
          if len(flat) + len(indent) < maxline:
              return '%s' % flat

          # Too long for one line: optionally keep the first child next to
          # the opener, then put the remaining children one per line.
          if parts and len(opener) + len(parts[0]) < maxmerged:
              opener = '%s%s,' % (opener, parts.pop(0))
          separator = ',\n%s' % child_indent
          body = separator.join(parts).lstrip()
          return '%s\n%s%s%s' % (opener, child_indent, body, closer)

      return render(node, name, initial_indent)
  def strip_tree(node,
                 # Runtime optimization
                 iter_node=iter_node, special=ast.AST,
                 list=list, isinstance=isinstance, type=type, len=len):
      """Remove, in place, every attribute that is not listed in ``_fields``.

      The tree is walked recursively through ``iter_node``.  On each node the
      attributes that are not part of ``_fields`` are deleted; a node that
      yields no children at all and is an ``ast.AST`` instance loses all of
      its instance attributes.  A ``ctx`` attribute is deleted as well and
      ``ctx`` is dropped from that node's ``_fields``.

      Returns the set of attribute names that were stripped.  The intent is
      to canonicalize two trees so that they can be compared.
      """
      removed = set()

      def prune(item, indent):
          extras = set()
          has_children = False
          for child, _ in iter_node(item, unknown=extras):
              has_children = True
              prune(child, indent + ' ')

          if not has_children and isinstance(item, special):
              # Childless AST node: everything stored on it is extra.
              extras = set(vars(item))

          removed.update(extras)
          for attr in extras:
              delattr(item, attr)

          if not hasattr(item, 'ctx'):
              return
          delattr(item, 'ctx')
          if 'ctx' in item._fields:
              remaining = list(item._fields)
              remaining.remove('ctx')
              item._fields = remaining

      prune(node, '')
      return removed
  class ExplicitNodeVisitor(ast.NodeVisitor):
      """A ``NodeVisitor`` without any implicit visits.

      ``visit`` dispatches to ``visit_<ClassName>`` exactly like the base
      class, but when no such method exists it calls an abort handler
      instead of silently falling back to a generic traversal.
      """

      def abort_visit(node):
          """Raise ``AttributeError`` naming the unhandled node type."""
          msg = 'No defined handler for node of type %s'
          raise AttributeError(msg % node.__class__.__name__)

      def visit(self, node, abort=abort_visit):
          """Visit a node.

          Calls ``self.visit_<ClassName>(node)`` and returns its result.
          If that method is missing, ``abort(node)`` is called instead;
          the default ``abort`` raises ``AttributeError``.
          """
          method = 'visit_' + node.__class__.__name__
          visitor = getattr(self, method, abort)
          return visitor(node)

      # ``abort_visit`` takes no ``self``.  It has to remain a plain function
      # while ``visit`` is being defined (the default above must be callable
      # on every Python version), so it is wrapped only afterwards.  This
      # makes ``instance.abort_visit(node)`` raise the intended
      # AttributeError rather than a TypeError about the argument count.
      abort_visit = staticmethod(abort_visit)
  def allow_ast_comparison():
      """Monkey-patch the ``ast`` node classes so that ``==``/``!=`` work.

      A small helper class is appended to the bases of every class found
      in the ``ast`` module that derives from ``ast.AST``.  The helper
      treats two objects as equal when they have the same type and equal
      instance dictionaries, which lets whole trees be compared with the
      ordinary comparison operators.  Classes whose bases cannot be
      reassigned (``TypeError``) are silently left alone.

      Used by the anti8 functions to compare old and new ASTs; could also
      be used by the test library.
      """

      class CompareHelper(object):
          def __eq__(self, other):
              same_type = type(self) == type(other)
              return same_type and vars(self) == vars(other)

          def __ne__(self, other):
              other_type = type(self) != type(other)
              return other_type or vars(self) != vars(other)

      for candidate in vars(ast).values():
          # Only plain classes that are AST node types are of interest.
          if type(candidate) != type or not issubclass(candidate, ast.AST):
              continue
          try:
              candidate.__bases__ = candidate.__bases__ + (CompareHelper,)
          except TypeError:
              pass
  def fast_compare(tree1, tree2):
      """Return True when two AST trees are structurally equal.

      The comparison is iterative (explicit work stack) and optimized for
      the trees used by rtrip.  Only the fields named in ``_fields`` are
      examined, never ``_attributes`` such as line numbers, and ``ctx``
      fields are ignored.

      Two items match when:

      - both are AST nodes of the same class with the same field names
        (``ctx`` aside) and all corresponding field values match;
      - both are lists of equal length whose items match pairwise;
      - otherwise, both have the same type and compare equal with ``==``.

      Anything else, including a list paired with a non-list, makes the
      function return False.
      """
      geta = ast.AST.__getattribute__

      work = [(tree1, tree2)]
      pop = work.pop
      extend = work.extend
      # TypeError in cPython, AttributeError in PyPy
      exception = TypeError, AttributeError
      zipl = zip_longest
      type_ = type
      list_ = list

      while work:
          n1, n2 = pop()
          try:
              f1 = geta(n1, '_fields')
              f2 = geta(n2, '_fields')
          except exception:
              # At least one side is not an AST node.
              if type_(n1) is list_:
                  # A list only matches another list; pairing it with a
                  # node or None would otherwise blow up in zip_longest,
                  # and a string or tuple would be matched item by item.
                  if type_(n2) is not list_:
                      return False
                  # zip_longest pads the shorter list with None, which
                  # then fails the comparison against a real item.
                  extend(zipl(n1, n2))
                  continue
              # Leaf values: require the same type as well, so that
              # 1, 1.0 and True are not treated as the same constant.
              if type_(n1) is type_(n2) and n1 == n2:
                  continue
              return False
          else:
              # Many node classes share the same (often empty) _fields,
              # e.g. Add/Sub or And/Or, so the class has to match too.
              if type_(n1) is not type_(n2):
                  return False
              f1 = [x for x in f1 if x != 'ctx']
              if f1 != [x for x in f2 if x != 'ctx']:
                  return False
              extend((geta(n1, fname), geta(n2, fname)) for fname in f1)

      return True