DeepSurv

class torch_survival.models.DeepSurv(hidden_layer_sizes={'n_layers': 4, 'n_neurons': 50}, activation={'choices': ['relu', 'selu']}, dropout={'high': 0.5, 'low': 0.0}, optimizer={'choices': ['sgd', 'adam']}, learning_rate={'high': 0.001, 'log': True, 'low': 1e-07}, momentum={'high': 0.95, 'low': 0.8}, scheduler='inverse_time', decay={'high': 0.001, 'low': 0.0}, n_epochs=500, n_trials=50, random_state=None, device=None)

Implements the DeepSurv model presented by Katzman et al. [1].

Uses a deep neural network trained with the Cox negative log partial likelihood to estimate the risk of each individual. In the original paper, the network’s hyperparameters are tuned with the Sobol solver. This implementation tries to stay faithful to the original paper, with the following deviations:

  • Optuna’s default TPE sampler is used instead of the Sobol sampler, and hyperparameters are selected with 5-fold rather than 3-fold internal cross-validation.

  • The hyperparameter search space is not detailed in the original paper, and the reference implementation is underspecified. We thus define our own shared search space in DeepSurvSearchSpace.

  • Our implementation neither supports nor tunes \(\ell_2\) regularization. We found it to be detrimental to performance and were unable to fully replicate the weight regularization described in the paper.
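The training objective named above can be illustrated with a short NumPy sketch. This is not the library's code: `cox_neg_log_partial_likelihood` is a hypothetical helper, ties are handled with the Breslow convention (everyone whose observed time is at least \(t_i\) belongs to the risk set of \(t_i\)), and averaging the loss over the number of observed events is an assumption about normalization.

```python
import numpy as np


def cox_neg_log_partial_likelihood(log_risk, time, event):
    """Hypothetical helper: Cox negative log partial likelihood, averaged over events.

    log_risk: predicted log-risk h(x_i) for each individual, shape (n,)
    time:     observed event or censoring time, shape (n,)
    event:    True if the event was observed, False if right-censored, shape (n,)
    """
    log_risk = np.asarray(log_risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)

    # at_risk[i, j] is True when individual j is still at risk at time[i]
    at_risk = time[None, :] >= time[:, None]

    # Row-wise log-sum-exp of h(x_j) over each risk set; the -inf mask drops
    # individuals outside the risk set, and the diagonal keeps every row finite
    masked = np.where(at_risk, log_risk[None, :], -np.inf)
    row_max = masked.max(axis=1, keepdims=True)
    log_risk_set = row_max[:, 0] + np.log(np.exp(masked - row_max).sum(axis=1))

    # Only uncensored individuals contribute a term to the partial likelihood
    n_events = max(int(event.sum()), 1)
    return -np.sum((log_risk - log_risk_set)[event]) / n_events


time = np.array([1.0, 2.0, 3.0])
event = np.array([True, True, True])
good = cox_neg_log_partial_likelihood([2.0, 1.0, 0.0], time, event)
bad = cox_neg_log_partial_likelihood([0.0, 1.0, 2.0], time, event)
```

Assigning higher risk to the individuals whose events occur earlier yields the smaller loss, so `good < bad` in this example.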

Parameters:
  • hidden_layer_sizes (list of ints or TunedTopology, default=TunedTopology(n_layers=4, n_neurons=50)) – If a list, the i-th element is the number of neurons in the i-th hidden layer. Alternatively, a TunedTopology specifies the maximum number of layers and the maximum number of neurons per layer.

  • activation ({'relu', 'selu'} or TunedCategorical, default=TunedCategorical(choices=['relu', 'selu'])) – Activation function for the hidden layers.

  • dropout (float or TunedFloat, default=TunedFloat(low=0.0, high=0.5)) – Dropout probability for the hidden layers.

  • optimizer ({'sgd', 'adam'} or TunedCategorical, default=TunedCategorical(choices=['sgd', 'adam'])) – Optimizer used for weight optimization.

  • learning_rate (float or TunedFloat, default=TunedFloat(low=1e-7, high=1e-3, log=True)) – Initial learning rate used when optimizing weights.

  • momentum (float or TunedFloat, default=TunedFloat(low=0.8, high=0.95)) – Momentum for SGD, or the exponential decay rate for estimates of the first moment vector when using Adam.

  • scheduler ({'inverse_time'} or TunedCategorical, default='inverse_time') – Learning-rate scheduler used during training.

  • decay (float or TunedFloat, default=TunedFloat(low=0.0, high=0.001)) – Decay rate used by the scheduler when updating the learning rate.

  • n_epochs (int, default=500) – Number of training epochs (how many times each data point will be used).

  • n_trials (int, default=50) – Number of hyperparameter optimization trials. Only relevant if tunable parameters are passed.

  • random_state (int or None, default=None) – Determines random number generation for hyperparameter optimization and weight initialization. Pass an int for reproducible results across multiple function calls.

  • device (str or torch.device, default=None) – Device on which tensors will be allocated. If None, uses CUDA if available, else CPU.
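For intuition about how learning_rate, decay and n_epochs interact, here is a sketch of an inverse-time schedule. It assumes the common form \(\eta_t = \eta_0 / (1 + \text{decay} \cdot t)\) with \(t\) counted in epochs; this page does not state the exact formula or whether decay is applied per epoch or per batch, and `inverse_time_lr` is a hypothetical name.

```python
def inverse_time_lr(initial_lr, decay, epoch):
    """Hypothetical helper: learning rate after `epoch` epochs of inverse-time decay.

    Assumes lr_t = lr_0 / (1 + decay * t), with t measured in epochs.
    """
    if initial_lr <= 0:
        raise ValueError("initial_lr must be positive")
    if decay < 0:
        raise ValueError("decay must be non-negative")
    return initial_lr / (1.0 + decay * epoch)


# Learning rates over a 500-epoch run, using values inside the default search ranges
schedule = [inverse_time_lr(1e-4, 1e-3, epoch) for epoch in range(500)]
```

With decay=0 (the lower end of the default range) the rate stays constant. In PyTorch, a schedule of this shape could be expressed with torch.optim.lr_scheduler.LambdaLR and lr_lambda=lambda epoch: 1 / (1 + decay * epoch).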

fit(X, y)

Fit the model to the given survival data.

Parameters:
  • X (array-like, shape = (n_samples, n_features)) – Data matrix.

  • y (structured array, shape = (n_samples,)) – A structured array with two fields. The first field is a boolean where True indicates an event and False indicates right-censoring. The second field is a float with the time of event or time of censoring.
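A minimal way to build y with plain NumPy, following the field order described above. The field names event and time are chosen here for illustration and `make_survival_target` is a hypothetical helper; scikit-survival's sksurv.util.Surv.from_arrays(event, time) builds a structured array with the same layout.

```python
import numpy as np


def make_survival_target(event, time):
    """Hypothetical helper: pack event indicators and times into a structured array."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    if event.ndim != 1 or event.shape != time.shape:
        raise ValueError("event and time must be 1D arrays of equal length")
    # First field: event indicator (True = event, False = right-censored)
    # Second field: time of event or time of censoring
    y = np.empty(event.shape[0], dtype=[("event", bool), ("time", float)])
    y["event"] = event
    y["time"] = time
    return y


y = make_survival_target([True, False, True], [5.0, 8.5, 2.0])
```

The resulting y has shape (n_samples,) and can be passed to fit(X, y) together with an X that has the same number of rows.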

Returns:

self – The trained estimator.

Return type:

DeepSurv

predict(X)

Predict risk scores.

The risk score is predicted directly by a neural network. A higher score indicates a higher risk of experiencing the event.

Parameters:

X (array-like, shape = (n_samples, n_features)) – Data matrix.

Returns:

risk_score – Predicted risk scores.

Return type:

array, shape = (n_samples,)
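Under proportional hazards the scores are relative, so they are mainly useful for ranking and comparing individuals. The sketch below assumes that predict returns the log-risk \(\hat{h}(x)\) appearing in the cumulative hazard and survival formulas on this page, in which case \(e^{\hat{h}(a) - \hat{h}(b)}\) is the hazard ratio of individual \(a\) relative to \(b\). `hazard_ratios` is a hypothetical helper and the scores are made-up inputs.

```python
import numpy as np


def hazard_ratios(risk_scores, reference=0):
    """Hypothetical helper: hazard of each individual relative to `reference`.

    Assumes `risk_scores` are log-risks, as in a Cox-type model.
    """
    risk_scores = np.asarray(risk_scores, dtype=float)
    return np.exp(risk_scores - risk_scores[reference])


scores = np.array([0.2, 1.2, -0.5])  # illustrative predicted risk scores
ranking = np.argsort(-scores)        # indices ordered from highest to lowest risk
ratios = hazard_ratios(scores)       # hazards relative to individual 0
```

Here individual 1 ranks first, and its hazard is \(e^{1}\) times that of individual 0.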

predict_cumulative_hazard_function(X, return_array=False)

Predict cumulative hazard function.

The cumulative hazard function for an individual with feature vector \(x\) is defined as

\[H(t \mid x) = H_0(t) e^{\hat{h}(x)},\]

where \(H_0(t)\) is the baseline cumulative hazard function, estimated by Breslow’s estimator, and \(\hat{h}(x)\) is the time-independent log-risk predicted by DeepSurv.
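The cumulative hazard can be sketched in NumPy as follows; the helpers are hypothetical and not the library's internal code. The baseline is the Breslow estimate \(\hat{H}_0(t) = \sum_{t_i \le t} d_i / \sum_{j: t_j \ge t_i} e^{\hat{h}(x_j)}\), where the outer sum runs over unique training event times \(t_i\) with \(d_i\) events each, and it is then scaled by each new individual's \(e^{\hat{h}(x)}\).

```python
import numpy as np


def breslow_baseline_cumulative_hazard(log_risk, time, event):
    """Hypothetical helper: Breslow estimate of H_0 at each unique event time.

    log_risk: log-risks h(x_j) of the training individuals, shape (n,)
    time, event: training times and event indicators, shape (n,)
    """
    risk = np.exp(np.asarray(log_risk, dtype=float))
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    unique_times = np.unique(time[event])
    increments = np.empty(unique_times.shape[0])
    for k, t in enumerate(unique_times):
        n_events_at_t = np.sum(event & (time == t))  # d_i: events tied at t
        risk_set_sum = np.sum(risk[time >= t])       # individuals still at risk at t
        increments[k] = n_events_at_t / risk_set_sum
    return unique_times, np.cumsum(increments)


def cumulative_hazard(log_risk_new, baseline_cum_hazard):
    """H(t | x) = H_0(t) * exp(h(x)); one row per individual, one column per event time."""
    risk_new = np.exp(np.asarray(log_risk_new, dtype=float))
    return risk_new[:, None] * baseline_cum_hazard[None, :]


times, H0 = breslow_baseline_cumulative_hazard(
    log_risk=[0.0, 0.0, 0.0], time=[1.0, 2.0, 3.0], event=[True, False, True]
)
cum_hazard = cumulative_hazard([0.0, np.log(2.0)], H0)
```

The 2D result mirrors the return_array=True layout: column k holds the cumulative hazard at times[k], the k-th unique training event time.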

Parameters:
  • X (array-like, shape = (n_samples, n_features)) – Data matrix.

  • return_array (bool, default=False) – Determines whether to return a 2D array of cumulative hazard values of shape (n_samples, n_unique_times) (if True) or a 1D array of sksurv.functions.StepFunction objects (if False).

Returns:

cum_hazard – If return_array is False, array of n_samples sksurv.functions.StepFunction objects. If return_array is True, numeric array of shape (n_samples, n_unique_times), where n_unique_times is the number of unique event times in the training data.

Return type:

ndarray, shape = (n_samples, n_unique_times) or (n_samples,)

predict_survival_function(X, return_array=False)

Predict survival function.

The survival function for an individual with feature vector \(x\) is defined as

\[S(t \mid x) = S_0(t)^{e^{\hat{h}(x)}},\]

where \(S_0(t)\) is the baseline survival function, estimated by Breslow’s estimator, and \(\hat{h}(x)\) is the time-independent log-risk predicted by DeepSurv.
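Since \(S_0(t) = e^{-H_0(t)}\), where \(H_0\) is the baseline cumulative hazard, raising \(S_0(t)\) to the power \(e^{\hat{h}(x)}\) is the same as computing \(e^{-H_0(t) e^{\hat{h}(x)}}\). The sketch below checks this numerically; `survival_from_baseline` is a hypothetical helper and the baseline values are illustrative, not fitted.

```python
import numpy as np


def survival_from_baseline(log_risk, baseline_cum_hazard):
    """Hypothetical helper: S(t | x) = S_0(t) ** exp(h(x)), with S_0 = exp(-H_0)."""
    s0 = np.exp(-np.asarray(baseline_cum_hazard, dtype=float))
    risk = np.exp(np.asarray(log_risk, dtype=float))
    return s0[None, :] ** risk[:, None]


H0 = np.array([1 / 3, 4 / 3])            # illustrative baseline cumulative hazard
log_risk = np.array([0.0, np.log(2.0)])  # second individual has twice the hazard
surv = survival_from_baseline(log_risk, H0)
```

Each row is non-increasing in time, and the individual with twice the hazard has a survival curve equal to the square of the reference curve.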

Parameters:
  • X (array-like, shape = (n_samples, n_features)) – Data matrix.

  • return_array (bool, default=False) – Determines whether to return a 2D array of survival probabilities of shape (n_samples, n_unique_times) (if True) or a 1D array of sksurv.functions.StepFunction objects (if False).

Returns:

survival_prob – If return_array is False, array of n_samples sksurv.functions.StepFunction objects. If return_array is True, numeric array of shape (n_samples, n_unique_times), where n_unique_times is the number of unique event times in the training data.

Return type:

ndarray, shape = (n_samples, n_unique_times) or (n_samples,)

score(X, y)

Returns the concordance index of the prediction.

Parameters:
  • X (array-like, shape = (n_samples, n_features)) – Test samples.

  • y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.

Returns:

cindex – Estimated concordance index.

Return type:

float

See also

sksurv.metrics.concordance_index_censored

Computes the concordance index.
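To make the metric concrete, the sketch below computes Harrell's concordance index by brute force: a pair is comparable when the individual with the shorter observed time experienced the event, and it is concordant when that individual also received the higher risk score, with tied scores counting one half. `harrell_cindex` is a hypothetical helper; it skips pairs with tied times, so on data with ties its value may differ from sksurv.metrics.concordance_index_censored, which should be preferred in practice.

```python
import numpy as np


def harrell_cindex(event, time, risk_score):
    """Hypothetical helper: Harrell's C for right-censored data, O(n^2)."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk_score = np.asarray(risk_score, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in range(time.shape[0]):
        if not event[i]:
            continue  # a censored individual cannot be the earlier member of a pair
        for j in range(time.shape[0]):
            if time[j] > time[i]:  # j was still event-free when i had the event
                comparable += 1
                if risk_score[i] > risk_score[j]:
                    concordant += 1.0
                elif risk_score[i] == risk_score[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


event = np.array([True, False, True])
time = np.array([1.0, 2.0, 3.0])
cindex = harrell_cindex(event, time, risk_score=[2.0, 3.0, 1.0])
```

Only individual 0 starts comparable pairs here (with individuals 1 and 2), and one of the two pairs is concordant, giving 0.5.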

save(path)

Save a trained model to disk.

This saves the model’s parameters (a dictionary) and weights (tensors) so that the estimator can be fully restored with DeepSurv.load(). Internally, it relies on torch.save().

Parameters:

path (str or file-like object or os.PathLike) – Path at which to save the model.

classmethod load(path, device=None)

Reload an already trained model.

Parameters:
  • path (str or file-like object or os.PathLike) – Path from which to load the model. See also DeepSurv.save() for additional details.

  • device (str or torch.device, default=None) – Device on which tensors will be allocated. If None, uses CUDA if available, else CPU.

Returns:

model – The trained estimator.

Return type:

DeepSurv