Survival Data Science Explainability performance issue

Hello, I work in an HPC support team. Below I summarize a challenge we received from a user who asked for our help through the ticket system.

Context:

Life scientists are interested in survival data science models. Unlike standard ML models that do regression or classification, a survival ML model returns a survival function. The library used in this ticket is scikit-survival 0.21.0.
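For readers unfamiliar with this workflow, here is a minimal sketch of training a survival model with scikit-survival and retrieving survival functions. The dataset (WHAS500, shipped with the library) and the hyperparameters are illustrative choices, not the user's actual setup:

```python
from sksurv.datasets import load_whas500
from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder

# Example dataset shipped with scikit-survival (not the user's data).
X, y = load_whas500()
X = OneHotEncoder().fit_transform(X)  # encode categorical columns

# Fit a random survival forest; hyperparameters are arbitrary here.
rsf = RandomSurvivalForest(n_estimators=100, random_state=0, n_jobs=4)
rsf.fit(X, y)

# Instead of a single number, each sample gets a whole survival function S(t).
surv_funcs = rsf.predict_survival_function(X.iloc[:5])
for fn in surv_funcs:
    print(fn.x[:3], fn.y[:3])  # time grid and survival probabilities
```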

After training and testing survival ML models, if a model is accurate enough, the life scientist wants to explain it. For this they use SurvSHAP(t).

Challenge:

The SurvSHAP(t) computation is killed because it runs out of memory.

Diagnostic:

A critical line of code in the SurvSHAP(t) Python framework is this one:

simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]

The memory/time complexity of itertools.product(range(2), repeat=p) is O(2^p), where p is the number of variables: it enumerates every possible on/off coalition of the variables. The user's dataset has 387 variables, so 2^387 combinations would have to be generated and stored, which is impossible, and the job crashes.
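To make the blow-up concrete, here is a small back-of-the-envelope check (pure Python, nothing specific to the user's environment); the coalition counts are exact, and materializing the list is only feasible for small p:

```python
import itertools

def n_coalitions(p: int) -> int:
    """Number of 0/1 vectors produced by itertools.product(range(2), repeat=p)."""
    return 2 ** p

for p in (10, 20, 30, 387):
    print(f"p={p:>3}: {n_coalitions(p):.3e} coalitions")

# Materializing the list only works for small p:
p = 10
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]
print(len(simplified_inputs))  # 1024
```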

Proposed solutions:

  • Reducing the number of variables

  • GitHub pull request:

I recently proposed a new feature to the developers of SurvSHAP(t). The feature samples the coalitions instead of enumerating all of them, to achieve results comparable to the true brute-force approach (a rough sketch of the idea is shown below).

It is up to the developers to make it official.
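This is not the actual pull request, only a minimal sketch of the general idea of replacing the exhaustive coalition enumeration with random sampling. The function name sample_coalitions and the sampling budget are made up for illustration:

```python
import numpy as np

def sample_coalitions(p, n_samples, rng=None):
    """Draw random 0/1 coalition vectors instead of enumerating all 2**p of them.

    Memory grows as O(n_samples * p) instead of O(2**p * p).
    """
    rng = np.random.default_rng() if rng is None else rng
    coalitions = rng.integers(0, 2, size=(n_samples, p))
    # Always include the empty and full coalitions, which SHAP-style
    # estimators typically use as anchors.
    coalitions[0, :] = 0
    coalitions[-1, :] = 1
    return coalitions

# 387 variables: 2**387 exact coalitions vs. a fixed sampling budget.
simplified_inputs = sample_coalitions(p=387, n_samples=10_000)
print(simplified_inputs.shape)  # (10000, 387)
```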

Conclusion:

The life scientist was able to reduce the number of variables, and that is how she proceeded.

For my part, I am proposing an update for approximating the SurvSHAP(t) algorithm and reporting to scikit-survival that SRForest predictions do not scale when n_jobs increases.
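A minimal way to reproduce the scaling observation for the report, assuming SRForest refers to scikit-survival's RandomSurvivalForest (my assumption); the dataset and forest size are placeholders, not the user's actual measurements:

```python
import time
from sksurv.datasets import load_whas500
from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder

X, y = load_whas500()
X = OneHotEncoder().fit_transform(X)

for n_jobs in (1, 2, 4, 8):
    rsf = RandomSurvivalForest(n_estimators=200, n_jobs=n_jobs, random_state=0)
    rsf.fit(X, y)
    start = time.perf_counter()
    rsf.predict_survival_function(X)
    elapsed = time.perf_counter() - start
    print(f"n_jobs={n_jobs}: predict took {elapsed:.2f}s")
```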
