瀏覽代碼

removing Dict typing

ogert 3 年之前
父節點
當前提交
515b4fbc66
共有 2 個文件被更改,包括 17 次插入23 次删除
  1. 11 14
      cdplib/hyperopt/HyperoptPipelineSelector.py
  2. 6 9
      cdplib/pipeline_selector/PipelineSelector.py

+ 11 - 14
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -12,8 +12,6 @@ Created on Tue Oct  6 15:04:25 2020
  the best pipeline so far as well as the entire tuning history
  if needed.
 """
-from __future__ import annotations
-
 import os
 import sys
 
@@ -35,14 +33,12 @@ from cdplib.pipeline_selector.PipelineSelector import PipelineSelector,\
 
 # from typing import Callable, Optional, Literal, Dict, Union, List
 if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
-    from typing import Callable, \
+    from typing import Callable, Optional,\
     Literal, Dict, List, Union
 else:
     # from typing_extensions import *
-    Dict = dict
-    List = list
-    from typing_extensions import Callable, \
-    Literal, Dict, List, Union
+    from typing_extensions import Callable, Optional,\
+    Literal, Union
 
 from cdplib.log import Log
 
@@ -71,7 +67,8 @@ class HyperoptPipelineSelector(PipelineSelector):
                  backup_trials_freq: Optional[int] = None,
                  cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
-                 additional_metrics: Optional[Dict[str, Callable]] = None,
+                #  additional_metrics: Optional[Dict[str, Callable]] = None,
+                additional_metrics = None,
                  strategy_name: Optional[str] = None,
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  = "INFO"):
@@ -198,8 +195,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self._logger.log_and_raise_error(err)
 
-    def _get_space_element_from_trial(self, trial: dict)\
-            -> Union[Dict[str, SpaceElementType], None]:
+    def _get_space_element_from_trial(self, trial: dict):#\
+            # -> Union[Dict[str, SpaceElementType], None]:
         """
         Hyperopt trials object does not contain the space
              elements that result in the corresponding trials.
@@ -234,8 +231,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self._logger.log_and_raise_error(err)
 
-    def _get_space_element_from_index(self, i: int)\
-            -> Union[Dict[str, SpaceElementType], None]:
+    def _get_space_element_from_index(self, i: int): #\
+            # -> Union[Dict[str, SpaceElementType], None]:
         """
         Gets the space element of shape
         {"name": NAME, "params": PARAMS, "pipeline": PIPELINE}
@@ -402,8 +399,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self._logger.log_and_raise_error(err)
 
-    def get_n_best_trial_pipelines_of_each_type(self, n: int)\
-            -> Union[Dict[str, List[Pipeline]], None]:
+    def get_n_best_trial_pipelines_of_each_type(self, n: int): #\
+            # -> Union[Dict[str, List[Pipeline]], None]:
         """
         :return: a dictiionry where keys are pipeline names,
         and values are lists of best pipelines with this name

+ 6 - 9
cdplib/pipeline_selector/PipelineSelector.py

@@ -16,8 +16,6 @@ Created on Wed Sep 30 14:23:23 2020
  save the current best result in a file or database during training.
  Children classes: hyperopt and custom gridsearch.
 """
-from __future__ import annotations
-
 import pickle
 import os
 import sys
@@ -35,11 +33,8 @@ if (sys.version_info.major == 3) & (sys.version_info.minor >= 8):
 else:
     # from typing_extensions import *
     print("I have python version {}.{} and will import typing_extensions".format(sys.version_info.major, sys.version_info.minor))
-    
-    Dict = dict
-    List = list
     from typing_extensions import Callable, TypedDict,\
-    Literal, Dict, Iterable, List, Tuple, Union
+    Literal, Iterable, Tuple, Union
 
 import functools
 from sklearn.pipeline import Pipeline
@@ -84,8 +79,10 @@ class PipelineSelector(ABC):
                  backup_trials_freq: int = None,
                  cross_validation_needs_scorer: bool = True,
                  cross_val_averaging_func: Callable = np.mean,
-                 additional_metrics: Dict[str, Callable] = None,
-                 additional_averaging_funcs: Dict[str, Callable] = None,
+                #  additional_metrics: Dict[str, Callable] = None,
+                 additional_metrics = None,
+                #  additional_averaging_funcs: Dict[str, Callable] = None,
+                 additional_averaging_funcs = None,
                  strategy_name: str = None,
                  stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                  = "INFO"):
@@ -567,7 +564,7 @@ class PipelineSelector(ABC):
 
             self._logger.log_and_raise_error(err)
 
-    def _evaluate(self, pipeline: Pipeline) -> Union[Dict[str, float], None]:
+    def _evaluate(self, pipeline: Pipeline) :#-> Union[Dict[str, float], None]:
         """
         Calculates the averaged cross-validated score and score variance,
         as well as the averaged values and variances of the additional metrics.