Browse Source

adding backward compability

ogert 2 years ago
parent
commit
fe008a8b19
1 changed files with 3 additions and 2 deletions
  1. 3 2
      cdplib/pipeline_selector/PipelineSelector.py

+ 3 - 2
cdplib/pipeline_selector/PipelineSelector.py

@@ -82,7 +82,7 @@ class PipelineSelector(ABC):
                 #  cross_val_averaging_func: Callable = np.mean,
                 cross_val_averaging_func = np.mean,
                 #  additional_metrics: Dict[str, Callable] = None,
-                 additional_metrics = None,
+                 additional_metrics = {},
                 #  additional_averaging_funcs: Dict[str, Callable] = None,
                  additional_averaging_funcs = None,
                  strategy_name: str = None,
@@ -609,7 +609,8 @@ class PipelineSelector(ABC):
         """
         try:
             
-            scoring = {"score": self._cost_func} | self._additional_metrics
+            # scoring = {"score": self._cost_func} | self._additional_metrics
+            scoring = {"score": self._cost_func, **self._additional_metrics}
             
             if self._cross_validation_needs_scorer:
                 for metric_name, metric in scoring.items():