# -*- coding: utf-8 -*-
# Natural Language Toolkit: Transformation-based learning
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Marcus Uneson <marcus.uneson@gmail.com>
#   based on previous (nltk2) version by
#   Christopher Maloof, Edward Loper, Steven Bird
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

from __future__ import print_function, absolute_import, division

import os
import pickle
import random
import time

from nltk.corpus import treebank
from nltk.tbl import error_list, Template
from nltk.tag.brill import Word, Pos
from nltk.tag import BrillTaggerTrainer, RegexpTagger, UnigramTagger


def demo():
    """
    Run a demo with defaults. See source comments for details,
    or docstrings of any of the more specific demo_* functions.
    """
    postag()


def demo_repr_rule_format():
    """
    Exemplify repr(Rule) (see also str(Rule) and Rule.format("verbose"))
    """
    postag(ruleformat="repr")


def demo_str_rule_format():
    """
    Exemplify str(Rule) (see also repr(Rule) and Rule.format("verbose"))
    """
    postag(ruleformat="str")


def demo_verbose_rule_format():
    """
    Exemplify Rule.format("verbose")
    """
    postag(ruleformat="verbose")


def demo_multiposition_feature():
    """
    Each feature of a template takes a list of positions
    relative to the current word where the feature should be
    looked for, conceptually joined by logical OR. For instance,
    Pos([-1, 1]), given a value V, will hold whenever V is found
    one step to the left and/or one step to the right.

    For contiguous ranges, a 2-arg form giving inclusive end
    points can also be used: Pos(-3, -1) is the same as the arg
    below.
    """
    postag(templates=[Template(Pos([-3, -2, -1]))])


def demo_multifeature_template():
    """
    Templates can have more than a single feature.
    """
    postag(templates=[Template(Word([0]), Pos([-2, -1]))])


def demo_template_statistics():
    """
    Show aggregate statistics per template. Little-used templates are
    candidates for deletion; much-used templates may possibly be refined.

    Deleting unused templates is mostly about saving time and/or space:
    training is basically O(T) in the number of templates T
    (also in terms of memory usage, which often will be the limiting factor).
    """
    postag(incremental_stats=True, template_stats=True)


def demo_generated_templates():
    """
    Template.expand and Feature.expand are class methods facilitating
    generating large amounts of templates. See their documentation for
    details.

    Note: training with 500 templates can easily fill all available memory,
    even on relatively small corpora.
    """
    wordtpls = Word.expand([-1, 0, 1], [1, 2], excludezero=False)
    tagtpls = Pos.expand([-2, -1, 0, 1], [1, 2], excludezero=True)
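    # combinations=(1, 3) restricts the generated templates to between one and
    # three features each, drawn from the word and tag feature sets above
    # (see Template.expand's documentation for details)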
    templates = list(Template.expand([wordtpls, tagtpls], combinations=(1, 3)))
    print(
        "Generated {0} templates for transformation-based learning".format(
            len(templates)
        )
    )
    postag(templates=templates, incremental_stats=True, template_stats=True)


def demo_learning_curve():
    """
    Plot a learning curve -- the contribution to tagging accuracy of
    the individual rules.

    Note: requires matplotlib
    """
    postag(
        incremental_stats=True,
        separate_baseline_data=True,
        learning_curve_output="learningcurve.png",
    )


def demo_error_analysis():
    """
    Writes a file with context for each erroneous word after tagging testing data
    """
    postag(error_output="errors.txt")


def demo_serialize_tagger():
    """
    Serializes the learned tagger to a file in pickle format; reloads it
    and validates the process.
    """
    postag(serialize_output="tagger.pcl")


def demo_high_accuracy_rules():
    """
    Discard rules with low accuracy. This may hurt performance a bit,
    but will often produce rules which are more interesting for a human to read.
    """
    postag(num_sents=3000, min_acc=0.96, min_score=10)


def postag(
    templates=None,
    tagged_data=None,
    num_sents=1000,
    max_rules=300,
    min_score=3,
    min_acc=None,
    train=0.8,
    trace=3,
    randomize=False,
    ruleformat="str",
    incremental_stats=False,
    template_stats=False,
    error_output=None,
    serialize_output=None,
    learning_curve_output=None,
    learning_curve_take=300,
    baseline_backoff_tagger=None,
    separate_baseline_data=False,
    cache_baseline_tagger=None,
):
  132. """
  133. Brill Tagger Demonstration
  134. :param templates: how many sentences of training and testing data to use
  135. :type templates: list of Template
  136. :param tagged_data: maximum number of rule instances to create
  137. :type tagged_data: C{int}
  138. :param num_sents: how many sentences of training and testing data to use
  139. :type num_sents: C{int}
  140. :param max_rules: maximum number of rule instances to create
  141. :type max_rules: C{int}
  142. :param min_score: the minimum score for a rule in order for it to be considered
  143. :type min_score: C{int}
  144. :param min_acc: the minimum score for a rule in order for it to be considered
  145. :type min_acc: C{float}
  146. :param train: the fraction of the the corpus to be used for training (1=all)
  147. :type train: C{float}
  148. :param trace: the level of diagnostic tracing output to produce (0-4)
  149. :type trace: C{int}
  150. :param randomize: whether the training data should be a random subset of the corpus
  151. :type randomize: C{bool}
  152. :param ruleformat: rule output format, one of "str", "repr", "verbose"
  153. :type ruleformat: C{str}
  154. :param incremental_stats: if true, will tag incrementally and collect stats for each rule (rather slow)
  155. :type incremental_stats: C{bool}
  156. :param template_stats: if true, will print per-template statistics collected in training and (optionally) testing
  157. :type template_stats: C{bool}
  158. :param error_output: the file where errors will be saved
  159. :type error_output: C{string}
  160. :param serialize_output: the file where the learned tbl tagger will be saved
  161. :type serialize_output: C{string}
  162. :param learning_curve_output: filename of plot of learning curve(s) (train and also test, if available)
  163. :type learning_curve_output: C{string}
  164. :param learning_curve_take: how many rules plotted
  165. :type learning_curve_take: C{int}
  166. :param baseline_backoff_tagger: the file where rules will be saved
  167. :type baseline_backoff_tagger: tagger
  168. :param separate_baseline_data: use a fraction of the training data exclusively for training baseline
  169. :type separate_baseline_data: C{bool}
  170. :param cache_baseline_tagger: cache baseline tagger to this file (only interesting as a temporary workaround to get
  171. deterministic output from the baseline unigram tagger between python versions)
  172. :type cache_baseline_tagger: C{string}
  173. Note on separate_baseline_data: if True, reuse training data both for baseline and rule learner. This
  174. is fast and fine for a demo, but is likely to generalize worse on unseen data.
  175. Also cannot be sensibly used for learning curves on training data (the baseline will be artificially high).
  176. """
    # defaults
    baseline_backoff_tagger = baseline_backoff_tagger or REGEXP_TAGGER
    if templates is None:
        from nltk.tag.brill import describe_template_sets, brill24

        # some pre-built template sets taken from typical systems or publications
        # are available. Print a list with describe_template_sets().
        # for instance:
        templates = brill24()
    (training_data, baseline_data, gold_data, testing_data) = _demo_prepare_data(
        tagged_data, train, num_sents, randomize, separate_baseline_data
    )

    # creating (or reloading from cache) a baseline tagger (unigram tagger)
    # this is just a mechanism for getting deterministic output from the baseline
    # between python versions
    if cache_baseline_tagger:
        if not os.path.exists(cache_baseline_tagger):
            baseline_tagger = UnigramTagger(
                baseline_data, backoff=baseline_backoff_tagger
            )
            with open(cache_baseline_tagger, 'wb') as print_rules:
                pickle.dump(baseline_tagger, print_rules)
            print(
                "Trained baseline tagger, pickled it to {0}".format(
                    cache_baseline_tagger
                )
            )
        with open(cache_baseline_tagger, 'rb') as print_rules:
            baseline_tagger = pickle.load(print_rules)
            print("Reloaded pickled tagger from {0}".format(cache_baseline_tagger))
    else:
        baseline_tagger = UnigramTagger(baseline_data, backoff=baseline_backoff_tagger)
        print("Trained baseline tagger")
    if gold_data:
        print(
            " Accuracy on test set: {0:0.4f}".format(
                baseline_tagger.evaluate(gold_data)
            )
        )

    # creating a Brill tagger
    tbrill = time.time()
    trainer = BrillTaggerTrainer(
        baseline_tagger, templates, trace, ruleformat=ruleformat
    )
    print("Training tbl tagger...")
    brill_tagger = trainer.train(training_data, max_rules, min_score, min_acc)
    print("Trained tbl tagger in {0:0.2f} seconds".format(time.time() - tbrill))
    if gold_data:
        print(" Accuracy on test set: %.4f" % brill_tagger.evaluate(gold_data))

    # printing the learned rules, if learned silently
    if trace == 1:
        print("\nLearned rules: ")
        for (ruleno, rule) in enumerate(brill_tagger.rules(), 1):
            print("{0:4d} {1:s}".format(ruleno, rule.format(ruleformat)))

    # printing template statistics (optionally including comparison with the training data)
    # note: if not separate_baseline_data, then baseline accuracy will be artificially high
    if incremental_stats:
        print(
            "Incrementally tagging the test data, collecting individual rule statistics"
        )
        (taggedtest, teststats) = brill_tagger.batch_tag_incremental(
            testing_data, gold_data
        )
        print(" Rule statistics collected")
        if not separate_baseline_data:
            print(
                "WARNING: train_stats asked for separate_baseline_data=True; the baseline "
                "will be artificially high"
            )
        trainstats = brill_tagger.train_stats()
        if template_stats:
            brill_tagger.print_template_statistics(teststats)
        if learning_curve_output:
            _demo_plot(
                learning_curve_output, teststats, trainstats, take=learning_curve_take
            )
            print("Wrote plot of learning curve to {0}".format(learning_curve_output))
    else:
        print("Tagging the test data")
        taggedtest = brill_tagger.tag_sents(testing_data)
        if template_stats:
            brill_tagger.print_template_statistics()

    # writing error analysis to file
    if error_output is not None:
        with open(error_output, 'w', encoding='utf-8') as f:
            f.write('Errors for Brill Tagger %r\n\n' % serialize_output)
            f.write('\n'.join(error_list(gold_data, taggedtest)) + '\n')
        print("Wrote tagger errors including context to {0}".format(error_output))

    # serializing the tagger to a pickle file and reloading (just to see it works)
    if serialize_output is not None:
        taggedtest = brill_tagger.tag_sents(testing_data)
        with open(serialize_output, 'wb') as print_rules:
            pickle.dump(brill_tagger, print_rules)
        print("Wrote pickled tagger to {0}".format(serialize_output))
        with open(serialize_output, 'rb') as print_rules:
            brill_tagger_reloaded = pickle.load(print_rules)
        print("Reloaded pickled tagger from {0}".format(serialize_output))
        taggedtest_reloaded = brill_tagger_reloaded.tag_sents(testing_data)
        if taggedtest == taggedtest_reloaded:
            print("Reloaded tagger tried on test set, results identical")
        else:
            print("PROBLEM: Reloaded tagger gave different results on test set")


def _demo_prepare_data(
    tagged_data, train, num_sents, randomize, separate_baseline_data
):
    # train is the proportion of data used in training; the rest is reserved
    # for testing.
    if tagged_data is None:
        print("Loading tagged data from treebank... ")
        tagged_data = treebank.tagged_sents()
    if num_sents is None or len(tagged_data) <= num_sents:
        num_sents = len(tagged_data)
    if randomize:
        random.seed(len(tagged_data))
        # materialize the lazy corpus view so it can be shuffled in place
        tagged_data = list(tagged_data)
        random.shuffle(tagged_data)
    cutoff = int(num_sents * train)
    training_data = tagged_data[:cutoff]
    gold_data = tagged_data[cutoff:num_sents]
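    # the test input is the gold sentences with the tags stripped off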
    testing_data = [[t[0] for t in sent] for sent in gold_data]
    if not separate_baseline_data:
        baseline_data = training_data
    else:
        bl_cutoff = len(training_data) // 3
        (baseline_data, training_data) = (
            training_data[:bl_cutoff],
            training_data[bl_cutoff:],
        )
    (trainseqs, traintokens) = corpus_size(training_data)
    (testseqs, testtokens) = corpus_size(testing_data)
    (bltrainseqs, bltraintokens) = corpus_size(baseline_data)
    print("Read testing data ({0:d} sents/{1:d} wds)".format(testseqs, testtokens))
    print("Read training data ({0:d} sents/{1:d} wds)".format(trainseqs, traintokens))
    print(
        "Read baseline data ({0:d} sents/{1:d} wds) {2:s}".format(
            bltrainseqs,
            bltraintokens,
            "" if separate_baseline_data else "[reused the training set]",
        )
    )
    return (training_data, baseline_data, gold_data, testing_data)


def _demo_plot(learning_curve_output, teststats, trainstats=None, take=None):
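    # the stats dicts record the initial error count and the number of errors
    # each rule removed; accumulate the remaining-error counts per rule, then
    # convert them to accuracies (1 - errors/tokencount)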
    testcurve = [teststats['initialerrors']]
    for rulescore in teststats['rulescores']:
        testcurve.append(testcurve[-1] - rulescore)
    testcurve = [1 - x / teststats['tokencount'] for x in testcurve[:take]]

    traincurve = [trainstats['initialerrors']]
    for rulescore in trainstats['rulescores']:
        traincurve.append(traincurve[-1] - rulescore)
    traincurve = [1 - x / trainstats['tokencount'] for x in traincurve[:take]]

    import matplotlib.pyplot as plt

    r = list(range(len(testcurve)))
    plt.plot(r, testcurve, r, traincurve)
    plt.axis([None, None, None, 1.0])
    plt.savefig(learning_curve_output)


NN_CD_TAGGER = RegexpTagger([(r'^-?[0-9]+(\.[0-9]+)?$', 'CD'), (r'.*', 'NN')])

REGEXP_TAGGER = RegexpTagger(
    [
        (r'^-?[0-9]+(\.[0-9]+)?$', 'CD'),  # cardinal numbers
        (r'(The|the|A|a|An|an)$', 'AT'),  # articles
        (r'.*able$', 'JJ'),  # adjectives
        (r'.*ness$', 'NN'),  # nouns formed from adjectives
        (r'.*ly$', 'RB'),  # adverbs
        (r'.*s$', 'NNS'),  # plural nouns
        (r'.*ing$', 'VBG'),  # gerunds
        (r'.*ed$', 'VBD'),  # past tense verbs
        (r'.*', 'NN'),  # nouns (default)
    ]
)
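# A minimal sketch of swapping in the simpler baseline backoff defined above
# (postag() falls back to REGEXP_TAGGER when baseline_backoff_tagger is None):
#
#     postag(baseline_backoff_tagger=NN_CD_TAGGER)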


def corpus_size(seqs):
    return (len(seqs), sum(len(x) for x in seqs))


if __name__ == '__main__':
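    # any of the other demo_* functions above can be run here instead, e.g.
    # demo(), demo_error_analysis(), or demo_serialize_tagger()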
    demo_learning_curve()