浏览代码

added total_tuning_time property

tanja 3 年之前
父节点
当前提交
ff659f0836
共有 2 个文件被更改,包括 10 次插入4 次删除
  1. 4 2
      cdplib/gridsearch/GridSearchPipelineSelector.py
  2. 6 2
      cdplib/hyperopt/HyperoptPipelineSelector.py

+ 4 - 2
cdplib/gridsearch/GridSearchPipelineSelector.py

@@ -15,7 +15,7 @@ Created on Wed Sep 30 14:15:17 2020
 
 import os
 import sys
-import time
+import datetime
 from itertools import product
 from collections import ChainMap
 from sklearn.pipeline import Pipeline
@@ -141,7 +141,8 @@ class GridSearchPipelineSelector(PipelineSelector):
 
         self.finished_tuning = True
 
-        self.total_tuning_time = time.time() - self.start_tuning_time
+        self.total_tuning_time = datetime.datetime.today()\
+            - self.start_tuning_time
 
         self._backup_trials()
 
@@ -344,6 +345,7 @@ if __name__ == "__main__":
     gs.run_trials()
 
     logger.info("Best trial: {}".format(gs.best_trial))
+    logger.info("Total tuning time: {}".format(gs.total_tuning_time))
 
     for file in [trials_path, data_path, cv_path]:
         os.remove(file)

+ 6 - 2
cdplib/hyperopt/HyperoptPipelineSelector.py

@@ -19,7 +19,7 @@ import pickle
 
 from copy import deepcopy
 
-import time
+import datetime
 
 from typing import Callable
 
@@ -165,7 +165,8 @@ class HyperoptPipelineSelector(PipelineSelector):
 
             self.finished_tuning = True
 
-            self.total_tuning_time = time.time() - self.start_tuning_time
+            self.total_tuning_time = datetime.datetime.today()\
+                - self.start_tuning_time
 
             self._backup_trials()
 
@@ -446,6 +447,9 @@ if __name__ == '__main__':
 
     hs.run_trials(niter=10)
 
+    logger.info("Best Trial: {}".format(hs.best_trial))
+    logger.info("Total tuning time: {}".format(hs.total_tuning_time))
+
     for file in [trials_path, data_path, cv_path]:
         os.remove(file)