Minimal Parenthesization of Lambda Terms

2024-09-15

Thanks to Max Bernstein for inspiring me to write this down.

When writing a compiler, it's useful to be able to visualize the syntax tree as it gets processed. One simple way of doing so is to serialize the tree to text, using parentheses to specify the structure. The most naive approach is to simply wrap every subexpression in parentheses, which means we get representations that look like

(((((((((1 + 2) + 3) + 4) + 5) + 6) + 7) + 8) + 9) + 10)

even though a human would have probably written the term as

1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10
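To make the problem concrete, here is a sketch of the naive approach on a hypothetical arithmetic-only tree. The names `arith`, `Num`, `Plus`, and `naive_show` are illustrative and are not part of the calculus defined below. The sum of 1 through 10 is nested to the left, the way a left-associative parser would build it, and printing it reproduces the heavily parenthesized output above.

```ocaml
(* A hypothetical arithmetic-only syntax tree, used only for this illustration *)
type arith =
  | Num of int
  | Plus of arith * arith

(* The naive approach: wrap every compound subexpression in parentheses *)
let rec naive_show : arith -> string = function
  | Num n -> string_of_int n
  | Plus (l, r) -> Printf.sprintf "(%s + %s)" (naive_show l) (naive_show r)

(* Build ((1 + 2) + 3) + ... + 10, nested to the left *)
let sum_1_to_10 : arith =
  List.fold_left (fun acc n -> Plus (acc, Num n)) (Num 1) [2; 3; 4; 5; 6; 7; 8; 9; 10]

let () = print_endline (naive_show sum_1_to_10)
```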

As terms become bigger, excessive parentheses make it hard to look at a term and quickly understand its structure. Max Bernstein has written about a simple algorithm that finds the minimal parenthesization necessary for terms in an arithmetic expression language. We'll apply the same idea to the simply typed lambda calculus (implemented in OCaml).

The Term Language

Our calculus will have a single base type of machine integers:

type ty =
  | Z64                 (* machine integers *)
  | Arrow of ty * ty    (* functions        *)

We'll use strings as identifiers:

type identifier = string

(* associate an identifier with some value *)
type 'a binding = {
  name : identifier ;
  value : 'a ;
}

We'll include primitive operations for integer addition, subtraction, multiplication, and exponentiation:

type binop =
  | Add
  | Sub
  | Mul
  | Exp

Expressions are defined as follows:

type expression =
  | Lit of int                                  (* integer literal      *)
  | Bin of binop * expression * expression      (* integer arithmetic   *)
  | Var of identifier                           (* variable             *)
  | App of expression * expression              (* function application *)
  | Abs of ty binding * expression              (* function abstraction *)
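The tree records nesting explicitly, so parentheses only become an issue once a term is turned into text. As a sketch, the term λ f : Z -> Z . λ x : Z . f (f x) could be built as follows. The name `twice` is a hypothetical example, and the type definitions are repeated so the snippet compiles on its own.

```ocaml
type ty =
  | Z64
  | Arrow of ty * ty

type identifier = string

type 'a binding = {
  name : identifier ;
  value : 'a ;
}

type binop = Add | Sub | Mul | Exp

type expression =
  | Lit of int
  | Bin of binop * expression * expression
  | Var of identifier
  | App of expression * expression
  | Abs of ty binding * expression

(* λ f : Z -> Z . λ x : Z . f (f x), i.e. apply f twice *)
let twice : expression =
  Abs ({ name = "f" ; value = Arrow (Z64, Z64) },
       Abs ({ name = "x" ; value = Z64 },
            (* "f (f x)" is just an App node whose argument is another App node *)
            App (Var "f", App (Var "f", Var "x"))))
```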

Representing Types

Let's start by considering the problem in the case of types, which have a relatively simple grammar. A naive implementation might look something like the following:

let print = Printf.sprintf

let rec naive_show_ty : ty -> string = function
  | Z64 -> "Z"
  | Arrow (domain, codomain) ->
      print "(%s -> %s)" (naive_show_ty domain) (naive_show_ty codomain)

This will represent the type Arrow (Z64, Arrow (Z64, Z64)) as "(Z -> (Z -> Z))". If we take function arrows to be right associative (as they conventionally are), then this could have been represented as "Z -> Z -> Z".

In the case of types, the only information we need is whether we are on the left of a function arrow, in which case an arrow will need to be parenthesized. We can pass that information down, recursively. The resulting algorithm is:

let show_ty : ty -> string =
  let rec show (left : bool) = function
    | Z64 -> "Z"
    | Arrow (domain, codomain) ->
        let representation = print "%s -> %s" (show true domain) (show false codomain) in
        if left then print "(%s)" representation else representation in
  show false
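As a quick check, the sketch below repeats both printers and compares them on a curried type and on a higher-order type. Only an arrow in domain position keeps its parentheses. The `examples` list is an assumption added for illustration.

```ocaml
type ty =
  | Z64
  | Arrow of ty * ty

let print = Printf.sprintf

(* naive: parenthesize every arrow *)
let rec naive_show_ty : ty -> string = function
  | Z64 -> "Z"
  | Arrow (domain, codomain) ->
      print "(%s -> %s)" (naive_show_ty domain) (naive_show_ty codomain)

(* minimal: parenthesize an arrow only when it appears as a domain *)
let show_ty : ty -> string =
  let rec show (left : bool) = function
    | Z64 -> "Z"
    | Arrow (domain, codomain) ->
        let representation = print "%s -> %s" (show true domain) (show false codomain) in
        if left then print "(%s)" representation else representation in
  show false

let () =
  let examples = [
    Arrow (Z64, Arrow (Z64, Z64));  (* naive: (Z -> (Z -> Z))   minimal: Z -> Z -> Z *)
    Arrow (Arrow (Z64, Z64), Z64);  (* naive: ((Z -> Z) -> Z)   minimal: (Z -> Z) -> Z *)
  ] in
  List.iter (fun t -> print_endline (naive_show_ty t ^ "   " ^ show_ty t)) examples
```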

Representing Terms

The same kind of idea applies to the more complicated grammar of terms. The key insight is that parenthesization decisions require only local context: parentheses are necessary if and only if the precedence of an operation is less than or equal to the precedence of the parent. For example, the fact that 3 + 4 must be parenthesized in 1 + 2 * (3 + 4) - 5 depends only on the fact that + has lower precedence than *.

As we descend the syntax tree, we'll pass down the precedence of each parent. In other words, at each step we pass down the precedence that a subterm must exceed to avoid being parenthesized. (Associativity may be understood as requiring one less level of precedence on the associative side.)
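The two rules can be tried on a hypothetical arithmetic-only tree. The `arith` type and the `plus`, `minus`, and `times` helpers are illustrative and are not part of the calculus. Every operator here is left associative, so the left operand is checked against p' - 1 and the right operand against p'.

```ocaml
(* A hypothetical arithmetic-only tree, used only to illustrate the rule *)
type arith =
  | Num of int
  | Op of string * int * arith * arith   (* symbol, precedence, lhs, rhs *)

(* [p] is the precedence passed down by the parent: a compound child whose
   precedence is <= p gets parenthesized. Every operator here is left
   associative, so the left operand is shown with p' - 1 and the right
   operand with p'. *)
let rec show (p : int) : arith -> string = function
  | Num n -> string_of_int n
  | Op (symbol, p', lhs, rhs) ->
      let s = Printf.sprintf "%s %s %s" (show (p' - 1) lhs) symbol (show p' rhs) in
      if p' <= p then "(" ^ s ^ ")" else s

let plus a b = Op ("+", 2, a, b)
let minus a b = Op ("-", 2, a, b)
let times a b = Op ("*", 3, a, b)

let () =
  (* 1 + 2 * (3 + 4) - 5: only the sum under * needs parentheses *)
  print_endline (show 0 (minus (plus (Num 1) (times (Num 2) (plus (Num 3) (Num 4)))) (Num 5)));
  (* 1 - (2 - 3): a right operand of equal precedence needs them *)
  print_endline (show 0 (minus (Num 1) (minus (Num 2) (Num 3))));
  (* 1 - 2 - 3: a left operand of equal precedence does not *)
  print_endline (show 0 (minus (minus (Num 1) (Num 2)) (Num 3)))
```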

Here's the precedence and associativity table for our expression language:

OPERATION     PRECEDENCE  ASSOCIATIVITY
lambda        1           n/a
+             2           left
-             2           left
*             3           left
^             4           right
application   5           left

We'll express our text serialization algorithm in terms of nullary, unary, and binary operations with precedence and associativity. To do this, we'll define a data structure that captures the relevant properties of an operation:

type associativity =
  | Left
  | Right

type precedence = int

type operation =
  | Nullary
  | Unary   of precedence * expression
  (* associativity is only applicable in the case of binary operations *)
  | Binary  of precedence * associativity * expression * expression

Mapping our expression language into this form is straightforward:

let structure : expression -> operation = function
  | Lit _ -> Nullary
  | Bin (op, lhs, rhs) ->
      begin match op with
      | Add -> Binary (2, Left , lhs, rhs)
      | Sub -> Binary (2, Left , lhs, rhs)
      | Mul -> Binary (3, Left , lhs, rhs)
      | Exp -> Binary (4, Right, lhs, rhs)
      end
  | Var _ -> Nullary
  | App (f, x) -> Binary (5, Left, f, x)
  | Abs (_, body) -> Unary (1, body)

Rather than being hard-coded here, these precedence and associativity values could be read from a shared source, so that the parser and the printer always agree.
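One possible shape for such a source is sketched below. A hypothetical `binop_table` holds each operator's symbol, precedence, and associativity, and supports lookups in both directions. The names `binop_table`, `binop_info`, and `binop_of_symbol` are assumptions for illustration. `List.find_map` requires OCaml 4.10 or later.

```ocaml
type associativity = Left | Right
type precedence = int
type binop = Add | Sub | Mul | Exp

(* Hypothetical single source of truth: symbol, precedence, associativity.
   Both the printer and a parser could consult this table. *)
let binop_table : (binop * (string * precedence * associativity)) list = [
  (Add, ("+", 2, Left));
  (Sub, ("-", 2, Left));
  (Mul, ("*", 3, Left));
  (Exp, ("^", 4, Right));
]

(* printer direction: operator -> (symbol, precedence, associativity) *)
let binop_info (op : binop) : string * precedence * associativity =
  List.assoc op binop_table

(* parser direction: symbol -> operator, if the symbol is known *)
let binop_of_symbol (symbol : string) : binop option =
  List.find_map
    (fun (op, (s, _, _)) -> if s = symbol then Some op else None)
    binop_table

let () =
  let (symbol, prec, _) = binop_info Mul in
  Printf.printf "%s has precedence %d\n" symbol prec
```

With a table like this, the Bin case of structure could become `let (_, p, a) = binop_info op in Binary (p, a, lhs, rhs)`, and node_text could build its operator strings from the same symbols.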

Notice that lambda abstraction is a unary operator, but it doesn't have a constant textual representation in the way that the arithmetic operators do. To account for this, we'll express the representation of a node as a function:

let node_text : expression -> string = function
  | Lit i -> print "%d" i
  | Bin (Add, _, _) -> " + "
  | Bin (Sub, _, _) -> " - "
  | Bin (Mul, _, _) -> " * "
  | Bin (Exp, _, _) -> " ^ "
  | Var id -> id
  | App _ -> " "
  | Abs ({ name ; value = domain }, _) ->
      print "λ %s : %s . " name (show_ty domain)
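A small check of node_text follows, with the definitions it needs repeated so the snippet stands alone. The operator text depends only on the root constructor, never on the operands. A lambda's annotation goes through show_ty, so a higher-order domain keeps only the parentheses it needs.

```ocaml
type ty = Z64 | Arrow of ty * ty
type identifier = string
type 'a binding = { name : identifier ; value : 'a }
type binop = Add | Sub | Mul | Exp
type expression =
  | Lit of int
  | Bin of binop * expression * expression
  | Var of identifier
  | App of expression * expression
  | Abs of ty binding * expression

let print = Printf.sprintf

let show_ty : ty -> string =
  let rec show (left : bool) = function
    | Z64 -> "Z"
    | Arrow (domain, codomain) ->
        let r = print "%s -> %s" (show true domain) (show false codomain) in
        if left then print "(%s)" r else r in
  show false

let node_text : expression -> string = function
  | Lit i -> print "%d" i
  | Bin (Add, _, _) -> " + "
  | Bin (Sub, _, _) -> " - "
  | Bin (Mul, _, _) -> " * "
  | Bin (Exp, _, _) -> " ^ "
  | Var id -> id
  | App _ -> " "
  | Abs ({ name ; value = domain }, _) -> print "λ %s : %s . " name (show_ty domain)

let () =
  (* operator text ignores the operands entirely *)
  assert (node_text (Bin (Exp, Lit 2, Var "x")) = " ^ ");
  (* a higher-order annotation: the domain arrow is parenthesized *)
  print_endline (node_text (Abs ({ name = "g" ; value = Arrow (Arrow (Z64, Z64), Z64) }, Var "g")))
```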

At this point, we're ready to write the complete serializer:

let show_expression : expression -> string =

  (* wrap a string in parentheses when a condition is met *)
  let wrap (s : string) (condition : bool) =
    if condition then print "(%s)" s else s in

  let rec show (p : precedence) (expr : expression) =
    let atom = node_text expr in
    match structure expr with
    | Nullary -> atom
    | Unary (p', e) ->
        let s = print "%s%s" atom (show (p' - 1) e) in
        wrap s (p' <= p)
    | Binary (p', assoc, lhs, rhs) ->
        let (left, right) = match assoc with
        | Left  -> (p' - 1, p')
        | Right -> (p', p' - 1) in
        let s = print "%s%s%s" (show left lhs) atom (show right rhs) in
        wrap s (p' <= p) in

  show 0

For unary and binary operators, the pattern is the same: we compare the precedence of the subterm we're looking at (p') to the precedence passed down by the parent (p). When p' isn't greater than p, we wrap the subterm in parentheses. A unary operator passes p' - 1 down to its operand, which lets the body of a lambda be another lambda without parentheses. A binary operator passes p' down to both operands, except on its associative side, where it passes p' - 1: the left side for a left-associative operator, the right side for a right-associative one.

This gives us representations that look like this:

(λ a : Z . λ b : Z . a + b) 1 2 * (λ a : Z . λ b : Z . a - b) 3 4
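For reference, the pieces above can be assembled into a single file. The sketch below condenses the post's definitions, merging a few match arms without changing behavior. It builds the example term with a hypothetical helper `curried` and prints it. The accompanying checks cover associativity: 1 - (2 - 3), (2 ^ 3) ^ 4, and f (g x) keep their parentheses, while 1 - 2 - 3, 2 ^ 3 ^ 4, and f x y do not.

```ocaml
type ty = Z64 | Arrow of ty * ty
type identifier = string
type 'a binding = { name : identifier ; value : 'a }
type binop = Add | Sub | Mul | Exp
type expression =
  | Lit of int
  | Bin of binop * expression * expression
  | Var of identifier
  | App of expression * expression
  | Abs of ty binding * expression

type associativity = Left | Right
type precedence = int
type operation =
  | Nullary
  | Unary of precedence * expression
  | Binary of precedence * associativity * expression * expression

let print = Printf.sprintf

let show_ty : ty -> string =
  let rec show (left : bool) = function
    | Z64 -> "Z"
    | Arrow (domain, codomain) ->
        let r = print "%s -> %s" (show true domain) (show false codomain) in
        if left then print "(%s)" r else r in
  show false

(* precedence and associativity, as in the table above *)
let structure : expression -> operation = function
  | Lit _ | Var _ -> Nullary
  | Bin ((Add | Sub), lhs, rhs) -> Binary (2, Left, lhs, rhs)
  | Bin (Mul, lhs, rhs) -> Binary (3, Left, lhs, rhs)
  | Bin (Exp, lhs, rhs) -> Binary (4, Right, lhs, rhs)
  | App (f, x) -> Binary (5, Left, f, x)
  | Abs (_, body) -> Unary (1, body)

let node_text : expression -> string = function
  | Lit i -> print "%d" i
  | Bin (Add, _, _) -> " + "
  | Bin (Sub, _, _) -> " - "
  | Bin (Mul, _, _) -> " * "
  | Bin (Exp, _, _) -> " ^ "
  | Var id -> id
  | App _ -> " "
  | Abs ({ name ; value = domain }, _) -> print "λ %s : %s . " name (show_ty domain)

let show_expression : expression -> string =
  let wrap (s : string) (condition : bool) =
    if condition then print "(%s)" s else s in
  let rec show (p : precedence) (expr : expression) =
    let atom = node_text expr in
    match structure expr with
    | Nullary -> atom
    | Unary (p', e) -> wrap (atom ^ show (p' - 1) e) (p' <= p)
    | Binary (p', assoc, lhs, rhs) ->
        (* the associative side tolerates one level less *)
        let (left, right) = match assoc with
          | Left -> (p' - 1, p')
          | Right -> (p', p' - 1) in
        wrap (show left lhs ^ atom ^ show right rhs) (p' <= p) in
  show 0

(* hypothetical helper: λ a : Z . λ b : Z . a <op> b *)
let curried (op : binop) : expression =
  Abs ({ name = "a" ; value = Z64 },
       Abs ({ name = "b" ; value = Z64 }, Bin (op, Var "a", Var "b")))

let example : expression =
  Bin (Mul,
       App (App (curried Add, Lit 1), Lit 2),
       App (App (curried Sub, Lit 3), Lit 4))

let () = print_endline (show_expression example)
```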
