fine_tuning.py 944 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Sat May 1 13:46:42 2021
  5. @author: tanya
  6. """
  7. import sys
  8. import numpy as np
  9. # from numpy.typing import ArrayLike
  10. if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
  11. from typing import Iterable, Callable
  12. else:
  13. from typing_extensions import * #Iterable, Callable
  14. def get_optimal_proba_threshold(score_func: Callable,
  15. # y_true: ArrayLike,
  16. # proba: ArrayLike,
  17. y_true,
  18. proba,
  19. threshold_set: Iterable = None):
  20. """
  21. """
  22. scores = {}
  23. if threshold_set is None:
  24. threshold_set = np.arange(0, 1, 0.1)
  25. for threshold in threshold_set:
  26. y_pred = (proba >= threshold).astype(int)
  27. scores[threshold] = score_func(y_true, y_pred)
  28. return max(scores, key=scores.get)