|
@@ -66,7 +66,7 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
self._trials = self._trials or []
|
|
self._trials = self._trials or []
|
|
|
|
|
|
done_trial_ids = [{"name": trial["name"],
|
|
done_trial_ids = [{"name": trial["name"],
|
|
- "param_set": trial["param_set"]}
|
|
|
|
|
|
+ "params": trial["params"]}
|
|
for trial in self._trials]
|
|
for trial in self._trials]
|
|
|
|
|
|
# list (generator) of (flattened) dictionaries
|
|
# list (generator) of (flattened) dictionaries
|
|
@@ -75,7 +75,7 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
# from the space definition.
|
|
# from the space definition.
|
|
space_unfolded = ({"name": pipeline_dist["name"],
|
|
space_unfolded = ({"name": pipeline_dist["name"],
|
|
"pipeline": pipeline_dist["pipeline"],
|
|
"pipeline": pipeline_dist["pipeline"],
|
|
- "param_set": param_set}
|
|
|
|
|
|
+ "params": param_set}
|
|
for pipeline_dist in self._space
|
|
for pipeline_dist in self._space
|
|
for param_set in
|
|
for param_set in
|
|
(dict(ChainMap(*tup)) for tup in
|
|
(dict(ChainMap(*tup)) for tup in
|
|
@@ -86,7 +86,7 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
for space_element in space_unfolded:
|
|
for space_element in space_unfolded:
|
|
|
|
|
|
trial_id = {"name": space_element["name"],
|
|
trial_id = {"name": space_element["name"],
|
|
- "param_set": space_element["param_set"]}
|
|
|
|
|
|
+ "params": space_element["params"]}
|
|
|
|
|
|
if trial_id in done_trial_ids:
|
|
if trial_id in done_trial_ids:
|
|
continue
|
|
continue
|
|
@@ -94,10 +94,10 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
result = self._objective(space_element)
|
|
result = self._objective(space_element)
|
|
|
|
|
|
pipeline = space_element["pipeline"].set_params(
|
|
pipeline = space_element["pipeline"].set_params(
|
|
- **space_element["param_set"])
|
|
|
|
|
|
+ **space_element["params"])
|
|
|
|
|
|
self._trials.append({"name": space_element["name"],
|
|
self._trials.append({"name": space_element["name"],
|
|
- "param_set": space_element["param_set"],
|
|
|
|
|
|
+ "params": space_element["params"],
|
|
"pipeline": pipeline,
|
|
"pipeline": pipeline,
|
|
"result": result})
|
|
"result": result})
|
|
|
|
|
|
@@ -164,7 +164,7 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
|
|
|
- # Small test
|
|
|
|
|
|
+ # Runn through test/example
|
|
|
|
|
|
from sklearn.datasets import load_breast_cancer
|
|
from sklearn.datasets import load_breast_cancer
|
|
from sklearn.metrics import accuracy_score
|
|
from sklearn.metrics import accuracy_score
|
|
@@ -190,6 +190,6 @@ if __name__ == "__main__":
|
|
gs.run_trials()
|
|
gs.run_trials()
|
|
|
|
|
|
logger.info("Best trial: {}".format(gs.best_trial))
|
|
logger.info("Best trial: {}".format(gs.best_trial))
|
|
- logger.info("Best trial: {}".format(gs.best_trial_pipeline))
|
|
|
|
|
|
+ logger.info("Best trial pipeline: {}".format(gs.best_trial_pipeline))
|
|
|
|
|
|
logger.info("End test")
|
|
logger.info("End test")
|