Why continuation-passing style works, and the Cont monad

May 02, 2022

And now in posts that absolutely no one asked for: let's talk about the Cont monad!

Despite having read plenty of articles on the subject and written non-trivial code in continuation-passing style, I always felt there was something I didn't quite "get" about CPS. I'd heard a lot of platitudes about how continuations "reify return addresses" and how call/cc lets you "time travel" in your program, but none of them told me why any of this worked. If anything, they made me more confused; I couldn't get an intuitive sense of where the "power" of CPS was coming from.

Similarly, I had brushed against the Cont monad before. I vaguely knew it was related to CPS, but didn't really understand what it was for or how exactly the two dovetailed.

What follows is what I've found on my journey to understand these two things that had been eluding me. For a long time I was happy not to dig into CPS more deeply, since it's completely irrelevant to writing useful code. But I always had a vague sense that understanding it would teach me a lot, that there was some enlightenment to acquire. I figured now was as good a time as any to reach it, so I recently decided to buckle down and stretch my brain around these ideas.

We won't be going over what continuation-passing style is; instead, this post is specifically about the Cont monad, how it's implemented, and how it relates to CPS. If you'd like an introduction to CPS, check out this excellent overview of its uses by Ziyang Liu. That post takes a cursory look at the Cont monad but doesn't dive deeply into how it works, which is where this post comes in.

An insight for why CPS works

Here's the type we're going to be working with, along with callCC, the single function at the core of the Cont monad.

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

It might be helpful to contrast this with the signature of a function written in CPS. Let's use our favorite functional guinea pig, factorial.

factCPS :: Int -> (Int -> r) -> r

If you squint a bit, Cont is just the last part of that type signature, (Int -> r) -> r: the part where factCPS takes in a continuation and returns a value. One way I've conceptualized this is that a Cont r a is a "finished" computation, just as factCPS, once it has its Int parameter, is only waiting to know where to put the result. So a Cont r a is an a waiting for a continuation.
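To make "an a waiting for a continuation" concrete, here is a minimal sketch. The post only gives factCPS's signature, so the body below is an assumed, straightforward CPS factorial, and fact5 is a hypothetical name. Partially applying factCPS to 5 leaves exactly the (Int -> r) -> r payload that Cont wraps.

```haskell
-- The Cont newtype from the post.
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

-- Assumed CPS factorial: instead of returning n!, pass it to k.
-- (Never terminates for negative n.)
factCPS :: Int -> (Int -> r) -> r
factCPS 0 k = k 1
factCPS n k = factCPS (n - 1) (\m -> k (n * m))

-- Hypothetical name: once factCPS has its argument, what remains has type
-- (Int -> r) -> r, which is exactly what Cont wraps.
fact5 :: Cont r Int
fact5 = Cont (factCPS 5)
```

The caller picks the final continuation: runCont fact5 id gives 120, while runCont fact5 show gives "120".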

What if we wanted to calculate, say, 5!! (that is, (5!)!)? Setting aside the fact that it won't fit in an Int[1], how would we write a program to do it in CPS? Well, we'd need the result of the first call to factCPS, which means the second call would have to go inside the continuation we pass...

let fivebangbang = factCPS 5 (\f5 -> factCPS f5 id)

Okay, what about situations where we didn't need to pass the result of one CPS function into another? Say we just wanted to use factCPS to calculate 3! + 5! + 7!. What does that look like?

foo :: (Int -> r) -> r
foo cc =
  factCPS 3 (\f3 ->
    factCPS 5 (\f5 ->
      factCPS 7 (\f7 ->
        cc (f3 + f5 + f7))))

The definition looks... quite similar, actually! All our function calls get chained together with continuations that bind the result of the function, so that at the end all the results are in scope and we can pass the final sum to the current continuation. It's also similar to the way monadic binds work.
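If you want to run this version of foo, here is a self-contained sketch. The body of factCPS is an assumption (a plain recursive CPS factorial); foo is copied from above. Because the result is handed to whatever continuation we supply, we can get it back unchanged with id or, for example, rendered with show.

```haskell
-- Assumed CPS factorial (the post only shows its signature).
factCPS :: Int -> (Int -> r) -> r
factCPS 0 k = k 1
factCPS n k = factCPS (n - 1) (\m -> k (n * m))

-- foo as written above: each call sits in tail position inside the previous
-- call's continuation, so f3, f5 and f7 are all in scope when the final
-- continuation cc receives their sum.
foo :: (Int -> r) -> r
foo cc =
  factCPS 3 (\f3 ->
    factCPS 5 (\f5 ->
      factCPS 7 (\f7 ->
        cc (f3 + f5 + f7))))
```

Here foo id evaluates to 5166 (6 + 120 + 5040), and foo show to "5166".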

In fact, properly written CPS code always takes this form: the continuation passed to the first function call has the second call in tail-call position, the continuation passed to the second call has the third call in tail-call position, and so on.

fn1 ... (\x -> fn2 ... (\y -> fn3 ... (\z -> fn4 ... (\w -> ...))))

I hadn't fully appreciated that CPS code always takes this shape, and noticing it was a major step for me. This was where the lightbulb went off.

That is,

Every function call in CPS must be in tail-call position.

which means

Every function call has complete control over the value of the entire expression.

Now, that's not entirely true: at least in Haskell, the polymorphic r prevents a function from returning whatever it wants, so it has to go through an appropriately typed continuation first. But given such a continuation, calling it and returning the value it produces immediately ends the evaluation of the whole expression. And since CPS functions have the current continuation available as a value, they can pass it around as-is, without ever having to "modify" it by building up further computation on top. That potentially allows every single function in the call stack to "modify control flow" by calling some continuation other than the current one, and there's no telling what computation, if any, has been built up inside the continuation it calls.

Look at that shape above one more time. fn1 obviously determines the result of the whole expression; it's the outermost function call. But this is CPS, so if fn1 wants to return a value, it has to call some continuation. Say it calls the (\x -> fn2 ...) continuation it was given. fn2 determines the result of that continuation, which means that once that continuation gets called, fn2 has control over what gets returned! Every continuation-passing style function you call gets its turn to decide whether to call the continuation it was given, or to "hijack" the value of the expression by calling a different continuation instead.
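A small sketch of that hijacking, using hypothetical names: doubleOrBail is a CPS step that is handed both its normal continuation and an escape continuation, and pipeline chains two such steps in the fn1 ... (\x -> fn2 ...) shape. Whenever a step sees a negative number, it ignores the continuation it was given and calls the escape one instead, so the rest of the chain never runs.

```haskell
-- Hypothetical CPS step: doubles n and passes the result to k, unless n is
-- negative, in which case it hijacks the result by calling `escape`.
doubleOrBail :: (String -> r) -> Int -> (Int -> r) -> r
doubleOrBail escape n k
  | n < 0     = escape ("bailed out on " ++ show n)
  | otherwise = k (n * 2)

-- Hypothetical chain of two steps in the usual CPS shape. The final
-- continuation `done` is also handed to each step as its escape hatch.
pipeline :: Int -> (String -> r) -> r
pipeline n done =
  doubleOrBail done n (\x ->
    doubleOrBail done (x - 10) (\y ->
      done ("finished with " ++ show y)))
```

With these definitions, pipeline 10 id gives "finished with 20", while pipeline 3 id gives "bailed out on -4": the second step took over the result of the whole expression.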

Contrast all this with a more normal definition of foo from above:

fact :: Int -> Int
fact n = product [1 .. n]

foo :: Int
foo =
  let
    f3 = fact 3
    f5 = fact 5
    f7 = fact 7
  in f3 + f5 + f7

Here, the calls to fact are not in tail-call position; the "final say" over what gets returned always rests with foo. The calls to fact are isolated and don't interact with each other. With the CPS definition of foo, by contrast, once that top-level call to factCPS happens, factCPS could return anything. It might call the continuation you passed it, but it might not: perhaps it has closed over some other continuation that it will call instead. Control over the return value gets relinquished in CPS in a way that doesn't happen with "normal" code.

This is what gives continuation-passing style its power: every function call is in tail-call position, so every one of them can fully determine the result. It's why a CPS function can "look like" it changes control flow or jumps back into a function call higher in the stack, even though in some sense control flow hasn't changed; at the end of the day it's still a sequence of function calls. The function just needs to call a continuation that produces the result as if control flow had changed.
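As a sketch of such a "jump" with no real change in control flow, here is a hypothetical productCPS over a list. Each recursive step builds up a continuation that will perform its pending multiplication, but a step that sees a 0 ignores that continuation and calls the continuation of the whole product directly, so the pending multiplications simply never happen.

```haskell
-- Hypothetical example: product of a list in CPS with a short-circuit.
-- `exit` is the continuation for the entire product.
productCPS :: [Int] -> (Int -> r) -> r
productCPS xs exit = go xs exit
  where
    -- k carries the multiplications still owed to earlier elements.
    go [] k = k 1
    -- On a 0, discard k and hand 0 straight to the outer continuation,
    -- which looks like a jump back up past all the recursive calls.
    go (0 : _) _ = exit 0
    go (y : ys) k = go ys (\p -> k (y * p))
```

Because the escape never touches the rest of the list, productCPS (2 : 0 : [1 ..]) id returns 0 even though the list is infinite, while productCPS [1, 2, 3, 4] id returns 24 as usual.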

Again, all of this is "obvious" from the construction of CPS, but it wasn't something I had fully internalized yet.

For convenience, I may refer to calling a continuation as "jumping" or "early exit" for the rest of this article, but you should read that as shorthand for "a function call in tail-call position has called some continuation other than its normal one, producing a different return value that looks as if control flow had been altered."

Okay, but how does the Cont monad work?

Before we look at Cont proper, here's one fundamental piece of knowledge:

The types a and forall r. (a -> r) -> r are isomorphic.

Which we can show by implementing the morphisms between them:

{-# LANGUAGE RankNTypes #-}  -- for the forall in fromCPS (already on under GHC2021)

toCPS :: a -> (a -> r) -> r
toCPS x f = f x

fromCPS :: (forall r. (a -> r) -> r) -> a
fromCPS f = f id
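A quick sanity check of the two morphisms, sketched under the assumption that RankNTypes is enabled: going a → CPS → a gives back the original value, and a CPS value rebuilt via fromCPS and toCPS hands its continuation the same thing as the original. The sample answerCPS is a hypothetical name.

```haskell
{-# LANGUAGE RankNTypes #-}

toCPS :: a -> (a -> r) -> r
toCPS x f = f x

-- The argument must work for *every* r, which is what lets us pick r = a.
fromCPS :: (forall r. (a -> r) -> r) -> a
fromCPS f = f id

-- Hypothetical sample CPS value: hands 42 to whatever continuation it gets.
answerCPS :: (Int -> r) -> r
answerCPS k = k 42
```

In GHCi, fromCPS (toCPS 'x') gives back 'x', fromCPS answerCPS gives 42, and toCPS (fromCPS answerCPS) show agrees with answerCPS show.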

This immediately gives us an intuitive sense of what the Functor/Applicative/Monad instances for the Cont monad should look like.

  Cont r a
≡ (a → r) → r
≡ a
≡ Identity a

So, at least when r is left polymorphic, whatever the monad instance for Cont looks like, it should behave much like the Identity monad.

With that in mind, let's just try implementing the typeclass instances for Cont! Here's the definition of the Cont type again, for your convenience.

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

Exercise 1: Implement the Functor, Applicative, and Monad instances for Cont.

Don't worry if these are tricky; in some sense they're supposed to be. Writing code in CPS is often very unintuitive. Keep in mind the "shape" of a CPS computation that we outlined above, and think about how to pass the result of one computation as an input to another in a CPS-y way.

Hint

For the most part you can also type-tetris these definitions; the types are generic enough that if a definition typechecks, it's probably correct, regardless of whether you understand it.

In order to "use" the value inside a Cont r a, you have to pass it a continuation. For all three instances, focus on what continuation you'd have to construct so that you'd get the result you want.

Solution

instance Functor (Cont r) where
  fmap f (Cont g) = Cont $ \cc -> g (cc . f)

instance Applicative (Cont r) where
  pure x = Cont $ \cc -> cc x
  (<*>) (Cont mab) (Cont ma) = Cont $ \cc ->
    mab (\f -> ma (\a -> cc (f a)))

instance Monad (Cont r) where
  return = pure
  (>>=) (Cont ma) f = Cont $ \cc ->
    ma (\a -> runCont (f a) cc)
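One way to sanity-check your answers, sketched below: put the solution in a file alongside a helper (evalC, a hypothetical name) that runs a computation with id as the final continuation. Since Cont should behave like Identity when nothing unusual happens to the continuation, each expression ought to evaluate just as it would with plain values. The sketch omits return, since it defaults to pure.

```haskell
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  -- Post-compose f onto the continuation before handing it to g.
  fmap f (Cont g) = Cont $ \cc -> g (cc . f)

instance Applicative (Cont r) where
  pure x = Cont $ \cc -> cc x
  -- Get the function, then the argument, then pass (f a) on.
  Cont mab <*> Cont ma = Cont $ \cc -> mab (\f -> ma (\a -> cc (f a)))

instance Monad (Cont r) where
  -- Run ma, and inside its continuation run (f a) with the outer cc.
  Cont ma >>= f = Cont $ \cc -> ma (\a -> runCont (f a) cc)

-- Hypothetical helper: run a computation with the identity continuation.
evalC :: Cont a a -> a
evalC m = runCont m id
```

For example, evalC (fmap (+ 1) (pure 41)) should give 42, and evalC (pure 3 >>= \x -> pure (x * 2)) should give 6.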

Exercise 2: Implement factCPS and foo from before using Cont. This should be pretty simple.

Solution

factCPS :: Int -> Cont r Int
factCPS 0 = pure 1
factCPS n = (* n) <$> factCPS (n-1)

foo :: Cont r Int
foo = do
  f3 <- factCPS 3
  f5 <- factCPS 5
  f7 <- factCPS 7
  pure (f3 + f5 + f7)
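A self-contained version to try in GHCi, assuming the Exercise 1 instances: the result matches the hand-written CPS foo, and, as before, the caller still chooses the final continuation.

```haskell
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont g) = Cont $ \cc -> g (cc . f)

instance Applicative (Cont r) where
  pure x = Cont $ \cc -> cc x
  Cont mab <*> Cont ma = Cont $ \cc -> mab (\f -> ma (\a -> cc (f a)))

instance Monad (Cont r) where
  Cont ma >>= f = Cont $ \cc -> ma (\a -> runCont (f a) cc)

-- The Exercise 2 solution: no continuation is ever mentioned explicitly.
factCPS :: Int -> Cont r Int
factCPS 0 = pure 1
factCPS n = (* n) <$> factCPS (n - 1)

foo :: Cont r Int
foo = do
  f3 <- factCPS 3
  f5 <- factCPS 5
  f7 <- factCPS 7
  pure (f3 + f5 + f7)
```

runCont foo id gives 5166, and runCont foo show gives "5166".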

These definitions for factCPS and foo are fairly simple. In fact, they pretty much look like normal, straight-line code! We didn't have to mention the current continuation or construct weird lambdas at all — which is the whole point. We still have the possibility of getting freaky with continuations, but we don't have to constantly mangle our code when we're not using that capability. If we don't do anything weird with continuations, then we're essentially working in the Identity monad, like we expected.

Just to make sure we fully understand what's going on here, let's look at what the monadic code actually turns into, given our definitions for the typeclass instances.

Say we had a do-block like so:

do x <- f1 10
   y <- f2 20
   pure (x+y)

Which desugars into calls to bind:

f1 10 >>= (\x -> f2 20 >>= (\y -> pure (x+y)))

And let's go one step further, substituting in the definitions of pure and (>>=) for Cont:

-- since Cont is a newtype, we elide references to its constructor
-- or to runCont to simplify things

-- substitute the inner bind
f1 10 >>= (\x -> (\cc -> f2 20 (\a -> (\y -> pure (x+y)) a cc)))
-- substitute the outer bind
(\cc -> f1 10 (\a -> (\x ->
  (\cc' -> f2 20 (\a' -> (\y -> pure (x+y)) a' cc'))) a cc))
-- substitute the pure
(\cc -> f1 10 (\a -> (\x ->
  (\cc' -> f2 20 (\a' -> (\y ->
    (\cc'' -> cc'' (x+y))) a' cc'))) a cc))
-- beta reduce the applications with the a's and cc's
(\cc -> f1 10 (\a -> (f2 20 (\a' -> (cc (a+a'))))))
-- renames
(\cc -> f1 10 (\x -> (f2 20 (\y -> cc (x+y)))))

So we can see that the monad instance produces the very same CPS sequencing that we wrote by hand when we were working with bare (a -> r) -> r functions! If f1 and f2 are CPS functions of type Int -> (Int -> r) -> r, plugging that last definition into GHCi gives exactly the expected type, (Int -> r) -> r: a finished computation waiting for a continuation. The monad instance handles the boring busywork of building up continuations for you, and you get to write what essentially looks like straight-line code.
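To check the derivation concretely, here is a sketch with hypothetical stand-ins for f1 and f2. It compares the do-block against the final reduced form, with runCont written back in since we elided it above; both should give the same answer for any continuation.

```haskell
newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont g) = Cont $ \cc -> g (cc . f)

instance Applicative (Cont r) where
  pure x = Cont $ \cc -> cc x
  Cont mab <*> Cont ma = Cont $ \cc -> mab (\f -> ma (\a -> cc (f a)))

instance Monad (Cont r) where
  Cont ma >>= f = Cont $ \cc -> ma (\a -> runCont (f a) cc)

-- Hypothetical stand-ins for f1 and f2.
f1, f2 :: Int -> Cont r Int
f1 n = pure (n * 2)
f2 n = pure (n + 1)

-- The do-block from above.
viaDo :: Cont r Int
viaDo = do
  x <- f1 10
  y <- f2 20
  pure (x + y)

-- The final hand-reduced form, with runCont made explicit again.
viaCPS :: (Int -> r) -> r
viaCPS = \cc -> runCont (f1 10) (\x -> runCont (f2 20) (\y -> cc (x + y)))
```

Both runCont viaDo id and viaCPS id evaluate to 41 (20 + 21).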

However, now that the continuations are abstracted away underneath the monad instance, we have no way to access the current continuation and pass it around, and without that we lose the whole point of transforming our code this way. That's where callCC comes in.

Getting freaky with continuations: callCC

call/cc, otherwise known as call-with-current-continuation, is a rather infamous little function, which, as the name suggests, passes you the current continuation to do whatever you want with. It originated in Scheme, and most of the languages that implement it are functional in nature; a notable exception is Ruby.

Its infamy partly stems from the fact that the continuation that call/cc gives you is truly global; call/cc can be called from anywhere, the continuation you get contains the entire rest of the program from the point at which call/cc is called, and any part of the program can call that continuation to "jump" back to where call/cc was called. This is further exacerbated by the ability to store the continuation that call/cc gives you into a reference or global variable, potentially creating extremely complicated global control flow.

As we'll see, callCC in Haskell isn't quite as extreme, although we'll be able to get pretty close.

But as we noted above, the main point of callCC is simply to give us access to the current continuation, which has been hidden from us. Let's try implementing it. Here's the type signature of callCC for your convenience:

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a

Let's unpack that a little: callCC takes a single parameter, a function that itself expects a function. That inner a -> Cont r b function is the current continuation, which callCC passes to the function it receives. Its return type is Cont r b rather than Cont r a so that the continuation can be called anywhere, whatever type that spot would "normally" produce; since calling it never returns to that spot, the b doesn't matter. If the continuation does get called, the computation defined by the body function exits early. As we've seen, all that's really going on under the hood is that every function call is in tail-call position; we just now have the unmodified current continuation in scope, and thus a different possible way to return a value.

Exercise 3: Implement callCC.

Hint

Once again, you can sort of type-tetris this definition.

Your definition needs access to the current continuation before the body function modifies it, so you'll want to start your definition with something like Cont (\cc -> ...).

Solution

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont $ \cc ->
  let Cont g = f (\x -> Cont (const (cc x)))
  in g cc
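Here is a sketch of early exit from a loop with this callCC; firstNegative and exit are hypothetical names. Calling exit ignores whatever continuation it is given (the rest of the loop) and hands its argument to callCC's own continuation instead.

```haskell
import Control.Monad (forM_, when)

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

instance Functor (Cont r) where
  fmap f (Cont g) = Cont $ \cc -> g (cc . f)

instance Applicative (Cont r) where
  pure x = Cont $ \cc -> cc x
  Cont mab <*> Cont ma = Cont $ \cc -> mab (\f -> ma (\a -> cc (f a)))

instance Monad (Cont r) where
  Cont ma >>= f = Cont $ \cc -> ma (\a -> runCont (f a) cc)

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont $ \cc ->
  let Cont g = f (\x -> Cont (const (cc x)))
  in g cc

-- Hypothetical example: find the first negative number in a list.
firstNegative :: [Int] -> Cont r (Maybe Int)
firstNegative xs = callCC $ \exit -> do
  -- exit skips the remaining iterations and the `pure Nothing` below.
  forM_ xs $ \x -> when (x < 0) (exit (Just x))
  pure Nothing
```

runCont (firstNegative [3, 1, -4, -1]) id gives Just (-4), and runCont (firstNegative [1, 2]) id gives Nothing. Since the loop is abandoned at the first hit, the list after it is never inspected.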

And with that, we have everything we need to write convenient, continuation-based code! For instance, here's a very small example of making use of early exit:

bar :: Bool -> String -> Cont r String
bar flag name =
  callCC (\cc -> do
    msg <- addGreeting flag name cc
    pure ("Result: " <> msg))

addGreeting :: Bool -> String -> (String -> Cont r String) -> Cont r String
addGreeting flag name k =
  if flag
    then k "Flag is set, exiting early"
    else pure ("Hello, " <> name)

λ> runCont (bar False "William") id
>>> "Result: Hello, William"

λ> runCont (bar True "William") id
>>> "Flag is set, exiting early"

Notice how, when the flag is set, the follow-up code that prepends "Result: " gets skipped entirely. Again, this makes sense given that we know the monadic code is equivalent to the fully tail-call-ized code we looked at before.

Finally, a few notes on this definition of call/cc in Haskell versus its full generality in other languages. Firstly, with this definition the "scope" of call/cc is limited by the Cont type itself: calling a continuation might "jump" anywhere within a Cont computation, but you can safely assume it won't jump out of the computation. Secondly, callCC passes the current continuation to a function rather than returning it directly, which nudges you toward using the continuation only within that function's scope, that is, jumping from inside the scope to outside. That makes early exits natural and loops much less so, although the type doesn't strictly rule loops out, since the body function is free to return the continuation itself and call it again later. Things get more general still with the ContT monad transformer, but that's a topic for another time.

Wrapping up

I probably shouldn't have to say this, but CPS is wildly impractical for writing real code with. Well, I won't tell you what you should and shouldn't do, but for obvious reasons giving every single function in your callstack the ability to make completely arbitrary control flow decisions at the "global" level is horrible for modularity, and can destroy any ability to understand code in isolation. Yes, things like exceptions and error types for early exit are less powerful than continuations; they are also more tractable to squishy human meat brains.

If you do choose to use CPS/the Cont monad in your own code, then just like with the other monads we've talked about, use the implementation defined in transformers.
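For comparison, a sketch of the earlier examples against transformers' Control.Monad.Trans.Cont, which exports a Cont type along with cont, runCont and callCC. The names factCont and greet are hypothetical, and this assumes the transformers package (a GHC boot library) is available.

```haskell
import Control.Monad.Trans.Cont (Cont, callCC, cont, runCont)

-- Hypothetical: a raw CPS factorial lifted into transformers' Cont via `cont`.
factCont :: Int -> Cont r Int
factCont n = cont (go n)
  where
    go 0 k = k 1
    go m k = go (m - 1) (\p -> k (m * p))

-- Hypothetical: the bar/addGreeting example folded into one function,
-- using transformers' callCC for the early exit.
greet :: Bool -> String -> Cont r String
greet flag name = callCC $ \exit -> do
  msg <- if flag
           then exit "Flag is set, exiting early"
           else pure ("Hello, " <> name)
  pure ("Result: " <> msg)
```

runCont (greet False "William") id gives "Result: Hello, William", and runCont (greet True "William") id gives "Flag is set, exiting early", matching the hand-rolled version.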

That wraps up the Cont monad. Found this useful, or otherwise have comments or questions? Talk to me!




Footnotes

[1] And depending on how factCPS is implemented, it might not terminate either.