|
@@ -82,7 +82,7 @@ class PipelineSelector(ABC):
|
|
|
# cross_val_averaging_func: Callable = np.mean,
|
|
|
cross_val_averaging_func = np.mean,
|
|
|
# additional_metrics: Dict[str, Callable] = None,
|
|
|
- additional_metrics = None,
|
|
|
+ additional_metrics = {},
|
|
|
# additional_averaging_funcs: Dict[str, Callable] = None,
|
|
|
additional_averaging_funcs = None,
|
|
|
strategy_name: str = None,
|
|
@@ -609,7 +609,8 @@ class PipelineSelector(ABC):
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
- scoring = {"score": self._cost_func} | self._additional_metrics
|
|
|
+ # scoring = {"score": self._cost_func} | self._additional_metrics
|
|
|
+ scoring = {"score": self._cost_func, **self._additional_metrics}
|
|
|
|
|
|
if self._cross_validation_needs_scorer:
|
|
|
for metric_name, metric in scoring.items():
|