Просмотр исходного кода

added the option not to make_scorer

tanja лет назад: 3
Родитель
Сommit
e728ca4388
1 измененных файлов с 11 добавлено и 5 удалено
  1. 11 5
      cdplib/pipeline_selector/PipelineSelector.py

+ 11 - 5
cdplib/pipeline_selector/PipelineSelector.py

@@ -75,6 +75,7 @@ class PipelineSelector(ABC):
                  greater_is_better: bool,
                  trials_path: str,
                  backup_trials_freq: int = None,
+                 cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
                  additional_metrics: Dict[str, Callable] = None,
                  additional_averaging_funcs: Dict[str, Callable] = None,
@@ -139,6 +140,7 @@ class PipelineSelector(ABC):
             self.configured_summary_saving = False
 
             self._cost_func = cost_func
+            self._greater_is_better = greater_is_better
             # score factor is 1 when cost_func is minimized,
             # -1 when cost func is maximized
             self._score_factor = (not greater_is_better) - greater_is_better
@@ -160,6 +162,8 @@ class PipelineSelector(ABC):
             # if cross-valition is not configured,
             # sklearn cross-validation method is taken by default
             self._cross_validation = sklearn_cross_validation
+            
+            self._cross_validation_needs_scorer = cross_validation_needs_scorer
 
             # if a trials object already exists at the given path,
             # it is loaded and the search is continued. Else,
@@ -589,11 +593,13 @@ class PipelineSelector(ABC):
             over the folds.
         """
         try:
-            scoring = {"score": make_scorer(self._cost_func)}
-
-            scoring.update({metric_name: make_scorer(metric)
-                            for metric_name, metric
-                            in self._additional_metrics.items()})
+            
+            scoring = {"score": self._cost_func} | self._additional_metrics
+            
+            if self._cross_validation_needs_scorer:
+                for metric_name, metric in scoring.itmes():
+                    scoring[metric_name] = make_scorer(
+                        metric, greater_is_better=self._greater_is_better)
 
             scores = self._cross_validation(
                     estimator=pipeline,