ソースを参照

adding backward compability

ogert 2 年 前
コミット
fe008a8b19
共有1 個のファイルを変更した3 個の追加2 個の削除を含む
  1. 3 2
      cdplib/pipeline_selector/PipelineSelector.py

+ 3 - 2
cdplib/pipeline_selector/PipelineSelector.py

@@ -82,7 +82,7 @@ class PipelineSelector(ABC):
                 #  cross_val_averaging_func: Callable = np.mean,
                 cross_val_averaging_func = np.mean,
                 #  additional_metrics: Dict[str, Callable] = None,
-                 additional_metrics = None,
+                 additional_metrics = {},
                 #  additional_averaging_funcs: Dict[str, Callable] = None,
                  additional_averaging_funcs = None,
                  strategy_name: str = None,
@@ -609,7 +609,8 @@ class PipelineSelector(ABC):
         """
         try:
             
-            scoring = {"score": self._cost_func} | self._additional_metrics
+            # scoring = {"score": self._cost_func} | self._additional_metrics
+            scoring = {"score": self._cost_func, **self._additional_metrics}
             
             if self._cross_validation_needs_scorer:
                 for metric_name, metric in scoring.items():