Hello, I work in an HPC support team. Below I summarize a challenge we received from a user who asked for our help through the ticket system.
Context:
Life scientists are interested in survival data science models. Unlike standard ML models that perform regression or classification, survival ML models return a survival function. The technology used in this ticket is scikit-survival 0.21.0.
After training and testing survival ML models, if a model is accurate enough, the life scientist wants to explain it. They use SurvivalShap(t).
Challenge:
The SurvivalShap(t) algorithm is killed due to lack of memory (out-of-memory).
Diagnostic:
A critical line of code in the SurvivalShap Python framework is this one:
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]
The memory and time complexity of itertools.product(range(2), repeat=n) is O(2^n), where n is the number of variables. The user has 387 variables, so the result is impossible to store or compute, and the process crashes.
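A small sketch of why this line blows up: for small p the enumeration of binary coalition vectors is cheap, but the count doubles with every variable, and at p = 387 the number of coalitions has over a hundred decimal digits.

```python
import itertools

# Small p: full enumeration of binary coalition vectors is cheap.
p = 3
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]
assert len(simplified_inputs) == 2 ** p  # 8 coalitions for 3 variables

# The user's case: p = 387. The list would need 2**387 entries.
print(len(str(2 ** 387)))  # 117 -- decimal digits of 2**387, far beyond any RAM
```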
Proposed solutions:
Reducing the number of variables
Github Pull Request:
I recently proposed a new feature to the developers of SurvivalShap. This feature samples the computations to achieve results comparable to the exact brute-force approach.
The life scientist was able to reduce the number of variables, and she proceeded that way.
I proposed an update for approximating the SurvivalShap algorithm, and reported to scikit-survival that SRForest predictions do not scale as n_jobs increases.
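The sampling idea can be sketched as follows. This is an illustrative sketch only, not the actual pull-request code: the function name sample_coalitions and the uniform sampling scheme are my own for illustration. The point is that drawing a fixed number of random coalition vectors replaces the exhaustive 2**p enumeration, so memory drops from O(2**p * p) to O(n_samples * p).

```python
import random

def sample_coalitions(p, n_samples, seed=0):
    # Hypothetical sketch (not the actual PR code): draw n_samples random
    # binary coalition vectors instead of enumerating all 2**p of them.
    rng = random.Random(seed)
    return [[rng.randint(0, 1) for _ in range(p)] for _ in range(n_samples)]

# With p = 387 variables, 1000 sampled coalitions fit easily in memory.
coalitions = sample_coalitions(p=387, n_samples=1000)
assert len(coalitions) == 1000 and all(len(c) == 387 for c in coalitions)
```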