瀏覽代碼

fixed input arguments in custom cross validation in pipeline selection

tanja 3 年之前
父節點
當前提交
05a907571d
共有 1 個文件被更改,包括 12 次插入7 次删除
  1. 12 7
      cdplib/pipeline_selector/PipelineSelector.py

+ 12 - 7
cdplib/pipeline_selector/PipelineSelector.py

@@ -600,14 +600,19 @@ class PipelineSelector(ABC):
                 for metric_name, metric in scoring.itmes():
                     scoring[metric_name] = make_scorer(
                         metric, greater_is_better=self._greater_is_better)
+                    
+            cross_validation_input_args = {
+                 "estimator": pipeline,
+                 "X": self._X,
+                 "y": self._y,
+                 "cv": self._cv,
+                 "scoring": scoring
+                 }
+            
+            if "error_score" in self._cross_validation.__annotations__:
+                cross_validation_input_args["error_score"] = np.nan
 
-            scores = self._cross_validation(
-                    estimator=pipeline,
-                    X=self._X,
-                    y=self._y,
-                    cv=self._cv,
-                    scoring=scoring,
-                    error_score=np.nan)
+            scores = self._cross_validation(**cross_validation_input_args)
 
             averaging_funcs = {
                     metric_name: self._additional_averaging_funcs[metric_name]