Explorar o código

fixed bug in type declaration

tanja %!s(int64=3) %!d(string=hai) anos
pai
achega
e1181077da
Modificáronse 1 ficheiros con 9 adicións e 10 borrados
  1. 9 10
      cdplib/pipeline_selector/PipelineSelector.py

+ 9 - 10
cdplib/pipeline_selector/PipelineSelector.py

@@ -65,12 +65,11 @@ class PipelineSelector(ABC):
                  cost_func: Union[Callable, str],
                  greater_is_better: bool,
                  trials_path: str,
-                 backup_trials_freq: Optional[int] = None,
+                 backup_trials_freq: int = None,
                  cross_val_averaging_func: Callable = np.mean,
-                 additional_metrics: Optional[Dict[str, Callable]] = None,
-                 additional_averaging_funcs:
-                     Optional[Dict[str, Callable]] = None,
-                 strategy_name: Optional[str] = None,
+                 additional_metrics: Dict[str, Callable] = None,
+                 additional_averaging_funcs: Dict[str, Callable] = None,
+                 strategy_name: str = None,
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  = "INFO"):
         """
@@ -302,13 +301,13 @@ class PipelineSelector(ABC):
             self._logger.loger_and_raise_error(err)
 
     def attach_data(self, X_train: Union[pd.DataFrame, np.ndarray],
-                    y_train: Optional[pd.DataFrame, pd.Series, np.ndarray]
+                    y_train: Union[pd.DataFrame, pd.Series, np.ndarray]
                     = None,
-                    X_val: Optional[pd.DataFrame, np.ndarray]
+                    X_val: Union[pd.DataFrame, np.ndarray]
                     = None,
-                    y_val: Optional[pd.DataFrame, pd.Series, np.ndarray]
+                    y_val: Union[pd.DataFrame, pd.Series, np.ndarray]
                     = None,
-                    cv: Optional[Iterable[Tuple[List[int], List[int]]]]
+                    cv: Union[Iterable[Tuple[List[int], List[int]]]]
                     = None) -> None:
         '''
         :param array X_train: data on which
@@ -468,7 +467,7 @@ class PipelineSelector(ABC):
                                 = functools.partial(
                                         pd.DataFrame.to_excel,
                                         **{"path_or_buf": "result.csv"}),
-                                kwargs: Optional[dict] = None) -> None:
+                                kwargs: dict = None) -> None:
         """
         When the score calculated by _objective function improves,
         the default summary is updated with information about the