Automatic code changing in Python with the ast module

One of the changes that Python 3 brings versus Python 2 is that if you define the __eq__ method of a class without also defining __hash__, then __hash__ is automatically set to None. This means that instances of your class become unhashable (and thus unusable in many settings, such as set members or dictionary keys). This makes sense, as the hash function is closely related to object equality. But if you have Python 2 code that is nonetheless working and you are planning to move to Python 3, then automatically adding a __hash__ method to recover object hashability might be a pragmatically acceptable solution.
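
To make the behaviour concrete, here is a minimal sketch for Python 3. The class names and the is_hashable helper are made up for illustration, and the value-based __hash__ is just one reasonable choice, not the patch used later in this post.

```python
class WithEq:
    """Defines __eq__ only, so Python 3 sets __hash__ to None."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return isinstance(other, WithEq) and self.value == other.value


class WithEqAndHash(WithEq):
    """Restores hashability by defining __hash__ explicitly."""

    def __hash__(self):
        return hash(self.value)


def is_hashable(obj):
    # hash() raises TypeError when the class has __hash__ set to None
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(WithEq(1)))         # False
print(is_hashable(WithEqAndHash(1)))  # True
```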

This not-so-perfect example sets the stage for a mini-tutorial on traversing a Python source tree, finding classes that have new __eq__ methods but lack a new __hash__ method, and then patching each such class to have one. This is, thus, mostly an excuse for a brief exploration of the ast module of Python...

The ast module in Python

The code presented here works on Python 3. There are a few differences for Python 2. For instance in Python 2 print is a statement, not an ordinary function. But the gist of things is the same.

The ast module allows you to access the Abstract Syntax Tree of a Python program. This can be useful for program analysis and program transformation, among other things. Let's start with the AST representation of a Hello World program:

#Typical hello world example
print('Hello World')

Getting the AST tree for this is trivial:

import ast
tree = ast.parse(open('hello.py').read())

Now, the tree object might seem daunting at first, but don't let appearances fool you: it is actually quite simple. The type of tree will be _ast.Module. If you ask for its children by doing tree._fields (as you can with all AST nodes), you will see that a Module has a body (i.e. tree.body). The body attribute holds, you guessed it, the body of the file:

print(tree.body)
[<_ast.Expr object at 0x...>]

The body of a module is composed of a list of statements. In our case we have a single node: an Expr (Expression is another type of node, do not confuse them). Notice that the AST representation will not include the comments; these are "lost".

The only field of an Expr is value, which is just an indirection to the expression that is being used as a statement. In our case we have a Call to the print function. Remember, in Python 3 print is just an ordinary function, not a statement, so this would look different in Python 2.

OK, back to our Call. What is in a call? Well, you call a function with arguments so a call is a function name plus a set of arguments:

print(tree.body[0].value._fields)
('func', 'args', 'keywords', 'starargs', 'kwargs')

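You will notice the func plus a lot of stuff for arguments. This is because, as you know, Python has quite a rich way of passing arguments (positional arguments, named arguments, ...). The starargs and kwargs fields only exist up to Python 3.4; from Python 3.5 on, a Call has just func, args and keywords. For now let's concentrate on func and args only. func is a Name with an attribute called id (the function name). args is a list with a single element, a string node (Str here; from Python 3.8 on the parser produces a Constant instead, with the text in its value attribute):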

print(tree.body[0].value.func.id)
print
print(tree.body[0].value.args[0].s)
Hello World

This is actually quite simple. Here is a graphical representation:

The best way to find all types of nodes is to check the abstract grammar documentation, or you might prefer the Python 2 version.
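
For a quick look at a concrete tree without any drawing, ast.dump prints a textual representation. The following is a minimal sketch that parses the Hello World source from a string instead of hello.py (an assumption made to keep it self-contained) and that assumes Python 3.8 or later, where the string argument is a Constant with a value attribute.

```python
import ast

# Same program as above, given inline instead of read from hello.py
source = "# Typical hello world example\nprint('Hello World')\n"
tree = ast.parse(source)

call = tree.body[0].value            # Expr -> Call

print(type(tree).__name__)           # Module
print(type(tree.body[0]).__name__)   # Expr
print(type(call).__name__)           # Call
print(call.func.id)                  # print
print(call.args[0].value)            # Hello World (Python 3.8+)

# Full textual dump; the exact format depends on the Python version
print(ast.dump(tree))
```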

Second attempt (a function definition)

Let's have a look at the AST for a function definition:

def print_hello(who):
    print('Hello %s' % who)

If you get a tree for this code, you will notice that the body is still composed of a single statement:

print(tree.body)
[<_ast.FunctionDef object at 0x...>]

At the top level there is only a single statement; the print function call belongs to the function definition proper. A function definition has name, args, body, decorator_list and returns. name is a string (print_hello), no more indirections here, easy. args cannot be very simple because it has to accommodate the fairly rich Python syntax for argument specification, so its fields are ('args', 'vararg', 'kwonlyargs', 'kw_defaults', 'kwarg', 'defaults'). In our case we just have a simple positional argument, so we can find it in args.args:

print(tree.body[0].args.args)
[<_ast.arg object at 0x...>]

The arg object has an arg field (tree.body[0].args.args[0].arg starts to sound a bit ludicrous, but such is life), which is a string (who). As a side note, in Python 3 you can annotate function parameters (so you also have an annotation field).

Now, the function body can be found in the body field of the function definition:

print(tree.body[0].body)
[<_ast.Expr object at 0x...>]

To finalise this version, I just want to present the sub-tree inside the print (i.e. the "Hello %s" % who):
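
That sub-tree can also be explored in code. A minimal sketch, parsing the function from an inline string (an assumption made to keep it self-contained): the argument of print is a BinOp whose operator is Mod, with the format string on the left and the name who on the right.

```python
import ast

source = "def print_hello(who):\n    print('Hello %s' % who)\n"
tree = ast.parse(source)

func = tree.body[0]               # FunctionDef
call = func.body[0].value         # Expr -> Call to print
binop = call.args[0]              # the 'Hello %s' % who expression

print(type(binop).__name__)       # BinOp
print(type(binop.op).__name__)    # Mod
print(type(binop.left).__name__)  # Constant on Python 3.8+, Str before that
print(binop.right.id)             # who
```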

OK, before we go to the final version, one very last thing:

# Let's look at two extra properties of the print line...
print(tree.body[0].body[0].lineno, tree.body[0].body[0].col_offset)
2 4

Yes, you can retrieve the line number and the column of a statement...

Third attempt (a class definition)

class Hello:
    def __init__(self, who):
        self.who = who

    def print_hello(self):
        print('Hello %s' % self.who)

It should not come as a shock that the module only has a single statement:

[<_ast.ClassDef object at 0x...>]

And yes, as you might expect by now, the ClassDef object has a name and a body attribute (things start making sense and are consistent). Indeed most things are as expected: the class has two statements (two FunctionDefs). There are only a couple of conceptually new things, and these are visible in the line:

self.who = who

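Here we have two new features: the assignment (=) and the compound name (self.who). The assignment is an Assign node with two fields, targets and value, and self.who is an Attribute node whose value is the Name self and whose attr is the string who.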

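Notice that the targets field of the assignment is a list. This is because assignments can be chained, giving one target per name (tuple unpacking such as x, y = 1, 2 is different: it has a single target, a Tuple node):

x = y = 1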
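
The shape of assignment nodes is easy to check for yourself. A minimal sketch, using small inline snippets as input: it looks at the Attribute target of self.who = who, and compares tuple unpacking (a single Tuple target) with a chained assignment (two targets).

```python
import ast

tree = ast.parse("self.who = who\nx, y = 1, 2\n")
first, second = tree.body             # two Assign nodes

target = first.targets[0]
print(type(target).__name__)          # Attribute
print(target.value.id, target.attr)   # self who
print(first.value.id)                 # who

# Tuple unpacking: still one target, which is a Tuple of Names
tuple_target = second.targets[0]
print(len(second.targets))                        # 1
print([name.id for name in tuple_target.elts])    # ['x', 'y']

# Chained assignment: one target per name
chained = ast.parse("a = b = 0").body[0]
print(len(chained.targets))                       # 2
```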

There is much more that can be said. Processing ASTs requires a recursive mindset; indeed, if you are not used to thinking recursively, I suspect that might be your biggest hurdle in doing things with the ast module. And with Python things can get very nested indeed (nested functions, nested classes, lambda functions, ...).
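
The ast module ships a helper for exactly this kind of recursion, ast.NodeVisitor. The following is a minimal sketch (the ClassCollector name and the sample source are made up for illustration) that finds classes at any nesting depth and records a dotted name built from the enclosing classes.

```python
import ast


class ClassCollector(ast.NodeVisitor):
    """Collects class names, including nested ones, as dotted paths."""

    def __init__(self):
        self.names = []
        self._stack = []   # names of the classes we are currently inside

    def visit_ClassDef(self, node):
        self._stack.append(node.name)
        self.names.append('.'.join(self._stack))
        self.generic_visit(node)   # recurse into the class body
        self._stack.pop()


source = '''
class Outer:
    class Inner:
        pass

    def method(self):
        class Local:
            pass
'''

collector = ClassCollector()
collector.visit(ast.parse(source))
print(collector.names)   # ['Outer', 'Outer.Inner', 'Outer.Local']
```

Only classes are pushed on the stack, so the class defined inside method shows up as Outer.Local.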

A complete solution

OK, let's switch gears completely and apply this to a concrete case. The Abjad project helps composers build up complex pieces of music notation in an iterative and incremental way. It is currently Python 2 only, and I am volunteering some of my time to help it support Python 3. The code base is highly object-oriented, very well documented and comes with a massive load of test cases (kudos to the main developers for this). Some of the classes do have __eq__ methods defined but lack __hash__ methods, so a pragmatic (though not totally rigorous) solution is to add the __hash__ methods required by Python 3.

A caveat here: the processing has to be done in Python 2, so the code below is Python 2. The idea is to generate Python 2 code with hashes that can be automatically translated by 2to3 (no need for monkey patching after 2to3).

First thing, we have to traverse all the files:

import os

def traverse_dir(my_dir):
    # Recursively visit my_dir and process every .py file found
    for element in os.listdir(my_dir):
        path = os.path.join(my_dir, element)
        if os.path.isdir(path):
            traverse_dir(path)
        elif os.path.isfile(path) and element.endswith('.py'):
            process_file(path)

Nothing special here, just a plain traversal of a directory structure.
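
The same traversal can also be written with os.walk from the standard library, which does the recursion for you. This is an alternative sketch, not the code used for the patching; find_python_files is a hypothetical helper name.

```python
import os


def find_python_files(top):
    """Yield the path of every .py file found below top."""
    for dirpath, dirnames, filenames in os.walk(top):
        for filename in sorted(filenames):
            if filename.endswith('.py'):
                yield os.path.join(dirpath, filename)
```

Each yielded path could then be handed to a per-file function such as process_file.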

Now we want to find all classes that have an __eq__ method, but not a __hash__ method:

def get_classes(tree):
    # Only looks at top-level statements, so it will not work for nested classes
    my_classes = []
    for expr in tree.body:
        if isinstance(expr, ast.ClassDef):
            my_classes.append(expr)
    return my_classes

The function above will return all classes from a list of statements (typically a module body).

def get_class_methods(tree):
    # tree is expected to be a ClassDef; only its direct methods are returned
    my_methods = []
    for expr in tree.body:
        if isinstance(expr, ast.FunctionDef):
            my_methods.append(expr)
    return my_methods

The function above will return all function definitions from a ClassDef object.

import ast
import shutil

def process_file(my_file):
    # Keep a backup; patch() reads it and rewrites the original file
    shutil.copyfile(my_file, my_file + '.bak')
    with open(my_file) as source:
        tree = ast.parse(source.read())
    my_classes = get_classes(tree)
    patches = {}  # maps a line number to the name of the class patched there
    for my_class in my_classes:
        methods = get_class_methods(my_class)
        names = [method.name for method in methods]
        has_eq = '__eq__' in names
        has_hash = '__hash__' in names
        if has_eq and not has_hash:
            lineno = compute_patch(methods)
            patches[lineno] = my_class.name
    patch(my_file, patches)

This is the main loop applied to each file: We get all the available classes; for each class we get all available methods and if there is an __eq__ method with no __hash__ method then a patch is computed and then applied.
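
The detection step on its own can be tried on a string, without touching any file. A minimal sketch under that assumption: classes_needing_hash is a hypothetical helper that mirrors the checks above for top-level classes only, and the sample source is made up.

```python
import ast


def classes_needing_hash(source):
    """Return names of top-level classes with __eq__ but no __hash__."""
    needing = []
    for node in ast.parse(source).body:
        if isinstance(node, ast.ClassDef):
            names = set(item.name for item in node.body
                        if isinstance(item, ast.FunctionDef))
            if '__eq__' in names and '__hash__' not in names:
                needing.append(node.name)
    return needing


SAMPLE = '''
class A(object):
    def __eq__(self, other):
        return True

class B(object):
    def __eq__(self, other):
        return True
    def __hash__(self):
        return 0

class C(object):
    pass
'''

print(classes_needing_hash(SAMPLE))   # ['A']
```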

The first thing that we need is to know where to patch the code:

def compute_patch(methods):
    names = [method.name for method in methods]
    names.append('__hash__')
    try:
        names.remove('__init__')
    except ValueError:
        pass
    names.sort()
    try:
        method_after = names[names.index('__hash__') + 1]
    except IndexError:
        method_after = names[-2]
    for method in methods:
        if method.name == method_after:
            return method.lineno

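The main point here is to decide at which line to apply the patch. It has of course to be inside the class, but we want to be slightly more elegant than that: we want to put the method in lexicographical order with all the others (so __hash__ would go right after __eq__). __init__ is left out of the ordering, so the patch is never inserted directly in front of it, and if no method sorts after __hash__ the patch goes in front of the method that sorts just before it. The function returns the line number of the method in front of which the new code will be written. Now we can patch: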
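
The ordering rule is easier to follow on plain lists of names. The sketch below isolates it in a hypothetical helper, method_to_insert_before, which follows the same logic as compute_patch but returns the name of the neighbouring method instead of its line number.

```python
def method_to_insert_before(method_names, new_name='__hash__'):
    """Return the existing method in front of which new_name is inserted."""
    # __init__ is excluded, as in compute_patch
    names = [name for name in method_names if name != '__init__']
    names.append(new_name)
    names.sort()
    position = names.index(new_name)
    if position + 1 < len(names):
        return names[position + 1]   # next method in sorted order
    return names[-2]                 # nothing sorts after: use the previous one


print(method_to_insert_before(['__init__', '__eq__', '__repr__', 'play']))  # __repr__
print(method_to_insert_before(['__init__', '__eq__']))                      # __eq__
```

Note that this only gives a tidy result when the methods in the source file are already sorted, which is an assumption about the code base being patched.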

def patch(my_file, patches):
    # Copy the backup over the original, inserting __hash__ where needed
    with open(my_file + '.bak') as f, open(my_file, 'w') as w:
        for lineno, line in enumerate(f, 1):
            if lineno in patches:
                w.write("""    def __hash__(self):
        r'''Hashes my class.

        Required to be explicitly re-defined on Python 3 if __eq__ changes.

        Returns integer.
        '''
        return super(%s, self).__hash__()

""" % patches[lineno])
            w.write(line)
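
To see the whole idea work end to end without creating files, here is a simplified in-memory sketch. It is not the code above: add_hash_methods is a hypothetical helper, it always inserts the new method right in front of __eq__ instead of computing a sorted position, and it assumes four-space indentation, top-level classes and an undecorated __eq__.

```python
import ast

HASH_TEMPLATE = (
    "    def __hash__(self):\n"
    "        return super(%s, self).__hash__()\n"
    "\n"
)


def add_hash_methods(source):
    """Return source with a __hash__ added to classes that need one."""
    patches = {}   # line number of __eq__ -> class name
    for node in ast.parse(source).body:
        if not isinstance(node, ast.ClassDef):
            continue
        methods = [m for m in node.body if isinstance(m, ast.FunctionDef)]
        names = [m.name for m in methods]
        if '__eq__' in names and '__hash__' not in names:
            eq = next(m for m in methods if m.name == '__eq__')
            patches[eq.lineno] = node.name
    out = []
    for lineno, line in enumerate(source.splitlines(True), 1):
        if lineno in patches:
            out.append(HASH_TEMPLATE % patches[lineno])
        out.append(line)
    return ''.join(out)


SOURCE = (
    "class Note(object):\n"
    "    def __init__(self, pitch):\n"
    "        self.pitch = pitch\n"
    "\n"
    "    def __eq__(self, other):\n"
    "        return self.pitch == other.pitch\n"
)

print(add_hash_methods(SOURCE))
```

On Python 3 the patched class can be put in a set again, while the original cannot.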

Final notes

This could be done in many other ways. Indeed this is not even the standard way (that would be coding a 2to3 fixer). The practical solution is not general (for instance, it does not support nested classes). Also, it has problems with comments (as we lose them in the AST). But for the practical purpose (patching Abjad) it was good enough (you can see the result here). For further reading, you might want to have a look at this Stack Overflow question.