tgoop.com/AspiringDataScience/754
Last Update:
#optimization #ml #metrics #python #numba #codegems
В общем, sklearn-овские метрики оказались слишком медленными, пришлось их переписать на numba. Вот пример classification_report, который работает в тысячу раз быстрее и поддерживает почти всю функциональность (кроме весов и микровзвешивания). Также оптимизировал метрики auc (алгоритм взят из fastauc) и calibration (считаю бины предсказанные vs реальные, потом mae/std от их разностей). На 8M сэмплов всё работает за ~30 миллисекунд кроме auc, та ~300 мс. Для сравнения, scikit-learn-овские работают от нескольких секунд до нескольких десятков секунд.@njit()
def fast_classification_report(
y_true: np.ndarray, y_pred: np.ndarray, nclasses: int = 2, zero_division: int = 0
):
"""Custom classification report, proof of concept."""
N_AVG_ARRAYS = 3 # precisions, recalls, f1s
# storage inits
weighted_averages = np.empty(N_AVG_ARRAYS, dtype=np.float64)
macro_averages = np.empty(N_AVG_ARRAYS, dtype=np.float64)
supports = np.zeros(nclasses, dtype=np.int64)
allpreds = np.zeros(nclasses, dtype=np.int64)
misses = np.zeros(nclasses, dtype=np.int64)
hits = np.zeros(nclasses, dtype=np.int64)
# count stats
for true_class, predicted_class in zip(y_true, y_pred):
supports[true_class] += 1
allpreds[predicted_class] += 1
if predicted_class == true_class:
hits[predicted_class] += 1
else:
misses[predicted_class] += 1
# main calcs
accuracy = hits.sum() / len(y_true)
balanced_accuracy = np.nan_to_num(hits / supports, copy=True, nan=zero_division).mean()
recalls = hits / supports
precisions = hits / allpreds
f1s = 2 * (precisions * recalls) / (precisions + recalls)
# fix nans & compute averages
i=0
for arr in (precisions, recalls, f1s):
np.nan_to_num(arr, copy=False, nan=zero_division)
weighted_averages[i] = (arr * supports).sum() / len(y_true)
macro_averages[i] = arr.mean()
i+=1
return hits, misses, accuracy, balanced_accuracy, supports, precisions, recalls, f1s, macro_averages, weighted_averages
BY Aspiring Data Science
Share with your friend now:
tgoop.com/AspiringDataScience/754