Nonparametric Survival Estimators in Haskell

In today’s post, I walk through a Haskell implementation of two fundamental estimators in survival analysis: the product-limit (Kaplan-Meier) estimator (KM) of the survival curve and the Nelson-Aalen estimator (NA) of the cumulative hazard. While toying around with the monoidimator package a few months ago, I realized that one could implement a data structure such that the data would not need to be pre-sorted by time and KM could be evaluated at any point and updated as an online algorithm. My implementation is not necessarily efficient (maybe it is, I don’t know), but I think it is useful for a few reasons:

  • seeing the dual-like relationship between the survival curve and the cumulative hazard (and hence the two estimators). I say “dual-like” in the sense that, in categories with products, the dual of a product is the coproduct, or sum.
  • exposing statisticians who are used to working in a dynamic language like R to the joy (IMO) of Haskell
  • demonstrating property-based testing
  • perhaps serving as a starting place for a survival analysis library for Haskell

Quick Background: the estimators

The product-limit (Kaplan-Meier) estimator of the survival curve is typically presented as:

\(\hat{S}(t) = \prod_{i: t_i \leq t} \left(1 - \frac{d_i}{n_i} \right),\)

where \(d_i\) is the number of subjects who had the event at time \(t_i\), \(n_i\) is the number still at risk at that time, and the \(t_i\) are the distinct times at which an event of interest occurred (which I’ll call an outcome).

Using the same definitions, the Nelson-Aalen estimator of the cumulative hazard can be written as:

\(\hat{H}(t) = \sum_{i: t_i \leq t} \left(\frac{d_i}{n_i} \right).\)

In order to compute either estimator, you need to have the data sorted and summarized by the times of the outcomes. Often survival data is presented as pairs \((Y, \delta)\) where \(Y\) is the time of the outcome if \(\delta = 1\) and the time of the outcome is known to be \(\geq Y\) when \(\delta = 0\). This is how I’ll represent the data in my Haskell program. But rather than using an indicator, I’ll encode whether the time is an outcome or censored directly in the type.
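To make the two formulas concrete, here is a small worked example with made-up data (the five observations are hypothetical, not from any real study). Take the \((Y, \delta)\) pairs \((1, 1), (2, 0), (3, 1), (3, 1), (4, 0)\). The outcome times are \(t_1 = 1\) with \(d_1 = 1, n_1 = 5\) and \(t_2 = 3\) with \(d_2 = 2, n_2 = 3\), so at \(t = 3\):

\(\hat{S}(3) = \left(1 - \frac{1}{5}\right)\left(1 - \frac{2}{3}\right) = \frac{4}{5} \cdot \frac{1}{3} = \frac{4}{15},\)

\(\hat{H}(3) = \frac{1}{5} + \frac{2}{3} = \frac{13}{15}.\)

The censored observation at time 2 contributes no term of its own; it only lowers the number at risk from 4 to 3 by the time of the second outcome.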

The Program

Event Data

First, I define the main data type that I call an Event. I parameterize Event by a type variable a, meaning I can create more specific types such as Event Int or Event Double.

module Estimator where

data Event a = 
      Outcome a 
    | Censor  a
    deriving (Show, Eq)

Now, I have a container for the typical survival data by the following equivalences: \((x, \delta = 1) \simeq\) Outcome x and \((x, \delta = 0) \simeq\) Censor x. Next I need a way to order Events, in particular to handle the case where an outcome and a censoring occur at the same time. To do this, I make Event an instance of the Ord typeclass and encode the convention that, in the case of ties, an outcome comes before censoring:

instance (Ord a) => Ord (Event a) where
  compare (Outcome x) (Outcome y) = compare x y
  compare (Censor  x) (Censor  y) = compare x y
  -- On ties, an outcome sorts before a censoring.
  compare (Outcome x) (Censor y)
    | x <= y    = LT
    | otherwise = GT
  compare (Censor x) (Outcome y)
    | x < y     = LT
    | otherwise = GT
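
As a quick illustration of the tie-breaking convention, the sketch below sorts a small list of Events. The type and instance are restated so the snippet loads on its own, and the names sample and sorted, along with the data, are made up for the example.

```haskell
import Data.List (sort)

-- Event type and tie-breaking convention restated from the post.
data Event a = Outcome a | Censor a
  deriving (Show, Eq)

instance Ord a => Ord (Event a) where
  compare (Outcome x) (Outcome y) = compare x y
  compare (Censor x)  (Censor y)  = compare x y
  compare (Outcome x) (Censor y)  = if x <= y then LT else GT
  compare (Censor x)  (Outcome y) = if x < y  then LT else GT

-- Hypothetical data: an outcome and a censoring tied at time 2.
sample :: [Event Int]
sample = [Censor 2, Outcome 3, Outcome 2, Censor 1]

sorted :: [Event Int]
sorted = sort sample
-- sorted == [Censor 1, Outcome 2, Censor 2, Outcome 3]
```

Note that Outcome 2 lands before Censor 2 even though Censor 2 came first in the input.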

I include a few utility functions for working with Events. Note the use and usefulness of pattern matching in the function definitions.

getTime :: Event a -> a
getTime (Outcome x) = x
getTime (Censor x)  = x

isOutcome :: Event a -> Bool
isOutcome (Outcome _) = True
isOutcome _ = False

isCensor :: Event a -> Bool
isCensor = not.isOutcome

Tracker Data

I imagine processing streams of possibly unordered Events. I need a way to keep counts of the types of events and when they occurred. Basically, I want to create some form of counting process. To track what I observe, I store information in a Tracker, which stores:

  • getEvent: the Event (i.e. the type and time). Based on the ordering of the sum types defined above, for two events at the same time, Outcome \(\leq\) Censor, and a Tracker also follows this ordering.
  • getOCount: the number of times an outcome has occurred at this time
  • getCCount: the number of times a censor event occurred at this time
  • getAtRisk: the number at risk at this time.

data Tracker a = Tracker {
    getEvent  :: Event a
  , getOCount :: Int
  , getCCount :: Int
  , getAtRisk :: Int
} deriving (Show, Eq)

instance (Ord a) => Ord (Tracker a) where
  compare x y = compare (getEvent x) (getEvent y)

Summing Trackers

I use a recursive algorithm for constructing my desired data structure, a list of Trackers. The gist is to box an Event into a new Tracker then recursively “sum” existing, ordered Trackers until the new Tracker is less than or equal to an existing Tracker or it reaches the end of the list. I put “sum” in quotes because I need to define what it means to add two Trackers.

Let me get more helper functions out of the way first. The function getTotalCount simply returns the total number of outcomes and censorings recorded in a Tracker. The initTracker function initializes a Tracker from an Event. You can see that for an Outcome the outcome count is set to 1 (and the censor count to 0), and conversely for a Censor. The setAtRisk function replaces the at-risk count. (All this getting and setting points to lens as a potentially useful abstraction for future iterations.)

getTotalCount :: Tracker a -> Int
getTotalCount x = getOCount x + getCCount x

initTracker :: Event a -> Tracker a
initTracker x
   | isOutcome x = Tracker x 1 0 1
   | otherwise   = Tracker x 0 1 1

setAtRisk :: Tracker a -> Int -> Tracker a
setAtRisk (Tracker x o c _) n' = (Tracker x o c n')

Now, I can define how Trackers are added. The most complicated case is when the event times are equal because Outcome needs to replace Censor in the Tracker event slot:

addEqTrackers :: Tracker a -> Tracker a -> Tracker a
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Outcome _) o' c' n') =
     Tracker (Outcome t) (o + o') (c + c') (o + c + n')
addEqTrackers (Tracker (Censor _) o c _) (Tracker (Outcome t') o' c' n') =
     Tracker (Outcome t') (o + o') (c + c') (n' + o + c)
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Censor _) _ c' n') =
     Tracker (Outcome t) o (c + c') (n' + o + c)
addEqTrackers (Tracker (Censor t) _ c _) (Tracker (Censor _) _ c' n') =
     Tracker (Censor t) 0 (c + c') (n' + c)

I’m sure addEqTrackers could be significantly simplified and possibly combined with the following functions; but as an initial pass, I like spelling out all the cases clearly.

addLtTrackers :: Tracker a -> Tracker a -> Tracker a
addLtTrackers x y = setAtRisk x (getAtRisk y + getTotalCount x)

addGtTrackers :: Tracker a -> Tracker a -> Tracker a
addGtTrackers x y = setAtRisk y (getAtRisk y + getTotalCount x)

addGtTrackers' :: Tracker a -> Tracker a -> Tracker a
addGtTrackers' x y = setAtRisk x (getAtRisk y + getTotalCount x - getTotalCount y)

The add*Trackers functions do the work of updating Trackers as needed. For example, if I have two Trackers x and y and x < y, then addLtTrackers sets the at-risk count of x to the sum of x’s total count and the at-risk count of y (see the definition of f below to see this in action).
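
To see the at-risk arithmetic on concrete numbers, here is a small sketch. The types and helpers are restated from the post so it loads on its own, while the values existing, earlier, later, bumped and carried are hypothetical and exist only for this example.

```haskell
-- Types and helpers restated from the post.
data Event a = Outcome a | Censor a deriving (Show, Eq)

data Tracker a = Tracker
  { getEvent  :: Event a
  , getOCount :: Int
  , getCCount :: Int
  , getAtRisk :: Int
  } deriving (Show, Eq)

getTotalCount :: Tracker a -> Int
getTotalCount x = getOCount x + getCCount x

setAtRisk :: Tracker a -> Int -> Tracker a
setAtRisk (Tracker x o c _) n' = Tracker x o c n'

addLtTrackers, addGtTrackers, addGtTrackers' :: Tracker a -> Tracker a -> Tracker a
addLtTrackers  x y = setAtRisk x (getAtRisk y + getTotalCount x)
addGtTrackers  x y = setAtRisk y (getAtRisk y + getTotalCount x)
addGtTrackers' x y = setAtRisk x (getAtRisk y + getTotalCount x - getTotalCount y)

-- Hypothetical existing Tracker: two outcomes at time 3, three subjects at risk.
existing :: Tracker Int
existing = Tracker (Outcome 3) 2 0 3

-- A new outcome at time 1 lands before it: its at-risk count is 3 + 1 = 4.
earlier :: Tracker Int
earlier = addLtTrackers (Tracker (Outcome 1) 1 0 1) existing

-- A new censoring at time 4 lands after it.
later :: Tracker Int
later = Tracker (Censor 4) 0 1 1

-- The existing Tracker gains one subject at risk (3 + 1 = 4), and the new
-- Tracker carries 3 + 1 - 2 = 2 forward to compare with the rest of the list.
bumped, carried :: Tracker Int
bumped  = addGtTrackers  later existing
carried = addGtTrackers' later existing
```

Here earlier is Tracker (Outcome 1) 1 0 4, bumped is Tracker (Outcome 3) 2 0 4, and carried is Tracker (Censor 4) 0 1 2.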

Accumulating Trackers

A recursive function (that I uncreatively named f) uses the add*Trackers functions to take a new Tracker (wrapped in Maybe) and update a List of Trackers.

f :: (Ord a) => Maybe (Tracker a) -> [Tracker a] -> [Tracker a]
f (Just x) [] = [x]
f Nothing []  = []
f Nothing ts  = ts
f (Just x) (t:ts) 
  | lt = [addLtTrackers x t] ++ t:ts
  | eq = [addEqTrackers x t] ++ ts
  | gt = [addGtTrackers x t] ++ (f (Just $ addGtTrackers' x t) ts)
  where cm = compare ((getTime.getEvent) x) ((getTime.getEvent) t)
        lt = (cm == LT)
        eq = (cm == EQ)
        gt = (cm == GT)

This function f does the heavy lifting of creating a List of Trackers ordered by time and tracking counts as new events are processed. All that remains is converting a List of Events to a List of Trackers, which, now that I have f, is a one-liner using a fold pattern:

processEvents :: (Ord a, Num a) => [Event a] -> [Tracker a]
processEvents z = foldl (\x -> \y -> f (Just $ initTracker y) x) [] z
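
For comparison, here is a minimal batch sketch of the table that processEvents is meant to build: one row per distinct time with the outcome count, censor count and number at risk. It sorts first rather than inserting online, and it uses plain tuples instead of the post's types. The names Obs, Row, summarize and sample are assumptions made for this example only.

```haskell
import Data.Function (on)
import Data.List (groupBy, sortOn)

-- (time, isOutcome) pairs stand in for the post's Event type.
type Obs = (Int, Bool)

-- One row per distinct time: (time, outcomes, censorings, at risk).
type Row = (Int, Int, Int, Int)

summarize :: [Obs] -> [Row]
summarize obs = go (length obs) (groupBy ((==) `on` fst) (sortOn fst obs))
  where
    -- n is the number of subjects whose time is at or after this group's time.
    go n (g@((t, _) : _) : gs) =
      let d = length (filter snd g)
      in (t, d, length g - d, n) : go (n - length g) gs
    -- End of input (groupBy never yields an empty group).
    go _ _ = []

-- Hypothetical unsorted data.
sample :: [Obs]
sample = [(3, True), (1, True), (2, False), (3, True), (4, False)]
-- summarize sample == [(1,1,0,5), (2,0,1,4), (3,2,0,3), (4,0,1,1)]
```

Feeding the corresponding Events (Outcome 3, Outcome 1, Censor 2, Outcome 3, Censor 4) through processEvents should give Trackers with the same counts, which makes a batch version like this a handy cross-check.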

The Estimators

All the above code is for creating a data structure that is a time-ordered list of information needed to compute the KM or NA estimators. Once we have this data structure, computing the estimators is a matter of converting each element of the list to a term in the estimator and taking the product (or sum).

Since survival curves are generally presented at outcome times, I include a function that filters a Tracker list to just Outcomes. This filtration isn’t strictly necessary, but I’ve written the property tests below assuming the filter is in place.

outcomes :: [Tracker a] -> [Tracker a]
outcomes [] = []
outcomes (e:es) 
   | isOutcome (getEvent e) = [e] ++ outcomes es
   | isCensor (getEvent e)  = outcomes es

The eval function evaluates an estimator from a List of Trackers, given two functions (the first two arguments). The first function (Tracker a -> Rational) converts a Tracker to an algebraic term. For example, the KM estimator wants \(1 - d_i/n_i\) at outcome times, which for a Tracker x is basically 1 - (getOCount x)/(getAtRisk x). The second argument, (Rational -> Rational -> Rational), is the algebraic operator. For the KM estimator, this is multiplication (*). The third argument is the starting point of the estimate (keep in mind the estimate is a curve or a list of coordinates). The starting point is \((0, 1)\) for KM and \((0, 0)\) for NA.

eval ::  (Num a) => (Tracker a -> Rational)
      -> (Rational -> Rational -> Rational) 
      -> (a, Rational)
      -> [Tracker a] 
      -> [(a, Rational)]
eval g op start ecs = scanl (\x -> \y -> ((getTime.getEvent) y, op (snd x) (g y))) start (outcomes ecs)
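
The scanl pattern in eval can be tried in isolation. The sketch below skips the Tracker machinery and starts from hypothetical pre-summarized rows of (time, d_i, n_i); the helper curve and the names rows, kmCurve and naCurve are made up for this illustration and are not part of the module above.

```haskell
import Data.Ratio ((%))

-- Hypothetical pre-summarized outcome times: (time, events d_i, at risk n_i).
rows :: [(Int, Integer, Integer)]
rows = [(1, 1, 5), (3, 2, 3)]

-- Running estimate: combine the per-time terms with `op`, keeping every step.
curve :: (Rational -> Rational -> Rational)
      -> ((Integer, Integer) -> Rational)
      -> (Int, Rational)
      -> [(Int, Rational)]
curve op term start = scanl step start rows
  where step (_, acc) (t, d, n) = (t, op acc (term (d, n)))

kmCurve, naCurve :: [(Int, Rational)]
kmCurve = curve (*) (\(d, n) -> 1 - d % n) (0, 1)
naCurve = curve (+) (\(d, n) -> d % n)     (0, 0)
-- kmCurve == [(0, 1), (1, 4 % 5), (3, 4 % 15)]
-- naCurve == [(0, 0), (1, 1 % 5), (3, 13 % 15)]
```

Only the operator, the term function and the starting value differ between the two curves.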

Estimators

The KM and NA estimators are created by passing different functions to eval:

toKMterm :: (Num a) => Tracker a -> Rational
toKMterm x
  | isOutcome (getEvent x) && (n /= 0) = 1 - (d/n)
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise   = 1
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x

productLimit :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
productLimit x = eval toKMterm (*) (0, 1) (processEvents x)

toNAterm :: (Num a) => Tracker a -> Rational
toNAterm x
  | isOutcome (getEvent x) && (n /= 0) = d/n
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise   = 0  -- additive identity: a censoring adds nothing to the sum
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x

cumHazard :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
cumHazard x = eval toNAterm (+) (0, 0) (processEvents x) 

I think it’s neat to see how different estimators can be created as specific cases of a more general structure.
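
One way to push this observation a step further, offered as an alternative framing rather than what the module above does, is to let the Monoid instances for Product and Sum from Data.Monoid supply both the operator and the starting value. The names running, counts, kmValues and naValues are hypothetical.

```haskell
import Data.Monoid (Product (..), Sum (..))
import Data.Ratio ((%))

-- Running monoidal accumulation: mempty is the starting value and (<>)
-- plays the role of the operator passed to eval in the post.
running :: Monoid m => [m] -> [m]
running = scanl (<>) mempty

-- Hypothetical (d_i, n_i) pairs at two outcome times.
counts :: [(Integer, Integer)]
counts = [(1, 5), (2, 3)]

kmValues :: [Rational]
kmValues = map getProduct (running [Product (1 - d % n) | (d, n) <- counts])

naValues :: [Rational]
naValues = map getSum (running [Sum (d % n) | (d, n) <- counts])
-- kmValues == [1, 4 % 5, 4 % 15]
-- naValues == [0, 1 % 5, 13 % 15]
```

In this form the two estimators differ only in which newtype wraps the term, which is the product-versus-sum relationship mentioned at the top of the post.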

Does it work?

The code above compiles, so in that sense yes. The code will run. But I have not fully encoded the logic of the estimators into the type system. For example, I do not enforce the time ordering of the [Tracker a] data structure at the type level. Future versions could do this.

I’ve tested the product-limit estimator against the R survival package and the values agree in the few cases I’ve done. Certainly more comparative testing is warranted. For now, I’m including the property-based tests that I’ve come up with so far. An in-depth review of property-based testing is beyond the scope of this post, but property-based testing should be familiar ground to most statisticians. It’s very similar to checking the operating characteristics of estimators via simulation.

Here I check the following properties by randomly generating data and then checking that the properties hold:

  1. The length of the survival curve (or cumulative hazard) is equal to the number of outcomes + 1 (as I include a start point in the returned list).
  2. The estimators are invariant to shuffling the inputs
  3. The estimators are monotonic (decreasing for the product-limit and increasing for the cumulative hazard)

import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck.Modifiers
import Test.QuickCheck
import Estimator
import Data.List (nub)

instance (Num a, Ord a, Arbitrary a) => Arbitrary (Event a) where
    arbitrary = do
        b <- choose (False, True)
        case b of 
            True  -> Outcome <$> abs <$> arbitrary
            False -> Censor  <$> abs <$> arbitrary

newtype ShuffledEvents = ShuffledEvents ([Event Int], [Event Int]) deriving Show
instance Arbitrary ShuffledEvents where
    arbitrary = do
        i <- arbitrary
        o <- shuffle i
        return $ ShuffledEvents (i, o)

prop_length :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_length estimator x =
    (length $ nub $ filter isOutcome x) + 1 === (length $ estimator x)

prop_monotonic_decreasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_decreasing estimator x =
    (length $ nub $ filter isOutcome x) > 1 ==> 
         ( all (\x -> fst x < snd x) (zip (tail z) z) )
    where z = (map snd $ estimator x)

prop_monotonic_increasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_increasing estimator x =
    (length $ nub $ filter isOutcome x) > 1 ==> 
         ( all (\x -> fst x > snd x) (zip (tail z) z) )
    where z = (map snd $ estimator x)

prop_shuffle :: ([Event Int] -> [(Int, Rational)]) -> ShuffledEvents -> Property
prop_shuffle estimator (ShuffledEvents x) = 
  (estimator $ fst x) === (estimator $ snd x) 

main :: IO ()
main = hspec $ do
  describe "productLimit invariants" $ 
    modifyMaxSuccess (*10000) $
    do 
      it "length is number of unique outcomes + 1" $ property (prop_length productLimit)
      it "result is invariant to permuting inputs" $ property (prop_shuffle productLimit)
      it "survival curve is monotonic decreasing" $ property (prop_monotonic_decreasing productLimit)

  describe "cumHazard invariants" $ 
    modifyMaxSuccess (*10000) $
    do 
      it "length is number of unique outcomes + 1" $ property (prop_length cumHazard)
      it "result is invariant to permuting inputs" $ property (prop_shuffle cumHazard)
      it "cumulative hazard is monotonic increasing" $ property (prop_monotonic_increasing cumHazard)

Running one million test cases yields:

$ cabal test
Test suite estimator-test: RUNNING...

productLimit invariants
  length is number of unique outcomes + 1
    +++ OK, passed 1000000 tests.
  result is invariant to permuting inputs
    +++ OK, passed 1000000 tests.
  survival curve is monotonic decreasing
    +++ OK, passed 1000000 tests; 872098 discarded.
cumHazard invariants
  length is number of unique outcomes + 1
    +++ OK, passed 1000000 tests.
  result is invariant to permuting inputs
    +++ OK, passed 1000000 tests.
  cumulative hazard is monotonic increasing
    +++ OK, passed 1000000 tests; 872098 discarded.

So many tests are discarded in the monotonicity checks because the Property throws out cases with zero or one unique Outcomes. QuickCheck reports discarded cases separately from passed ones, so the monotonicity property still held in each of the 1000000 randomly generated datasets that met the precondition.
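
For small inputs, the shuffle property can also be checked exhaustively with nothing beyond base, as a complement to random generation. The sketch below uses a simplified stand-in estimator rather than the module above: insertObs, km, sample and shuffleInvariant are hypothetical names, and (time, isOutcome) pairs replace the Event type.

```haskell
import Data.List (permutations)
import Data.Ratio ((%))

-- (time, isOutcome) pairs stand in for the post's Event type.
type Obs = (Int, Bool)

-- Insert one observation into a time-ordered list of (time, outcomes, censorings).
insertObs :: Obs -> [(Int, Integer, Integer)] -> [(Int, Integer, Integer)]
insertObs (t, isOut) rows = case rows of
  [] -> [new]
  r@(t', d, c) : rest
    | t < t'    -> new : rows
    | t == t'   -> (t', d + dO, c + dC) : rest
    | otherwise -> r : insertObs (t, isOut) rest
  where
    (dO, dC) = if isOut then (1, 0) else (0, 1)
    new = (t, dO, dC)

-- Product-limit values at outcome times, with rows built one observation at a time.
km :: [Obs] -> [(Int, Rational)]
km obs = scanl step (0, 1) [(t, d, n) | ((t, d, _), n) <- zip rows atRisk, d > 0]
  where
    rows   = foldr insertObs [] obs
    -- At risk at each time: everyone observed at that time or later.
    atRisk = scanr1 (+) [d + c | (_, d, c) <- rows]
    step (_, s) (t, d, n) = (t, s * (1 - d % n))

-- Hypothetical data with a tie at time 3.
sample :: [Obs]
sample = [(3, True), (1, True), (2, False), (3, True), (4, False)]

-- True when every one of the 120 orderings of the sample gives the same curve.
shuffleInvariant :: Bool
shuffleInvariant = all ((== km sample) . km) (permutations sample)
```

This only scales to a handful of observations, since the number of permutations grows factorially, but it leaves no ordering of a small dataset unchecked.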

Summary

That’s a lot of code for a blog post, and I haven’t really shown how you actually use the code to analyze data. But I hope I’ve demonstrated how particular estimators can be seen as special cases of higher-order functions, given a taste of Haskell, and hinted at the value of property-based testing.