Hello, I work in an HPC support team. Here is a summary of a challenge submitted by a user who asked for our help through our ticket system.
Life scientists are interested in survival analysis models. Unlike standard ML models, which perform regression or classification, survival ML models return a survival function S(t): the probability that the event of interest has not yet occurred by time t. The library used in this ticket is scikit-survival (scikit-survival 0.21.0).
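To make "returns a survival function" concrete, here is a small sketch in plain Python (no scikit-survival involved) of the classic Kaplan-Meier estimate, which turns follow-up times and censoring flags into a step curve S(t). The function name `kaplan_meier` is illustrative only; scikit-survival's own models expose their curves through `predict_survival_function`.

```python
from collections import Counter

def kaplan_meier(times, events):
    """Illustrative Kaplan-Meier estimator of S(t) = P(T > t).

    times  -- observed follow-up time for each subject
    events -- 1 if the event was observed, 0 if the subject was censored
    Returns (event_times, survival_probs) describing a step function.
    """
    deaths = Counter(t for t, e in zip(times, events) if e)
    surv = 1.0
    event_times, probs = [], []
    for t in sorted(deaths):
        at_risk = sum(1 for u in times if u >= t)  # subjects still observed at time t
        surv *= 1.0 - deaths[t] / at_risk          # multiply by conditional survival at t
        event_times.append(t)
        probs.append(surv)
    return event_times, probs

# Five subjects; the ones with event flag 0 were censored (lost to follow-up).
event_times, surv = kaplan_meier([1, 2, 2, 3, 4], [1, 1, 0, 1, 0])
# event_times == [1, 2, 3]; surv is approximately [0.8, 0.6, 0.3]
```

Each model output is a whole curve over time rather than a single label or number, which is why explaining it requires a time-dependent method.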
After training and testing survival ML models, if a model is accurate enough, the life scientist wants to explain it. For this, she uses SurvivalShap(t).
The SurvivalShap(t) computation is killed because it runs out of memory.
A critical line of code in the SurvivalShap Python framework is this one:
```python
import itertools
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]  # all 2**p on/off masks of the p variables
```
This line enumerates every binary on/off mask (coalition) of the p variables and stores all of them in a list, so its time and memory complexity is O(2^p). The user has p = 387 variables, i.e. 2^387 ≈ 3.2 × 10^116 coalitions, which can be neither stored nor computed, so the job crashes.
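A quick back-of-the-envelope check of that growth, as a sketch (the helper `count_coalitions` is hypothetical, not part of SurvivalShap): for small p the enumeration matches 2**p exactly, while for p = 387 the count alone is a 117-digit number.

```python
import itertools
import math

def count_coalitions(p: int) -> int:
    """Number of masks produced by itertools.product(range(2), repeat=p)."""
    return 2 ** p

# For small p the full enumeration is feasible and has exactly 2**p entries.
for p in (1, 4, 12):
    masks = [list(z) for z in itertools.product(range(2), repeat=p)]
    assert len(masks) == count_coalitions(p)

# For the user's p = 387 the list could never be materialized: even at one
# bit per variable per mask, the required memory dwarfs any machine.
p = 387
print(f"2**{p} has {len(str(count_coalitions(p)))} digits")
print(f"about 10^{p * math.log10(2):.1f} masks")
```

Adding a single variable doubles the work, so no amount of HPC memory closes the gap; the number of variables or the number of evaluated coalitions has to shrink.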
Reducing the number of variables
- GitHub pull request:
I recently proposed a new feature to the SurvivalShap developers: instead of enumerating all 2^p coalitions, it samples the computations, producing results comparable to the exact brute-force approach at a fraction of the cost.
It is now up to the developers to decide whether to make it official.
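The code of the proposed feature is not reproduced here. As a generic illustration of why sampling helps, the sketch below compares brute-force Shapley values (all 2**p coalitions) with Monte Carlo permutation sampling, a standard approximation whose cost grows linearly in p per sampled permutation. The toy value function is a hypothetical stand-in: in SurvivalShap(t) the "value" would be a survival curve evaluated on a time grid, and a scalar keeps the sketch short.

```python
import itertools
import math
import random

def exact_shapley(value, p):
    """Brute force over all 2**p coalitions (infeasible for large p)."""
    phi = [0.0] * p
    for mask in itertools.product(range(2), repeat=p):
        s = sum(mask)
        for j in range(p):
            if mask[j] == 0:
                # Shapley weight |S|! (p - |S| - 1)! / p! for a coalition S without j
                w = math.factorial(s) * math.factorial(p - s - 1) / math.factorial(p)
                with_j = mask[:j] + (1,) + mask[j + 1:]
                phi[j] += w * (value(with_j) - value(mask))
    return phi

def sampled_shapley(value, p, n_permutations=200, seed=0):
    """Permutation sampling: average marginal contributions over random orderings."""
    rng = random.Random(seed)
    phi = [0.0] * p
    for _ in range(n_permutations):
        order = list(range(p))
        rng.shuffle(order)
        mask = [0] * p
        prev = value(tuple(mask))
        for j in order:  # switch features on one at a time in random order
            mask[j] = 1
            cur = value(tuple(mask))
            phi[j] += cur - prev
            prev = cur
    return [x / n_permutations for x in phi]

# Hypothetical toy model: additive effects plus one interaction between features 0 and 1.
weights = [1.0, 2.0, 0.5, 0.0, -1.0, 3.0]

def toy_value(mask):
    return sum(w * m for w, m in zip(weights, mask)) + mask[0] * mask[1]

exact = exact_shapley(toy_value, p=6)     # 2**6 = 64 coalitions
approx = sampled_shapley(toy_value, p=6)  # 200 permutations x 7 evaluations each
```

The sampled estimate is unbiased and its error shrinks as the number of permutations grows, which trades a small, controllable approximation error for tractability when p is in the hundreds.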
- GitHub issue:
I also reported a performance issue with Random Survival Forest predictions in scikit-survival:
"Survival Random Forest predict_survival_function does not scale with `n_jobs`" (sebp/scikit-survival, issue #382)
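While the issue is open, one possible workaround, sketched here as an assumption rather than something the issue or the library prescribes, is to split the rows of X into chunks, predict each chunk concurrently, and concatenate the results. The helper `predict_in_chunks` is hypothetical. Whether threads give a real speed-up depends on how much of the prediction releases the GIL; if it does not, a process pool (with a picklable model) would be needed instead.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence

def predict_in_chunks(predict_fn: Callable[[Sequence], Sequence],
                      X: Sequence, n_workers: int = 4) -> List:
    """Hypothetical helper: split X row-wise into about n_workers chunks, call
    predict_fn on each chunk from a thread pool, and return the concatenated
    predictions in the original row order."""
    size = max(1, math.ceil(len(X) / n_workers))
    chunks = [X[i:i + size] for i in range(0, len(X), size)]
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # pool.map yields results in the order of `chunks`, not completion order
        parts = list(pool.map(predict_fn, chunks))
    return [y for part in parts for y in part]

# Stand-in predictor for demonstration; with scikit-survival one might pass
# rsf.predict_survival_function, where rsf is a fitted RandomSurvivalForest.
preds = predict_in_chunks(lambda rows: [r * 2 for r in rows], list(range(10)), n_workers=3)
print(preds)
```

Keeping the chunking outside the model makes it easy to time the same workload with different worker counts and check whether throughput actually scales.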
In the end, the life scientist was able to reduce the number of variables, and she proceeded with the smaller set.
In addition, I proposed an update to SurvivalShap that approximates the algorithm, and I reported to the scikit-survival developers that Random Survival Forest predictions do not scale as `n_jobs` increases.