torch-survival: A PyTorch-based library for Deep Survival Analysis
torch-survival is a library built upon PyTorch and scikit-survival to make survival analysis using deep learning more accessible.
Its main goal is to implement a diverse set of deep learning methods for survival analysis following scikit-survival's API design, thus offering a better out-of-the-box experience than existing libraries (see Alternatives). In addition, it provides many common loss functions and metrics that can also be used standalone in other PyTorch-based projects.
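The sketch below shows what a standalone survival loss looks like in plain PyTorch. It computes the negative Cox partial log-likelihood, which is the objective DeepSurv optimizes. The function name `cox_ph_loss` and its signature are hypothetical and are not torch-survival's API. Tied event times receive no special correction: the sort breaks ties arbitrarily.

```python
import torch


def cox_ph_loss(log_risk: torch.Tensor, time: torch.Tensor, event: torch.Tensor) -> torch.Tensor:
    """Negative Cox partial log-likelihood (hypothetical helper, no tie correction).

    log_risk: (n,) network outputs g(x_i), i.e. the log relative hazard
    time:     (n,) observed times (event or censoring)
    event:    (n,) 1.0 if the event was observed, 0.0 if right-censored
    """
    # Sort by descending time so the risk set {j : t_j >= t_i} of sample i
    # is every sample placed at or before i in the sorted order
    order = torch.argsort(time, descending=True)
    log_risk = log_risk[order]
    event = event[order].to(log_risk.dtype)
    # log sum_{j in risk set} exp(g(x_j)), computed stably as a cumulative logsumexp
    log_risk_set = torch.logcumsumexp(log_risk, dim=0)
    # Only uncensored samples contribute a term; average over observed events
    n_events = event.sum().clamp(min=1.0)
    return -((log_risk - log_risk_set) * event).sum() / n_events
```

In a custom training loop, call it as `loss = cox_ph_loss(net(x).squeeze(-1), time, event)` and backpropagate as usual. Here `net` is any PyTorch module that produces one log-risk score per sample.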
Important
While this library has been extensively tested and evaluated as part of the SurvHub benchmark, some parts of the API design and documentation are still considered a work in progress. Expect breaking changes as we converge towards a consistent and enjoyable developer experience.
Installation
torch-survival is available from PyPI and only depends on PyTorch, scikit-survival, and (for now) pycox. To install:
pip install torch-survival
For now only alpha releases are available.
Getting Started
Since the API design of torch-survival closely mimics that of scikit-survival (and, by extension, scikit-learn), getting started takes only a few lines of code:
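import pandas as pd
from sklearn.model_selection import train_test_split
from sksurv.datasets import load_whas500
from torch_survival.models import DeepSurv

# Load the Worcester Heart Attack Study data: X is a DataFrame, y a structured (event, time) array
X, y = load_whas500()
# One-hot encode categorical columns; dtype=float keeps the whole feature matrix numeric
X = pd.get_dummies(X, drop_first=True, dtype=float)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Fit a DeepSurv model on the CPU with a fixed seed for reproducibility
model = DeepSurv(random_state=42, device='cpu')
model.fit(X_train, y_train)

# score() returns the concordance index (C-index) on the held-out data
c_index = model.score(X_test, y_test)
print(f"C-index: {c_index:.3f}")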
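The metric can be sketched in a few lines of NumPy, assuming `score` follows scikit-survival's convention of returning Harrell's concordance index on the predicted risk scores. `harrell_c_index` is a hypothetical helper and a simplification: it skips pairs with tied times, whereas scikit-survival's `concordance_index_censored` handles ties explicitly.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Simplified Harrell's C (hypothetical helper): share of comparable pairs ranked correctly.

    A pair (i, j) is comparable when sample i had an observed event strictly
    before time[j]. It is concordant when the model assigns i the higher risk.
    Tied risk scores count as half-concordant; tied times are not compared.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in np.flatnonzero(event):
        # Samples whose observed time is strictly later than sample i's event
        later = time > time[i]
        comparable += int(later.sum())
        concordant += float((risk[i] > risk[later]).sum())
        concordant += 0.5 * float((risk[i] == risk[later]).sum())
    return concordant / comparable if comparable else float("nan")
```

This version needs only NumPy, so you can use it to cross-check the model above. Assume `model.predict` returns risk scores where higher means riskier. Run `event_field, time_field = y_test.dtype.names`, then call `harrell_c_index(y_test[event_field], y_test[time_field], model.predict(X_test))`. Small differences from `score` are expected whenever tied times occur.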
Implemented Models
This list summarizes the currently available models. We aim to steadily improve coverage and also welcome community contributions.
DeepSurv from Katzman et al.: DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network (BMC Medical Research Methodology 2018)
DeepHit from Lee et al.: DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks (AAAI 2018)
DeepWeiSurv from Bennis et al.: Estimation of Conditional Mixture Weibull Distribution with Right Censored Data Using Neural Network for Time-to-Event Analysis (PAKDD 2020)
RankDeepSurv from Jing et al.: A deep survival analysis method based on ranking (Artificial Intelligence in Medicine 2019)
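Unlike the Cox-based DeepSurv, parametric models such as DeepWeiSurv fit a distribution over event times. The title of Bennis et al. describes a conditional mixture of Weibull distributions fitted to right-censored data. Below is a minimal sketch of such a censored mixture likelihood. Several details are assumptions for illustration: the function name, its inputs (per-component log-shape, log-scale, and mixture logits predicted by a network), and the plain mean over samples. This is not torch-survival's implementation, and the paper may include further terms. Observed events contribute the log density, and censored samples contribute the log survival probability.

```python
import torch


def weibull_mixture_nll(log_shape, log_scale, logits, time, event):
    """Negative log-likelihood of a K-component Weibull mixture under right censoring.

    log_shape, log_scale, logits: (n, K) network outputs (log k, log lambda, mixture logits)
    time:  (n,) strictly positive observed times
    event: (n,) 1.0 for an observed event, 0.0 for right-censored
    """
    k = log_shape.exp()
    log_t = time.log().unsqueeze(-1)            # (n, 1), broadcast over components
    log_w = torch.log_softmax(logits, dim=-1)   # log mixture weights
    z = k * (log_t - log_scale)                 # log((t / lambda)^k)
    log_surv = -z.exp()                         # log S_k(t) = -(t / lambda)^k
    # log f_k(t) = log k - log t + k (log t - log lambda) - (t / lambda)^k
    log_pdf = log_shape - log_t + z + log_surv
    # Combine components in log space for numerical stability
    mix_log_pdf = torch.logsumexp(log_w + log_pdf, dim=-1)
    mix_log_surv = torch.logsumexp(log_w + log_surv, dim=-1)
    event = event.to(mix_log_pdf.dtype)
    ll = event * mix_log_pdf + (1.0 - event) * mix_log_surv
    return -ll.mean()
```

Predicting log k and log λ instead of k and λ keeps both Weibull parameters positive without explicit constraints. The log-softmax keeps the mixture weights on the simplex.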
Alternatives
While torch-survival aims to be a comprehensive, ready-to-use library for deep survival analysis, there are other related libraries that you may want to consider:
pycox has building blocks for both continuous-time models (think DeepSurv) and discrete-time models (think DeepHit) but lacks ready-to-use models.
torchsurv implements loss functions and metrics with a heavier focus on statistical evaluation at the expense of computational performance.