A Practical Introduction to the Python Abstract Syntax Tree
What is an Abstract Syntax Tree?
An Abstract Syntax Tree (AST) is a representation of the source code structure as data. It is called 'abstract' because it removes surface details like comments, formatting or extra parentheses. It can be thought of as a nested dictionary or JSON-like document that describes the code.
The expression 1 + 2 can be represented in different ways:
# AST as a list
[{"type": "id", "value": "+"},
 [{"type": "literal", "value": 1},
  {"type": "literal", "value": 2}]
]
# AST as a dictionary
{
    "type": "Add",
    "left": {"type": "Number", "value": 1},
    "right": {"type": "Number", "value": 2}
}
Representing code as data makes all sorts of software possible: Abstract Syntax Trees are used in interpreters, compilers (including tools like Babel in JavaScript), linters and formatters (flake8, ruff and black in Python, ESLint in JavaScript) and many more tools.
These programs can be surprisingly approachable, at least conceptually.
An interpreter, whose job is to execute code, typically parses the source code into an AST and then evaluates it.
The evaluation step might look like:
def evaluate(node):
    """An AST interpreter for basic arithmetic expressions."""
    if node["type"] == "Add":
        return evaluate(node["left"]) + evaluate(node["right"])
    if node["type"] == "Subtract":
        return evaluate(node["left"]) - evaluate(node["right"])
    if node["type"] == "Number":
        return node["value"]
    raise ValueError(f"unsupported node type {node['type']}")
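The parsing step that produces these dictionaries can be sketched in a few lines too. This is a minimal illustration, not how production interpreters parse: it assumes well-formed input made only of non-negative integers, + and - (no parentheses), and the tokenize and parse names are illustrative. evaluate is repeated so the block runs on its own.

```python
import re


def tokenize(source):
    """Split a string like "10 - 4 + 1" into number and operator tokens."""
    return re.findall(r"\d+|[+-]", source)


def parse(source):
    """Build a left-associative tree of dicts for + and - on non-negative integers."""
    tokens = tokenize(source)
    node = {"type": "Number", "value": int(tokens[0])}
    # Consume (operator, number) pairs; the tree built so far becomes the left child.
    for op_token, number in zip(tokens[1::2], tokens[2::2]):
        node = {
            "type": "Add" if op_token == "+" else "Subtract",
            "left": node,
            "right": {"type": "Number", "value": int(number)},
        }
    return node


def evaluate(node):
    """An AST interpreter for basic arithmetic expressions (same as above)."""
    if node["type"] == "Add":
        return evaluate(node["left"]) + evaluate(node["right"])
    if node["type"] == "Subtract":
        return evaluate(node["left"]) - evaluate(node["right"])
    if node["type"] == "Number":
        return node["value"]
    raise ValueError(f"unsupported node type {node['type']}")


print(evaluate(parse("10 - 4 + 1")))  # 7
```

Wrapping the tree built so far as the left child is what makes 10 - 4 + 1 mean (10 - 4) + 1 rather than 10 - (4 + 1).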
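A linter, on the other hand, might do something like:
def lint(node):
    # report_error is a hypothetical helper that records a problem at a source location
    if node["type"] == "FunctionDefinition" and len(node["args"]) > 5:
        report_error(node["location"], "too many parameters in function")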
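With Python's own ast module, the same check can run on real code. This is a minimal sketch that keeps the threshold of five parameters from the lint example. The function name too_many_parameters is illustrative, and *args/**kwargs are deliberately not counted.

```python
import ast

MAX_PARAMS = 5  # same threshold as the lint sketch


def too_many_parameters(source):
    """Return (line number, message) pairs for functions with more than MAX_PARAMS parameters."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            arguments = node.args  # an ast.arguments node
            # Count positional-only, regular and keyword-only parameters; *args/**kwargs are ignored.
            count = len(arguments.posonlyargs) + len(arguments.args) + len(arguments.kwonlyargs)
            if count > MAX_PARAMS:
                problems.append((node.lineno, f"too many parameters in function {node.name!r}"))
    return problems


source = """
def ok(a, b):
    pass

def busy(a, b, c, d, e, f):
    pass
"""
print(too_many_parameters(source))  # [(5, "too many parameters in function 'busy'")]
```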
Visualizing ASTs makes them much less mysterious.
https://ast-explorer.dev is a great tool to experiment with them in many languages.
For Python specifically, the ast module can also be used from the command line (Python 3.9+) for a quick look: python -m ast <<< "1 + 2" in bash or zsh.
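From inside Python, ast.dump prints the same kind of tree. Parsing in "eval" mode, which accepts a single expression, keeps the output short, and the indent argument (Python 3.9+) spreads it over several lines. The shape mirrors the dictionary version: a BinOp node with left, op and right fields.

```python
import ast

# mode="eval" parses a single expression and wraps it in an ast.Expression node
tree = ast.parse("1 + 2", mode="eval")
print(ast.dump(tree, indent=4))
# Expression(
#     body=BinOp(
#         left=Constant(value=1),
#         op=Add(),
#         right=Constant(value=2)))
```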
In Python
When executing Python code with CPython:
- the source code is parsed into an AST (a tree of Python objects where nodes are instances of classes like ast.FunctionDef or ast.Call)
- the AST is compiled into code objects (sometimes cached on disk as .pyc files)
- the compiled code is executed by the bytecode interpreter
When people say "the Python interpreter", they usually refer to this whole pipeline[1].
The Python standard library directly exposes this process:
import ast
tree = ast.parse("a = 1 + 2; print(a)")
code_object = compile(tree, filename="<ast>", mode="exec")
# this executes the code and prints 3
exec(code_object)
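Each stage can also be inspected on its own, not just through the final output. In the sketch below, dis.dis prints the bytecode stored in the code object. Its exact output differs between Python versions, so it is not reproduced here. In recent CPython versions you will likely see the constant 3 loaded directly, because the compiler folds 1 + 2 ahead of time.

```python
import ast
import dis

source = "a = 1 + 2; print(a)"

# Stage 1: source text -> AST (a Module whose body holds one node per statement)
tree = ast.parse(source)
print([type(statement).__name__ for statement in tree.body])  # ['Assign', 'Expr']

# Stage 2: AST -> code object containing bytecode
code_object = compile(tree, filename="<ast>", mode="exec")
print(code_object.co_names)  # names used by the bytecode, such as 'a' and 'print'
dis.dis(code_object)         # human-readable bytecode listing

# Stage 3: run the bytecode (prints 3)
exec(code_object)
```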
This direct access to the AST means we can modify the tree before calling compile, changing the program before it runs.
These modifications are known as AST transforms.
AST Transforms
An AST transform makes a program behave as if its source code had been written differently.
A fun example is what pytest does to display nice error messages when an assert fails.
A real-world example: pytest assertions
Python's default behavior when an assertion fails is to print:
Traceback (most recent call last):
  File "/home/laurent/test/assert.py", line 2, in <module>
    assert a == b
AssertionError
Want to know the values of a and b? Add print statements and run the code again.
Pytest is much friendlier:
> assert a == b
E assert 1 == 2
assert.py:3: AssertionError
It is a little magical: we did not change the code; we just ran pytest test_module.py instead of python test_module.py.
It is still Python executing the code, but when running under pytest, instead of:
assert a == b
Python executes something like:
try:
    assert a == b
except AssertionError:
    # [pytest-generated code to display a nice error message]
    raise
Pytest does that without modifying the source code (your files on disk), by transforming the AST before Python gets to execute it[2].
The pytest code for this uses the parse/transform/compile pipeline we described:
tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)
# Later `co` is executed with `exec`
So all that remains is to write the rewrite_asserts transform logic.
Mechanics of an AST transform
An AST transform takes a tree as input and modifies it. ast.NodeTransformer is a helper class that traverses the tree for us.
An AST is made of many different node types (ast.Assign, ast.Call, ast.Name and many more).
NodeTransformer walks the tree and calls the visit_<node_type> methods when they exist.
The example from the official documentation replaces every variable (name lookup) with data["variable_name"]:
from ast import Constant, Load, Name, NodeTransformer, Subscript


class RewriteName(NodeTransformer):
    def visit_Name(self, node):
        return Subscript(
            value=Name(id='data', ctx=Load()),
            slice=Constant(value=node.id),
            ctx=node.ctx
        )
This turns b = a + 1 into data['b'] = data['a'] + 1.
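To see the documentation's transformer work end to end, the sketch below applies it, prints the rewritten code with ast.unparse (Python 3.9+) and then executes it against a data dictionary. The namespace dictionary and the value 41 are only for illustration. ast.fix_missing_locations fills in line numbers for the newly created nodes, because compile requires them.

```python
import ast


class RewriteName(ast.NodeTransformer):
    def visit_Name(self, node):
        # Replace the name lookup `x` with `data['x']`, keeping the original Load/Store context
        return ast.Subscript(
            value=ast.Name(id="data", ctx=ast.Load()),
            slice=ast.Constant(value=node.id),
            ctx=node.ctx,
        )


tree = ast.parse("b = a + 1")
tree = ast.fix_missing_locations(RewriteName().visit(tree))
print(ast.unparse(tree))  # data['b'] = data['a'] + 1

# Run the rewritten program: variables now live in the `data` dictionary.
namespace = {"data": {"a": 41}}
exec(compile(tree, filename="<ast>", mode="exec"), namespace)
print(namespace["data"]["b"])  # 42
```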
Not too useful in practice, but it shows NodeTransformer in action:
- it walks the syntax tree and calls the visit_<node_type> method
- returning a new node replaces the original in the tree; returning None deletes it
- generic_visit is called for nodes that are not handled by a specific method (it should be called explicitly to process the children of a visited node)
- by default, nodes are left untouched
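The last three points are easier to see in a small example. The hypothetical RemovePrints transformer below deletes print(...) statements by returning None. Because it defines visit_FunctionDef, NodeTransformer no longer descends into function bodies on its own, so that method calls generic_visit explicitly. Without that call, the print inside greet would survive.

```python
import ast


class RemovePrints(ast.NodeTransformer):
    """Hypothetical example: delete expression statements that call print(...)."""

    def visit_Expr(self, node):
        call = node.value
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "print"
        ):
            return None  # returning None removes the statement from its parent's body
        return node  # returning the node unchanged keeps it

    def visit_FunctionDef(self, node):
        # Defining this method means children are no longer visited automatically,
        # so generic_visit is called to keep processing the function body.
        self.generic_visit(node)
        return node


source = """
print("top level")
def greet(name):
    print("hello", name)
    return name
"""
tree = RemovePrints().visit(ast.parse(source))
print(ast.unparse(tree))
```

The printed code keeps only the greet function with its return statement.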
Building a pytest-like assert transformer
Say we want to write a transformer that converts:
assert a == b
into:
try:
    assert a == b
except AssertionError:
    raise AssertionError(f"a == b failed\na = {a}\nb = {b}")
This achieves a behavior similar to pytest: it improves on Python's assert by showing us the values of variables in the error message.
General advice for writing such a transformer:
- We don't need to know all the node types ahead of time. It's easy to pick them up as needed.
- It helps a lot to visualize the source and target ASTs, with a web tool or ast.dump.
- ast.unparse can be used to convert an AST back to source code, which is useful for testing the transformer.
Some moderately ugly visualizations of the source and target ASTs:
[Figure: the source and target ASTs; in the target tree, the JoinedStr node's children are Constant and FormattedValue nodes]
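For a text-only comparison of the two trees, ast.dump works well. In this short sketch the source and target strings are copied from the example above. The dump formatting varies slightly between Python versions, so its output is not reproduced.

```python
import ast

source = "assert a == b"
# Raw string so that the \n escapes stay inside the f-string of the target code.
target = r'''
try:
    assert a == b
except AssertionError:
    raise AssertionError(f"a == b failed\na = {a}\nb = {b}")
'''

source_tree = ast.parse(source)
target_tree = ast.parse(target)

# The source is a single Assert statement ...
print(ast.dump(source_tree.body[0], indent=2))
# ... and the target is a Try whose body holds that same Assert and whose handler raises.
print(ast.dump(target_tree.body[0], indent=2))
```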
Looking at the code and the trees:
- We want to replace each Assert node with a Try node that contains the original Assert node and has an exception handler with a custom Raise node.
- We'll need to construct a few new AST nodes like JoinedStr, Constant or FormattedValue.
- We'll want the a == b part of assert a == b as a string literal, to include in our custom error message. The Assert node exposes this expression via its test attribute, which we can turn back into a code string with ast.unparse.
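The two less familiar ingredients, ast.unparse on an Assert node's test and a hand-built f-string, can be tried in isolation before writing the full transformer. In this sketch, conversion=-1 is the value Python uses for an f-string field without !r, !s or !a. Comparing a hand-built node against what ast.parse produces is a quick way to check it.

```python
import ast

# The test attribute of an Assert node holds the asserted expression.
assert_node = ast.parse("assert a == b").body[0]
print(ast.unparse(assert_node.test))  # a == b

# Build the AST for f"a = {a}" by hand ...
built = ast.JoinedStr(values=[
    ast.Constant(value="a = "),
    # conversion=-1 means no !r / !s / !a conversion was requested
    ast.FormattedValue(value=ast.Name(id="a", ctx=ast.Load()), conversion=-1, format_spec=None),
])
# ... and compare it with the tree Python builds for the same f-string.
parsed = ast.parse('f"a = {a}"', mode="eval").body
print(ast.dump(built) == ast.dump(parsed))  # True
print(ast.unparse(built))
```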
Along the way we will need to collect the variables used in the Assert node (a and b in our example) so they can be included in the error message.
Variables are represented as Name nodes, where the id attribute is the variable name.
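For a one-off check, the same Name nodes can be collected without a transformer by iterating over every node with ast.walk. This is a minimal sketch with an illustrative assertion. dict.fromkeys removes duplicate names while keeping first-seen order, which the implementation that follows does not bother to do.

```python
import ast

assert_node = ast.parse("assert a == b or a > limit").body[0]

# ast.walk yields every node in the subtree; keep only the variable references.
names = [node.id for node in ast.walk(assert_node.test) if isinstance(node, ast.Name)]
unique_names = list(dict.fromkeys(names))  # drop repeats, keep first-seen order
print(unique_names)  # ['a', 'b', 'limit']
```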
Here's an implementation:
test_runner.py
import ast


class RewriteAssertNodeTransformer(ast.NodeTransformer):
    def __init__(self):
        # use a stack to collect variables when inside an assert and clear the list after.
        self.assert_stack = []

    def visit_Assert(self, node: ast.Assert):
        self.assert_stack.append([])
        # visit children to collect variables
        self.generic_visit(node)
        name_nodes = self.assert_stack.pop()

        assertion_test_as_text = ast.unparse(node.test)
        assertion_msg_parts = get_assertion_message_parts(assertion_test_as_text, name_nodes)
        assertion_error_message = ast.JoinedStr(values=assertion_msg_parts)

        except_handler = ast.ExceptHandler(
            type=ast.Name(id="AssertionError", ctx=ast.Load()),
            body=[ast.Raise(ast.Call(ast.Name(id="AssertionError", ctx=ast.Load()), [assertion_error_message], []))],
        )
        return ast.Try(
            body=[node],
            handlers=[except_handler],
            orelse=[],
            finalbody=[],
        )

    def visit_Name(self, node: ast.Name):
        # collect information and return node unchanged
        if self.assert_stack:
            self.assert_stack[-1].append(node)
        return node


def get_assertion_message_parts(assertion_test_as_text, name_nodes):
    assertion_msg_parts = [ast.Constant(assertion_test_as_text + " failed")]
    for name_node in name_nodes:
        assertion_msg_parts.append(ast.Constant(f"\n{name_node.id} = "))
        assertion_msg_parts.append(ast.FormattedValue(
            ast.Name(id=name_node.id, ctx=ast.Load()),
            conversion=-1,
            format_spec=None)
        )
    return assertion_msg_parts


if __name__ == "__main__":
    source = "assert a == b"
    tree = ast.parse(source)
    transformed = RewriteAssertNodeTransformer().visit(tree)
    # fix_missing_locations is a technicality: we didn't bother defining line/column numbers, but they are required
    transformed = ast.fix_missing_locations(transformed)
    print(ast.unparse(transformed))
This program transforms the AST for assert a == b and prints code for the transformed AST. The result matches our target source code.
This could be adapted to take a file as input, compile the modified tree and exec it. This would give an interface similar to pytest: test_runner.py test_module.py.
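That adaptation could look roughly like the sketch below. It is hypothetical glue, not how pytest does it (pytest hooks into the import system instead). run_transformed is an illustrative name: it takes source text plus any ast.NodeTransformer instance and executes the result in a fresh namespace. A plain ast.NodeTransformer() leaves the tree unchanged, so it serves as a placeholder. Passing RewriteAssertNodeTransformer() from test_runner.py instead would give the improved assert messages.

```python
import ast
import sys


def run_transformed(source, filename, transformer):
    """Parse source, apply an ast.NodeTransformer, then compile and execute the result."""
    tree = ast.parse(source, filename=filename)
    tree = ast.fix_missing_locations(transformer.visit(tree))
    # dont_inherit=True mirrors the pytest excerpt: don't inherit this file's __future__ flags
    code_object = compile(tree, filename, "exec", dont_inherit=True)
    # A fresh namespace, with __name__ set as if the file were run directly with python
    namespace = {"__name__": "__main__", "__file__": filename}
    exec(code_object, namespace)
    return namespace


if __name__ == "__main__" and len(sys.argv) == 2:
    # Hypothetical usage: python run_transformed.py test_module.py
    path = sys.argv[1]
    with open(path, encoding="utf-8") as f:
        # ast.NodeTransformer() is a no-op placeholder; swap in a real transformer here.
        run_transformed(f.read(), path, ast.NodeTransformer())
```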
This example is intentionally minimal to remain approachable. If this still feels a bit overwhelming, spending time looking at the ASTs in the explorer tool should help.
An interesting exercise could be to support attribute access (assert a.b == c.d) and indexing into lists (assert a[b] == c[d]).
The real-world pytest transform does a lot more, like showing intermediate values of computations/function calls.
The code pytest generates
Strictly speaking, pytest does not generate Python code. As we have seen, it transforms the AST instead.
But we can still use ast.unparse to get source code from the transformed AST.
Here's what it looks like for assert a == b (using pytest 7.4.3 with Python 3.11.2)[3]:
import builtins as @py_builtins
import _pytest.assertion.rewrite as @pytest_ar
@py_assert1 = a == b
if not @py_assert1:
    @py_format3 = @pytest_ar._call_reprcompare(('==',), (@py_assert1,), ('%(py0)s == %(py2)s',), (a, b)) % {
        'py0': @pytest_ar._saferepr(a) if 'a' in @py_builtins.locals() or @pytest_ar._should_repr_global_name(a) else 'a',
        'py2': @pytest_ar._saferepr(b) if 'b' in @py_builtins.locals() or @pytest_ar._should_repr_global_name(b) else 'b'}
    @py_format5 = ('' + 'assert %(py4)s') % {'py4': @py_format3}
    raise AssertionError(@pytest_ar._format_explanation(@py_format5))
@py_assert1 = None
That's a lot of code for just a == b!
It uses if not @py_assert1: to check the assertion instead of a try/except. The code also looks much more complicated than in our example[4], but the transform principles are the same.
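pytest's @-prefixed names are worth a quick experiment of their own. In this sketch, the parser rejects @py_assert1 as a variable name, but a hand-built Assign node that uses the same name compiles and runs. The name is borrowed from the generated code above; nothing here calls into pytest.

```python
import ast

# As source text, "@py_assert1 = ..." is a syntax error ...
parser_rejected = False
try:
    ast.parse("@py_assert1 = 1 + 1")
except SyntaxError:
    parser_rejected = True
print(parser_rejected)  # True

# ... but the same assignment built directly as an AST compiles and runs.
assign = ast.Assign(
    targets=[ast.Name(id="@py_assert1", ctx=ast.Store())],
    value=ast.BinOp(left=ast.Constant(1), op=ast.Add(), right=ast.Constant(1)),
)
tree = ast.fix_missing_locations(ast.Module(body=[assign], type_ignores=[]))

namespace = {}
exec(compile(tree, filename="<ast>", mode="exec"), namespace)
print(namespace["@py_assert1"])  # 2
```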
Conclusion
Manipulating code as data can make for some fun and powerful tools. Like most metaprogramming techniques, it should be used with great responsibility.
Even if you rarely use the ast module directly in application code, understanding Abstract Syntax Trees makes many of the developer tools we use less mysterious.
Footnotes
1. Though Python is usually described as 'interpreted', it is both interpreted and compiled. If this feels annoying, this Crafting Interpreters section does a great job of explaining the compiler vs interpreter terminology.
2. pytest hooks into the import system to do this as modules are loaded, and caches the resulting bytecode. It's easier to think about this for a single file; the extra machinery is only required for multiple files with imports.
3. This original writeup shows the executed code for pytest 2.1. The feature did not look very different 15 years ago!
4. You might notice @py_assert1 is not a valid name for a Python variable! It turns out we cannot use the @ character in names in Python code because the parser does not allow it. But once we're past the parser, such names are accepted. Using @ in a variable name guarantees that pytest won't accidentally overwrite a variable in scope.