Source code for chameleon.astutil

# -*- coding: utf-8 -*-
#
# Copyright (C) 2008-2009 Edgewall Software
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at http://genshi.edgewall.org/wiki/License.
#
# This software consists of voluntary contributions made by many
# individuals. For the exact contribution history, see the revision
# history and logs, available at http://genshi.edgewall.org/log/.

"""Support classes for generating code from abstract syntax trees."""

try:
    import ast
except ImportError:
    from chameleon import ast25 as ast

import sys
import logging
import weakref
import collections

# Maps an AST node to the node (or value) that should stand in for it.
# Weak keys let an annotation disappear together with its node.
node_annotations = weakref.WeakKeyDictionary()

try:
    # Probe whether AST nodes can be weakly referenced on this interpreter.
    node_annotations[ast.Name()] = None
except TypeError:
    logging.debug(
        "Unable to create weak references to AST nodes. "
        "A lock will be used around compilation loop."
    )

    # Fall back to a plain dictionary holding strong references.
    node_annotations = {}

__docformat__ = 'restructuredtext en'


def annotated(value):
    """Return a fresh ``annotation`` name node registered with *value*.

    The node is a load-context ``ast.Name`` with id ``"annotation"``;
    *value* is stored for it in ``node_annotations``.
    """
    placeholder = load("annotation")
    node_annotations[placeholder] = value
    return placeholder


def parse(source, mode='eval'):
    """Compile *source* to an AST (not a code object) using *mode*."""
    flags = ast.PyCF_ONLY_AST
    return compile(source, '', mode, flags)


def load(name):
    """Return an ``ast.Name`` for *name* in load context."""
    context = ast.Load()
    return ast.Name(id=name, ctx=context)


def store(name):
    """Return an ``ast.Name`` for *name* in store context."""
    context = ast.Store()
    return ast.Name(id=name, ctx=context)


def param(name):
    """Return an ``ast.Name`` for *name* in parameter context."""
    context = ast.Param()
    return ast.Name(id=name, ctx=context)


def delete(name):
    """Return an ``ast.Name`` for *name* in delete context."""
    context = ast.Del()
    return ast.Name(id=name, ctx=context)


def subscript(name, value, ctx):
    """Build the node for ``value[name]`` with *name* as a string key.

    :param name: string used as the constant subscript key.
    :param value: AST node that is being subscripted.
    :param ctx: expression context instance, e.g. ``ast.Load()``.
    :returns: an ``ast.Subscript`` node.
    """
    # ``ast.Str`` is only a deprecated alias of ``ast.Constant`` from
    # Python 3.8 on and is missing altogether on recent versions, so it
    # is used on older interpreters only.
    if sys.version_info >= (3, 8):
        key = ast.Constant(value=name)
    else:
        key = ast.Str(s=name)

    # From Python 3.9 on the key expression is used as the slice
    # directly; older versions expect it wrapped in ``ast.Index``.
    if sys.version_info < (3, 9):
        key = ast.Index(value=key)

    return ast.Subscript(value=value, slice=key, ctx=ctx)


def walk_names(target, mode):
    """Yield the id of every name node below *target* whose context is
    an instance of *mode* (e.g. ``ast.Load`` or ``ast.Store``)."""
    for candidate in ast.walk(target):
        if not isinstance(candidate, ast.Name):
            continue
        if isinstance(candidate.ctx, mode):
            yield candidate.id


def iter_fields(node):
    """
    Yield a ``(fieldname, value)`` pair for every name listed in
    ``node._fields`` that is actually set on *node*; names without a
    matching attribute are skipped.
    """
    for name in node._fields:
        if hasattr(node, name):
            yield name, getattr(node, name)


def iter_child_nodes(node):
    """
    Yield the direct children of *node*: every field value that is a
    ``Node`` and every ``Node`` item found in a field that is a list.
    """
    for _, value in iter_fields(node):
        if isinstance(value, Node):
            yield value
            continue
        if isinstance(value, list):
            for member in value:
                if isinstance(member, Node):
                    yield member


def walk(node):
    """
    Yield *node* and then all of its descendants, breadth first.

    Children are discovered through ``iter_child_nodes``; callers should
    not rely on the order.  Handy for in-place modification when the
    surrounding context does not matter.
    """
    pending = collections.deque((node,))
    while pending:
        current = pending.popleft()
        # Queue the children before handing the node to the consumer.
        pending.extend(iter_child_nodes(current))
        yield current


def copy(source, target):
    """Turn *target* into an alias of *source*.

    *target* takes over the class of *source* and shares (does not
    duplicate) its attribute dictionary.
    """
    kind = source.__class__
    target.__class__ = kind
    state = source.__dict__
    target.__dict__ = state


def swap(root, replacement, name):
    """Annotate every load of *name* below *root* with *replacement*.

    Nodes that already carry an annotation keep it.
    """
    for candidate in ast.walk(root):
        if not isinstance(candidate, ast.Name):
            continue
        if not isinstance(candidate.ctx, ast.Load):
            continue
        if candidate.id != name:
            continue
        assert hasattr(replacement, '_fields')
        node_annotations.setdefault(candidate, replacement)


def marker(name):
    """Return a string-literal node with the value ``"__<name>"``.

    :param name: text inserted after the double-underscore prefix.
    :returns: ``ast.Constant`` on Python 3.8 and later, ``ast.Str`` on
        older interpreters.
    """
    text = "__%s" % name
    # ``ast.Str`` is only a deprecated alias of ``ast.Constant`` from
    # Python 3.8 on and is missing altogether on recent versions.
    if sys.version_info >= (3, 8):
        return ast.Constant(value=text)
    return ast.Str(s=text)


class Node(object):
    """Base class for custom AST nodes with a convenience constructor.

    Subclasses list their attribute names in the ``_fields`` tuple.
    Positional arguments are matched against ``_fields`` in order;
    keyword arguments are stored as given.
    """

    _fields = ()

    def __init__(self, *args, **kwargs):
        assert isinstance(self._fields, tuple)
        # Keywords go in first; positional values are assigned afterwards
        # and therefore win if both name the same field.
        vars(self).update(kwargs)
        for field, value in zip(self._fields, args):
            setattr(self, field, value)

    def __repr__(self):
        """Poor man's single-line pretty printer."""

        parts = []
        for field in self._fields:
            parts.append(" %s=%r" % (field, getattr(self, field, "\"?\"")))
        return '<%s%s at %x>' % (
            type(self).__name__, "".join(parts), id(self))

    def extract(self, condition):
        """Return the nodes found by ``walk(self)`` (this node included)
        for which ``condition(node)`` is true."""
        return [node for node in walk(self) if condition(node)]


class Builtin(Node):
    """A reference to a Python builtin.

    The compiler uses this node when it needs a builtin itself, so that
    the reference cannot clash with a user assignment of the same name
    (``help``, for instance, is a builtin but is also commonly assigned
    in templates).
    """

    _fields = ("id", "ctx")

    # Default context, shared by all instances that do not set their own.
    ctx = ast.Load()


class Symbol(Node):
    """An importable symbol, held in the ``value`` field."""

    _fields = ("value",)


class Static(Node):
    """A static value, optionally carrying a ``name``."""

    _fields = ("value", "name")

    # No name unless one is given.
    name = None


class Comment(Node):
    """A comment node with ``text``, ``space`` and ``stmt`` fields."""

    _fields = ("text", "space", "stmt")

    # Defaults for the optional fields.
    space = ""
    stmt = None


class TokenRef(Node):
    """A reference to a source-code token, given as ``pos`` and
    ``length``."""

    _fields = ("pos", "length")


class ASTCodeGenerator(object):
    """General purpose base class for AST transformations.

    Every visitor method can be overridden to return an AST node that has been
    altered or replaced in some way.

    Instantiating the class visits *tree* right away.  Afterwards the
    generated source is available as ``code`` and the individual lines
    as ``lines``.  ``lines_info`` holds, for each line, a list of
    ``(column, position)`` pairs where *position* is the
    ``(lineno, col_offset)`` of the node being written from that column
    on, or ``None`` when no positioned node was being visited.
    """

    def __init__(self, tree):
        self.lines_info = []
        self.line_info = []
        self.lines = []
        self.line = ""
        self.last = None
        self.indent = 0
        self.blame_stack = []
        self.visit(tree)

        if self.line.strip():
            self._new_line()

        self.line = None
        self.line_info = None

        # strip trivial lines
        self.code = "\n".join(
            line if line.strip() else ""
            for line in self.lines
            )

    # -- output helpers ---------------------------------------------------

    def _change_indent(self, delta):
        self.indent += delta

    def _new_line(self):
        """Store the current line and start a new one at the current
        indentation level."""
        if self.line is not None:
            self.lines.append(self.line)
            self.lines_info.append(self.line_info)
        self.line = ' ' * 4 * self.indent
        if self.blame_stack:
            self.last = self.blame_stack[-1]
            self.line_info = [(0, self.last)]
        else:
            self.last = None
            self.line_info = []

    def _write(self, s):
        """Append *s* to the current line, recording the column whenever
        the node position responsible for the output changes."""
        if len(s) == 0:
            return
        current = self.blame_stack[-1] if self.blame_stack else None
        if self.last != current:
            self.last = current
            self.line_info.append((len(self.line), current))
        self.line += s

    def _write_identifier(self, name):
        """Write *name*, which is either a plain string or a node.

        The Python 2 and Python 3 grammars disagree on whether some
        identifiers (``*args``, ``except ... as name``) are strings or
        nodes, so both are accepted.
        """
        if hasattr(name, '_fields'):
            self.visit(name)
        else:
            self._write(name)

    def _write_separated(self, nodes, separator=', '):
        """Visit each of *nodes*, writing *separator* between them."""
        for position, item in enumerate(nodes):
            if position:
                self._write(separator)
            self.visit(item)

    def _write_block(self, statements):
        """Visit *statements* one indentation level deeper."""
        self._change_indent(1)
        for statement in statements:
            self.visit(statement)
        self._change_indent(-1)

    def _write_orelse(self, node):
        """Write the ``else:`` clause of *node* if it has a body."""
        orelse = getattr(node, 'orelse', None)
        if orelse:
            self._new_line()
            self._write('else:')
            self._write_block(orelse)

    def _write_keyword(self, keyword):
        """Write ``name=value``, or ``**value`` for a nameless keyword."""
        if keyword.arg is not None:
            self._write(keyword.arg)
            self._write('=')
        else:
            self._write('**')
        self.visit(keyword.value)

    def _write_comprehension(self, opening, node, closing):
        """Write the element and ``for``/``if`` clauses of *node*
        between the *opening* and *closing* brackets."""
        self._write(opening)
        self.visit(node.elt)
        for generator in node.generators:
            # comprehension = (expr target, expr iter, expr* ifs)
            self._write(' for ')
            self.visit(generator.target)
            self._write(' in ')
            self.visit(generator.iter)
            for ifexpr in generator.ifs:
                self._write(' if ')
                self.visit(ifexpr)
        self._write(closing)

    def flush(self):
        if self.line:
            self._new_line()

    # -- dispatch ---------------------------------------------------------

    def visit(self, node):
        """Dispatch *node* to its ``visit_<ClassName>`` method.

        ``None`` is passed through and a tuple is visited item by item.
        Raises ``Exception`` when no handler exists for the node class.
        """
        if node is None:
            return None
        if type(node) is tuple:
            return tuple([self.visit(n) for n in node])
        try:
            position = (node.lineno, node.col_offset,)
        except AttributeError:
            position = None
        else:
            self.blame_stack.append(position)
        try:
            visitor = getattr(
                self, 'visit_%s' % node.__class__.__name__, None)
            if visitor is None:
                raise Exception('No handler for ``%s`` (%s).' % (
                    node.__class__.__name__, repr(node)))
            return visitor(node)
        finally:
            # Pop even when the visitor raises, so that the stack stays
            # balanced for a caller that handles the error.
            if position is not None:
                self.blame_stack.pop()

    # -- statements -------------------------------------------------------

    def visit_Module(self, node):
        for n in node.body:
            self.visit(n)
    visit_Interactive = visit_Module
    visit_Suite = visit_Module

    def visit_Expression(self, node):
        return self.visit(node.body)

    # arguments = (expr* args, identifier? vararg,
    #              identifier? kwarg, expr* defaults)
    # Python 3 adds posonlyargs, kwonlyargs and kw_defaults, and turns
    # vararg/kwarg into ``arg`` nodes.
    def visit_arguments(self, node):
        state = {'first': True}

        def separate():
            if state['first']:
                state['first'] = False
            else:
                self._write(', ')

        posonly = list(getattr(node, 'posonlyargs', None) or ())
        positional = posonly + list(node.args)
        defaults = node.defaults
        # Defaults belong to the *last* positional parameters.
        no_default_count = len(positional) - len(defaults)
        for i, arg in enumerate(positional):
            separate()
            self.visit(arg)
            if i >= no_default_count:
                self._write('=')
                self.visit(defaults[i - no_default_count])
            if i + 1 == len(posonly):
                self._write(', /')

        vararg = getattr(node, 'vararg', None)
        kwonly = getattr(node, 'kwonlyargs', None) or ()
        if vararg:
            separate()
            self._write('*')
            self._write_identifier(vararg)
        elif kwonly:
            # A bare ``*`` introduces keyword-only parameters.
            separate()
            self._write('*')

        kw_defaults = getattr(node, 'kw_defaults', None) or ()
        for arg, default in zip(kwonly, kw_defaults):
            separate()
            self.visit(arg)
            if default is not None:
                self._write('=')
                self.visit(default)

        kwarg = getattr(node, 'kwarg', None)
        if kwarg:
            separate()
            self._write('**')
            self._write_identifier(kwarg)

    def visit_arg(self, node):
        self._write(node.arg)

    # FunctionDef(identifier name, arguments args,
    #                           stmt* body, expr* decorators)
    def visit_FunctionDef(self, node):
        self._new_line()
        for decorator in getattr(node, 'decorator_list', ()):
            self._new_line()
            self._write('@')
            self.visit(decorator)
        self._new_line()
        self._write('def ' + node.name + '(')
        self.visit(node.args)
        self._write('):')
        self._write_block(node.body)

    # ClassDef(identifier name, expr* bases, stmt* body)
    # Newer grammars add ``keywords`` and ``decorator_list``.
    def visit_ClassDef(self, node):
        self._new_line()
        for decorator in getattr(node, 'decorator_list', ()):
            self._write('@')
            self.visit(decorator)
            self._new_line()
        self._write('class ' + node.name)
        bases = list(node.bases)
        keywords = list(getattr(node, 'keywords', None) or ())
        if bases or keywords:
            self._write('(')
            self._write_separated(bases)
            for position, keyword in enumerate(keywords):
                if bases or position:
                    self._write(', ')
                self._write_keyword(keyword)
            self._write(')')
        self._write(':')
        self._write_block(node.body)

    # Return(expr? value)
    def visit_Return(self, node):
        self._new_line()
        self._write('return')
        if getattr(node, 'value', None):
            self._write(' ')
            self.visit(node.value)

    # Delete(expr* targets)
    def visit_Delete(self, node):
        self._new_line()
        self._write('del ')
        self._write_separated(node.targets)

    # Assign(expr* targets, expr value)
    def visit_Assign(self, node):
        self._new_line()
        for target in node.targets:
            self.visit(target)
            self._write(' = ')
        self.visit(node.value)

    # AugAssign(expr target, operator op, expr value)
    def visit_AugAssign(self, node):
        self._new_line()
        self.visit(node.target)
        self._write(' ' + self.binary_operators[node.op.__class__] + '= ')
        self.visit(node.value)

    # Print(expr? dest, expr* values, bool nl)
    def visit_Print(self, node):
        self._new_line()
        self._write('print')
        values = getattr(node, 'values', None)
        if getattr(node, 'dest', None):
            self._write(' >> ')
            self.visit(node.dest)
            if values:
                self._write(', ')
        else:
            self._write(' ')
        if values:
            self._write_separated(values)
        if not node.nl:
            self._write(',')

    # For(expr target, expr iter, stmt* body, stmt* orelse)
    def visit_For(self, node):
        self._new_line()
        self._write('for ')
        self.visit(node.target)
        self._write(' in ')
        self.visit(node.iter)
        self._write(':')
        self._write_block(node.body)
        self._write_orelse(node)

    # While(expr test, stmt* body, stmt* orelse)
    def visit_While(self, node):
        self._new_line()
        self._write('while ')
        self.visit(node.test)
        self._write(':')
        self._write_block(node.body)
        self._write_orelse(node)

    # If(expr test, stmt* body, stmt* orelse)
    def visit_If(self, node):
        self._new_line()
        self._write('if ')
        self.visit(node.test)
        self._write(':')
        self._write_block(node.body)
        self._write_orelse(node)

    # With(expr context_expr, expr? optional_vars, stmt* body)
    # Python 3.3+: With(withitem* items, stmt* body)
    def visit_With(self, node):
        self._new_line()
        self._write('with ')
        items = getattr(node, 'items', None)
        if items:
            self._write_separated(items)
        else:
            # Old-style node: it carries the item attributes itself.
            self.visit_withitem(node)
        self._write(':')
        self._write_block(node.body)

    # withitem = (expr context_expr, expr? optional_vars)
    def visit_withitem(self, node):
        self.visit(node.context_expr)
        if getattr(node, 'optional_vars', None):
            self._write(' as ')
            self.visit(node.optional_vars)

    # Raise(expr? type, expr? inst, expr? tback)
    # Python 3: Raise(expr? exc, expr? cause)
    def visit_Raise(self, node):
        self._new_line()
        self._write('raise')
        if not getattr(node, "type", None):
            exc = getattr(node, "exc", None)
            if exc is None:
                return
            self._write(' ')
            self.visit(exc)
            cause = getattr(node, "cause", None)
            if cause is not None:
                self._write(' from ')
                self.visit(cause)
            return
        self._write(' ')
        self.visit(node.type)
        if not node.inst:
            return
        self._write(', ')
        self.visit(node.inst)
        if not node.tback:
            return
        self._write(', ')
        self.visit(node.tback)

    # Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody)
    def visit_Try(self, node):
        self._new_line()
        self._write('try:')
        self._write_block(node.body)
        for handler in getattr(node, 'handlers', None) or ():
            self.visit(handler)
        self._new_line()

        orelse = getattr(node, 'orelse', None)
        if orelse:
            self._write('else:')
            self._write_block(orelse)

        finalbody = getattr(node, 'finalbody', None)
        if finalbody:
            self._new_line()
            self._write('finally:')
            self._write_block(finalbody)

    # TryExcept(stmt* body, excepthandler* handlers, stmt* orelse)
    # Same output as ``Try``; the node simply has no ``finalbody``.
    visit_TryExcept = visit_Try

    # excepthandler = (expr? type, expr? name, stmt* body)
    def visit_ExceptHandler(self, node):
        self._new_line()
        self._write('except')
        if getattr(node, 'type', None):
            self._write(' ')
            self.visit(node.type)
        name = getattr(node, 'name', None)
        if name:
            if sys.version_info[0] == 2:
                assert getattr(node, 'type', None)
                self._write(', ')
            else:
                self._write(' as ')
            # A node on Python 2, a plain string on Python 3.
            self._write_identifier(name)
        self._write(':')
        self._write_block(node.body)
    visit_excepthandler = visit_ExceptHandler

    # TryFinally(stmt* body, stmt* finalbody)
    def visit_TryFinally(self, node):
        self._new_line()
        self._write('try:')
        self._write_block(node.body)

        finalbody = getattr(node, 'finalbody', None)
        if finalbody:
            self._new_line()
            self._write('finally:')
            self._write_block(finalbody)

    # Assert(expr test, expr? msg)
    def visit_Assert(self, node):
        self._new_line()
        self._write('assert ')
        self.visit(node.test)
        if getattr(node, 'msg', None):
            self._write(', ')
            self.visit(node.msg)

    def visit_alias(self, node):
        self._write(node.name)
        if getattr(node, 'asname', None):
            self._write(' as ')
            self._write(node.asname)

    # Import(alias* names)
    def visit_Import(self, node):
        self._new_line()
        self._write('import ')
        self._write_separated(node.names)

    # ImportFrom(identifier module, alias* names, int? level)
    def visit_ImportFrom(self, node):
        self._new_line()
        self._write('from ')
        if node.level:
            self._write('.' * node.level)
        # ``from . import name`` has no module.
        if node.module:
            self._write(node.module)
        self._write(' import ')
        self._write_separated(node.names)

    # Exec(expr body, expr? globals, expr? locals)
    def visit_Exec(self, node):
        self._new_line()
        self._write('exec ')
        self.visit(node.body)
        if not node.globals:
            return
        self._write(', ')
        self.visit(node.globals)
        if not node.locals:
            return
        self._write(', ')
        self.visit(node.locals)

    # Global(identifier* names)
    def visit_Global(self, node):
        self._new_line()
        self._write('global ')
        # The names are plain identifier strings, not nodes.
        for position, name in enumerate(node.names):
            if position:
                self._write(', ')
            self._write_identifier(name)

    # Expr(expr value)
    def visit_Expr(self, node):
        self._new_line()
        self.visit(node.value)

    # Pass
    def visit_Pass(self, node):
        self._new_line()
        self._write('pass')

    # Break
    def visit_Break(self, node):
        self._new_line()
        self._write('break')

    # Continue
    def visit_Continue(self, node):
        self._new_line()
        self._write('continue')

    ### EXPRESSIONS
    def with_parens(f):
        def _f(self, node):
            self._write('(')
            f(self, node)
            self._write(')')
        return _f

    bool_operators = {ast.And: 'and', ast.Or: 'or'}

    # BoolOp(boolop op, expr* values)
    @with_parens
    def visit_BoolOp(self, node):
        joiner = ' ' + self.bool_operators[node.op.__class__] + ' '
        self._write_separated(node.values, joiner)

    binary_operators = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Mod: '%',
        ast.Pow: '**',
        ast.LShift: '<<',
        ast.RShift: '>>',
        ast.BitOr: '|',
        ast.BitXor: '^',
        ast.BitAnd: '&',
        ast.FloorDiv: '//'
    }
    if hasattr(ast, 'MatMult'):
        binary_operators[ast.MatMult] = '@'

    # BinOp(expr left, operator op, expr right)
    @with_parens
    def visit_BinOp(self, node):
        self.visit(node.left)
        self._write(' ' + self.binary_operators[node.op.__class__] + ' ')
        self.visit(node.right)

    unary_operators = {
        ast.Invert: '~',
        ast.Not: 'not',
        ast.UAdd: '+',
        ast.USub: '-',
    }

    # UnaryOp(unaryop op, expr operand)
    # Parenthesized like the other operators: without the parentheses
    # ``(not a) == b`` would come out as ``not a == b``.
    @with_parens
    def visit_UnaryOp(self, node):
        self._write(self.unary_operators[node.op.__class__] + ' ')
        self.visit(node.operand)

    # Lambda(arguments args, expr body)
    @with_parens
    def visit_Lambda(self, node):
        self._write('lambda ')
        self.visit(node.args)
        self._write(': ')
        self.visit(node.body)

    # IfExp(expr test, expr body, expr orelse)
    @with_parens
    def visit_IfExp(self, node):
        self.visit(node.body)
        self._write(' if ')
        self.visit(node.test)
        self._write(' else ')
        self.visit(node.orelse)

    # Dict(expr* keys, expr* values)
    def visit_Dict(self, node):
        self._write('{')
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**mapping}``: an unpacked mapping has no key.
                self._write('**')
                self.visit(value)
            else:
                self.visit(key)
                self._write(': ')
                self.visit(value)
            self._write(', ')
        self._write('}')

    def visit_Set(self, node):
        if not node.elts:
            # ``{}`` would be an empty dict.
            self._write('set()')
            return
        self._write('{')
        self._write_separated(node.elts)
        self._write('}')

    # ListComp(expr elt, comprehension* generators)
    def visit_ListComp(self, node):
        self._write_comprehension('[', node, ']')

    # GeneratorExp(expr elt, comprehension* generators)
    def visit_GeneratorExp(self, node):
        self._write_comprehension('(', node, ')')

    # Yield(expr? value)
    def visit_Yield(self, node):
        self._write('yield')
        if getattr(node, 'value', None):
            self._write(' ')
            self.visit(node.value)

    comparison_operators = {
        ast.Eq: '==',
        ast.NotEq: '!=',
        ast.Lt: '<',
        ast.LtE: '<=',
        ast.Gt: '>',
        ast.GtE: '>=',
        ast.Is: 'is',
        ast.IsNot: 'is not',
        ast.In: 'in',
        ast.NotIn: 'not in',
    }

    # Compare(expr left, cmpop* ops, expr* comparators)
    @with_parens
    def visit_Compare(self, node):
        self.visit(node.left)
        for op, comparator in zip(node.ops, node.comparators):
            self._write(' ' + self.comparison_operators[op.__class__] + ' ')
            self.visit(comparator)

    # Call(expr func, expr* args, keyword* keywords,
    #                         expr? starargs, expr? kwargs)
    def visit_Call(self, node):
        self.visit(node.func)
        self._write('(')
        first = True
        for arg in node.args:
            if not first:
                self._write(', ')
            first = False
            self.visit(arg)

        for keyword in node.keywords:
            if not first:
                self._write(', ')
            first = False
            # keyword = (identifier arg, expr value)
            self._write_keyword(keyword)

        # Attribute removed in Python 3.5
        if getattr(node, 'starargs', None):
            if not first:
                self._write(', ')
            first = False
            self._write('*')
            self.visit(node.starargs)

        # Attribute removed in Python 3.5
        if getattr(node, 'kwargs', None):
            if not first:
                self._write(', ')
            first = False
            self._write('**')
            self.visit(node.kwargs)
        self._write(')')

    # Repr(expr value)
    def visit_Repr(self, node):
        self._write('`')
        self.visit(node.value)
        self._write('`')

    # Constant(object value)
    def visit_Constant(self, node):
        if node.value is Ellipsis:
            self._write('...')
        else:
            self._write(repr(node.value))

    # Num(object n)
    def visit_Num(self, node):
        self._write(repr(node.n))

    # Str(string s)
    def visit_Str(self, node):
        self._write(repr(node.s))

    def visit_Ellipsis(self, node):
        self._write('...')

    # Attribute(expr value, identifier attr, expr_context ctx)
    def visit_Attribute(self, node):
        self.visit(node.value)
        self._write('.')
        self._write(node.attr)

    # Subscript(expr value, slice slice, expr_context ctx)
    def visit_Subscript(self, node):
        self.visit(node.value)
        self._write('[')
        if isinstance(node.slice, ast.Tuple) and node.slice.elts:
            self.visit(node.slice.elts[0])
            if len(node.slice.elts) == 1:
                self._write(', ')
            else:
                for dim in node.slice.elts[1:]:
                    self._write(', ')
                    self.visit(dim)
        else:
            self.visit(node.slice)
        self._write(']')

    # Slice(expr? lower, expr? upper, expr? step)
    def visit_Slice(self, node):
        if getattr(node, 'lower', None) is not None:
            self.visit(node.lower)
        self._write(':')
        if getattr(node, 'upper', None) is not None:
            self.visit(node.upper)
        if getattr(node, 'step', None) is not None:
            self._write(':')
            self.visit(node.step)

    # Index(expr value)
    def visit_Index(self, node):
        self.visit(node.value)

    # ExtSlice(slice* dims)
    def visit_ExtSlice(self, node):
        self.visit(node.dims[0])
        if len(node.dims) == 1:
            self._write(', ')
        else:
            for dim in node.dims[1:]:
                self._write(', ')
                self.visit(dim)

    # Starred(expr value, expr_context ctx)
    def visit_Starred(self, node):
        self._write('*')
        self.visit(node.value)

    # Name(identifier id, expr_context ctx)
    def visit_Name(self, node):
        self._write(node.id)

    # List(expr* elts, expr_context ctx)
    def visit_List(self, node):
        self._write('[')
        for elt in node.elts:
            self.visit(elt)
            self._write(', ')
        self._write(']')

    # Tuple(expr *elts, expr_context ctx)
    def visit_Tuple(self, node):
        self._write('(')
        for elt in node.elts:
            self.visit(elt)
            self._write(', ')
        self._write(')')

    # NameConstant(singleton value)
    def visit_NameConstant(self, node):
        self._write(str(node.value))

class AnnotationAwareVisitor(ast.NodeVisitor):
    """Node visitor that honours ``node_annotations``.

    A node that has a (non-``None``) annotation is visited through its
    annotation instead.  Subclasses provide ``transform``, a callable
    used by ``apply_transform``.
    """

    def visit(self, node):
        """Visit *node*, or its annotation if it has one, and return
        whatever the selected visitor method returns."""
        annotation = node_annotations.get(node)
        if annotation is not None:
            assert hasattr(annotation, '_fields')
            node = annotation

        # Hand the visitor's result back, as ``NodeVisitor.visit`` does.
        return super(AnnotationAwareVisitor, self).visit(node)

    def apply_transform(self, node):
        """Annotate *node* with ``self.transform(node)``.

        Nothing is recorded when the node is already annotated or when
        the transform returns ``None`` or the node itself.
        """
        if node not in node_annotations:
            result = self.transform(node)
            if result is not None and result is not node:
                node_annotations[node] = result


class NameLookupRewriteVisitor(AnnotationAwareVisitor):
    """Apply ``transform`` to names that are not bound in the current
    scope.

    Calling the visitor on a tree returns the set of name ids that were
    handed to the transform.
    """

    def __init__(self, transform):
        self.transform = transform
        self.transformed = set()
        # Stack of name sets; the last entry is the innermost scope.
        self.scopes = [set()]

    def __call__(self, node):
        self.visit(node)
        return self.transformed

    def visit_arg(self, node):
        self.scopes[-1].add(node.arg)

    def visit_Name(self, node):
        local_names = self.scopes[-1]
        if isinstance(node.ctx, ast.Param):
            local_names.add(node.id)
            return
        if node.id in local_names:
            return
        self.transformed.add(node.id)
        self.apply_transform(node)

    def visit_FunctionDef(self, node):
        # Only the function name is recorded; the body is not visited.
        self.scopes[-1].add(node.name)

    def visit_alias(self, node):
        bound = node.asname
        if bound is None:
            bound = node.name
        self.scopes[-1].add(bound)

    def visit_Lambda(self, node):
        # Lambda parameters live in a scope of their own.
        self.scopes.append(set())
        try:
            for part in (node.args, node.body):
                self.visit(part)
        finally:
            self.scopes.pop()


class ItemLookupOnAttributeErrorVisitor(AnnotationAwareVisitor):
    """Apply ``transform`` to every attribute access, innermost first."""

    def __init__(self, transform):
        self.transform = transform

    def visit_Attribute(self, node):
        # Children first, so nested attribute chains are handled from
        # the inside out.
        self.generic_visit(node)
        self.apply_transform(node)