|
@@ -284,7 +284,10 @@ class GridSearchPipelineSelector(PipelineSelector):
|
|
|
.sort_values(by=["name", "score"],
|
|
|
ascending=False)\
|
|
|
.groupby("name")\
|
|
|
- .head(n)[["pipeline"]]
|
|
|
+ .head(n)\
|
|
|
+ .groupby("name")["pipeline"]\
|
|
|
+ .apply(lambda x: list(x))\
|
|
|
+ .to_dict()
|
|
|
|
|
|
except Exception as e:
|
|
|
err = ("Failed to retrieve n best trials of each type."
|