#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Helpers for choosing a decision threshold on predicted probabilities.

Created on Sat May 1 13:46:42 2021
@author: tanya
"""
import sys
# Iterable and Callable have been available from ``typing`` since Python 3.5,
# so no version-dependent fallback import is needed.
from typing import Callable, Iterable

import numpy as np
def get_optimal_proba_threshold(score_func: Callable,
                                y_true,
                                proba,
                                threshold_set: Iterable = None):
    """Return the probability threshold that maximizes ``score_func``.

    For every candidate threshold ``t`` the predicted labels are
    ``(proba >= t).astype(int)``: 1 where the probability reaches the
    threshold, 0 otherwise. Each candidate is scored with
    ``score_func(y_true, y_pred)``.

    Parameters
    ----------
    score_func : Callable
        Called as ``score_func(y_true, y_pred)``. A larger return value is
        treated as better, so pass a score (e.g. accuracy), not a loss.
    y_true
        Ground-truth labels, forwarded to ``score_func`` unchanged.
    proba
        Predicted probabilities. Array-likes with an ``astype`` method
        (ndarray, pandas Series, ...) are used as-is. Other sequences
        (list, tuple) are converted with ``np.asarray``. Presumably one
        positive-class probability per sample rather than a full
        ``predict_proba`` matrix; confirm against callers.
    threshold_set : Iterable, optional
        Candidate thresholds. Defaults to ``np.arange(0, 1, 0.1)``, i.e.
        0.0, 0.1, ..., 0.9. These values carry float round-off (e.g.
        0.30000000000000004).

    Returns
    -------
    The candidate threshold with the highest score. On ties, the earliest
    candidate in ``threshold_set`` wins.

    Raises
    ------
    ValueError
        If ``threshold_set`` is empty.
    """
    if threshold_set is None:
        threshold_set = np.arange(0, 1, 0.1)

    # ``(proba >= t).astype(int)`` needs an array-like. Plain sequences fail
    # (``list >= float`` raises TypeError), so convert those. Array-likes are
    # left untouched so score_func keeps receiving the same container type.
    if not hasattr(proba, "astype"):
        proba = np.asarray(proba)

    scores = {
        threshold: score_func(y_true, (proba >= threshold).astype(int))
        for threshold in threshold_set
    }
    if not scores:
        raise ValueError("threshold_set must contain at least one threshold")

    # dicts keep insertion order and max() returns the first maximal key,
    # so ties resolve to the earliest candidate threshold.
    return max(scores, key=scores.get)