Quellcode durchsuchen

remove union typing

ogert vor 3 Jahren
Ursprung
Commit
359c4e0853

+ 11 - 10
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -38,7 +38,7 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
 else:
     # from typing_extensions import *
     from typing_extensions import Callable, Optional,\
-    Literal, Union
+    Literal
 
 from cdplib.log import Log
 
@@ -61,7 +61,8 @@ class HyperoptPipelineSelector(PipelineSelector):
     a better pipeline was found.
     """
     def __init__(self,
-                 cost_func: Union[Callable, str],
+                #  cost_func: Union[Callable, str],
+                 cost_func,
                  greater_is_better: bool,
                  trials_path: str,
                  backup_trials_freq: Optional[int] = None,
@@ -180,7 +181,7 @@ class HyperoptPipelineSelector(PipelineSelector):
             self._logger.log_and_raise_error(err)
 
     @property
-    def number_of_trials(self) -> Union[int, None]:
+    def number_of_trials(self):# -> Union[int, None]:
         """
         :return: number of trials run so far
             with the given Trials object
@@ -251,7 +252,7 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self._logger.log_and_raise_error(err)
 
-    def _get_pipeline_from_index(self, i: int) -> Union[Pipeline, None]:
+    def _get_pipeline_from_index(self, i: int):# -> Union[Pipeline, None]:
         """
         Gets a pipeline with set parameters from the trial number i
         """
@@ -267,7 +268,7 @@ class HyperoptPipelineSelector(PipelineSelector):
             self._logger.log_and_raise_error(err)
 
     @property
-    def best_trial(self) -> Union[dict, None]:
+    def best_trial(self):# -> Union[dict, None]:
         """
         :return: dictionary with the summary of the best trial
             and space element (name, pipeline, params)
@@ -308,7 +309,7 @@ class HyperoptPipelineSelector(PipelineSelector):
                 self._logger.log_and_raise_error(err)
 
     @property
-    def best_trial_score(self) -> Union[float, None]:
+    def best_trial_score(self):# -> Union[float, None]:
         """
         """
         try:
@@ -324,7 +325,7 @@ class HyperoptPipelineSelector(PipelineSelector):
             self._logger.log_and_raise_error(err)
 
     @property
-    def best_trial_score_variance(self) -> Union[float, None]:
+    def best_trial_score_variance(self):# -> Union[float, None]:
         """
         """
         try:
@@ -340,7 +341,7 @@ class HyperoptPipelineSelector(PipelineSelector):
             self._logger.log_and_raise_error(err)
 
     @property
-    def best_trial_pipeline(self) -> Union[Pipeline, None]:
+    def best_trial_pipeline(self):# -> Union[Pipeline, None]:
         """
         """
         try:
@@ -376,8 +377,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self._logger.log_and_raise_error(err)
 
-    def get_n_best_trial_pipelines(self, n: int)\
-            -> Union[List[Pipeline], None]:
+    def get_n_best_trial_pipelines(self, n: int):#\
+            # -> Union[List[Pipeline], None]:
         """
         :return: the list of n best pipelines
         documented in trials

+ 17 - 11
cdplib/pipeline_selector/PipelineSelector.py

@@ -34,7 +34,7 @@ else:
     # from typing_extensions import *
     print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
     from typing_extensions import Callable, TypedDict,\
-    Literal, Tuple, Union
+    Literal
 
 import functools
 from sklearn.pipeline import Pipeline
@@ -73,7 +73,8 @@ class PipelineSelector(ABC):
     Children classes: hyperopt and custom gridsearch.
     """
     def __init__(self,
-                 cost_func: Union[Callable, str],
+                #  cost_func: Union[Callable, str],
+                 cost_func,
                  greater_is_better: bool,
                  trials_path: str,
                  backup_trials_freq: int = None,
@@ -329,16 +330,21 @@ class PipelineSelector(ABC):
 
             self._logger.loger_and_raise_error(err)
 
-    def attach_data(self, X_train: Union[pd.DataFrame, np.ndarray],
-                    y_train: Union[pd.DataFrame, pd.Series, np.ndarray]
-                    = None,
-                    X_val: Union[pd.DataFrame, np.ndarray]
-                    = None,
-                    y_val: Union[pd.DataFrame, pd.Series, np.ndarray]
-                    = None,
+    def attach_data(self, 
+                    # X_train: Union[pd.DataFrame, np.ndarray],
+                    # y_train: Union[pd.DataFrame, pd.Series, np.ndarray]
+                    # = None,
+                    # X_val: Union[pd.DataFrame, np.ndarray]
+                    # = None,
+                    # y_val: Union[pd.DataFrame, pd.Series, np.ndarray]
+                    # = None,
                     # cv: Union[Iterable[Tuple[List[int], List[int]]]]
-                    cv
-                    = None) -> None:
+                    # = None
+                    X_train,
+                    y_train = None,
+                    X_val = None,
+                    y_val = None,
+                    cv = None) -> None:
         '''
         :param array X_train: data on which
             machine learning pipelines are trained