Explorar o código

adding backward compability

ogert %!s(int64=3) %!d(string=hai) anos
pai
achega
fe008a8b19
Modificáronse 1 ficheiros con 3 adicións e 2 borrados
  1. 3 2
      cdplib/pipeline_selector/PipelineSelector.py

+ 3 - 2
cdplib/pipeline_selector/PipelineSelector.py

@@ -82,7 +82,7 @@ class PipelineSelector(ABC):
                 #  cross_val_averaging_func: Callable = np.mean,
                 cross_val_averaging_func = np.mean,
                 #  additional_metrics: Dict[str, Callable] = None,
-                 additional_metrics = None,
+                 additional_metrics = {},
                 #  additional_averaging_funcs: Dict[str, Callable] = None,
                  additional_averaging_funcs = None,
                  strategy_name: str = None,
@@ -609,7 +609,8 @@ class PipelineSelector(ABC):
         """
         try:
             
-            scoring = {"score": self._cost_func} | self._additional_metrics
+            # scoring = {"score": self._cost_func} | self._additional_metrics
+            scoring = {"score": self._cost_func, **self._additional_metrics}
             
             if self._cross_validation_needs_scorer:
                 for metric_name, metric in scoring.items():