
CPS and beta-reduction

I have finally implemented a pass in Sly that converts code to continuation-passing style (CPS), after reading an interesting article by Matt Might. I was not satisfied with the conversion algorithms I had seen before, but the “hybrid CPS conversion” it describes was exactly what I was looking for. No extra β- or η-redexes are inserted during conversion (strictly speaking, one η-redex is introduced when converting conditionals, to avoid an exponential explosion in code size). I will not say much more about the conversion algorithm, because Matt’s articles are excellent.
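
For readers who want a feel for the idea, here is a minimal sketch of a hybrid converter in the spirit of Might's article. It is not Sly's implementation. It assumes a tiny core language: variables, numbers, booleans, single-body lambda, if and applications. It treats primitives like any other procedure. To match the output shown later in this post, it passes each continuation as the first argument and binds non-trivial continuations with letrec.

Unlike Sly's output, it leaves a lambda in operator position in place instead of naming it. The names fresh, M, T-c, T-k, T*-k and cps-convert, and the top-level continuation halt, are hypothetical. T-c is used when the continuation is already a variable, so that variable is passed along unchanged. T-k receives a Scheme procedure that builds the rest of the code, so atomic values are plugged in directly rather than wrapped in administrative lambdas.

```scheme
;; Hypothetical fresh-name generator (R7RS has no gensym).
(define counter 0)
(define (fresh prefix)
  (set! counter (+ counter 1))
  (string->symbol (string-append prefix (number->string counter))))

(define (lambda? e) (and (pair? e) (eq? (car e) 'lambda)))
(define (atomic? e)
  (or (symbol? e) (number? e) (boolean? e) (lambda? e)))

;; M: convert an atomic expression. A lambda gets an extra first
;; parameter for its continuation, and its body is converted with T-c.
(define (M e)
  (if (lambda? e)
      (let* ((k (fresh "k"))
             (body (T-c (list-ref e 2) k)))
        (list 'lambda (cons k (cadr e)) body))
      e))

;; T-c: convert e so that its value is passed to the continuation
;; variable c. A call in tail position just receives c itself.
(define (T-c e c)
  (cond ((atomic? e) (list c (M e)))
        ((eq? (car e) 'if)
         (T-k (cadr e)
              (lambda (test)
                (let* ((conseq (T-c (list-ref e 2) c))
                       (alt (T-c (list-ref e 3) c)))
                  (list 'if test conseq alt)))))
        (else
         (T-k (car e)
              (lambda (f)
                (T*-k (cdr e)
                      (lambda (args) (cons f (cons c args)))))))))

;; T-k: convert e and hand its atomic value to k, a Scheme procedure that
;; builds the rest of the code. A non-atomic e needs a real continuation:
;; k is reified once as (lambda (r) ...), bound to a fresh name with
;; letrec, and e is converted with T-c, so both arms of an if share it.
(define (T-k e k)
  (if (atomic? e)
      (k (M e))
      (let* ((j (fresh "k"))
             (r (fresh "r"))
             (rest (k r)))
        `(letrec ((,j (lambda (,r) ,rest)))
           ,(T-c e j)))))

;; T*-k: convert a list of expressions left to right, then pass the list
;; of their atomic values to k.
(define (T*-k es k)
  (if (null? es)
      (k '())
      (T-k (car es)
           (lambda (v)
             (T*-k (cdr es)
                   (lambda (vs) (k (cons v vs))))))))

;; Entry point, with the hypothetical halt as the top-level continuation.
(define (cps-convert e)
  (set! counter 0)
  (T-c e 'halt))
```

For example, (cps-convert '(lambda (x) (f (g x)))) produces (halt (lambda (k1 x) (letrec ((k2 (lambda (r3) (f k1 r3)))) (g k2 x)))). The tail call to f receives k1 unchanged, and only the non-tail call to g gets a new continuation.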

Using CPS as an intermediate language has many advantages in general, and especially for Scheme. In general, CPS is a simpler, more uniform language in which all transfers of control are explicit and all intermediate values are named. For Scheme, it makes every continuation explicit, so any of them can then be captured by call-with-current-continuation. Nevertheless, most optimising compilers that use CPS as an intermediate language mark the continuation closures introduced during conversion as special and allocate them on the ordinary call/return stack, even for Scheme, because a stack interacts better with the memory caches of current processors. Capturing the current continuation then requires more elaborate techniques, such as copying the control stack to the heap.
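
Once every procedure receives its continuation, call-with-current-continuation needs no special machinery. It can hand the current continuation to its argument, wrapped as an ordinary CPS procedure that ignores the continuation it is called with.

The sketch below assumes the continuation-first convention of the CPS listings later in this post. The names call/cc-cps and escape-example are hypothetical, and the outermost continuation is a plain Scheme procedure for brevity. It says nothing about the stack-allocation techniques just mentioned.

```scheme
;; call/cc in CPS (hypothetical name call/cc-cps): k is the caller's
;; continuation and f is a CPS procedure taking (k escape).
(define (call/cc-cps k f)
  ;; escape discards the continuation it receives and resumes k instead.
  (f k (lambda (ignored-k v) (k v))))

;; CPS form of (+ 1 (call/cc (lambda (escape) (* 2 (escape 10))))).
;; The outermost continuation is a plain Scheme procedure for brevity.
(define (escape-example)
  (call/cc-cps (lambda (v) (+ 1 v))
               (lambda (k escape)
                 ;; The continuation for (* 2 ...) is passed to escape,
                 ;; which throws it away and delivers 10 to k.
                 (escape (lambda (r) (k (* 2 r))) 10))))
```

(escape-example) returns 11. The pending multiplication by 2 is never performed.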

But how exactly does CPS make it easier to apply optimisations to the code? I am going to illustrate one of the possible optimisations with an example based on the code to solve the N-Queens problem:

(define (nqueens n)
 
  (define (dec-to n)
    (let loop ((i n) (l '()))
      (if (= i 0) l (loop (- i 1) (cons i l)))))
 
  (define (try x y z)
    (if (null? x)
      (if (null? y)
        1
        0)
      (+ (if (ok? (car x) 1 z)
           (try (append (cdr x) y) '() (cons (car x) z))
           0)
         (try (cdr x) (cons (car x) y) z))))
 
  (define (ok? row dist placed)
    (if (null? placed)
      #t
      (and (not (= (car placed) (+ row dist)))
           (not (= (car placed) (- row dist)))
           (ok? row (+ dist 1) (cdr placed)))))
 
  (try (dec-to n) '() '()))

For now I will focus on the ok? procedure. After macro expansion we have:

(letrec ((ok? (lambda (row dist placed)
                   (if (null? placed)
                       #t
                       ((lambda (temp1)
                          (if temp1
                              ((lambda (temp2)
                                 (if temp2
                                     (ok? row (+ dist 1) (cdr placed))
                                     temp2))
                               (not (= (car placed) (- row dist))))
                              temp1))
                        (not (= (car placed) (+ row dist)))))))))
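
An and macro whose expansion has exactly this shape can be written with syntax-rules. This is a sketch for illustration only: my-and is a hypothetical name, and Sly's own definition of and may differ. Each operand except the last is bound to a temporary by an immediately applied lambda and then tested. Hygiene keeps the nested temporaries apart.

```scheme
;; Hypothetical and-like macro producing immediately applied lambdas.
(define-syntax my-and
  (syntax-rules ()
    ((_) #t)                      ; (and) is true
    ((_ e) e)                     ; last operand is in tail position
    ((_ e1 e2 ...)
     ;; bind e1 to temp, continue only if it is true, else return it
     ((lambda (temp) (if temp (my-and e2 ...) temp)) e1))))
```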

The expansion of the and macro has created two procedures that are applied to their arguments immediately and are not used anywhere else. Now let us convert this code to CPS:

(letrec ((ok? (lambda (k1 row dist placed)
                (letrec ((k2 (lambda (r1)
                                 (if r1
                                     (k1 #t)
                                     (letrec ((l1 (lambda (k3 temp1)
                                                      (if temp1
                                                          (letrec ((l2 (lambda (k4 temp2)
                                                                           (if temp2
                                                                               (letrec ((k5 (lambda (r2)
                                                                                                (letrec ((k6 (lambda (r3)
                                                                                                                 (ok? k4 row r2 r3))))
                                                                                                  (cdr k6 placed)))))
                                                                                 (+ k5 dist 1))
                                                                               (k4 temp2)))))
                                                            (letrec ((k7 (lambda (r4)
                                                                             (letrec ((k8 (lambda (r5)
                                                                                              (letrec ((k9 (lambda (r6)
                                                                                                               (letrec ((k10
                                                                                                                         (lambda (r7)
                                                                                                                           (l2 k3 r7))))
                                                                                                                 (not k10 r6)))))
                                                                                                (= k9 r4 r5)))))
                                                                               (- k8 row dist)))))
                                                              (car k7 placed)))
                                                          (k3 temp1)))))
                                       (letrec ((k11 (lambda (r8)
                                                        (letrec ((k12 (lambda (r9)
                                                                         (letrec ((k13 (lambda (r10)
                                                                                          (letrec ((k14 (lambda (r11)
                                                                                                           (l1 k1 r11))))
                                                                                            (not k14 r10)))))
                                                                           (= k13 r8 r9)))))
                                                          (+ k12 row dist)))))
                                         (car k11 placed)))))))
                  (null? k2 placed))))))
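
In this output every primitive, like every other procedure, receives its continuation as the first argument. To run a fragment of it directly, one can shadow the primitives with continuation-first wrappers.

The sketch below does that for the first conjunct of ok?, (not (= (car placed) (+ row dist))), reusing the k11 to k14 chain from the listing. The wrappers and the procedure name first-conjunct are assumptions for illustration. Here k14 returns its value instead of calling l1.

```scheme
;; Hypothetical continuation-first wrappers around the host primitives.
;; They are bound with let, so inside each wrapper car, +, = and not
;; still refer to the ordinary primitives.
(define (first-conjunct row dist placed)
  (let ((car (lambda (k p) (k (car p))))
        (+   (lambda (k a b) (k (+ a b))))
        (=   (lambda (k a b) (k (= a b))))
        (not (lambda (k x) (k (not x)))))
    ;; The k11..k14 chain from the CPS code above; every intermediate
    ;; value (r8, r9, r10, r11) is named by a continuation parameter.
    (letrec ((k11 (lambda (r8)
                    (letrec ((k12 (lambda (r9)
                                    (letrec ((k13 (lambda (r10)
                                                    (letrec ((k14 (lambda (r11) r11)))
                                                      (not k14 r10)))))
                                      (= k13 r8 r9)))))
                      (+ k12 row dist)))))
      (car k11 placed))))
```

For instance, (first-conjunct 1 1 '(2 4)) returns #f because 2 = 1 + 1, while (first-conjunct 1 1 '(3)) returns #t.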

The CPS code is larger, but it is simpler and easier to analyse and manipulate. For instance, in the code above l2 is a procedure that is called only once and is never passed as an argument to another procedure. We can therefore replace its only call site with the body of the procedure, substituting the actual arguments for the formal parameters. This inlining is called β-reduction. Since no argument in CPS can have side effects (every argument is atomic: a variable, a constant or a lambda expression), β-reduction is sound:

(letrec ((ok? (lambda (k1 row dist placed)
                (letrec ((k2 (lambda (r1)
                                 (if r1
                                     (k1 #t)
                                     (letrec ((l1 (lambda (k3 temp1)
                                                    (if temp1
                                                        (letrec ((k7 (lambda (r4)
                                                                       (letrec ((k8 (lambda (r5)
                                                                                      (letrec ((k9 (lambda (r6)
                                                                                                     (letrec ((k10
                                                                                                               (lambda (r7)
                                                                                                                 (if r7
                                                                                                                     (letrec ((k5 (lambda (r2)
                                                                                                                                    (letrec ((k6 (lambda (r3)
                                                                                                                                                   (ok? k3 row r2 r3))))
                                                                                                                                      (cdr k6 placed)))))
                                                                                                                       (+ k5 dist 1))
                                                                                                                     (k3 r7)))))
                                                                                                       (not k10 r6)))))
                                                                                        (= k9 r4 r5)))))
                                                                         (- k8 row dist)))))
                                                          (car k7 placed))
                                                          (k3 temp1)))))
                                       (letrec ((k11 (lambda (r8)
                                                        (letrec ((k12 (lambda (r9)
                                                                         (letrec ((k13 (lambda (r10)
                                                                                          (letrec ((k14 (lambda (r11)
                                                                                                           (l1 k1 r11))))
                                                                                            (not k14 r10)))))
                                                                           (= k13 r8 r9)))))
                                                          (+ k12 row dist)))))
                                         (car k11 placed)))))))
                  (null? k2 placed))))))

Reducing l2 eliminates one closure that would otherwise have to be allocated and called. We can do the same with l1, which is also called only once:

(letrec ((ok? (lambda (k1 row dist placed)
                (letrec ((k2 (lambda (r1)
                                 (if r1
                                     (k1 #t)
                                     (letrec ((k11 (lambda (r8)
                                                     (letrec ((k12 (lambda (r9)
                                                                     (letrec ((k13 (lambda (r10)
                                                                                     (letrec ((k14 (lambda (r11)
                                                                                                     (if r11
                                                                                                         (letrec ((k7 (lambda (r4)
                                                                                                                        (letrec ((k8 (lambda (r5)
                                                                                                                                       (letrec ((k9 (lambda (r6)
                                                                                                                                                      (letrec ((k10
                                                                                                                                                                (lambda (r7)
                                                                                                                                                                  (if r7
                                                                                                                                                                      (letrec ((k5 (lambda (r2)
                                                                                                                                                                                     (letrec ((k6 (lambda (r3)
                                                                                                                                                                                                    (ok? k1 row r2 r3))))
                                                                                                                                                                                       (cdr k6 placed)))))
                                                                                                                                                                        (+ k5 dist 1))
                                                                                                                                                                      (k1 r7)))))
                                                                                                                                                        (not k10 r6)))))
                                                                                                                                         (= k9 r4 r5)))))
                                                                                                                          (- k8 row dist)))))
                                                                                                           (car k7 placed))
                                                                                                         (k1 r11)))))
                                                                                       (not k14 r10)))))
                                                                       (= k13 r8 r9)))))
                                                       (+ k12 row dist)))))
                                       (car k11 placed))))))
                  (null? k2 placed))))))
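
The rewrite applied to l2 and l1 can be sketched as a small source-to-source function. This is a minimal illustration, not Sly's pass. It assumes:

- CPS code shaped like the listings above,
- unique variable names,
- letrec forms that each bind a single lambda,
- matching arity at the call site.

A letrec-bound procedure is inlined when three conditions hold: it is not recursive, it occurs exactly once in the letrec body, and that occurrence is in operator position. These are the conditions stated earlier for l2. Continuations such as k6 are left alone because they are only passed as arguments. The helper names are hypothetical.

```scheme
;; Sketch assumptions: unique bound names, one lambda per letrec, atomic
;; call arguments, so substitution cannot capture or duplicate effects.
(define (sum xs) (apply + xs))

;; Number of occurrences of the symbol x anywhere in e.
(define (count-occ x e)
  (cond ((eq? e x) 1)
        ((pair? e) (+ (count-occ x (car e)) (count-occ x (cdr e))))
        (else 0)))

;; Number of subforms of e that are calls with x in operator position.
(define (count-calls x e)
  (if (pair? e)
      (+ (if (eq? (car e) x) 1 0)
         (sum (map (lambda (sub) (count-calls x sub)) e)))
      0))

;; Replace symbols found in env, an alist of (formal . actual).
(define (subst e env)
  (cond ((symbol? e) (let ((b (assq e env))) (if b (cdr b) e)))
        ((pair? e) (map (lambda (sub) (subst sub env)) e))
        (else e)))

;; Replace the call (x a ...) inside e by lbody with each formal
;; substituted by the corresponding actual argument.
(define (inline-call x formals lbody e)
  (cond ((not (pair? e)) e)
        ((eq? (car e) x) (subst lbody (map cons formals (cdr e))))
        (else (map (lambda (sub) (inline-call x formals lbody sub)) e))))

;; Bottom-up β-reduction of single-use, non-recursive letrec procedures.
(define (beta e)
  (cond ((not (pair? e)) e)
        ((eq? (car e) 'letrec)
         (let* ((binding (car (cadr e)))
                (name (car binding))
                (formals (cadr (cadr binding)))
                (lbody (beta (list-ref (cadr binding) 2)))
                (body (beta (list-ref e 2))))
           (if (and (= (count-occ name lbody) 0)     ; not recursive
                    (= (count-occ name body) 1)      ; used exactly once
                    (= (count-calls name body) 1))   ; ...as the operator
               (inline-call name formals lbody body)
               (list 'letrec (list (list name (list 'lambda formals lbody)))
                     body))))
        (else (map beta e))))
```

Consider the fragment (letrec ((l2 (lambda (k4 temp2) (if temp2 (go k4) (k4 temp2))))) (letrec ((k7 (lambda (r7) (l2 k3 r7)))) (car k7 placed))), where go is a hypothetical stand-in for the recursive branch. The sketch turns it into (letrec ((k7 (lambda (r7) (if r7 (go k3) (k3 r7))))) (car k7 placed)). This is the same substitution of k3 for k4 and r7 for temp2 performed in the listings above.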

The two closures introduced during macro expansion are now gone. Notice that the continuation passed to the recursive call to ok? is k1, confirming that tail calls do not create new continuations. In the original code, dec-to is also called only once, to compute the first argument to try, so it too is β-reduced in the same compiler pass. CPS makes other optimisations simple as well, and I have implemented some of them in the same pass: dead-variable elimination, removal of unreachable branches, and dropping of unused arguments to known procedures.
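
Two of these optimisations, removal of unreachable branches and dead-variable elimination, can be sketched in a few lines. This is an illustration under stated assumptions rather than Sly's code. It assumes CPS code shaped like the listings in this post, unique variable names and single-binding letrec forms. Only literal booleans and numbers are recognised as constant tests.

Dropping unused arguments is left out, since it requires rewriting every call site of the known procedure. The names simplify and occurs? are hypothetical. The body is simplified before the liveness check, because removing a branch can leave a continuation unused.

```scheme
;; Does the symbol x occur anywhere in e?
(define (occurs? x e)
  (cond ((eq? e x) #t)
        ((pair? e) (or (occurs? x (car e)) (occurs? x (cdr e))))
        (else #f)))

(define (simplify e)
  (cond ((not (pair? e)) e)
        ;; Unreachable branches: an if whose test is a known constant.
        ((and (eq? (car e) 'if) (eq? (cadr e) #f))
         (simplify (list-ref e 3)))
        ((and (eq? (car e) 'if) (or (eq? (cadr e) #t) (number? (cadr e))))
         (simplify (list-ref e 2)))
        ;; Dead variables: drop a letrec binding its body never mentions.
        ((eq? (car e) 'letrec)
         (let* ((binding (car (cadr e)))
                (name (car binding))
                (lam (simplify (cadr binding)))
                (body (simplify (list-ref e 2))))
           (if (occurs? name body)
               (list 'letrec (list (list name lam)) body)
               body)))
        (else (map simplify e))))
```

For example, simplifying (letrec ((k1 (lambda (r1) (halt r1)))) (letrec ((k2 (lambda (r2) (halt r2)))) (if #t (f k1 x) (g k2 x)))) leaves (letrec ((k1 (lambda (r1) (halt r1)))) (f k1 x)). The else branch is unreachable, and once it is gone k2 is dead.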
