api.py

# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Models
#
# Copyright (C) 2001-2019 NLTK Project
# Authors: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Interface."""
from __future__ import division, unicode_literals

import random
from abc import ABCMeta, abstractmethod
from bisect import bisect

from six import add_metaclass

from nltk.lm.counter import NgramCounter
from nltk.lm.util import log_base2
from nltk.lm.vocabulary import Vocabulary

try:
    from itertools import accumulate
except ImportError:
    import operator

    def accumulate(iterable, func=operator.add):
        """Return running totals"""
        # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
        it = iter(iterable)
        try:
            total = next(it)
        except StopIteration:
            return
        yield total
        for element in it:
            total = func(total, element)
            yield total


@add_metaclass(ABCMeta)
class Smoothing(object):
    """Ngram Smoothing Interface

    Implements Chen & Goodman 1995's idea that all smoothing algorithms have
    certain features in common. This should ideally allow smoothing algorithms
    to work both with Backoff and Interpolation.
    """

    def __init__(self, vocabulary, counter):
        """
        :param vocabulary: The Ngram vocabulary object.
        :type vocabulary: nltk.lm.vocab.Vocabulary
        :param counter: The counts of the vocabulary items.
        :type counter: nltk.lm.counter.NgramCounter
        """
        self.vocab = vocabulary
        self.counts = counter

    @abstractmethod
    def unigram_score(self, word):
        raise NotImplementedError()

    @abstractmethod
    def alpha_gamma(self, word, context):
        raise NotImplementedError()
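
# Illustrative sketch (not part of NLTK's API): a toy subclass showing the two
# hooks a concrete smoothing algorithm is expected to fill in. It assumes that
# `alpha_gamma` returns an (alpha, gamma) pair which an interpolated model
# combines with the lower-order score, and that `unigram_score` returns a
# relative frequency; the fixed gamma below is purely for illustration.
#
#     class ToySmoothing(Smoothing):
#         def unigram_score(self, word):
#             # Relative frequency of `word` among all unigrams.
#             return self.counts.unigrams.freq(word)
#
#         def alpha_gamma(self, word, context):
#             # Weight the highest-order estimate against a constant fallback mass.
#             gamma = 0.5
#             return (1 - gamma) * self.counts[context].freq(word), gamma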


def _mean(items):
    """Return average (aka mean) for sequence of items."""
    return sum(items) / len(items)


def _random_generator(seed_or_generator):
    if isinstance(seed_or_generator, random.Random):
        return seed_or_generator
    return random.Random(seed_or_generator)


def _weighted_choice(population, weights, random_generator=None):
    """Like random.choice, but with weights.

    Heavily inspired by python 3.6 `random.choices`.
    """
    if not population:
        raise ValueError("Can't choose from empty population")
    if len(population) != len(weights):
        raise ValueError("The number of weights does not match the population")
    cum_weights = list(accumulate(weights))
    total = cum_weights[-1]
    threshold = random_generator.random()
    return population[bisect(cum_weights, total * threshold)]
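
# Example (illustrative): with population ["a", "b"] and weights (0.25, 0.75),
# the cumulative weights are [0.25, 1.0]; a uniform draw below 0.25 picks "a",
# anything at or above 0.25 picks "b". Note that the caller is expected to
# supply an explicit `random.Random` instance:
#
#     _weighted_choice(["a", "b"], (0.25, 0.75), random.Random(3))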


@add_metaclass(ABCMeta)
class LanguageModel(object):
    """ABC for Language Models.

    Cannot be instantiated directly.
    """

    def __init__(self, order, vocabulary=None, counter=None):
        """Creates a new LanguageModel.

        :param vocabulary: If provided, this vocabulary will be used instead
            of creating a new one when training.
        :type vocabulary: `nltk.lm.Vocabulary` or None
        :param counter: If provided, use this object to count ngrams.
        :type counter: `nltk.lm.NgramCounter` or None
        :param ngrams_fn: If given, defines how sentences in training text are
            turned into ngram sequences.
        :type ngrams_fn: function or None
        :param pad_fn: If given, defines how sentences in training text are padded.
        :type pad_fn: function or None
        """
        self.order = order
        self.vocab = Vocabulary() if vocabulary is None else vocabulary
        self.counts = NgramCounter() if counter is None else counter

    def fit(self, text, vocabulary_text=None):
        """Trains the model on a text.

        :param text: Training text as a sequence of sentences.
        """
        if not self.vocab:
            if vocabulary_text is None:
                raise ValueError(
                    "Cannot fit without a vocabulary or text to create it from."
                )
            self.vocab.update(vocabulary_text)
        self.counts.update(self.vocab.lookup(sent) for sent in text)
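
    # Usage sketch (illustrative): `text` should be a sequence of sentences,
    # each already turned into ngram tuples, and `vocabulary_text` a stream of
    # words. One common way to build both is
    # `nltk.lm.preprocessing.padded_everygram_pipeline`:
    #
    #     from nltk.lm import MLE
    #     from nltk.lm.preprocessing import padded_everygram_pipeline
    #
    #     train, vocab = padded_everygram_pipeline(2, [["a", "b", "c"]])
    #     lm = MLE(2)
    #     lm.fit(train, vocab)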

    def score(self, word, context=None):
        """Masks out-of-vocabulary (OOV) words and computes their model score.

        For model-specific logic of calculating scores, see the `unmasked_score`
        method.
        """
        return self.unmasked_score(
            self.vocab.lookup(word), self.vocab.lookup(context) if context else None
        )
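
    # Example (illustrative): a word missing from the vocabulary is mapped to
    # the vocabulary's unknown label (by default "<UNK>") before scoring, so for
    # a fitted model `lm`:
    #
    #     lm.score("never-seen-word") == lm.unmasked_score("<UNK>")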

    @abstractmethod
    def unmasked_score(self, word, context=None):
        """Score a word given some optional context.

        Concrete models are expected to provide an implementation.
        Note that this method does not mask its arguments with the OOV label.
        Use the `score` method for that.

        :param str word: Word for which we want the score
        :param tuple(str) context: Context the word is in.
            If `None`, compute unigram score.
        :type context: tuple(str) or None
        :rtype: float
        """
        raise NotImplementedError()

    def logscore(self, word, context=None):
        """Evaluate the log score of this word in this context.

        The arguments are the same as for `score` and `unmasked_score`.
        """
        return log_base2(self.score(word, context))

    def context_counts(self, context):
        """Helper method for retrieving counts for a given context.

        Assumes context has been checked and OOV words in it masked.

        :type context: tuple(str) or None
        """
        return (
            self.counts[len(context) + 1][context] if context else self.counts.unigrams
        )

    def entropy(self, text_ngrams):
        """Calculate cross-entropy of model for given evaluation text.

        :param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
        :rtype: float
        """
        return -1 * _mean(
            [self.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
        )

    def perplexity(self, text_ngrams):
        """Calculates the perplexity of the given text.

        This is simply 2 ** cross-entropy for the text, so the arguments are the same.
        """
        return pow(2.0, self.entropy(text_ngrams))
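
    # Worked example (illustrative): if every ngram in the evaluation text is
    # scored at 0.25, each logscore is log2(0.25) = -2, so entropy is 2 bits
    # and perplexity is 2 ** 2 = 4. Both methods take ngram tuples, e.g.:
    #
    #     lm.perplexity([("a", "b"), ("b", "c")])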

    def generate(self, num_words=1, text_seed=None, random_seed=None):
        """Generate words from the model.

        :param int num_words: How many words to generate. By default 1.
        :param text_seed: Generation can be conditioned on preceding context.
        :param random_seed: A random seed or an instance of `random.Random`. If provided,
            makes the random sampling part of generation reproducible.
        :return: One (str) word or a list of words generated from model.

        Examples:

        >>> from nltk.lm import MLE
        >>> lm = MLE(2)
        >>> lm.fit([[("a", "b"), ("b", "c")]], vocabulary_text=['a', 'b', 'c'])
        >>> lm.fit([[("a",), ("b",), ("c",)]])
        >>> lm.generate(random_seed=3)
        'a'
        >>> lm.generate(text_seed=['a'])
        'b'
        """
        text_seed = [] if text_seed is None else list(text_seed)
        random_generator = _random_generator(random_seed)
        # base recursion case
        if num_words == 1:
            context = (
                text_seed[-self.order + 1 :]
                if len(text_seed) >= self.order
                else text_seed
            )
            samples = self.context_counts(self.vocab.lookup(context))
            while context and not samples:
                context = context[1:] if len(context) > 1 else []
                samples = self.context_counts(self.vocab.lookup(context))
            # sorting achieves two things:
            # - reproducible randomness when sampling
            # - turning Mapping into Sequence which _weighted_choice expects
            samples = sorted(samples)
            return _weighted_choice(
                samples,
                tuple(self.score(w, context) for w in samples),
                random_generator,
            )
        # build up text one word at a time
        generated = []
        for _ in range(num_words):
            generated.append(
                self.generate(
                    num_words=1,
                    text_seed=text_seed + generated,
                    random_seed=random_generator,
                )
            )
        return generated