1234567891011121314151617181920212223242526272829303132333435363738 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Sat May 1 13:46:42 2021
- @author: tanya
- """
- import sys
- import numpy as np
- from numpy.typing import ArrayLike
- if sys.version_info >= (3, 8):
- from typing import Iterable, Callable
- else:
- from typing_extensions import Iterable, Callable
def get_optimal_proba_threshold(score_func: Callable,
                                y_true: ArrayLike,
                                proba: ArrayLike,
                                threshold_set: Iterable = None):
    """Return the probability threshold that maximises ``score_func``.

    Each candidate threshold binarises ``proba`` (``1`` where
    ``proba >= threshold``, else ``0``). The resulting predictions are
    scored with ``score_func(y_true, y_pred)``.

    Parameters
    ----------
    score_func : Callable
        Metric called as ``score_func(y_true, y_pred)``; higher is better.
    y_true : ArrayLike
        Ground-truth labels, passed to ``score_func`` unchanged.
    proba : ArrayLike
        Predicted probabilities. Any array-like (list, tuple, ndarray) is
        accepted and converted with ``np.asarray``.
    threshold_set : Iterable, optional
        Candidate thresholds. Defaults to ``np.arange(0, 1, 0.1)``
        (0.0, 0.1, ..., 0.9). Because the step is a float, some values are
        not exact decimals (e.g. ``0.30000000000000004``).

    Returns
    -------
    The candidate threshold with the highest score. Ties are resolved in
    favour of the candidate that appears first in ``threshold_set``.

    Raises
    ------
    ValueError
        If ``threshold_set`` yields no thresholds.
    """
    # Convert once, outside the loop: comparing a plain list/tuple with a
    # float raises TypeError, so the comparison below needs an ndarray.
    proba = np.asarray(proba)
    if threshold_set is None:
        threshold_set = np.arange(0, 1, 0.1)

    scores = {}
    for threshold in threshold_set:
        y_pred = (proba >= threshold).astype(int)
        scores[threshold] = score_func(y_true, y_pred)

    # Checked after the loop so one-shot iterables (generators) also work.
    if not scores:
        raise ValueError("threshold_set must contain at least one threshold")
    # dicts keep insertion order and max() returns the first maximal key,
    # so ties go to the earliest candidate threshold.
    return max(scores, key=scores.get)
|