# Natural Language Toolkit: Decision Tree Classifiers
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

  7. """
  8. A classifier model that decides which label to assign to a token on
  9. the basis of a tree structure, where branches correspond to conditions
  10. on feature values, and leaves correspond to label assignments.
  11. """
from __future__ import print_function, unicode_literals, division

from collections import defaultdict

from nltk.probability import FreqDist, MLEProbDist, entropy
from nltk.classify.api import ClassifierI
from nltk.compat import python_2_unicode_compatible

@python_2_unicode_compatible
class DecisionTreeClassifier(ClassifierI):
    def __init__(self, label, feature_name=None, decisions=None, default=None):
        """
        :param label: The most likely label for tokens that reach
            this node in the decision tree.  If this decision tree
            has no children, then this label will be assigned to
            any token that reaches this decision tree.
        :param feature_name: The name of the feature that this
            decision tree selects for.
        :param decisions: A dictionary mapping from feature values
            for the feature identified by ``feature_name`` to
            child decision trees.
        :param default: The child that will be used if the value of
            feature ``feature_name`` does not match any of the keys in
            ``decisions``.  This is used when constructing binary
            decision trees.
        """
        self._label = label
        self._fname = feature_name
        self._decisions = decisions
        self._default = default
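
    # A minimal illustrative sketch (not part of the NLTK API): building
    # a tiny tree by hand.  The feature name 'outlook' and its values are
    # hypothetical.
    #
    #     >>> sunny_leaf = DecisionTreeClassifier('no')
    #     >>> rainy_leaf = DecisionTreeClassifier('yes')
    #     >>> tree = DecisionTreeClassifier(
    #     ...     'yes', feature_name='outlook',
    #     ...     decisions={'sunny': sunny_leaf, 'rainy': rainy_leaf})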

    def labels(self):
        labels = [self._label]
        if self._decisions is not None:
            for dt in self._decisions.values():
                labels.extend(dt.labels())
        if self._default is not None:
            labels.extend(self._default.labels())
        return list(set(labels))

    def classify(self, featureset):
        # Decision leaf:
        if self._fname is None:
            return self._label
        # Decision tree:
        fval = featureset.get(self._fname)
        if fval in self._decisions:
            return self._decisions[fval].classify(featureset)
        elif self._default is not None:
            return self._default.classify(featureset)
        else:
            return self._label
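
    # Continuing the hypothetical 'outlook' sketch above: classify()
    # follows the matching branch, and falls back on this node's own
    # label when the value is unseen and no default child exists.
    #
    #     >>> tree.classify({'outlook': 'sunny'})
    #     'no'
    #     >>> tree.classify({'outlook': 'overcast'})
    #     'yes'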

    def error(self, labeled_featuresets):
        """
        Return the fraction of ``labeled_featuresets`` that this
        classifier labels incorrectly.
        """
        errors = 0
        for featureset, label in labeled_featuresets:
            if self.classify(featureset) != label:
                errors += 1
        return errors / len(labeled_featuresets)

    def pretty_format(self, width=70, prefix='', depth=4):
        """
        Return a string containing a pretty-printed version of this
        decision tree.  Each line in this string corresponds to a
        single decision tree node or leaf, and indentation is used to
        display the structure of the decision tree.
        """
        # [xx] display default!!
        if self._fname is None:
            n = width - len(prefix) - 15
            return '{0}{1} {2}\n'.format(prefix, '.' * n, self._label)
        s = ''
        for i, (fval, result) in enumerate(sorted(self._decisions.items())):
            hdr = '{0}{1}={2}? '.format(prefix, self._fname, fval)
            n = width - 15 - len(hdr)
            s += '{0}{1} {2}\n'.format(hdr, '.' * n, result._label)
            if result._fname is not None and depth > 1:
                s += result.pretty_format(width, prefix + '  ', depth - 1)
        if self._default is not None:
            n = width - len(prefix) - 21
            s += '{0}else: {1} {2}\n'.format(prefix, '.' * n, self._default._label)
            if self._default._fname is not None and depth > 1:
                s += self._default.pretty_format(width, prefix + '  ', depth - 1)
        return s

    def pseudocode(self, prefix='', depth=4):
        """
        Return a string representation of this decision tree that
        expresses the decisions it makes as a nested set of pseudocode
        if statements.
        """
        if self._fname is None:
            return "{0}return {1!r}\n".format(prefix, self._label)
        s = ''
        for (fval, result) in sorted(self._decisions.items()):
            s += '{0}if {1} == {2!r}: '.format(prefix, self._fname, fval)
            if result._fname is not None and depth > 1:
                s += '\n' + result.pseudocode(prefix + '  ', depth - 1)
            else:
                s += 'return {0!r}\n'.format(result._label)
        if self._default is not None:
            if len(self._decisions) == 1:
                s += '{0}if {1} != {2!r}: '.format(
                    prefix, self._fname, list(self._decisions.keys())[0]
                )
            else:
                s += '{0}else: '.format(prefix)
            if self._default._fname is not None and depth > 1:
                s += '\n' + self._default.pseudocode(prefix + '  ', depth - 1)
            else:
                s += 'return {0!r}\n'.format(self._default._label)
        return s
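
    # For the hypothetical 'outlook' tree sketched above, pseudocode()
    # would render the decisions roughly as:
    #
    #     if outlook == 'rainy': return 'yes'
    #     if outlook == 'sunny': return 'no'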

    def __str__(self):
        return self.pretty_format()

    @staticmethod
    def train(
        labeled_featuresets,
        entropy_cutoff=0.05,
        depth_cutoff=100,
        support_cutoff=10,
        binary=False,
        feature_values=None,
        verbose=False,
    ):
        """
        :param binary: If true, then treat all feature/value pairs as
            individual binary features, rather than using a single n-way
            branch for each feature.
        """
        # Collect a list of all feature names.
        feature_names = set()
        for featureset, label in labeled_featuresets:
            for fname in featureset:
                feature_names.add(fname)

        # Collect a list of the values each feature can take.
        if feature_values is None and binary:
            feature_values = defaultdict(set)
            for featureset, label in labeled_featuresets:
                for fname, fval in featureset.items():
                    feature_values[fname].add(fval)

        # Start with a stump.
        if not binary:
            tree = DecisionTreeClassifier.best_stump(
                feature_names, labeled_featuresets, verbose
            )
        else:
            tree = DecisionTreeClassifier.best_binary_stump(
                feature_names, labeled_featuresets, feature_values, verbose
            )

        # Refine the stump.
        tree.refine(
            labeled_featuresets,
            entropy_cutoff,
            depth_cutoff - 1,
            support_cutoff,
            binary,
            feature_values,
            verbose,
        )

        # Return it.
        return tree
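
    # A minimal illustrative sketch of train() on toy data (the feature
    # name and labels are hypothetical; support_cutoff is lowered so the
    # tiny sample can actually be split):
    #
    #     >>> toy = [({'size': 'big'}, 'heavy'),
    #     ...        ({'size': 'small'}, 'light')] * 6
    #     >>> tree = DecisionTreeClassifier.train(toy, support_cutoff=1)
    #     >>> tree.classify({'size': 'big'})
    #     'heavy'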

    @staticmethod
    def leaf(labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()
        return DecisionTreeClassifier(label)

    @staticmethod
    def stump(feature_name, labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()

        # Find the best label for each value.
        freqs = defaultdict(FreqDist)  # freq(label|value)
        for featureset, label in labeled_featuresets:
            feature_value = featureset.get(feature_name)
            freqs[feature_value][label] += 1

        decisions = dict(
            (val, DecisionTreeClassifier(freqs[val].max())) for val in freqs
        )
        return DecisionTreeClassifier(label, feature_name, decisions)

    def refine(
        self,
        labeled_featuresets,
        entropy_cutoff,
        depth_cutoff,
        support_cutoff,
        binary=False,
        feature_values=None,
        verbose=False,
    ):
        # Stop growing if there are too few examples, if this node is
        # already a leaf, or if the maximum depth has been reached.
        if len(labeled_featuresets) <= support_cutoff:
            return
        if self._fname is None:
            return
        if depth_cutoff <= 0:
            return
        # Re-train any child whose examples are still impure enough,
        # i.e. whose label entropy exceeds the cutoff.
        for fval in self._decisions:
            fval_featuresets = [
                (featureset, label)
                for (featureset, label) in labeled_featuresets
                if featureset.get(self._fname) == fval
            ]
            label_freqs = FreqDist(label for (featureset, label) in fval_featuresets)
            if entropy(MLEProbDist(label_freqs)) > entropy_cutoff:
                self._decisions[fval] = DecisionTreeClassifier.train(
                    fval_featuresets,
                    entropy_cutoff,
                    depth_cutoff,
                    support_cutoff,
                    binary,
                    feature_values,
                    verbose,
                )
        # Do the same for the default child, using the examples whose
        # feature value matched none of the decision branches.
        if self._default is not None:
            default_featuresets = [
                (featureset, label)
                for (featureset, label) in labeled_featuresets
                if featureset.get(self._fname) not in self._decisions
            ]
            label_freqs = FreqDist(label for (featureset, label) in default_featuresets)
            if entropy(MLEProbDist(label_freqs)) > entropy_cutoff:
                self._default = DecisionTreeClassifier.train(
                    default_featuresets,
                    entropy_cutoff,
                    depth_cutoff,
                    support_cutoff,
                    binary,
                    feature_values,
                    verbose,
                )

    @staticmethod
    def best_stump(feature_names, labeled_featuresets, verbose=False):
        best_stump = DecisionTreeClassifier.leaf(labeled_featuresets)
        best_error = best_stump.error(labeled_featuresets)
        for fname in feature_names:
            stump = DecisionTreeClassifier.stump(fname, labeled_featuresets)
            stump_error = stump.error(labeled_featuresets)
            if stump_error < best_error:
                best_error = stump_error
                best_stump = stump
        if verbose:
            # The plain leaf may win, in which case it has no feature
            # name; formatting None would raise a TypeError.
            descr = best_stump._fname if best_stump._fname is not None else '(leaf)'
            print(
                'best stump for {:6d} toks uses {:20} err={:6.4f}'.format(
                    len(labeled_featuresets), descr, best_error
                )
            )
        return best_stump

    @staticmethod
    def binary_stump(feature_name, feature_value, labeled_featuresets):
        label = FreqDist(label for (featureset, label) in labeled_featuresets).max()

        # Find the best label for each value.
        pos_fdist = FreqDist()
        neg_fdist = FreqDist()
        for featureset, label in labeled_featuresets:
            if featureset.get(feature_name) == feature_value:
                pos_fdist[label] += 1
            else:
                neg_fdist[label] += 1

        decisions = {}
        # Fall back on a leaf carrying the majority label, so the default
        # is always a classifier rather than a bare label.
        default = DecisionTreeClassifier(label)
        # But hopefully we have observations!
        if pos_fdist.N() > 0:
            decisions = {feature_value: DecisionTreeClassifier(pos_fdist.max())}
        if neg_fdist.N() > 0:
            default = DecisionTreeClassifier(neg_fdist.max())
        return DecisionTreeClassifier(label, feature_name, decisions, default)
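
    # Illustrative sketch: a binary stump asks a single yes/no question
    # about one (hypothetical) feature/value pair, routing non-matching
    # examples through the default child:
    #
    #     >>> toy = [({'size': 'big'}, 'heavy'),
    #     ...        ({'size': 'small'}, 'light')] * 3
    #     >>> stump = DecisionTreeClassifier.binary_stump('size', 'big', toy)
    #     >>> stump.classify({'size': 'small'})
    #     'light'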

    @staticmethod
    def best_binary_stump(
        feature_names, labeled_featuresets, feature_values, verbose=False
    ):
        best_stump = DecisionTreeClassifier.leaf(labeled_featuresets)
        best_error = best_stump.error(labeled_featuresets)
        for fname in feature_names:
            for fval in feature_values[fname]:
                stump = DecisionTreeClassifier.binary_stump(
                    fname, fval, labeled_featuresets
                )
                stump_error = stump.error(labeled_featuresets)
                if stump_error < best_error:
                    best_error = stump_error
                    best_stump = stump
        if verbose:
            if best_stump._decisions:
                descr = '{0}={1}'.format(
                    best_stump._fname, list(best_stump._decisions.keys())[0]
                )
            else:
                descr = '(default)'
            print(
                'best stump for {:6d} toks uses {:20} err={:6.4f}'.format(
                    len(labeled_featuresets), descr, best_error
                )
            )
        return best_stump


##//////////////////////////////////////////////////////
##  Demo
##//////////////////////////////////////////////////////


def f(x):
    return DecisionTreeClassifier.train(x, binary=True, verbose=True)


def demo():
    from nltk.classify.util import names_demo, binary_names_demo_features

    classifier = names_demo(
        f, binary_names_demo_features  # DecisionTreeClassifier.train,
    )
    print(classifier.pretty_format(depth=7))
    print(classifier.pseudocode(depth=7))


if __name__ == '__main__':
    demo()