FineTunedClassifierCV.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 23 08:51:53 2020

@author: tanya

@description: class for fine-tuning an sklearn classifier
(optimizing the probability threshold)
"""

import pandas as pd
import numpy as np

from typing import Callable, Union

from sklearn.base import (BaseEstimator, ClassifierMixin,
                          clone, MetaEstimatorMixin)
# used only to build a default split when no cv object is given
from sklearn.model_selection import KFold

from cdplib.log import Log
from cdplib.utils.TyperConverter import TypeConverter


class FineTunedClassifierCV(BaseEstimator, ClassifierMixin,
                            MetaEstimatorMixin):
    """
    Probability threshold tuning for a given estimator.
    Overrides the predict method of the given sklearn classifier
    and returns predictions made with the optimal value of
    the probability threshold.

    An object of this class can be passed to an sklearn Pipeline.
    """

    def __init__(self, estimator, cost_func: Callable, greater_is_better: bool,
                 cv=None, threshold_step: float = 0.1):
        """
        :param estimator: sklearn classifier implementing predict_proba
        :param cost_func: cost function called as cost_func(y_true, y_pred)
        :param greater_is_better: whether a larger cost_func value is better
        :param cv: iterable of (train_indices, validation_indices) pairs;
            if None, a default split is created in fit
        :param threshold_step: step of the grid of candidate thresholds
        """
        self.estimator = estimator
        self.is_fitted = False
        self.greater_is_better = greater_is_better
        self.cv = cv
        self.cost_func = cost_func
        self.threshold_step = threshold_step
        self.optimal_threshold = 0.5

        self._logger = Log("FineTunedClassifierCV")

    def _get_best_threshold(self, y_val: Union[pd.DataFrame, np.ndarray],
                            proba_pred: Union[pd.DataFrame, np.ndarray]):
        """
        Returns the probability threshold that gives the best value
        of the cost function on the validation fold.
        """
        costs = {}

        for t in np.arange(self.threshold_step, 1, self.threshold_step):
            costs[t] = self.cost_func(y_val, (proba_pred >= t).astype(int))

        if self.greater_is_better:
            return max(costs, key=costs.get)
        else:
            return min(costs, key=costs.get)

    def fit(self, X: Union[pd.DataFrame, np.ndarray],
            y: Union[pd.DataFrame, np.ndarray] = None,
            **fit_args):
        """
        Finds the best probability threshold on each cross-validation fold,
        stores the mean of the per-fold thresholds and refits the estimator
        on the complete data set.
        """
        X = TypeConverter().convert_to_ndarray(X)
        if y is not None:
            y = TypeConverter().convert_to_ndarray(y)

        # assumption: default to a 5-fold split if no cv object was given
        cv = self.cv if self.cv is not None else KFold(n_splits=5).split(X, y)

        optimal_thrs_per_fold = []

        for train_inds, val_inds in cv:
            X_train, X_val = X[train_inds], X[val_inds]

            if y is not None:
                y_train, y_val = y[train_inds], y[val_inds]
            else:
                y_train, y_val = None, None

            estimator = clone(self.estimator)
            estimator.fit(X_train, y_train, **fit_args)

            # probability of the positive class (binary classification assumed)
            proba_pred = estimator.predict_proba(X_val)[:, 1]

            optimal_thr = self._get_best_threshold(y_val, proba_pred)
            optimal_thrs_per_fold.append(optimal_thr)

        self.optimal_threshold = np.mean(optimal_thrs_per_fold)

        self.estimator.fit(X, y, **fit_args)
        self.is_fitted = True

        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predicts class labels by applying the tuned probability threshold
        to the predicted probability of the positive class.
        """
        if self.is_fitted:
            proba_pred = self.estimator.predict_proba(X)[:, 1]
            return (proba_pred >= self.optimal_threshold).astype(int)
        else:
            self._logger.warn("You should fit first")

    def get_params(self, deep: bool = True):
        """
        Returns the parameters of the wrapped estimator together with
        the cv object and the cost function of the wrapper.
        """
        params = self.estimator.get_params(deep=deep)
        params.update({"cv": self.cv, "cost_func": self.cost_func})

        return params

    def set_params(self, **params: dict):
        """
        Sets the cv object and the cost function of the wrapper;
        all remaining parameters are passed to the wrapped estimator.
        """
        # popping while iterating over the dict would raise a RuntimeError,
        # so the wrapper's own parameters are extracted explicitly
        if "cv" in params:
            self.cv = params.pop("cv")
        if "cost_func" in params:
            self.cost_func = params.pop("cost_func")

        self.estimator.set_params(**params)

        return self


if __name__ == "__main__":

    # test
    from sklearn.datasets import load_iris
    from sklearn.metrics import accuracy_score
    import gc

    from xgboost import XGBRFClassifier

    data = load_iris()
    X, y = data["data"], data["target"]
    y = (y == 1).astype(int)
    del data
    gc.collect()

    # make a custom cv object
    val_len = len(X)//10
    split_inds = range(len(X)//2, len(X), val_len)

    cv = []

    for i in split_inds:
        train_inds = list(range(i))
        val_inds = list(range(i, i + val_len))
        cv.append((train_inds, val_inds))

    clf = XGBRFClassifier()

    fine_tuned_clf = FineTunedClassifierCV(estimator=clf,
                                           cv=cv,
                                           greater_is_better=True,
                                           cost_func=accuracy_score)

    fine_tuned_clf.fit(X=X, y=y)
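
    # A minimal sketch of how the fitted wrapper could be used (not part of
    # the original test): predictions with the tuned threshold are compared
    # against a fixed 0.5 cut-off on the same data; the variable names below
    # are illustrative only.
    y_pred_tuned = fine_tuned_clf.predict(X)
    y_pred_default = (clf.predict_proba(X)[:, 1] >= 0.5).astype(int)

    print("optimal threshold:", fine_tuned_clf.optimal_threshold)
    print("accuracy with tuned threshold:", accuracy_score(y, y_pred_tuned))
    print("accuracy with 0.5 threshold:", accuracy_score(y, y_pred_default))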