A non-mathematician’s introduction to monads

Introduction

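This vignette introduces the functional programming concept of a monad, without going into much technical detail. {chronicler} is an implementation of a logger monad, but in truth, it is not necessary to know what monads are to use this package. However, if you are curious, read on. A monad is a computation device that offers two things: a way to decorate functions so that they can provide additional output without their core implementation being touched, and a way to compose these decorated functions.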

(This definition is an oversimplification of the actual definition of a monad, but good enough for our purposes.)

To understand what a monad is, I believe it is useful to explain what sort of problem monads solve.

Suppose for instance that you wish for your functions to provide a log when they’re run. If your function looks like this:

my_sqrt <- function(x){

  sqrt(x)

}

Then you would need to rewrite this function like this:

my_sqrt <- function(x, log = ""){

  list(sqrt(x),
       c(log,
         paste0("Running sqrt with input ", x)))

}

There are two problems with such an implementation: every function that should provide a log has to be rewritten, and these functions don’t compose.

What do I mean by “these functions don’t compose”? Consider another such function my_log():

my_log <- function(x, log = ""){

  list(log(x),
       c(log,
         paste0("Running log with input ", x)))

}

sqrt() and log() compose, or rather, they can be chained:

10 |>
  sqrt() |>
  log()
#> [1] 1.151293

while this is not true for my_sqrt() and my_log():

10 |>
  my_sqrt() |>
  my_log()
#> Error in log(x) : non-numeric argument to mathematical function

This is because my_log() expects a number, not a list which is what my_sqrt() returns.
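The chain can be made to work by hand, by unpacking the list between the two calls and passing the value and the log separately. The following is only a sketch of that manual plumbing (the names step1 and step2 are mine); it shows what a general solution has to automate:

```r
# Same definitions as above
my_sqrt <- function(x, log = ""){
  list(sqrt(x),
       c(log, paste0("Running sqrt with input ", x)))
}

my_log <- function(x, log = ""){
  list(log(x),
       c(log, paste0("Running log with input ", x)))
}

# Unpack the result of my_sqrt() by hand: element 1 is the value,
# element 2 is the log accumulated so far
step1 <- my_sqrt(10)
step2 <- my_log(step1[[1]], log = step1[[2]])

step2[[1]]  # the value, equal to log(sqrt(10))
step2[[2]]  # the log: the initial "" followed by one entry per call
```

This works, but the unpacking has to be repeated between every pair of calls, which is exactly the kind of boilerplate we would like to avoid.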

A “monad” is what we need to solve these two problems. The first problem, not having to rewrite every function, can be tackled using function factories. Let’s write one for our problem:

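log_it <- function(.f, ..., log = NULL){

  # Capture the name of the function being decorated, for the log message
  fstring <- deparse(substitute(.f))

  # Return the decorated function; .log carries the log accumulated so far
  function(..., .log = log){

    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}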

We can now create our functions easily:

l_sqrt <- log_it(sqrt)

l_sqrt(10)
#> $result
#> [1] 3.162278
#> 
#> $log
#> [1] "Running sqrt with argument 10"

l_log <- log_it(log)

l_log(10)
#> $result
#> [1] 2.302585
#> 
#> $log
#> [1] "Running log with argument 10"

We can call l_sqrt() and l_log() decorated functions and the values they return monadic values.
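One limitation of this toy factory is that paste0() glues several arguments together without a separator, so a call with two arguments produces a hard-to-read log entry. A possible variant, here called log_it2() (a hypothetical name, not part of {chronicler}), deparses each argument separately before building the message:

```r
log_it2 <- function(.f, log = NULL){

  fstring <- deparse(substitute(.f))

  function(..., .log = log){

    # Deparse each argument on its own, then join them with ", "
    args <- vapply(list(...),
                   function(a) paste(deparse(a), collapse = ""),
                   character(1))

    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring,
                        " with arguments ", toString(args))))
  }
}

l_round <- log_it2(round)

res_round <- l_round(3.14159, 2)

res_round$result  # the rounded value
res_round$log     # a single log entry listing both arguments
```

The rest of this vignette keeps using the simpler log_it(), since all the examples take a single argument.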

The second issue remains though; l_sqrt() and l_log() can’t be composed/chained. To solve this issue, we need another function, called bind():

bind <- function(.l, .f, ...){

  .f(.l$result, ..., .log = .l$log)

}

Using bind(), it is now possible to compose l_sqrt() and l_log():

10 |>
  l_sqrt() |>
  bind(l_log)
#> $result
#> [1] 1.151293
#> 
#> $log
#> [1] "Running sqrt with argument 10"             
#> [2] "Running log with argument 3.16227766016838"

bind() takes care of providing the right arguments to the underlying function. We can check that the result is correct by comparing the $result value of the returned object to log(sqrt(10)):

log(sqrt(10))
#> [1] 1.151293
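Nothing limits the chain to two steps. As a small sketch (l_exp is an extra decorated function that I add here for illustration), any number of decorated functions can be chained with repeated calls to bind(), and the log grows by one entry per step:

```r
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}

bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}

l_sqrt <- log_it(sqrt)
l_log  <- log_it(log)
l_exp  <- log_it(exp)

res <- 10 |>
  l_sqrt() |>
  bind(l_log) |>
  bind(l_exp)

res$result       # exp(log(sqrt(10))), i.e. sqrt(10) up to rounding
length(res$log)  # one log entry per step
```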

This combination of a function factory and a helper function that makes the decorated functions compose is, loosely speaking, what constitutes a monad, although strictly speaking this description is not precisely correct. It is instructive to look at the actual definition from Haskell, a pure functional programming language in which monads must be used to solve certain issues:

Monads can be viewed as a standard programming interface to various data or control structures, which is captured by Haskell’s Monad class. All the common monads are members of it:
class Monad m where
  (>>=)  :: m a -> (  a -> m b) -> m b
  (>>)   :: m a ->  m b         -> m b
  return ::   a                 -> m a

(Source: Monad)

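This definition is quite cryptic, especially if you don’t know Haskell, but what this means is that a Monad (in Haskell) is something that has three methods: >>= (which corresponds to our bind()), >> (which sequences two monadic values and discards the result of the first), and return.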

While we didn’t implement return (also called unit, which is not a good name either), our function factory log_it() does the job of return/unit, except that it returns m f(a) instead of m a. Function factories come more naturally to R users than return/unit, which is why I did not focus on return/unit. Besides, with our function factory, it is easy to implement return/unit:

unit <- log_it(identity)

so return/unit is just the identity() function that went through the function factory. In a sense, the function factory is even more necessary for defining a monad than return/unit.
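The second method of the Haskell class, >>, has no counterpart in this vignette. For a logger, a reasonable reading is: keep the result of the second monadic value and concatenate both logs. A minimal sketch, assuming that reading (then_log() is a hypothetical name of mine, not a {chronicler} function):

```r
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}

l_sqrt <- log_it(sqrt)
l_log  <- log_it(log)

# Sequence two monadic values: the result of the first one is discarded,
# but its log is kept in front of the log of the second one
then_log <- function(.l, .m){
  list(result = .m$result,
       log = c(.l$log, .m$log))
}

res_then <- then_log(l_sqrt(10), l_log(100))

res_then$result  # log(100): the result of l_sqrt(10) is dropped
res_then$log     # both log entries, in order
```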

Finally, you might sometimes read that monads are objects that have a flatmap() method. I think that this definition, too, is not strictly correct and is very likely an oversimplification. But what is flatmap() anyway? In practical terms, it is equivalent to bind(), but how you get there is different. To implement flatmap(), two additional functions are needed: fmap() and flatten() (which is quite often called join(), but this has nothing to do with joining data frames, so I used flatten() instead).

fmap() takes a monadic value and an undecorated function as arguments, and applies the function to the value wrapped inside the monadic value:

fmap <- function(m, f, ...){

  fstring <- deparse(substitute(f))

  list(result = f(m$result, ...),
       log = c(m$log,
               paste0("fmapping ", fstring, " with arguments ", paste0(m$result, ..., collapse = ","))))
}

Let’s first define a monadic value:

# Let’s use unit(), which we defined above, for this.

(m <- unit(10))
#> $result
#> [1] 10
#> 
#> $log
#> [1] "Running identity with argument 10"

Let’s now use fmap() to apply a non-decorated function to m:

fmap(m, log)
#> $result
#> [1] 2.302585
#> 
#> $log
#> [1] "Running identity with argument 10" "fmapping log with arguments 10"

Great, now what about flatten() (or join())? Why is that useful? Suppose that instead of log() we used l_log() with fmap() (so we’re using a decorated function instead of an undecorated one):

fmap(m, l_log)
#> $result
#> $result$result
#> [1] 2.302585
#> 
#> $result$log
#> [1] "Running log with argument 10"
#> 
#> 
#> $log
#> [1] "Running identity with argument 10" "fmapping l_log with arguments 10"

As you can see from the output, this produced a nested list, a monadic value where the value is itself a monadic value. We would like flatten()/join() to take care of this for us. So this could be an implementation of flatten():

flatten <- function(m){

  list(result = m$result$result,
       log = c(m$log))

}

Let’s try now:

flatten(fmap(m, l_log))
#> $result
#> [1] 2.302585
#> 
#> $log
#> [1] "Running identity with argument 10" "fmapping l_log with arguments 10"

Great! Now, as explained earlier, flatmap() and bind() are the same thing. But we have implemented flatten() and fmap(), so how do these two functions relate to flatmap()? It turns out that flatmap() is the composition of flatten() and fmap():

# I first define a composition operator for functions
`%.%` <- \(f,g)(function(...)(f(g(...))))

# I now compose flatten() and fmap()
# flatten %.% fmap is read as "flatten after fmap"
flatmap <- flatten %.% fmap

So this means that we can now replace:

10 |>
  l_sqrt() |>
  bind(l_log)
#> $result
#> [1] 1.151293
#> 
#> $log
#> [1] "Running sqrt with argument 10"             
#> [2] "Running log with argument 3.16227766016838"

by:

10 |>
  l_sqrt() |>
  flatmap(l_log)
#> $result
#> [1] 1.151293
#> 
#> $log
#> [1] "Running sqrt with argument 10"                 
#> [2] "fmapping l_log with arguments 3.16227766016838"

and we get the same result (well, not quite, since the log is different). I prefer introducing monads using bind(), because bind() comes as a natural solution to the problem of decorated functions not composing. Not so with flatmap(), but in some applications it might be easier to first define fmap() and flatten() and get flatmap() from them instead of trying to write bind() directly, so it’s good to know both approaches.
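The claim that both routes agree on the value but not on the log can be checked directly. The sketch below repeats the toy definitions so that it runs on its own. It also adds flatten_keep(), a hypothetical variant of flatten() (my own, not from {chronicler}) that appends the inner log instead of dropping it, in case you want the flatmap() route to record the decorated function’s own message as well:

```r
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}

bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}

fmap <- function(m, f, ...){
  fstring <- deparse(substitute(f))
  list(result = f(m$result, ...),
       log = c(m$log,
               paste0("fmapping ", fstring, " with arguments ",
                      paste0(m$result, ..., collapse = ","))))
}

flatten <- function(m){
  list(result = m$result$result,
       log = c(m$log))
}

`%.%` <- \(f, g)(function(...)(f(g(...))))

flatmap <- flatten %.% fmap

l_sqrt <- log_it(sqrt)
l_log  <- log_it(log)

with_bind    <- 10 |> l_sqrt() |> bind(l_log)
with_flatmap <- 10 |> l_sqrt() |> flatmap(l_log)

# Same value, different logs
all.equal(with_bind$result, with_flatmap$result)
identical(with_bind$log, with_flatmap$log)

# Hypothetical variant: keep the outer log, then append the inner log
flatten_keep <- function(m){
  list(result = m$result$result,
       log = c(m$log, m$result$log))
}

flatmap_keep <- flatten_keep %.% fmap

kept <- 10 |> l_sqrt() |> flatmap_keep(l_log)

kept$log  # sqrt entry, fmap entry, then the entry written by l_log()
```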

Before continuing with the final part of this introduction, I just want to share with you that lists are also monads. We have everything we need: as.list() is unit(), purrr::map() is fmap() and purrr::flatten() is flatten(). This means we can obtain flatmap() from composing purrr::flatten() and purrr::map():

# Since I'm using `{purrr}`, might as well use purrr::compose() instead of my own implementation
flatmap_list <- purrr::compose(purrr::flatten, purrr::map)

# Functions that return lists: they don't compose!
# no worries, we implemented `flatmap_list()`
list_sqrt <- \(x)(as.list(sqrt(x)))
list_log <- \(x)(as.list(log(x)))

10 |>
  list_sqrt() |>
  flatmap_list(list_log)
#> [[1]]
#> [1] 1.151293

(thanks to @armcn_ for showing me this)
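The list monad becomes more interesting when a function returns several values. The sketch below uses base R only, under the assumption that lapply() plays the role of fmap() and unlist(recursive = FALSE) the role of flatten(); the names flatten_base(), flatmap_base() and both_roots() are mine:

```r
# flatten(): remove one level of nesting from a list of lists
flatten_base <- function(l) unlist(l, recursive = FALSE)

# flatmap(): map first, then flatten
flatmap_base <- function(l, f) flatten_base(lapply(l, f))

# A function that returns a list with two elements: both square roots
both_roots <- function(x) list(sqrt(x), -sqrt(x))

roots <- list(1, 4, 9) |>
  flatmap_base(both_roots)

str(roots)  # a flat list of six numbers instead of three nested lists
```

Without the flattening step, lapply() alone would return a list of three two-element lists, which could not be passed on to another function expecting a flat list of numbers.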

In sum, monads are useful when you need values to carry something more with them. This something can be a log, as shown here, but there are many other examples. For another example of a monad implemented as an R package, see the {maybe} package, which implements the maybe monad. {chronicler} actually takes advantage of the {maybe} package and uses the maybe monad to handle cases where functions fail. I provide a short introduction to the maybe monad in the Maybe monad vignette.

Monadic laws

Monads need to satisfy the so-called “monadic laws”. We’re going to verify if the monad implemented in {chronicler} satisfies these monadic laws.

First law

The first law states that binding a monadic value to a monadic function with bind() (bind_record() in the {chronicler} package) gives the same result as applying that monadic function directly to the underlying value.

library(chronicler)
library(testthat)

a <- as_chronicle(10)
r_sqrt <- record(sqrt)

test_that("first monadic law", {
  expect_equal(bind_record(a, r_sqrt)$value, r_sqrt(10)$value)
})
#> Test passed

Turns out that this is not quite the case here; the logs of the two objects will be slightly different. So I only check the value.

Second law

The second law states that binding a monadic value to return() (called as_chronicle() in this package; in other words, the function that coerces values to chronicle objects) does nothing. Here again the log is an issue, which is why I focus on the value:

test_that("second monadic law", {
  expect_equal(bind_record(a, as_chronicle)$value, a$value)
})
#> Test passed

Third law

The third law is about associativity; applying monadic functions successively or composing them first gives the same result.

a <- as_chronicle(10)

r_sqrt <- record(sqrt)
r_exp <- record(exp)
r_mean <- record(mean)

test_that("third monadic law", {
  expect_equal(
    (
      bind_record(a, r_sqrt) |>
        bind_record(r_exp)
    )$value,
    (
      a |>
        (\(x) bind_record(x, r_sqrt) |> bind_record(r_exp))()
    )$value
  )
})
#> Test passed
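The same three checks can be run on the toy logger from the first part of this vignette, without {chronicler}. This is only a sketch: the definitions of log_it() and bind() are repeated so that the block runs on its own, sqrt_then_exp() is a helper name of mine, and, as above, only the values are compared because the logs differ for the first two laws:

```r
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}

bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}

unit   <- log_it(identity)
l_sqrt <- log_it(sqrt)
l_exp  <- log_it(exp)

m <- unit(10)

# First law: binding unit(10) to l_sqrt() vs calling l_sqrt(10) directly
law1 <- isTRUE(all.equal(bind(m, l_sqrt)$result, l_sqrt(10)$result))

# Second law: binding a monadic value to unit() leaves the value unchanged
law2 <- isTRUE(all.equal(bind(m, unit)$result, m$result))

# Third law: binding step by step vs binding to the composed function
sqrt_then_exp <- function(x, .log = NULL){
  bind(l_sqrt(x, .log = .log), l_exp)
}

law3 <- isTRUE(all.equal(
  bind(bind(m, l_sqrt), l_exp)$result,
  bind(m, sqrt_then_exp)$result
))

c(law1, law2, law3)
```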

flatmap() for chronicle objects

For completeness, I check that I can get flatmap_record() by composing flatten_record() and fmap_record():


r_sqrt <- record(sqrt)
r_exp <- record(exp)
r_mean <- record(mean)

a <- 1:10 |>
  r_sqrt() |>
  bind_record(r_exp) |>
  bind_record(r_mean)

flatmap_record <- purrr::compose(flatten_record, fmap_record)

b <- 1:10 |>
  r_sqrt() |>
  flatmap_record(r_exp) |>
  flatmap_record(r_mean)

identical(a$value, b$value)
#> [1] TRUE