Bladeren bron

fixed space element definition

tanja 3 jaren geleden
bovenliggende
commit
98769c9e78
1 gewijzigde bestanden met toevoegingen van 7 en 7 verwijderingen
  1. 7 7
      cdplib/gridsearch/gridsearch.py

+ 7 - 7
cdplib/gridsearch/gridsearch.py

@@ -66,7 +66,7 @@ class GridSearchPipelineSelector(PipelineSelector):
         self._trials = self._trials or []
 
         done_trial_ids = [{"name": trial["name"],
-                           "param_set": trial["param_set"]}
+                           "params": trial["params"]}
                           for trial in self._trials]
 
         # list (generator) of (flattened) dictionaries
@@ -75,7 +75,7 @@ class GridSearchPipelineSelector(PipelineSelector):
         # from the space definition.
         space_unfolded = ({"name": pipeline_dist["name"],
                            "pipeline": pipeline_dist["pipeline"],
-                           "param_set": param_set}
+                           "params": param_set}
                           for pipeline_dist in self._space
                           for param_set in
                           (dict(ChainMap(*tup)) for tup in
@@ -86,7 +86,7 @@ class GridSearchPipelineSelector(PipelineSelector):
         for space_element in space_unfolded:
 
             trial_id = {"name": space_element["name"],
-                        "param_set": space_element["param_set"]}
+                        "params": space_element["params"]}
 
             if trial_id in done_trial_ids:
                 continue
@@ -94,10 +94,10 @@ class GridSearchPipelineSelector(PipelineSelector):
             result = self._objective(space_element)
 
             pipeline = space_element["pipeline"].set_params(
-                    **space_element["param_set"])
+                    **space_element["params"])
 
             self._trials.append({"name": space_element["name"],
-                                 "param_set": space_element["param_set"],
+                                 "params": space_element["params"],
                                  "pipeline": pipeline,
                                  "result": result})
 
@@ -164,7 +164,7 @@ class GridSearchPipelineSelector(PipelineSelector):
 
 if __name__ == "__main__":
 
-    # Small test
+    # Runn through test/example
 
     from sklearn.datasets import load_breast_cancer
     from sklearn.metrics import accuracy_score
@@ -190,6 +190,6 @@ if __name__ == "__main__":
     gs.run_trials()
 
     logger.info("Best trial: {}".format(gs.best_trial))
-    logger.info("Best trial: {}".format(gs.best_trial_pipeline))
+    logger.info("Best trial pipeline: {}".format(gs.best_trial_pipeline))
 
     logger.info("End test")