torch-survival: A PyTorch-based library for Deep Survival Analysis

torch-survival is a library built upon PyTorch and scikit-survival to make survival analysis using deep learning more accessible.

Its main goal is to implement a diverse set of deep learning methods for survival analysis that follow scikit-survival’s API design, offering a better out-of-the-box experience than existing libraries (see Alternatives). In addition, it provides many common loss functions and metrics that can also be used standalone in other PyTorch-based projects.
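To give a feel for what such a standalone loss looks like, here is a hypothetical sketch in plain PyTorch of the negative Cox partial log-likelihood, the objective that DeepSurv-style models typically minimize. The function name `cox_ph_loss` and its argument layout are illustrative assumptions, not torch-survival's actual API, and the sketch does not apply a tie correction.

```python
import torch


def cox_ph_loss(log_risk: torch.Tensor, time: torch.Tensor, event: torch.Tensor) -> torch.Tensor:
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical sketch (not torch-survival's API). No tie correction:
    subjects with tied times are ordered arbitrarily by the sort.

    log_risk: (n,) model outputs, interpreted as log hazard ratios
    time:     (n,) observed (event or censoring) times
    event:    (n,) 1 if the event was observed, 0 if censored
    """
    # Sort by descending time so the risk set of the i-th sorted subject
    # (everyone still at risk at its time) is the prefix 0..i.
    order = torch.argsort(time, descending=True)
    log_risk = log_risk[order]
    event = event[order].to(log_risk.dtype)
    # log(sum_{j in risk set} exp(log_risk_j)), computed in a numerically stable way
    log_risk_set = torch.logcumsumexp(log_risk, dim=0)
    # Only subjects with an observed event contribute a likelihood term
    num_events = event.sum().clamp(min=1.0)
    return -((log_risk - log_risk_set) * event).sum() / num_events


# Usage: three subjects, the second one censored
log_risk = torch.tensor([0.5, -0.3, 1.2], requires_grad=True)
loss = cox_ph_loss(log_risk, torch.tensor([3.0, 1.0, 2.0]), torch.tensor([1, 0, 1]))
loss.backward()  # gradients flow back to the model outputs
```

Sorting once and using `torch.logcumsumexp` keeps the loss at O(n log n) and avoids overflow, instead of materializing an n×n risk-set mask.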

Important

While this library has been extensively tested and evaluated as part of the SurvHub benchmark, some parts of the API design and documentation are still considered a work in progress. Expect breaking changes as we converge towards a consistent and enjoyable developer experience.

Installation

torch-survival is available from PyPI and depends only on PyTorch, scikit-survival, and (for now) pycox. To install:

pip install torch-survival

For now, only alpha releases are available.

Getting Started

Since the API design of torch-survival closely mimics that of scikit-survival (and, by extension, scikit-learn), it only takes a few lines of code to get started:

import pandas as pd
from sklearn.model_selection import train_test_split
from sksurv.datasets import load_whas500
from torch_survival.models import DeepSurv

# Load WHAS500: X holds the covariates, y is a structured array of (event, time) pairs
X, y = load_whas500()
# One-hot encode the categorical covariates
X = pd.get_dummies(X, drop_first=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

model = DeepSurv(random_state=42, device='cpu')
model.fit(X_train, y_train)
# Evaluate on the held-out split (concordance index)
c_index = model.score(X_test, y_test)
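Assuming `model.score` follows scikit-survival’s convention of returning Harrell’s concordance index, the following sketch shows what that number measures. It builds a scikit-survival-style structured target by hand (the field names `event` and `time` are illustrative; WHAS500 uses its own field names) and computes a simplified Harrell’s C in NumPy. `harrell_c_index` is a hypothetical helper, not part of torch-survival or scikit-survival, and it skips pairs with tied times.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Simplified Harrell's concordance index (hypothetical helper).

    A pair (i, j) is comparable when subject i had an observed event and
    time[i] < time[j]. It is concordant when subject i, who failed first,
    has the higher predicted risk; tied risk scores count as 0.5.
    Pairs with tied times are skipped in this simplified version.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored subject cannot be the earlier failure in a pair
        for j in range(len(time)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Survival target in scikit-survival style: boolean event indicator + float time
y = np.array(
    [(True, 5.0), (False, 8.0), (True, 3.0), (True, 10.0)],
    dtype=[("event", "?"), ("time", "<f8")],
)
risk = np.array([0.7, 0.8, 0.9, 0.1])  # higher score = expected to fail sooner
print(harrell_c_index(y["event"], y["time"], risk))  # → 0.8
```

Here 4 of the 5 comparable pairs are ordered correctly; the one discordant pair is the subject failing at time 5 being scored below the subject censored at time 8. A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking.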

Implemented Models

This list summarizes the currently available models. We aim to steadily improve coverage and also welcome community contributions.

Alternatives

While torch-survival is the first comprehensive and production-ready library for deep survival analysis, there are other related libraries that you may want to consider:

  • pycox has building blocks for both continuous-time models (think DeepSurv) and discrete-time models (think DeepHit) but lacks ready-to-use models.

  • torchsurv implements loss functions and metrics with a heavier focus on statistical evaluation at the expense of computational performance.
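The continuous-time versus discrete-time distinction mentioned above mostly comes down to how the target is represented: discrete-time models such as DeepHit predict over a fixed grid of time intervals, so observed times must first be mapped onto that grid. The sketch below uses a simple equidistant grid; `discretize_times` is a hypothetical helper, and pycox ships its own preprocessing utilities whose grid choice and handling of censored times may differ.

```python
import numpy as np


def discretize_times(time, event, num_intervals):
    """Map continuous times onto an equidistant grid (hypothetical helper).

    Returns the grid cut points, the 0-based index of the interval
    (cuts[k], cuts[k + 1]] each observed time falls into, and the
    unchanged event indicator.
    """
    time = np.asarray(time, dtype=float)
    cuts = np.linspace(0.0, time.max(), num_intervals + 1)
    # side="left" assigns a time equal to a cut point to the interval ending there;
    # subtracting 1 turns the insertion position into an interval index
    idx = np.searchsorted(cuts, time, side="left") - 1
    # A time of exactly 0 would map to -1, so clip it into the first interval
    idx = np.clip(idx, 0, num_intervals - 1)
    return cuts, idx, np.asarray(event, dtype=bool)


cuts, idx, event = discretize_times([1.0, 2.0, 3.0, 10.0], [True, False, True, True], 5)
print(idx)  # → [0 0 1 4]
```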