fine_tuning.py 821 B

1234567891011121314151617181920212223242526272829303132333435363738
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Sat May 1 13:46:42 2021
  5. @author: tanya
  6. """
  7. import sys
  8. import numpy as np
  9. from numpy.typing import ArrayLike
  10. if sys.version_info >= (3, 8):
  11. from typing import Iterable, Callable
  12. else:
  13. from typing_extensions import Iterable, Callable
  14. def get_optimal_proba_threshold(score_func: Callable,
  15. y_true: ArrayLike,
  16. proba: ArrayLike,
  17. threshold_set: Iterable = None):
  18. """
  19. """
  20. scores = {}
  21. if threshold_set is None:
  22. threshold_set = np.arange(0, 1, 0.1)
  23. for threshold in threshold_set:
  24. y_pred = (proba >= threshold).astype(int)
  25. scores[threshold] = score_func(y_true, y_pred)
  26. return max(scores, key=scores.get)