#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 30 14:15:17 2020
@author: tanya
@description: a class for selecting a machine learning
pipeline from a deterministic space of parameter distributions
over multiple pipelines.
The selection is done in such a way that a Trials object is
maintained during the tuning process, from which one can retrieve
the best pipeline so far as well as the entire tuning history
if needed.
"""
import os
import datetime
import numpy as np
import pandas as pd
from itertools import product
from collections import ChainMap
from sklearn.pipeline import Pipeline
from typing import Callable, Optional, Literal, Dict, Union, List
from cdplib.log import Log
from cdplib.pipeline_selector.PipelineSelector import PipelineSelector


class GridSearchPipelineSelector(PipelineSelector):
    """
    A class for selecting a machine learning
    pipeline from a deterministic space of parameter distributions
    over multiple pipelines.
    The selection is done in such a way that a Trials object is
    maintained during the tuning process, from which one can retrieve
    the best pipeline so far as well as the entire tuning history
    if needed.
    """
    def __init__(self,
                 cost_func: Union[Callable, str],
                 greater_is_better: bool,
                 trials_path: str,
                 backup_trials_freq: Optional[int] = None,
                 cross_val_averaging_func: Callable = np.mean,
                 additional_metrics: Optional[Dict[str, Callable]] = None,
                 strategy_name: Optional[str] = None,
                 stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                 = "INFO"):
        """
        :param Callable cost_func: function to minimize or maximize
            over the elements of a given (pipeline/hyperparameter) space

        :param bool greater_is_better: when True
            cost_func is maximized, else minimized.

        :param str trials_path: path at which the trials object is saved
            in binary format. From the trials object we can
            select information about the obtained scores, score variations,
            and the pipelines and parameters tried out so far. If a trials
            object already exists at the given path, it is loaded and the
            search is continued, else the search is started from scratch.

        :param backup_trials_freq: frequency in iterations (trials)
            of saving the trials object at trials_path.
            If None, the trials object is backed up every time
            the score improves.

        :param Callable cross_val_averaging_func: function to aggregate
            the cross-validation scores.
            Example different from the mean: mean - c*var.

        :param additional_metrics: dict of additional metrics to save
            of the form {"metric_name": metric}, where metric is a Callable.

        :param str strategy_name:
            a strategy is defined by the data set (columns/features and rows),
            cv object, and cost function.
            When the strategy changes, one must start with new trials.

        :param str stdout_log_level: can be INFO, WARNING, ERROR
        """
        # create the logger before the try block, so that the
        # except branch can use it even if the parent initialization fails
        self._logger = Log("GridSearchPipelineSelector: ",
                           stdout_log_level=stdout_log_level)

        try:
            super().__init__(cost_func=cost_func,
                             greater_is_better=greater_is_better,
                             trials_path=trials_path,
                             backup_trials_freq=backup_trials_freq,
                             cross_val_averaging_func=cross_val_averaging_func,
                             additional_metrics=additional_metrics,
                             strategy_name=strategy_name,
                             stdout_log_level=stdout_log_level)

            self._trials = self._trials or []

        except Exception as e:
            err = "Failed initialization. Exit with error: {}".format(e)
            self._logger.log_and_raise_error(err)
    def run_trials(self) -> None:
        """
        Runs the grid search over the attached parameter space and
        stores one trial per pipeline/parameter combination.
        Combinations already present in the trials object are skipped,
        so an interrupted search can be resumed.
        """
        try:
            assert self.attached_space,\
                "Parameter distribution space must be attached"

            done_trial_ids = [{"name": trial["name"],
                               "params": trial["params"],
                               "status": trial["status"]}
                              for trial in self._trials]

            # generator of (flattened) dictionaries
            # with all different combinations of
            # parameters for different pipelines
            # from the space definition.
            space_unfolded = ({"name": param_dist["name"],
                               "pipeline": param_dist["pipeline"],
                               "params": param_set}
                              for param_dist in self._space
                              for param_set in
                              (dict(ChainMap(*tup)) for tup in
                               product(*[[{k: v} for v in
                                          param_dist["params"][k]]
                                         for k in param_dist["params"]])))
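            # Illustration (hypothetical space entry): for
            #   {"name": "lr", "pipeline": pipe,
            #    "params": {"clf__C": [0.1, 1.0], "clf__penalty": ["l2"]}}
            # the product/ChainMap construction above yields the param sets
            #   {"clf__C": 0.1, "clf__penalty": "l2"}
            #   {"clf__C": 1.0, "clf__penalty": "l2"}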
            for space_element in space_unfolded:

                # uniquely identifies the current space element
                trial_id = {"name": space_element["name"],
                            "params": space_element["params"],
                            "status": 'ok'}

                # skip the current pipeline/parameter combination
                # if it was already tested before
                if trial_id in done_trial_ids:
                    continue

                result = self._objective(space_element)

                pipeline = space_element["pipeline"].set_params(
                        **space_element["params"])

                trial = {"name": space_element["name"],
                         "params": space_element["params"],
                         "pipeline": pipeline}

                trial.update(result)

                self._trials.append(trial)

            self.finished_tuning = True

            self.total_tuning_time = datetime.datetime.today()\
                - self.start_tuning_time

            self._backup_trials()

        except Exception as e:
            err = "Failed to run trials. Exit with error: {}".format(e)
            self._logger.log_and_raise_error(err)
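    # Each stored trial is a dict of the form
    #   {"name": ..., "params": ..., "pipeline": ...,
    #    "score": ..., "score_variance": ..., "status": ..., ...},
    # where the score/status entries come from self._objective
    # in the parent class.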
    @property
    def number_of_trials(self) -> Union[int, None]:
        """
        Number of trials already run in the current trials object
        """
        try:
            return len(self._trials)

        except Exception as e:
            err = ("Failed to retrieve the number of trials. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    @property
    def best_trial(self) -> Union[dict, None]:
        """
        Trial with the best score. Note that the stored scores
        are compared directly, so a higher score counts as better.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return max(self._trials, key=lambda x: x["score"])

        except Exception as e:
            err = ("Could not retrieve the best trial. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    @property
    def best_trial_score(self) -> Union[float, None]:
        """
        Score of the best trial
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return self.best_trial["score"]

        except Exception as e:
            err = ("Could not retrieve the best trial score. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    @property
    def best_trial_score_variance(self) -> Union[float, None]:
        """
        Score variance of the best trial
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return self.best_trial["score_variance"]

        except Exception as e:
            err = ("Could not retrieve the best trial score variance. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    @property
    def best_trial_pipeline(self) -> Union[Pipeline, None]:
        """
        Pipeline of the best trial
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return self.best_trial["pipeline"]

        except Exception as e:
            err = ("Could not retrieve the best trial pipeline. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    def get_n_best_trial_pipelines(self, n: int)\
            -> Union[List[Pipeline], None]:
        """
        Returns the n best pipelines with their
        corresponding best hyperparameters set.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return [trial["pipeline"] for trial in
                    sorted(self._trials, key=lambda x: x["score"],
                           reverse=True)[:n]]

        except Exception as e:
            err = ("Failed to retrieve n best trials. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
    def get_n_best_trial_pipelines_of_each_type(self, n: int)\
            -> Union[Dict[str, List[Pipeline]], None]:
        """
        If the hyperparameter search is done over multiple
        pipelines, returns the n best pipelines of each
        pipeline type with their corresponding hyperparameters set.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call run_trials method.")

            return pd.DataFrame(self._trials)\
                     .sort_values(by=["name", "score"],
                                  ascending=False)\
                     .groupby("name")\
                     .head(n)\
                     .groupby("name")["pipeline"]\
                     .apply(lambda x: list(x))\
                     .to_dict()

        except Exception as e:
            err = ("Failed to retrieve n best trials of each type. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)
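    # Illustration (hypothetical names): for n=2 and a search over two
    # pipeline types, the method above returns a dict such as
    #   {"lr": [Pipeline(...), Pipeline(...)],
    #    "rf": [Pipeline(...), Pipeline(...)]}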
    def trials_to_excel(self, path: str) -> None:
        """
        Writes the trials object as a table to an excel file.
        The table contains the run number, the pipeline (as str),
        the hyperparameters (as str), and self.best_result
        (see the self._objective method), as well as additional
        information configured through the self.save_result method.
        """
        try:
            pd.DataFrame(self._trials).to_excel(path)

        except Exception as e:
            err = ("Failed to write trials to excel. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)


if __name__ == "__main__":

    # elementary example

    from sklearn.datasets import load_breast_cancer
    from sklearn.metrics import accuracy_score, precision_score
    from cdplib.gridsearch.space_sample import space
    from cdplib.db_handlers import MongodbHandler
    import pickle
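    # `space` is expected to be an iterable of entries of the form
    #   {"name": <pipeline type name>, "pipeline": <sklearn Pipeline>,
    #    "params": {<param name>: <list of values to try>, ...}}
    # (see run_trials above); the concrete sample lives in
    # cdplib.gridsearch.space_sample.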
    trials_path = "gridsearch_trials_TEST.pkl"
    additional_metrics = {"precision": precision_score}
    strategy_name = "strategy_1"
    data_path = "data_TEST.h5"
    cv_path = "cv_TEST.pkl"
    collection_name = 'TEST_' + strategy_name

    logger = Log("GridSearchPipelineSelector__TEST:")

    logger.info("Start test")

    data_loader = load_breast_cancer()

    X = data_loader["data"]
    y = data_loader["target"]

    pd.DataFrame(X).to_hdf(data_path, key="X_train")
    pd.Series(y).to_hdf(data_path, key="y_train")

    # two deterministic train/test splits over the row indices
    cv = [(list(range(len(X)//3)), list(range(len(X)//3, len(X)))),
          (list(range(2*len(X)//3)), list(range(2*len(X)//3, len(X))))]

    pickle.dump(cv, open(cv_path, "wb"))

    gs = GridSearchPipelineSelector(cost_func=accuracy_score,
                                    greater_is_better=True,
                                    trials_path=trials_path,
                                    additional_metrics=additional_metrics,
                                    strategy_name=strategy_name,
                                    stdout_log_level="WARNING")

    gs.attach_space(space=space)

    gs.attach_data_from_hdf5(data_hdf5_store_path=data_path,
                             cv_pickle_path=cv_path)

    save_method = MongodbHandler().insert_data_into_collection
    save_kwargs = {'collection_name': collection_name}

    gs.configer_summary_saving(save_method=save_method,
                               kwargs=save_kwargs)

    gs.run_trials()

    logger.info("Best trial: {}".format(gs.best_trial))
    logger.info("Total tuning time: {}".format(gs.total_tuning_time))

    # clean up test artifacts
    for file in [trials_path, data_path, cv_path]:
        os.remove(file)

    logger.info("End test")