fine_tuning.py 932 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Sat May 1 13:46:42 2021
  5. @author: tanya
  6. """
  7. import sys
  8. import numpy as np
  9. # from numpy.typing import ArrayLike
  10. if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
  11. from typing import Iterable, Callable
  12. else:
  13. from
  14. import * #Iterable, Callable
  15. def get_optimal_proba_threshold(score_func: Callable,
  16. # y_true: ArrayLike,
  17. # proba: ArrayLike,
  18. y_true,
  19. proba,
  20. threshold_set: Iterable = None):
  21. """
  22. """
  23. scores = {}
  24. if threshold_set is None:
  25. threshold_set = np.arange(0, 1, 0.1)
  26. for threshold in threshold_set:
  27. y_pred = (proba >= threshold).astype(int)
  28. scores[threshold] = score_func(y_true, y_pred)
  29. return max(scores, key=scores.get)