RankDeepSurv¶
- class torch_survival.models.RankDeepSurv(hidden_layer_sizes={'n_layers': 4, 'n_neurons': 50}, activation={'choices': ['relu', 'selu']}, dropout={'high': 0.5, 'low': 0.0}, optimizer={'choices': ['sgd', 'adam']}, learning_rate={'high': 0.01, 'log': True, 'low': 1e-06}, momentum={'high': 0.95, 'low': 0.8}, scheduler='inverse_time', decay={'high': 0.001, 'low': 0.0}, alpha={'high': 1.0, 'low': 0.0}, batch_size=128, n_epochs=500, n_trials=50, random_state=None, device=None)¶
Implements the RankDeepSurv model presented by Jing et al. [1].
Uses a deep neural network trained with a mean squared error and ranking loss, adapted for censoring, to estimate the survival time of each individual. This implementation tries to stay faithful to the original paper, with the following deviations:
Neither the original paper nor the provided implementation states how hyperparameters are derived; here, we use Optuna’s default TPE sampler with 5-fold internal cross-validation.
The original loss formulation uses separate \(\alpha\) and \(\beta\) weights for the regression and ranking loss, respectively. We instead use only a single \(\alpha\) such that \(\mathcal{L}= \alpha \mathcal{L}_{mse} + (1-\alpha) \mathcal{L}_{rank}\).
Our implementation does not support or tune \(\ell_2\) regularization. We found this to be detrimental to performance and were unable to fully replicate the described weight regularization.
Note
As the model predicts a single expected survival time, it does not support predicting either a cumulative hazard function or a survival function. This is not an omission but a fundamental consequence of the model’s design.
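As a rough illustration of the combined objective described in the deviations above, the following sketch shows one plausible PyTorch formulation. The masking of censored samples and the construction of comparable pairs are assumptions based on the original paper, not the library’s internal code.

```python
import torch

def rank_deepsurv_loss(pred, time, event, alpha):
    """Sketch of L = alpha * L_mse + (1 - alpha) * L_rank (not the library's code).

    pred  : (n,) predicted survival times
    time  : (n,) observed event or censoring times
    event : (n,) boolean event indicators (True = event observed)
    alpha : weight between the regression and ranking terms
    """
    # Regression term: squared error on uncensored samples and on censored
    # samples whose prediction falls below the censoring time (such a
    # prediction is known to be too small even under censoring).
    mse_mask = event | (pred <= time)
    l_mse = ((pred[mse_mask] - time[mse_mask]) ** 2).mean()

    # Ranking term over comparable pairs (i, j): i experienced the event and
    # t_j > t_i, so the predicted gap should be at least the observed gap;
    # any shortfall is penalised quadratically.
    dt = time.unsqueeze(0) - time.unsqueeze(1)   # dt[i, j] = t_j - t_i
    dp = pred.unsqueeze(0) - pred.unsqueeze(1)   # dp[i, j] = y_j - y_i
    comparable = event.unsqueeze(1) & (dt > 0)
    shortfall = torch.clamp(dt - dp, min=0.0)[comparable]
    l_rank = (shortfall ** 2).mean() if shortfall.numel() > 0 else pred.new_zeros(())

    return alpha * l_mse + (1 - alpha) * l_rank
```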
- Parameters:
hidden_layer_sizes (list of ints or TunedTopology, default=TunedTopology(n_layers=4, n_neurons=50)) – If a list, the ith element represents the number of neurons in the ith hidden layer. Alternatively, a tunable topology specifies the maximum number of layers and of neurons per layer.
activation ({'relu', 'selu'} or TunedCategorical, default=TunedCategorical(choices=['relu', 'selu'])) – Activation function for the hidden layers.
dropout (float or TunedFloat, default=TunedFloat(low=0.0, high=0.5)) – Dropout probability for the hidden layers.
optimizer ({'sgd', 'adam'} or TunedCategorical, default=TunedCategorical(choices=['sgd', 'adam'])) – Optimizer used for weight optimization.
learning_rate (float or TunedFloat, default=TunedFloat(low=1e-7, high=1e-3, log=True)) – Initial learning rate used when optimizing weights.
momentum (float or TunedFloat, default=TunedFloat(low=0.8, high=0.95)) – Momentum factor (SGD) or exponential decay rate for the first moment estimates (Adam).
scheduler ({'inverse_time'} or TunedCategorical, default='inverse_time') – Learning rate scheduler used during training.
decay (float or TunedFloat, default=TunedFloat(low=0.0, high=0.001)) – Decay used by scheduler when updating learning rate.
alpha (float or TunedFloat, default=TunedFloat(low=0.0, high=1.0)) – Weight term for regression and ranking loss, with \(\mathcal{L}=\alpha \mathcal{L}_{mse} + (1-\alpha) \mathcal{L}_{rank}\).
batch_size (int or TunedInt, default=128) – Size of minibatches. Training iterates over pairs of compatible samples, so this is the number of pairs per batch.
n_epochs (int, default=500) – Number of training epochs (how many times each pair of compatible samples will be used).
n_trials (int, default=50) – Number of hyperparameter optimization trials. Only relevant if tunable parameters are passed.
random_state (int, default=None) – Determines random number generation for hyperparameter optimization and weight initialization. Pass an int for reproducible results across multiple function calls.
device (str or torch.device, default=None) – Device on which tensors will be allocated. If None, uses CUDA if available, else CPU.
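A minimal construction sketch follows. The import path of the Tuned* wrappers shown in the defaults above is an assumption and may differ in your installation.

```python
from torch_survival.models import RankDeepSurv
# Assumed location of the tunable-parameter wrappers; adjust to your install.
from torch_survival.tuning import TunedFloat, TunedCategorical

# Fully fixed hyperparameters: no Optuna trials are run.
model = RankDeepSurv(
    hidden_layer_sizes=[50, 50, 25],
    activation="relu",
    dropout=0.2,
    optimizer="adam",
    learning_rate=1e-4,
    alpha=0.7,
    n_epochs=200,
    random_state=0,
)

# Mixing fixed values with tunable ranges triggers the internal search
# (n_trials trials with 5-fold internal cross-validation).
tuned = RankDeepSurv(
    dropout=TunedFloat(low=0.0, high=0.5),
    optimizer=TunedCategorical(choices=["sgd", "adam"]),
    n_trials=25,
    random_state=0,
)
```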
- fit(X, y)¶
Fit the model to the given survival data.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix.
y (structured array, shape = (n_samples,)) – A structured array with two fields. The first field is a boolean where True indicates an event and False indicates right-censoring. The second field is a float with the time of event or time of censoring.
- Returns:
self – The trained estimator.
- Return type:
RankDeepSurv
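A small end-to-end sketch with synthetic data. The field names ("event", "time") are arbitrary; the docstring only requires the first field to be the boolean indicator and the second the time.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))

# Structured array: boolean event indicator first, observed time second.
y = np.empty(200, dtype=[("event", "?"), ("time", "<f8")])
y["event"] = rng.random(200) < 0.7           # ~70% observed events
y["time"] = rng.exponential(scale=10.0, size=200)

model = RankDeepSurv(n_epochs=50, random_state=0).fit(X, y)
```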
- predict(X)¶
Predict survival times.
The survival time is predicted directly by a neural network.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix.
- Returns:
survival_time – Predicted survival times.
- Return type:
array, shape = (n_samples,)
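Continuing the fit example above, prediction returns one expected survival time per row:

```python
pred_times = model.predict(X)
print(pred_times.shape)   # (200,) -- one predicted survival time per sample
```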
- score(X, y)¶
Returns the concordance index of the prediction.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Test samples.
y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
- Returns:
cindex – Estimated concordance index.
- Return type:
float
See also
sksurv.metrics.concordance_index_censored – Computes the concordance index.
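Continuing the example above. The manual computation below is an assumption about how score orients the predictions; since higher predicted survival times should mean later events, they are negated to act as risk scores.

```python
cindex = model.score(X, y)

# Presumably equivalent, assuming score delegates to scikit-survival:
from sksurv.metrics import concordance_index_censored
cindex_manual = concordance_index_censored(y["event"], y["time"], -model.predict(X))[0]
```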
- save(path)¶
Save a trained model to disk.
This saves the model’s parameters (dictionary) and weights (tensors) such that the estimator can be fully restored using RankDeepSurv.load(). Internally relies on torch.save().
- Parameters:
path (str or file-like object or os.PathLike) – Path at which to save the model.
- classmethod load(path, device=None)¶
Reload an already trained model.
- Parameters:
path (str or file-like object or os.PathLike) – Path from which to load the model. See also RankDeepSurv.save() for additional details.
device (str or torch.device, default=None) – Device on which tensors will be allocated. If None, uses CUDA if available, else CPU.
- Returns:
model – The trained estimator.
- Return type:
RankDeepSurv
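A save/load round trip, continuing the example above (the file name is arbitrary):

```python
model.save("rankdeepsurv.pt")

# Restore later, optionally forcing CPU allocation.
restored = RankDeepSurv.load("rankdeepsurv.pt", device="cpu")
pred = restored.predict(X)
```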