|
@@ -162,7 +162,7 @@ import sys
|
|
|
import numpy as np
|
|
|
from itertools import zip_longest
|
|
|
|
|
|
-from numpy.typing import ArrayLike
|
|
|
+# from numpy.typing import ArrayLike
|
|
|
|
|
|
if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
|
|
|
from typing import Callable, Dict, Iterable, Union
|
|
@@ -183,9 +183,12 @@ from cdplib.ml_validation.fine_tuning import get_optimal_proba_threshold
|
|
|
def cross_validate_with_optimal_threshold(
|
|
|
score_func_threshold: Callable,
|
|
|
estimator: object,
|
|
|
- X: ArrayLike,
|
|
|
- y: ArrayLike = None,
|
|
|
- groups: ArrayLike = None,
|
|
|
+ # X: ArrayLike,
|
|
|
+ # y: ArrayLike = None,
|
|
|
+ # groups: ArrayLike = None,
|
|
|
+ X,
|
|
|
+ y = None,
|
|
|
+ groups = None,
|
|
|
scoring: Union[Callable, Dict] = None,
|
|
|
cv: Union[Iterable, int, None] = None,
|
|
|
n_jobs: int = None,
|
|
@@ -195,10 +198,14 @@ def cross_validate_with_optimal_threshold(
|
|
|
return_train_score: bool = False,
|
|
|
return_estimator: bool = False,
|
|
|
error_score: float = np.nan,
|
|
|
- X_val: ArrayLike = None,
|
|
|
- y_val: ArrayLike = None,
|
|
|
- X_val_threshold: ArrayLike = None,
|
|
|
- y_val_threshold: ArrayLike = None,
|
|
|
+ # X_val: ArrayLike = None,
|
|
|
+ # y_val: ArrayLike = None,
|
|
|
+ # X_val_threshold: ArrayLike = None,
|
|
|
+ # y_val_threshold: ArrayLike = None,
|
|
|
+ X_val = None,
|
|
|
+ y_val = None,
|
|
|
+ X_val_threshold = None,
|
|
|
+ y_val_threshold = None,
|
|
|
cv_threshold: Union[Iterable, int, None] = None,
|
|
|
threshold_set: Union[Iterable, None] = None,
|
|
|
scores: Dict = None)-> Dict:
|