Browse Source

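Fix cross_validate_with_fine_tuning: remove debug prints, pass use_label_encoder=False to XGBRFClassifier, and run the example call without explicit validation sets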

tanja 3 years ago
parent
commit
6a4eea2918
1 changed file with 5 additions and 13 deletions
  1. 5 13
      cdplib/ml_validation/cross_validate_with_fine_tuning.py

+ 5 - 13
cdplib/ml_validation/cross_validate_with_fine_tuning.py

@@ -171,8 +171,6 @@ def cross_validate_with_optimal_threshold(
 
         for train_inds, val_inds in cv_threshold:
 
-            print("----- In cv threshold fold")
-
             X_train_fold, X_val_fold, y_train_fold, y_val_fold =\
                 CVComposer().cv_slice_dataset(
                     X=X_train,
@@ -190,8 +188,6 @@ def cross_validate_with_optimal_threshold(
 
             thresholds.append(threshold)
 
-            print("----- Threshold:", threshold)
-
         scores["test_threshold"].append(np.mean(thresholds))
 
         if refit:
@@ -226,8 +222,6 @@ def cross_validate_with_optimal_threshold(
 
         for (train_inds, val_inds), cv_fold in zip_longest(cv, cv_threshold):
 
-            print("=== In cv fold")
-
             X_train_fold, X_val_fold, y_train_fold, y_val_fold =\
                 CVComposer().cv_slice_dataset(
                     X=X_train,
@@ -247,8 +241,6 @@ def cross_validate_with_optimal_threshold(
                     threshold_set=threshold_set,
                     scores=scores)
 
-            print("=== scores:", scores)
-
         return scores
 
 
@@ -266,7 +258,7 @@ if __name__ == "__main__":
 
     X_train, X_val, y_train, y_val = train_test_split(X, y)
 
-    estimator = XGBRFClassifier()
+    estimator = XGBRFClassifier(use_label_encoder=False)
 
     score_func = accuracy_score
 
@@ -351,10 +343,10 @@ if __name__ == "__main__":
             score_func=accuracy_score,
             X_train=X_train,
             y_train=y_train,
-            X_val=X_val,
-            y_val=y_val,
-            X_val_threshold=X_val_threshold,
-            y_val_threshold=y_val_threshold,
+            X_val=None,
+            y_val=None,
+            X_val_threshold=None,
+            y_val_threshold=None,
             cv=3,
             cv_threshold=None,
             additional_metrics=additional_metrics)