#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 30 14:15:17 2020

@author: tanya

@description: a class for selecting a machine learning
pipeline from a deterministic space of parameter distributions
over multiple pipelines.
The selection is designed in such a way that a Trials object is
maintained during the tuning process, from which one can retrieve
the best pipeline so far as well as the entire tuning history
if needed.
"""
import os
import datetime
import numpy as np
# pandas is used by the trial-retrieval methods below,
# so it must be imported at module level
import pandas as pd
from copy import deepcopy
from itertools import product
from collections import ChainMap
from sklearn.pipeline import Pipeline
from typing import Callable, Optional, Literal, Dict, Union, List
from cdplib.log import Log
from cdplib.pipeline_selector.PipelineSelector import PipelineSelector


class GridSearchPipelineSelector(PipelineSelector):
    """
    A class for selecting a machine learning
    pipeline from a deterministic space of parameter distributions
    over multiple pipelines.
    The selection is designed in such a way that a Trials object is
    maintained during the tuning process, from which one can retrieve
    the best pipeline so far as well as the entire tuning history
    if needed.
    """
    def __init__(self,
                 cost_func: Union[Callable, str],
                 greater_is_better: bool,
                 trials_path: str,
                 backup_trials_freq: Optional[int] = None,
                 cross_validation_needs_scorer: bool = True,
                 cross_val_averaging_func: Callable = np.mean,
                 additional_metrics: Optional[Dict[str, Callable]] = None,
                 strategy_name: Optional[str] = None,
                 stdout_log_level: Literal["INFO", "WARNING", "ERROR"]
                 = "INFO"):
  46. """
  47. ::param Callable cost_func: function to minimize or maximize
  48. over the elements of a given (pipeline/hyperparameter) space
  49. :param bool greater_is_better: when True
  50. cost_func is maximized, else minimized.
  51. :param str trials_path: path at which the trials object is saved
  52. in binary format. From the trials object we can
  53. select information about the obtained scores, score variations,
  54. and pipelines, and parameters tried out so far. If a trials object
  55. already exists at the given path, it is loaded and the
  56. search is continued, else, the search is started from scratch.
  57. :param backup_trials_freq: frequecy in interations (trials)
  58. of saving the trials object at the trials_path.
  59. if None, the trials object is backed up avery time
  60. the score improves.
  61. :param Callable cross_val_averaging_func: Function to aggregate
  62. the cross-validation scores.
  63. Example different from the mean: mean - c*var.
  64. :param additional_metics: dict of additional metrics to save
  65. of the form {"metric_name": metric} where metric is a Callable.
  66. :param str strategy_name:
  67. a strategy is defined by the data set (columns/features and rows),
  68. cv object, cost function.
  69. When the strategy changes, one must start with new trials.
  70. :param str stdout_log_level: can be INFO, WARNING, ERROR
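
        Example (a minimal usage sketch mirroring the test at the
        bottom of this module; paths and the space object are
        hypothetical, and space/data attachment are required before
        running the trials):

        >>> from sklearn.metrics import accuracy_score
        >>> gs = GridSearchPipelineSelector(cost_func=accuracy_score,
        ...                                 greater_is_better=True,
        ...                                 trials_path="trials.pkl")
        >>> gs.attach_space(space=space)
        >>> gs.attach_data_from_hdf5(data_hdf5_store_path="data.h5",
        ...                          cv_pickle_path="cv.pkl")
        >>> gs.run_trials()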
  71. """
        # create the logger before anything else, so that it is
        # available in the except branch even if the parent
        # initialization fails
        self._logger = Log("GridSearchPipelineSelector: ",
                           stdout_log_level=stdout_log_level)

        try:
            super().__init__(
                cost_func=cost_func,
                greater_is_better=greater_is_better,
                trials_path=trials_path,
                backup_trials_freq=backup_trials_freq,
                cross_validation_needs_scorer=cross_validation_needs_scorer,
                cross_val_averaging_func=cross_val_averaging_func,
                additional_metrics=additional_metrics,
                strategy_name=strategy_name,
                stdout_log_level=stdout_log_level)

            self._trials = self._trials or []

        except Exception as e:
            err = "Failed initialization. Exit with error: {}".format(e)
            self._logger.log_and_raise_error(err)

    def run_trials(self) -> None:
        """
        Evaluates the objective function on every pipeline/parameter
        combination from the attached space that has not been tried
        before and appends the results to the trials object.
        """
        try:
            assert self.attached_space,\
                "Parameter distribution space must be attached"

            # XXX Tanya: if the list of values is empty
            # in the space element, remove it
            done_trial_ids = [{"name": trial["name"],
                               "params": trial["params"],
                               "status": trial["status"]}
                              for trial in self._trials]

            # generator of (flattened) dictionaries
            # with all different combinations of
            # parameters for different pipelines
            # from the space definition
            space_unfolded = ({"name": param_dist["name"],
                               "pipeline": param_dist["pipeline"],
                               "params": param_set}
                              for param_dist in self._space
                              for param_set in
                              (dict(ChainMap(*tup)) for tup in
                               product(*[[{k: v} for v in
                                          param_dist["params"][k]]
                                         for k in param_dist["params"]])))
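
            # For illustration, with a hypothetical space element
            #   {"name": "logreg",
            #    "pipeline": <sklearn Pipeline>,
            #    "params": {"clf__C": [0.1, 1.0], "clf__penalty": ["l2"]}}
            # the generator above yields the flattened combinations
            #   {..., "params": {"clf__C": 0.1, "clf__penalty": "l2"}}
            #   {..., "params": {"clf__C": 1.0, "clf__penalty": "l2"}}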

            for space_element in space_unfolded:

                # uniquely identifies the current space element
                trial_id = {"name": space_element["name"],
                            "params": space_element["params"],
                            "status": 'ok'}

                # skip the current pipeline/parameters
                # if they were already tested before
                if trial_id in done_trial_ids:
                    continue

                result = self._objective(space_element)

                pipeline = deepcopy(space_element["pipeline"])
                pipeline = pipeline.set_params(**space_element["params"])

                trial = {"name": space_element["name"],
                         "params": space_element["params"],
                         "pipeline": pipeline}

                trial.update(result)

                self._trials.append(trial)

            self.finished_tuning = True

            self.total_tuning_time = datetime.datetime.today()\
                - self.start_tuning_time

            self._backup_trials()

        except Exception as e:
            err = "Failed to run trials. Exit with error: {}".format(e)
            self._logger.log_and_raise_error(err)

    @property
    def number_of_trials(self) -> Union[int, None]:
        """
        Number of trials already run in the current trials object
        """
        try:
            return len(self._trials)

        except Exception as e:
            err = ("Failed to retrieve the number of trials. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)

    @property
    def best_trial(self) -> Union[dict, None]:
        """
        Trial with the highest score in the current trials object.
  152. """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call the run_trials method.")

            return max(self._trials, key=lambda x: x["score"])

        except Exception as e:
            err = ("Could not retrieve the best trial. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)

    @property
    def best_trial_score(self) -> Union[float, None]:
        """
        Score of the best trial in the current trials object.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call the run_trials method.")

            return self.best_trial["score"]

        except Exception as e:
            err = ("Could not retrieve the best trial score. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)

    @property
    def best_trial_score_variance(self) -> Union[float, None]:
        """
        Score variance of the best trial in the current trials object.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call the run_trials method.")

            return self.best_trial["score_variance"]

        except Exception as e:
            err = ("Could not retrieve the best trial score variance. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)

    @property
    def best_trial_pipeline(self) -> Union[Pipeline, None]:
        """
        Pipeline of the best trial in the current trials object.
        """
        try:
            assert len(self._trials) > 0,\
                ("Trials object is empty. "
                 "Call the run_trials method.")

            return self.best_trial["pipeline"]

        except Exception as e:
            err = ("Could not retrieve the best trial pipeline. "
                   "Exit with error: {}".format(e))
            self._logger.log_and_raise_error(err)

    def get_n_best_trial_pipelines(self, n: int)\
            -> Union[List[Pipeline], None]:
        """
        N best pipelines with corresponding
        best hyperparameters.
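
        For illustration: get_n_best_trial_pipelines(3) returns up to
        three pipelines (parameters already applied via set_params),
        sorted by descending score.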
  206. """
  207. try:
  208. assert(len(self._trials) > 0),\
  209. ("Trials object is empty. "
  210. "Call run_trials method.")
  211. return [trial["pipeline"] for trial in
  212. sorted(self._trials, key=lambda x: x["score"],
  213. reverse=True)[:n]]
  214. except Exception as e:
  215. err = ("Failed to retrieve n best trials. "
  216. "Exit with error: {}".format(e))
  217. self._logger.log_and_raise_error(err)

    def get_n_best_trial_pipelines_of_each_type(self, n: int)\
            -> Union[Dict[str, List[Pipeline]], None]:
        """
        If the hyperparameter search is done over multiple
        pipeline types, returns the n best pipelines of each
        type with their corresponding hyperparameters.
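
        For illustration (hypothetical pipeline names), with n=2 the
        result could look like:

            {"logreg": [Pipeline(...), Pipeline(...)],
             "random_forest": [Pipeline(...), Pipeline(...)]}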
  224. """
  225. try:
  226. assert(len(self._trials) > 0),\
  227. ("Trials object is empty. "
  228. "Call run_trials method.")
  229. return pd.DataFrame(self._trials)\
  230. .sort_values(by=["name", "score"],
  231. ascending=False)\
  232. .groupby("name")\
  233. .head(n)\
  234. .groupby("name")["pipeline"]\
  235. .apply(lambda x: list(x))\
  236. .to_dict()
  237. except Exception as e:
  238. err = ("Failed to retrieve n best trials of each type."
  239. "Exit with error: {}".format(e))
  240. self._logger.log_and_raise_error(err)

    def trials_to_excel(self, path: str) -> None:
        """
        Writes the trials object in the shape of a table to Excel.
        The table contains the run number, the pipeline (as str),
        the hyperparameters (as str), self.best_result (see the
        self._objective method), as well as additional information
        configured through the self.save_result method.
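
        Example (hypothetical path):
            gs.trials_to_excel("gridsearch_trials.xlsx")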
  248. """
  249. try:
  250. pd.DataFrame(self._trials).to_excel(path)
  251. except Exception as e:
  252. err = ("Failed to write trials to excel. "
  253. "Exit with error: {}".format(e))
  254. self._logger.log_and_raise_error(err)


if __name__ == "__main__":

    # elementary example

    from sklearn.datasets import load_breast_cancer
    from sklearn.metrics import accuracy_score, precision_score
    from cdplib.gridsearch.space_sample import space
    from cdplib.db_handlers import MongodbHandler
    import pickle

    trials_path = "gridsearch_trials_TEST.pkl"
    additional_metrics = {"precision": precision_score}
    strategy_name = "strategy_1"
    data_path = "data_TEST.h5"
    cv_path = "cv_TEST.pkl"
    collection_name = 'TEST_' + strategy_name

    logger = Log("GridSearchPipelineSelector__TEST:")

    logger.info("Start test")

    data_loader = load_breast_cancer()

    X = data_loader["data"]
    y = data_loader["target"]

    pd.DataFrame(X).to_hdf(data_path, key="X_train")
    pd.Series(y).to_hdf(data_path, key="y_train")

    cv = [(list(range(len(X)//3)), list(range(len(X)//3, len(X)))),
          (list(range(2*len(X)//3)), list(range(2*len(X)//3, len(X))))]
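
    # Note: a cv given as a list of (train_indices, test_indices) tuples
    # is the format accepted by sklearn-style cross-validation; the two
    # contiguous splits above are only meant for this elementary test.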

    with open(cv_path, "wb") as f:
        pickle.dump(cv, f)

    gs = GridSearchPipelineSelector(cost_func=accuracy_score,
                                    greater_is_better=True,
                                    trials_path=trials_path,
                                    additional_metrics=additional_metrics,
                                    strategy_name=strategy_name,
                                    stdout_log_level="WARNING")

    gs.attach_space(space=space)

    gs.attach_data_from_hdf5(data_hdf5_store_path=data_path,
                             cv_pickle_path=cv_path)

    save_method = MongodbHandler().insert_data_into_collection
    save_kwargs = {'collection_name': collection_name}

    gs.configer_summary_saving(save_method=save_method,
                               kwargs=save_kwargs)

    gs.run_trials()

    logger.info("Best trial: {}".format(gs.best_trial))
    logger.info("Total tuning time: {}".format(gs.total_tuning_time))

    # clean up the test artifacts
    for file in [trials_path, data_path, cv_path]:
        os.remove(file)

    logger.info("End test")

    # XXX Tanya check warnings