Browse Source

removing Dict typing

ogert 2 years ago
parent
commit
515b4fbc66

+ 11 - 14
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -12,8 +12,6 @@ Created on Tue Oct  6 15:04:25 2020
  the best pipeline so far as well as the entire tuning history
  the best pipeline so far as well as the entire tuning history
  if needed.
  if needed.
 """
 """
-from __future__ import annotations
-
 import os
 import os
 import sys
 import sys
 
 
@@ -35,14 +33,12 @@ from cdplib.pipeline_selector.PipelineSelector import PipelineSelector,\
 
 
 # from typing import Callable, Optional, Literal, Dict, Union, List
 # from typing import Callable, Optional, Literal, Dict, Union, List
 if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
 if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
-    from typing import Callable, \
+    from typing import Callable, Optional,\
     Literal, Dict, List, Union
     Literal, Dict, List, Union
 else:
 else:
     # from typing_extensions import *
     # from typing_extensions import *
-    Dict = dict
-    List = list
-    from typing_extensions import Callable, \
-    Literal, Dict, List, Union
+    from typing_extensions import Callable, Optional,\
+    Literal, Union
 
 
 from cdplib.log import Log
 from cdplib.log import Log
 
 
@@ -71,7 +67,8 @@ class HyperoptPipelineSelector(PipelineSelector):
                  backup_trials_freq: Optional[int] = None,
                  backup_trials_freq: Optional[int] = None,
                  cross_validation_needs_scorer: bool = True,
                  cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
                  cross_val_averaging_func: Callable = np.mean,
-                 additional_metrics: Optional[Dict[str, Callable]] = None,
+                #  additional_metrics: Optional[Dict[str, Callable]] = None,
+                additional_metrics = None,
                  strategy_name: Optional[str] = None,
                  strategy_name: Optional[str] = None,
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  = "INFO"):
                  = "INFO"):
@@ -198,8 +195,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
 
             self._logger.log_and_raise_error(err)
             self._logger.log_and_raise_error(err)
 
 
-    def _get_space_element_from_trial(self, trial: dict)\
-            -> Union[Dict[str, SpaceElementType], None]:
+    def _get_space_element_from_trial(self, trial: dict):#\
+            # -> Union[Dict[str, SpaceElementType], None]:
         """
         """
         Hyperopt trials object does not contain the space
         Hyperopt trials object does not contain the space
              elements that result in the corresponding trials.
              elements that result in the corresponding trials.
@@ -234,8 +231,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
 
             self._logger.log_and_raise_error(err)
             self._logger.log_and_raise_error(err)
 
 
-    def _get_space_element_from_index(self, i: int)\
-            -> Union[Dict[str, SpaceElementType], None]:
+    def _get_space_element_from_index(self, i: int): #\
+            # -> Union[Dict[str, SpaceElementType], None]:
         """
         """
         Gets the space element of shape
         Gets the space element of shape
         {"name": NAME, "params": PARAMS, "pipeline": PIPELINE}
         {"name": NAME, "params": PARAMS, "pipeline": PIPELINE}
@@ -402,8 +399,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
 
             self._logger.log_and_raise_error(err)
             self._logger.log_and_raise_error(err)
 
 
-    def get_n_best_trial_pipelines_of_each_type(self, n: int)\
-            -> Union[Dict[str, List[Pipeline]], None]:
+    def get_n_best_trial_pipelines_of_each_type(self, n: int): #\
+            # -> Union[Dict[str, List[Pipeline]], None]:
         """
         """
         :return: a dictiionry where keys are pipeline names,
         :return: a dictiionry where keys are pipeline names,
         and values are lists of best pipelines with this name
         and values are lists of best pipelines with this name

+ 6 - 9
cdplib/pipeline_selector/PipelineSelector.py

@@ -16,8 +16,6 @@ Created on Wed Sep 30 14:23:23 2020
  save the current best result in a file or database during training.
  save the current best result in a file or database during training.
  Children classes: hyperopt and custom gridsearch.
  Children classes: hyperopt and custom gridsearch.
 """
 """
-from __future__ import annotations
-
 import pickle
 import pickle
 import os
 import os
 import sys
 import sys
@@ -35,11 +33,8 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
 else:
 else:
     # from typing_extensions import *
     # from typing_extensions import *
     print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
     print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
-    
-    Dict = dict
-    List = list
     from typing_extensions import Callable, TypedDict,\
     from typing_extensions import Callable, TypedDict,\
-    Literal, Dict, Iterable, List, Tuple, Union
+    Literal, Iterable, Tuple, Union
 
 
 import functools
 import functools
 from sklearn.pipeline import Pipeline
 from sklearn.pipeline import Pipeline
@@ -84,8 +79,10 @@ class PipelineSelector(ABC):
                  backup_trials_freq: int = None,
                  backup_trials_freq: int = None,
                  cross_validation_needs_scorer: bool = True,
                  cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
                  cross_val_averaging_func: Callable = np.mean,
-                 additional_metrics: Dict[str, Callable] = None,
-                 additional_averaging_funcs: Dict[str, Callable] = None,
+                #  additional_metrics: Dict[str, Callable] = None,
+                 additional_metrics = None,
+                #  additional_averaging_funcs: Dict[str, Callable] = None,
+                 additional_averaging_funcs = None,
                  strategy_name: str = None,
                  strategy_name: str = None,
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  = "INFO"):
                  = "INFO"):
@@ -567,7 +564,7 @@ class PipelineSelector(ABC):
 
 
             self._logger.log_and_raise_error(err)
             self._logger.log_and_raise_error(err)
 
 
-    def _evaluate(self, pipeline: Pipeline) -> Union[Dict[str, float], None]:
+    def _evaluate(self, pipeline: Pipeline) :#-> Union[Dict[str, float], None]:
         """
         """
         Calculates the averaged cross-validated score and score variance,
         Calculates the averaged cross-validated score and score variance,
         as well as the averaged values and variances of the additional metrics.
         as well as the averaged values and variances of the additional metrics.