# Nonparametric Survival Estimators in Haskell

In today’s post, I walk through a Haskell implementation of two fundamental estimators in survival analysis: the product-limit (Kaplan-Meier) estimator (KM) of the survival curve and the Nelson-Aalen estimator (NA) of the cumulative hazard. While toying around with the monoidimator package a few months ago, I realized that one could implement a data structure such that the data would not need to be pre-sorted by time, and KM could be evaluated at any point and updated as an online algorithm. My implementation is not necessarily efficient (maybe it is, I don’t know), but I think it is useful for a few reasons:

- seeing the dual-like relationship between the survival curve and the cumulative hazard (and hence between the two estimators). I say “dual-like” in the sense that, in category theory, the dual of a product is a coproduct, or sum.
- exposing statisticians used to working in a dynamic language like R to the joy (IMO) of Haskell
- demonstrating property-based testing
- perhaps serving as a starting place for a survival analysis library for Haskell

# Quick Background: the estimators

The product limit (Kaplan-Meier) estimator of the survival curve is typically presented as:

\(\hat{S}(t) = \prod_{i: t_i \leq t} \left(1 - \frac{d_i}{n_i} \right),\)

where the \(t_i\) are the distinct times at which an event of interest occurred (which I’ll call an *outcome*), \(d_i\) is the number of subjects who had the outcome at time \(t_i\), and \(n_i\) is the number still at risk at that time.

Using the same definitions, the Nelson-Aalen estimator of the cumulative hazard can be written as:

\(\hat{H}(t) = \sum_{i: t_i \leq t} \left(\frac{d_i}{n_i} \right).\)
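Here is a quick worked example on made-up data. Suppose four subjects are followed. Three have the outcome at times 1, 2 and 3, and one is censored at time 2. The censored subject still counts as at risk at time 2. That gives \(d_1 = d_2 = d_3 = 1\) and \(n_1 = 4\), \(n_2 = 3\), \(n_3 = 1\), so

\[
\begin{aligned}
\hat{S}(1) &= 1 - \tfrac{1}{4} = \tfrac{3}{4}, & \hat{H}(1) &= \tfrac{1}{4},\\
\hat{S}(2) &= \tfrac{3}{4}\left(1 - \tfrac{1}{3}\right) = \tfrac{1}{2}, & \hat{H}(2) &= \tfrac{1}{4} + \tfrac{1}{3} = \tfrac{7}{12},\\
\hat{S}(3) &= \tfrac{1}{2}\left(1 - \tfrac{1}{1}\right) = 0, & \hat{H}(3) &= \tfrac{7}{12} + \tfrac{1}{1} = \tfrac{19}{12}.
\end{aligned}
\]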

In order to compute either estimator, you need to have the data sorted and summarized by the times of the outcomes. Often survival data is presented as pairs \((Y, \delta)\). When \(\delta = 1\), \(Y\) is the time of the outcome. When \(\delta = 0\), the time of the outcome is only known to be \(\geq Y\) (i.e. the observation is censored). This is how I’ll represent the data in my Haskell program. But rather than using an indicator variable, I’ll encode whether the time is an outcome or censored directly in the type.
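For comparison, the conventional sort-then-summarize approach can be sketched directly on \((Y, \delta)\) pairs. This is a batch reference sketch, not the online method developed in this post.

- The names `summarize`, `kaplanMeierBatch` and `nelsonAalenBatch` are hypothetical.
- Unlike the functions later in the post, it does not prepend a \((0, 1)\) or \((0, 0)\) start point.
- Censored subjects count as at risk at their own time.

```haskell
import Data.Function (on)
import Data.List (groupBy, sortOn)

-- Hypothetical helper: summarize (Y, delta) pairs (delta = 1 for an outcome,
-- 0 for censoring) into (t_i, d_i, n_i) at each distinct outcome time t_i.
summarize :: Ord a => [(a, Int)] -> [(a, Int, Int)]
summarize ys =
  [ (t, d, n)
  | (grp@((t, _) : _), n) <- zip groups atRisk
  , let d = length (filter ((== 1) . snd) grp)
  , d > 0                      -- drop times that only have censorings
  ]
  where
    -- sort by time, then gather tied times into one group
    groups = groupBy ((==) `on` fst) (sortOn fst ys)
    -- at risk at a time = everyone observed at or after that time
    atRisk = scanr1 (+) (map length groups)

-- Batch Kaplan-Meier: running product of (1 - d_i / n_i).
kaplanMeierBatch :: Ord a => [(a, Int)] -> [(a, Rational)]
kaplanMeierBatch ys = zip [t | (t, _, _) <- s] (scanl1 (*) terms)
  where
    s = summarize ys
    terms = [1 - fromIntegral d / fromIntegral n | (_, d, n) <- s]

-- Batch Nelson-Aalen: running sum of d_i / n_i.
nelsonAalenBatch :: Ord a => [(a, Int)] -> [(a, Rational)]
nelsonAalenBatch ys = zip [t | (t, _, _) <- s] (scanl1 (+) terms)
  where
    s = summarize ys
    terms = [fromIntegral d / fromIntegral n | (_, d, n) <- s]

main :: IO ()
main = do
  -- outcomes at times 1, 2, 3 and a censoring at time 2, deliberately unsorted
  let ys = [(3, 1), (2, 0), (1, 1), (2, 1)] :: [(Int, Int)]
  print (kaplanMeierBatch ys)
  print (nelsonAalenBatch ys)
```

On the toy input in `main`, this should give survival values \(3/4, 1/2, 0\) and cumulative hazards \(1/4, 7/12, 19/12\) at times 1, 2 and 3. The explicit sort is exactly the step the `Tracker` structure below is designed to avoid.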

# The Program

## Event Data

First, I define the main data type, which I call an `Event`. I parameterize `Event` by a type variable `a`, meaning I can create more specific types such as `Event Int` or `Event Double`.

```
module Estimator where

data Event a =
    Outcome a
  | Censor a
  deriving (Show, Eq)
```

Now I have a container for typical survival data, via the following equivalences: \((x, \delta = 1) \simeq\) `Outcome x` and \((x, \delta = 0) \simeq\) `Censor x`. Next I need a way to order `Event`s, in particular to handle the case where an outcome and a censoring occur at the same time. To do this, I make `Event` an instance of the `Ord` typeclass and encode the convention that, in the case of ties, an outcome comes before censoring:

```
instance (Ord a) => Ord (Event a) where
  compare (Outcome x) (Outcome y) = compare x y
  compare (Censor x) (Censor y) = compare x y
  -- ties: an outcome sorts before a censoring at the same time
  compare (Outcome x) (Censor y)
    | x <= y = LT
    | otherwise = GT
  compare (Censor x) (Outcome y)
    | x < y = LT
    | otherwise = GT
```
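The instance above spells out every pairing of constructors. A more compact sketch compares events by a key made of the time plus a tag that puts outcomes first. Worked through case by case, it should agree with the four clauses above, but treat that as an assumption rather than a drop-in replacement. The block redefines `Event` so it compiles on its own:

```haskell
import Data.List (sort)
import Data.Ord (comparing)

data Event a = Outcome a | Censor a
  deriving (Show, Eq)

-- Sketch of an equivalent instance: order by time first, then by a tag
-- that puts an outcome (0) ahead of a censoring (1) at the same time.
instance Ord a => Ord (Event a) where
  compare = comparing key
    where
      key (Outcome x) = (x, 0 :: Int)
      key (Censor x)  = (x, 1)

main :: IO ()
main = print (sort [Censor 3, Outcome 3, Outcome 1, Censor 0 :: Event Int])
```

Sorting the mixed list in `main` puts `Outcome 3` ahead of `Censor 3`, which is the tie convention the estimators rely on.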

I include a few utility functions for working with `Event`s. Note the use and usefulness of pattern matching in the function definitions.

```
getTime :: Event a -> a
getTime (Outcome x) = x
getTime (Censor x) = x

isOutcome :: Event a -> Bool
isOutcome (Outcome _) = True
isOutcome _ = False

isCensor :: Event a -> Bool
isCensor = not . isOutcome
```
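Raw data usually arrives as \((Y, \delta)\) pairs, so a conversion helper in the same style as these utilities is handy. Here is a minimal sketch. It assumes \(\delta\) is stored as an `Int`, and `fromPair` is a hypothetical name, not part of the post's module. Any \(\delta\) other than `1` is treated as censoring, which is also an assumption:

```haskell
data Event a = Outcome a | Censor a
  deriving (Show, Eq)

-- Hypothetical helper: convert a conventional (Y, delta) pair into an Event.
-- delta = 1 means the outcome was observed; anything else is treated as censored.
fromPair :: (a, Int) -> Event a
fromPair (y, 1) = Outcome y
fromPair (y, _) = Censor y

getTime :: Event a -> a
getTime (Outcome x) = x
getTime (Censor x)  = x

main :: IO ()
main = do
  let evs = map fromPair [(5, 1), (3, 0), (7, 1)] :: [Event Int]
  print evs               -- the events, in input order
  print (map getTime evs) -- just the times, whatever the event type
```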

## Tracker Data

I imagine processing streams of possibly unordered `Event`s. I need a way to keep counts of the types of events and when they occurred. Basically, I want to create some form of counting process. To track what I observe, I store information in a `Tracker`, which stores:

- `getEvent`: the `Event` (i.e. the type and time). Based on the ordering of the sum type defined above, for two events at the same time, `Outcome` \(\leq\) `Censor`, and a `Tracker` also follows this ordering.
- `getOCount`: the number of times an outcome has occurred at this time
- `getCCount`: the number of times a censor event occurred at this time
- `getAtRisk`: the number at risk at this time.

```
data Tracker a = Tracker {
    getEvent :: Event a
  , getOCount :: Int
  , getCCount :: Int
  , getAtRisk :: Int
  } deriving (Show, Eq)

instance (Ord a) => Ord (Tracker a) where
  compare x y = compare (getEvent x) (getEvent y)
```
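The list of `Tracker`s is meant to keep one invariant. Once sorted by time, the at-risk count at each time is the total count (outcomes plus censorings) at that time and every later time. A minimal sketch of that relationship on plain counts follows. `atRiskFromCounts` is a hypothetical helper, not part of the post's code:

```haskell
-- Hypothetical helper: given the total count (outcomes + censorings) at each
-- distinct time, in ascending time order, return the number at risk at each
-- of those times, i.e. a running sum taken from the right.
atRiskFromCounts :: [Int] -> [Int]
atRiskFromCounts = scanr1 (+)

main :: IO ()
main =
  -- toy data: one outcome at time 1, an outcome and a censoring at time 2,
  -- and one outcome at time 3, so the totals per time are 1, 2 and 1
  print (atRiskFromCounts [1, 2, 1])
```

This prints `[4,3,1]`. Working through `processEvents` by hand on `[Outcome 1, Outcome 2, Censor 2, Outcome 3]` should give the same at-risk counts: 4, 3 and 1 for the trackers at times 1, 2 and 3.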

## Summing Trackers

I use a recursive algorithm for constructing my desired data structure, a list of `Tracker`s. The gist is to box an `Event` into a new `Tracker`, then recursively “sum” it with the existing, ordered `Tracker`s. This continues until the new `Tracker` is less than or equal to an existing `Tracker` or reaches the end of the list. I put “sum” in quotes because I need to define what it means to add two `Tracker`s.

Let me get more helper functions out of the way first. The function `getTotalCount` simply returns the total number of outcomes and censorings for a `Tracker`. The `initTracker` function initializes a `Tracker` from an `Event`. You can see that for an `Outcome` the outcome count is set to `1` (and the censor count to `0`), and conversely for a `Censor`. In both cases the at-risk count starts at `1`, the subject itself. The `setAtRisk` function replaces the at-risk count. (*All this getting and setting points to lens as a potentially useful abstraction for future iterations.*)

```
getTotalCount :: Tracker a -> Int
getTotalCount x = getOCount x + getCCount x

initTracker :: Event a -> Tracker a
initTracker x
  | isOutcome x = Tracker x 1 0 1
  | otherwise = Tracker x 0 1 1

setAtRisk :: Tracker a -> Int -> Tracker a
setAtRisk (Tracker x o c _) n' = Tracker x o c n'
```

Now I can define how `Tracker`s are added. The most complicated case is when the event times are equal, because an `Outcome` needs to replace a `Censor` in the `Tracker`’s event slot:

```
addEqTrackers :: Tracker a -> Tracker a -> Tracker a
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Outcome _) o' c' n') =
  Tracker (Outcome t) (o + o') (c + c') (o + c + n')
addEqTrackers (Tracker (Censor _) o c _) (Tracker (Outcome t') o' c' n') =
  Tracker (Outcome t') (o + o') (c + c') (n' + o + c)
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Censor _) _ c' n') =
  Tracker (Outcome t) o (c + c') (n' + o + c)
addEqTrackers (Tracker (Censor t) _ c _) (Tracker (Censor _) _ c' n') =
  Tracker (Censor t) 0 (c + c') (n' + c)
```

I’m sure `addEqTrackers` could be significantly simplified and possibly combined with the following functions, but as an initial pass, I like spelling out all the cases clearly.

```
addLtTrackers :: Tracker a -> Tracker a -> Tracker a
addLtTrackers x y = setAtRisk x (getAtRisk y + getTotalCount x)
addGtTrackers :: Tracker a -> Tracker a -> Tracker a
addGtTrackers x y = setAtRisk y (getAtRisk y + getTotalCount x)
addGtTrackers' :: Tracker a -> Tracker a -> Tracker a
addGtTrackers' x y = setAtRisk x (getAtRisk y + getTotalCount x - getTotalCount y)
```

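The `add*Trackers` functions do the work of updating `Tracker`s as needed. For example, suppose I have two `Tracker`s `x` and `y` with `x < y`. Then `addLtTrackers` sets the at-risk count of `x` to the sum of `x`’s total count and the at-risk count of `y` (see the definition of `f` below to see this in action). Conversely, when `x > y`, `addGtTrackers` adds `x`’s total count to the at-risk count of `y`, since anyone observed at `x`’s later time was also at risk at `y`’s time. Meanwhile, `addGtTrackers'` gives `x` a provisional at-risk count (the number at risk just after `y`’s time) to carry further down the list.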
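Here is a sketch of the simplification of `addEqTrackers` suggested earlier. It keeps the assumptions the four clauses already rely on: both trackers share the same time, and a `Censor` tracker always has an outcome count of `0`. Under those assumptions, all four cases appear to collapse into a single equation. `addEqTrackersSimple` is a hypothetical name, and the block repeats the needed definitions so it compiles on its own:

```haskell
data Event a = Outcome a | Censor a
  deriving (Show, Eq)

data Tracker a = Tracker
  { getEvent  :: Event a
  , getOCount :: Int
  , getCCount :: Int
  , getAtRisk :: Int
  } deriving (Show, Eq)

isOutcome :: Event a -> Bool
isOutcome (Outcome _) = True
isOutcome _           = False

getTotalCount :: Tracker a -> Int
getTotalCount x = getOCount x + getCCount x

-- Hypothetical one-equation version of addEqTrackers: x is the new tracker,
-- y the existing one at the same time. Counts add, and the at-risk count is
-- y's at-risk plus everything x brings.
addEqTrackersSimple :: Tracker a -> Tracker a -> Tracker a
addEqTrackersSimple x y =
  Tracker ev (getOCount x + getOCount y)
             (getCCount x + getCCount y)
             (getAtRisk y + getTotalCount x)
  where
    -- an Outcome takes the event slot whenever either side has one
    ev | isOutcome (getEvent y) = getEvent y
       | otherwise              = getEvent x

main :: IO ()
main = do
  let c = Tracker (Censor (2 :: Int)) 0 1 1 -- new censoring at time 2
      o = Tracker (Outcome 2) 1 0 1         -- existing outcome at time 2
  print (addEqTrackersSimple c o)
```

The `main` example merges a new censoring into an existing outcome at time 2. It yields an `Outcome 2` tracker with one outcome, one censoring and an at-risk count of 2, the same result the second clause of `addEqTrackers` gives.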

## Accumulating Trackers

A recursive function (that I uncreatively named `f`) uses the `add*Trackers` functions to take a new `Tracker` (wrapped in `Maybe`) and update a list of `Tracker`s.

```
f :: (Ord a) => Maybe (Tracker a) -> [Tracker a] -> [Tracker a]
f (Just x) [] = [x]
f Nothing [] = []
f Nothing ts = ts
f (Just x) (t:ts)
  | lt = addLtTrackers x t : t : ts
  | eq = addEqTrackers x t : ts
  | gt = addGtTrackers x t : f (Just $ addGtTrackers' x t) ts
  where cm = compare ((getTime . getEvent) x) ((getTime . getEvent) t)
        lt = (cm == LT)
        eq = (cm == EQ)
        gt = (cm == GT)
```

This function `f` does the heavy lifting of creating a list of `Tracker`s ordered by time and tracking counts as new events are processed. All that remains is converting a list of `Event`s to a list of `Tracker`s. Now that I have `f`, this is a one-liner using a fold pattern:

```
processEvents :: (Ord a, Num a) => [Event a] -> [Tracker a]
processEvents z = foldl (\x -> \y -> f (Just $ initTracker y) x) [] z
```
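For comparison only, here is a sketch of the same "no pre-sorting, update one event at a time" idea. It uses a balanced-tree map from the `containers` package instead of a hand-maintained list. This is a swapped-in technique, not the post's method:

- The map keeps (outcome count, censor count) per time.
- Each insertion is O(log n).
- At-risk counts are derived on demand with a right-to-left running sum rather than stored.

The names `Counts`, `addEvent`, `addPair` and `atRiskTable` are hypothetical.

```haskell
import Data.List (foldl')
import qualified Data.Map.Strict as Map

data Event a = Outcome a | Censor a
  deriving (Show, Eq)

-- Per-time counts: time -> (number of outcomes, number of censorings).
type Counts a = Map.Map a (Int, Int)

-- Online update: fold one event into the map in O(log n); input order is irrelevant.
addEvent :: Ord a => Counts a -> Event a -> Counts a
addEvent m (Outcome t) = Map.insertWith addPair t (1, 0) m
addEvent m (Censor t)  = Map.insertWith addPair t (0, 1) m

addPair :: (Int, Int) -> (Int, Int) -> (Int, Int)
addPair (o, c) (o', c') = (o + o', c + c')

-- Rows of (time, outcomes, censorings, at risk) in ascending time order;
-- at risk = total count at this time and all later times.
atRiskTable :: Counts a -> [(a, Int, Int, Int)]
atRiskTable m = zipWith row rows (scanr1 (+) [o + c | (_, (o, c)) <- rows])
  where
    rows = Map.toAscList m
    row (t, (o, c)) n = (t, o, c, n)

main :: IO ()
main = do
  -- one outcome at times 1, 2 and 3, one censoring at time 2, deliberately unsorted
  let evs = [Outcome 3, Censor 2, Outcome 1, Outcome 2] :: [Event Int]
  mapM_ print (atRiskTable (foldl' addEvent Map.empty evs))
```

On the unsorted input in `main`, this should print `(1,1,0,4)`, `(2,1,1,3)` and `(3,1,0,1)`. The trade-off is that at-risk counts are recomputed from the whole map when needed, instead of being kept up to date on every insertion.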

# The Estimators

All the above code is for creating a data structure that is a time-ordered list of information needed to compute the KM or NA estimators. Once we have this data structure, computing the estimators is a matter of converting each element of the list to a term in the estimator and taking the product (or sum).

Since survival curves are generally presented at outcome times, I include a function that filters a `Tracker` list to just `Outcome`s. This filtration isn’t strictly necessary, but I’ve written the property tests below assuming the filter is in place.

```
outcomes :: [Tracker a] -> [Tracker a]
outcomes [] = []
outcomes (e:es)
  | isOutcome (getEvent e) = e : outcomes es
  | otherwise = outcomes es
```

The `eval` function *evaluates* an estimator from a list of `Tracker`s, given two functions (the first two arguments). The first function, `(Tracker a -> Rational)`, converts a `Tracker` to an algebraic term. For example, the KM estimator wants \(1 - d_i/n_i\) at outcome times. For a `Tracker` `x`, that is basically `1 - (getOCount x)/(getAtRisk x)`, after converting both counts to `Rational`. The second argument, `(Rational -> Rational -> Rational)`, is the algebraic operator. For the KM estimator, this is multiplication (`*`). The third argument is the starting point of the estimate (keep in mind the estimate is a *curve*, i.e. a list of coordinates). The starting point is \((0, 1)\) for KM and \((0, 0)\) for NA.

```
eval :: (Num a) => (Tracker a -> Rational)
     -> (Rational -> Rational -> Rational)
     -> (a, Rational)
     -> [Tracker a]
     -> [(a, Rational)]
eval g op start ecs =
  scanl (\x y -> ((getTime . getEvent) y, op (snd x) (g y))) start (outcomes ecs)
```
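Stripped of the `Tracker` plumbing, `eval` is a `scanl`. The start point is the first coordinate of the curve. Each later coordinate pairs an outcome time with the previous value, combined via the operator with that time's term. Below is a minimal sketch of just that mechanism. `runCurve` is a hypothetical name. The terms are precomputed for outcome times 1, 2 and 3 with \(d_i = 1\) and \(n_i = 4, 3, 1\):

```haskell
-- Hypothetical stand-in for eval's scanl over (time, term) pairs: carry the
-- running value forward with op, starting from the given start point.
runCurve :: (Rational -> Rational -> Rational)
         -> (Int, Rational)
         -> [(Int, Rational)]
         -> [(Int, Rational)]
runCurve op = scanl (\(_, acc) (t, term) -> (t, acc `op` term))

main :: IO ()
main = do
  -- Kaplan-Meier: multiply the terms 1 - d_i/n_i, starting from (0, 1)
  print (runCurve (*) (0, 1) [(1, 3/4), (2, 2/3), (3, 0)])
  -- Nelson-Aalen: add the terms d_i/n_i, starting from (0, 0)
  print (runCurve (+) (0, 0) [(1, 1/4), (2, 1/3), (3, 1)])
```

The first curve should be \(1, 3/4, 1/2, 0\) and the second \(0, 1/4, 7/12, 19/12\). Both have one more point than there are outcome times, because the start point is included.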

## Estimators

The KM and NA estimators are created passing different functions to `eval`

:

```
toKMterm :: (Num a) => Tracker a -> Rational
toKMterm x
  | isOutcome (getEvent x) && (n /= 0) = 1 - (d/n)
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise = 1 -- identity for (*): a censoring leaves the curve unchanged
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x

productLimit :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
productLimit x = eval toKMterm (*) (0, 1) (processEvents x)

toNAterm :: (Num a) => Tracker a -> Rational
toNAterm x
  | isOutcome (getEvent x) && (n /= 0) = d/n
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise = 0 -- identity for (+): a censoring adds no hazard
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x

cumHazard :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
cumHazard x = eval toNAterm (+) (0, 0) (processEvents x)
```

I think it’s neat to see how different estimators can be created as specific cases of a more general structure.
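One way to push this further, echoing the product/sum duality from the introduction, is to let a `Monoid` supply both the operator and the starting value. `Product` gives multiplication with identity 1, and `Sum` gives addition with identity 0, so the neutral value cannot fall out of sync with the operator. This is a sketch under simplifying assumptions, not the post's `eval`:

- It works on hypothetical `(t_i, d_i, n_i)` triples instead of `Tracker`s.
- It fixes times to `Int`.
- `evalMonoid`, `kmTerm` and `naTerm` are illustrative names.

```haskell
import Data.Monoid (Product (..), Sum (..))

-- Hypothetical generalization of eval: the Monoid instance supplies the
-- operator (<>) and the start value (mempty); time 0 is the start point.
evalMonoid :: Monoid m
           => ((Int, Int) -> m)   -- term built from (d_i, n_i)
           -> [(Int, Int, Int)]   -- (t_i, d_i, n_i), ascending in t_i
           -> [(Int, m)]
evalMonoid term = scanl step (0, mempty)
  where
    step (_, acc) (t, d, n) = (t, acc <> term (d, n))

kmTerm :: (Int, Int) -> Product Rational
kmTerm (d, n) = Product (1 - fromIntegral d / fromIntegral n)

naTerm :: (Int, Int) -> Sum Rational
naTerm (d, n) = Sum (fromIntegral d / fromIntegral n)

main :: IO ()
main = do
  let triples = [(1, 1, 4), (2, 1, 3), (3, 1, 1)] -- toy (t_i, d_i, n_i)
  print [(t, getProduct s) | (t, s) <- evalMonoid kmTerm triples]
  print [(t, getSum h)     | (t, h) <- evalMonoid naTerm triples]
```

Take outcome times 1, 2 and 3 with \(d_i = 1\) and \(n_i = 4, 3, 1\). This should print the curves \(1, 3/4, 1/2, 0\) and \(0, 1/4, 7/12, 19/12\), the same values a plain running product and running sum give.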

# Does it work?

The code above compiles, so in that sense yes. The code will run. But I have not fully encoded the logic of the estimators into the type system. For example, I do not enforce the time ordering of the `[Tracker a]` data structure at the type level. Future versions could do this.

I’ve tested the product-limit estimator against the R `survival` package, and the values agree in the few cases I’ve tried. Certainly more comparative testing is warranted. For now, I’m including the property-based tests that I’ve come up with so far. An in-depth review of property-based testing is beyond the scope of this post, but it should be familiar ground to most statisticians. It’s very similar to checking the operating characteristics of estimators via simulation.

Here I check the following properties by randomly generating data and then checking that the properties hold:

- The length of the survival curve (or cumulative hazard) is equal to the number of distinct outcome times + 1 (as I include a start point in the returned list).
- The estimators are invariant to shuffling the inputs
- The estimators are monotonic (decreasing for the product-limit and increasing for the cumulative hazard)

```
import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck
import Estimator
import Data.List (nub)

-- Events with non-negative times, equally likely to be outcomes or censorings
instance (Num a, Ord a, Arbitrary a) => Arbitrary (Event a) where
  arbitrary = oneof [Outcome . abs <$> arbitrary, Censor . abs <$> arbitrary]

-- A random list of events paired with a random permutation of the same list
newtype ShuffledEvents = ShuffledEvents ([Event Int], [Event Int]) deriving Show

instance Arbitrary ShuffledEvents where
  arbitrary = do
    i <- arbitrary
    o <- shuffle i
    return $ ShuffledEvents (i, o)

prop_length :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_length estimator x =
  (length $ nub $ filter isOutcome x) + 1 === (length $ estimator x)

prop_monotonic_decreasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_decreasing estimator x =
  (length $ nub $ filter isOutcome x) > 1 ==>
    all (\(prev, next) -> next < prev) (zip z (drop 1 z))
  where z = map snd $ estimator x

prop_monotonic_increasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_increasing estimator x =
  (length $ nub $ filter isOutcome x) > 1 ==>
    all (\(prev, next) -> next > prev) (zip z (drop 1 z))
  where z = map snd $ estimator x

prop_shuffle :: ([Event Int] -> [(Int, Rational)]) -> ShuffledEvents -> Property
prop_shuffle estimator (ShuffledEvents x) =
  (estimator $ fst x) === (estimator $ snd x)

main :: IO ()
main = hspec $ do
  describe "productLimit invariants" $
    modifyMaxSuccess (*10000) $ do
      it "length is number of unique outcomes + 1" $ property (prop_length productLimit)
      it "result is invariant to permuting inputs" $ property (prop_shuffle productLimit)
      it "survival curve is monotonic decreasing" $ property (prop_monotonic_decreasing productLimit)
  describe "cumHazard invariants" $
    modifyMaxSuccess (*10000) $ do
      it "length is number of unique outcomes + 1" $ property (prop_length cumHazard)
      it "result is invariant to permuting inputs" $ property (prop_shuffle cumHazard)
      it "cumulative hazard is monotonic increasing" $ property (prop_monotonic_increasing cumHazard)
```

Running each property on one million generated cases (QuickCheck's default of 100 successful tests, scaled by 10000) yields:

```
$ cabal test
Test suite estimator-test: RUNNING...
productLimit invariants
length is number of unique outcomes + 1
+++ OK, passed 1000000 tests.
result is invariant to permuting inputs
+++ OK, passed 1000000 tests.
survival curve is monotonic decreasing
+++ OK, passed 1000000 tests; 872098 discarded.
cumHazard invariants
length is number of unique outcomes + 1
+++ OK, passed 1000000 tests.
result is invariant to permuting inputs
+++ OK, passed 1000000 tests.
cumulative hazard is monotonic increasing
+++ OK, passed 1000000 tests; 872098 discarded.
```

So many tests are discarded for the monotonicity properties because the `==>` precondition throws out cases with fewer than two distinct outcome times. QuickCheck reports discarded cases separately from passing ones. So the monotonicity property still held in all 1,000,000 generated datasets that met the precondition.

# Summary

That’s a lot of code for a blog post, and I haven’t really shown how you actually use the code to analyze data. But I hope I’ve demonstrated how particular estimators can be seen as special cases of higher-order functions, given a taste of Haskell, and hinted at the value of property-based testing.