Nonparametric Survival Estimators in Haskell
In today’s post, I walk through a Haskell implementation of two fundamental estimators in survival analysis: the product-limit (Kaplan-Meier, or KM) estimator of the survival curve and the Nelson-Aalen (NA) estimator of the cumulative hazard. While toying around with the monoidimator package a few months ago, I realized that one could implement a data structure such that the data would not need to be pre-sorted by time, and KM could be evaluated at any point and updated as an online algorithm. My implementation is not necessarily efficient (maybe it is, I don’t know), but I think it is useful for a few reasons:
- seeing the dual-like relationship between the survival curve and the cumulative hazard (and hence between the two estimators). I say “dual-like” in the sense that, in categories with products, the dual of a product is the coproduct, or sum.
- exposing statisticians used to working in a dynamic language like R to the joy (IMO) of Haskell
- demonstrating property-based testing
- perhaps serving as a starting place for a survival analysis library in Haskell
Quick Background: the estimators
The product limit (Kaplan-Meier) estimator of the survival curve is typically presented as:
\(\hat{S}(t) = \prod_{i: t_i \leq t} \left(1 - \frac{d_i}{n_i} \right),\)
where \(d_i\) is the number of subjects who had the event at time \(t_i\), \(n_i\) is the number still at risk at that time, and the times \(t_i\) are the times at which an event of interest occurred (which I’ll call an outcome).
Using the same definitions, the Nelson-Aalen estimator of the cumulative hazard can be written as:
\(\hat{H}(t) = \sum_{i: t_i \leq t} \left(\frac{d_i}{n_i} \right).\)
In order to compute either estimator, you need the data sorted and summarized by the times of the outcomes. Often survival data is presented as pairs \((Y, \delta)\), where \(Y\) is the time of the outcome if \(\delta = 1\), and the time of the outcome is known only to be \(\geq Y\) when \(\delta = 0\). This is how I’ll represent the data in my Haskell program, but rather than using an indicator variable, I’ll encode whether the time is an outcome or censored directly in the type.
The Program
Event Data
First, I define the main data type, which I call an `Event`. I parameterize `Event` by a type variable `a`, meaning I can create more specific types such as `Event Int` or `Event Double`.
module Estimator where

data Event a =
    Outcome a
  | Censor a
  deriving (Show, Eq)
Now I have a container for typical survival data via the following equivalences: \((x, \delta = 1) \simeq\) `Outcome x` and \((x, \delta = 0) \simeq\) `Censor x`. Next, I need a way to order `Event`s, in particular to handle the case where an outcome and a censoring event occur at the same time. To do this, I make `Event` an instance of the `Ord` typeclass and encode the convention that, in the case of ties, an outcome comes before censoring:
instance (Ord a) => Ord (Event a) where
  (Outcome x) `compare` (Outcome y) = x `compare` y
  (Censor x)  `compare` (Censor y)  = x `compare` y
  compare (Outcome x) (Censor y)
    | x <= y = LT
    | x > y  = GT
  compare (Censor x) (Outcome y)
    | x < y  = LT
    | x >= y = GT
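To make the tie-breaking convention concrete, here is a quick sanity check (a hypothetical GHCi session, with `sort` imported from `Data.List`):

ghci> compare (Outcome 3) (Censor 3)
LT
ghci> sort [Censor 2, Outcome 2, Outcome 1]
[Outcome 1,Outcome 2,Censor 2]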
I include a few utility functions for working with `Event`s. Note the use and usefulness of pattern matching in the function definitions.
getTime :: Event a -> a
getTime (Outcome x) = x
getTime (Censor x) = x
isOutcome :: Event a -> Bool
isOutcome (Outcome _) = True
isOutcome _ = False
isCensor :: Event a -> Bool
isCensor = not . isOutcome
Tracker Data
I imagine processing streams of possibly unordered `Event`s, so I need a way to keep counts of the types of events and when they occurred; basically, I want to create some form of counting process. To track what I observe, I store information in a `Tracker`, which holds:
- `getEvent`: the `Event` (i.e., the type and time). Based on the ordering of the sum type defined above, for two events at the same time, `Outcome` \(\leq\) `Censor`, and a `Tracker` also follows this ordering.
- `getOCount`: the number of times an outcome has occurred at this time
- `getCCount`: the number of times a censoring event has occurred at this time
- `getAtRisk`: the number at risk at this time
data Tracker a = Tracker {
    getEvent  :: Event a
  , getOCount :: Int
  , getCCount :: Int
  , getAtRisk :: Int
  } deriving (Show, Eq)

instance (Ord a) => Ord (Tracker a) where
  compare x y = compare (getEvent x) (getEvent y)
Summing Trackers
I use a recursive algorithm for constructing my desired data structure, a list of `Tracker`s. The gist is to box an `Event` into a new `Tracker`, then recursively “sum” the existing, ordered `Tracker`s until the new `Tracker` is less than or equal to an existing `Tracker` or reaches the end of the list. I put “sum” in quotes because I need to define what it means to add two `Tracker`s.
Let me get more helper functions out of the way first. The function `getTotalCount` simply returns the total number of outcome and censoring events for a `Tracker`. The `initTracker` function initializes a `Tracker` from an `Event`: for an `Outcome`, the outcome count is set to `1` (and the censor count to `0`), and conversely for a `Censor`. The `setAtRisk` function replaces the at-risk count. (All this getting and setting points to lenses as a potentially useful abstraction for future iterations.)
getTotalCount :: Tracker a -> Int
getTotalCount x = getOCount x + getCCount x

initTracker :: Event a -> Tracker a
initTracker x
  | isOutcome x = Tracker x 1 0 1
  | otherwise   = Tracker x 0 1 1

setAtRisk :: Tracker a -> Int -> Tracker a
setAtRisk (Tracker x o c _) n' = Tracker x o c n'
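As a quick illustration of `initTracker` (hypothetical GHCi output; the values follow from the definitions above):

ghci> initTracker (Outcome 5)
Tracker {getEvent = Outcome 5, getOCount = 1, getCCount = 0, getAtRisk = 1}
ghci> initTracker (Censor 5)
Tracker {getEvent = Censor 5, getOCount = 0, getCCount = 1, getAtRisk = 1}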
Now I can define how `Tracker`s are added. The most complicated case is when the event times are equal, because an `Outcome` needs to replace a `Censor` in the `Tracker` event slot:
addEqTrackers :: Tracker a -> Tracker a -> Tracker a
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Outcome _) o' c' n') =
  Tracker (Outcome t) (o + o') (c + c') (o + c + n')
addEqTrackers (Tracker (Censor _) o c _) (Tracker (Outcome t') o' c' n') =
  Tracker (Outcome t') (o + o') (c + c') (n' + o + c)
addEqTrackers (Tracker (Outcome t) o c _) (Tracker (Censor _) _ c' n') =
  Tracker (Outcome t) o (c + c') (n' + o + c)
addEqTrackers (Tracker (Censor t) _ c _) (Tracker (Censor _) _ c' n') =
  Tracker (Censor t) 0 (c + c') (n' + c)
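For instance, merging a new `Outcome` tracker into an existing `Censor` tracker at the same time keeps the `Outcome` in the event slot and pools the counts (hypothetical GHCi output; the values follow from the third clause above):

ghci> addEqTrackers (initTracker (Outcome 1)) (initTracker (Censor 1))
Tracker {getEvent = Outcome 1, getOCount = 1, getCCount = 1, getAtRisk = 2}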
I’m sure `addEqTrackers` could be significantly simplified and possibly combined with the following functions, but as an initial pass, I like spelling out all the cases clearly.
addLtTrackers :: Tracker a -> Tracker a -> Tracker a
addLtTrackers x y = setAtRisk x (getAtRisk y + getTotalCount x)
addGtTrackers :: Tracker a -> Tracker a -> Tracker a
addGtTrackers x y = setAtRisk y (getAtRisk y + getTotalCount x)
addGtTrackers' :: Tracker a -> Tracker a -> Tracker a
addGtTrackers' x y = setAtRisk x (getAtRisk y + getTotalCount x - getTotalCount y)
The `add*Trackers` functions do the work of updating `Tracker`s as needed. For example, if I have two `Tracker`s `x` and `y` with `x < y`, then `addLtTrackers` updates the at-risk count of `x` with the sum of `x`’s total count and the at-risk count of `y` (see the definition of `f` below, and the short example after this paragraph, to see this in action).
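As a small concrete case (hypothetical GHCi output; the values follow from the definitions above): a new censoring event at time 1 inherits the existing time-2 tracker’s at-risk count plus its own count, while a new event at time 3 instead bumps the existing tracker’s at-risk count:

ghci> addLtTrackers (initTracker (Censor 1)) (Tracker (Outcome 2) 1 0 1)
Tracker {getEvent = Censor 1, getOCount = 0, getCCount = 1, getAtRisk = 2}
ghci> addGtTrackers (initTracker (Outcome 3)) (Tracker (Outcome 2) 1 0 1)
Tracker {getEvent = Outcome 2, getOCount = 1, getCCount = 0, getAtRisk = 2}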
Accumulating Trackers
A recursive function (which I uncreatively named `f`) uses the `add*Trackers` functions to take a new `Tracker` (wrapped in `Maybe`) and update a list of `Tracker`s.
f :: (Ord a) => Maybe (Tracker a) -> [Tracker a] -> [Tracker a]
f (Just x) [] = [x]
f Nothing  [] = []
f Nothing  ts = ts
f (Just x) (t:ts)
  | lt = addLtTrackers x t : t : ts
  | eq = addEqTrackers x t : ts
  | gt = addGtTrackers x t : f (Just $ addGtTrackers' x t) ts
  where cm = compare ((getTime . getEvent) x) ((getTime . getEvent) t)
        lt = cm == LT
        eq = cm == EQ
        gt = cm == GT
This function `f` does the heavy lifting of creating a list of `Tracker`s ordered by time and tracking counts as new events are processed. All that remains is converting a list of `Event`s to a list of `Tracker`s, which, now that I have `f`, is a one-liner using a fold:
processEvents :: (Ord a, Num a) => [Event a] -> [Tracker a]
processEvents z = foldl (\ts e -> f (Just $ initTracker e) ts) [] z
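Running `processEvents` on a small, unsorted dataset shows the counting-process structure (a hypothetical GHCi session; the values follow from the definitions above). Three subjects are at risk at time 1, where one outcome and one censoring occur, leaving one subject at risk at time 2:

ghci> processEvents [Outcome 2, Censor 1, Outcome 1]
[Tracker {getEvent = Outcome 1, getOCount = 1, getCCount = 1, getAtRisk = 3},Tracker {getEvent = Outcome 2, getOCount = 1, getCCount = 0, getAtRisk = 1}]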
The Estimators
All the above code is for creating a data structure that is a time-ordered list of information needed to compute the KM or NA estimators. Once we have this data structure, computing the estimators is a matter of converting each element of the list to a term in the estimator and taking the product (or sum).
Since survival curves are generally presented at outcome times, I include a function that filters a `Tracker` list to just `Outcome`s. This filtering isn’t strictly necessary, but I’ve written the property tests below assuming the filter is in place.
outcomes :: [Tracker a] -> [Tracker a]
outcomes [] = []
outcomes (e:es)
  | isOutcome (getEvent e) = e : outcomes es
  | isCensor (getEvent e)  = outcomes es
The `eval` function evaluates an estimator from a list of `Tracker`s, given two functions as its first two arguments. The first function, `(Tracker a -> Rational)`, converts a `Tracker` to an algebraic term. For example, the KM estimator wants \(1 - d_i/n_i\) at outcome times, which for a `Tracker` `x` is basically `1 - (getOCount x)/(getAtRisk x)`. The second argument, `(Rational -> Rational -> Rational)`, is the algebraic operator; for the KM estimator, this is multiplication (`*`). The third argument is the starting point of the estimate (keep in mind the estimate is a curve, i.e., a list of coordinates). The starting point is \((0, 1)\) for KM and \((0, 0)\) for NA.
eval :: (Num a) => (Tracker a -> Rational)
                -> (Rational -> Rational -> Rational)
                -> (a, Rational)
                -> [Tracker a]
                -> [(a, Rational)]
eval g op start ecs =
  scanl (\x y -> ((getTime . getEvent) y, op (snd x) (g y))) start (outcomes ecs)
Estimators
The KM and NA estimators are created by passing different functions to `eval`:
toKMterm :: (Num a) => Tracker a -> Rational
toKMterm x
  | isOutcome (getEvent x) && (n /= 0) = 1 - (d / n)
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise = 1
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x
productLimit :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
productLimit x = eval toKMterm (*) (0, 1) (processEvents x)
toNAterm :: (Num a) => Tracker a -> Rational
toNAterm x
  | isOutcome (getEvent x) && (n /= 0) = d / n
  | isOutcome (getEvent x) && (n == 0) = 0
  | otherwise = 0 -- the additive identity; unreachable once `outcomes` filters the list
  where d = fromIntegral $ getOCount x
        n = fromIntegral $ getAtRisk x
cumHazard :: (Ord a, Num a) => [Event a] -> [(a, Rational)]
cumHazard x = eval toNAterm (+) (0, 0) (processEvents x)
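Putting it all together on the same toy dataset from above (hypothetical GHCi output; `2 % 3` is how GHCi prints the `Rational` \(2/3\)):

ghci> productLimit [Outcome 2, Censor 1, Outcome 1]
[(0,1 % 1),(1,2 % 3),(2,0 % 1)]
ghci> cumHazard [Outcome 2, Censor 1, Outcome 1]
[(0,0 % 1),(1,1 % 3),(2,4 % 3)]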
I think it’s neat to see how different estimators can be created as specific cases of a more general structure.
Does it work?
The code above compiles, so in that sense, yes: the code will run. But I have not fully encoded the logic of the estimators into the type system. For example, I do not enforce the time ordering of the `[Tracker a]` data structure at the type level. Future versions could do this; a sketch of one approach follows.
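One lightweight direction (my sketch, using hypothetical names `OrderedTrackers`, `getOrdered`, and `fromEvents` that are not part of the code above) is to wrap the list in a newtype and export only a smart constructor, so callers can obtain the list only in its time-ordered form:

-- A sketch: export OrderedTrackers abstractly (i.e., without its data
-- constructor), so the only way to build one is via fromEvents, which
-- always produces the time-ordered list.
newtype OrderedTrackers a = OrderedTrackers { getOrdered :: [Tracker a] }
  deriving (Show, Eq)

fromEvents :: (Ord a, Num a) => [Event a] -> OrderedTrackers a
fromEvents = OrderedTrackers . processEvents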
I’ve tested the product-limit estimator against the R `survival` package, and the values agree in the few cases I’ve checked. Certainly more comparative testing is warranted. For now, I’m including the property-based tests that I’ve come up with so far. An in-depth review of property-based testing is beyond the scope of this post, but it should be familiar ground to most statisticians: it’s very similar to checking the operating characteristics of estimators via simulation.
Here I check the following properties by randomly generating data and verifying that each holds:
- The length of the survival curve (or cumulative hazard) is equal to the number of unique outcome times + 1 (as I include a start point in the returned list).
- The estimators are invariant to shuffling the inputs
- The estimators are monotonic (decreasing for the product-limit and increasing for the cumulative hazard)
import Test.Hspec
import Test.Hspec.QuickCheck
import Test.QuickCheck.Modifiers
import Test.QuickCheck
import Estimator
import Data.List (nub)

instance (Num a, Ord a, Arbitrary a) => Arbitrary (Event a) where
  arbitrary = do
    b <- choose (True, False)
    case b of
      True  -> Outcome . abs <$> arbitrary
      False -> Censor . abs <$> arbitrary
newtype ShuffledEvents = ShuffledEvents ([Event Int], [Event Int]) deriving Show

instance Arbitrary ShuffledEvents where
  arbitrary = do
    i <- arbitrary
    o <- shuffle i
    return $ ShuffledEvents (i, o)
prop_length :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_length estimator x =
  (length $ nub $ filter isOutcome x) + 1 === (length $ estimator x)

prop_monotonic_decreasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_decreasing estimator x =
  (length $ nub $ filter isOutcome x) > 1 ==>
    all (\(next, prev) -> next < prev) (zip (tail z) z)
  where z = map snd $ estimator x

prop_monotonic_increasing :: ([Event Int] -> [(Int, Rational)]) -> [Event Int] -> Property
prop_monotonic_increasing estimator x =
  (length $ nub $ filter isOutcome x) > 1 ==>
    all (\(next, prev) -> next > prev) (zip (tail z) z)
  where z = map snd $ estimator x

prop_shuffle :: ([Event Int] -> [(Int, Rational)]) -> ShuffledEvents -> Property
prop_shuffle estimator (ShuffledEvents x) =
  (estimator $ fst x) === (estimator $ snd x)
main :: IO ()
main = hspec $ do
  describe "productLimit invariants" $
    modifyMaxSuccess (*10000) $ do
      it "length is number of unique outcomes + 1" $ property (prop_length productLimit)
      it "result is invariant to permuting inputs" $ property (prop_shuffle productLimit)
      it "survival curve is monotonic decreasing" $ property (prop_monotonic_decreasing productLimit)
  describe "cumHazard invariants" $
    modifyMaxSuccess (*10000) $ do
      it "length is number of unique outcomes + 1" $ property (prop_length cumHazard)
      it "result is invariant to permuting inputs" $ property (prop_shuffle cumHazard)
      it "cumulative hazard is monotonic increasing" $ property (prop_monotonic_increasing cumHazard)
Running one million test cases yields:
$ cabal test
Test suite estimator-test: RUNNING...
productLimit invariants
  length is number of unique outcomes + 1
    +++ OK, passed 1000000 tests.
  result is invariant to permuting inputs
    +++ OK, passed 1000000 tests.
  survival curve is monotonic decreasing
    +++ OK, passed 1000000 tests; 872098 discarded.
cumHazard invariants
  length is number of unique outcomes + 1
    +++ OK, passed 1000000 tests.
  result is invariant to permuting inputs
    +++ OK, passed 1000000 tests.
  cumulative hazard is monotonic increasing
    +++ OK, passed 1000000 tests; 872098 discarded.
The reason so many tests are discarded for the monotonicity tests is that the `Property` throws out cases where zero or one outcomes occurred. Note that QuickCheck counts discards separately from passes, so the monotonicity property still held in all one million randomly generated datasets that met the precondition.
Summary
That’s a lot of code for a blog post, and I haven’t really shown how you actually use the code to analyze data. But I hope I’ve demonstrated how particular estimators can be seen as special cases of higher order functions, given a taste of Haskell, and hinted at the value property-based testing.