vocabulary.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # -*- coding: utf-8 -*-
  2. # Natural Language Toolkit
  3. #
  4. # Copyright (C) 2001-2019 NLTK Project
  5. # Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. """Language Model Vocabulary"""
from __future__ import unicode_literals

import sys
from collections import Counter
from itertools import chain

from nltk import compat

try:
    # Python >= 3.3 (the aliases in ``collections`` were removed in 3.10)
    from collections.abc import Iterable
except ImportError:
    # Python 2
    from collections import Iterable

try:
    # Python >= 3.4
    from functools import singledispatch
except ImportError:
    # Python < 3.4
    from singledispatch import singledispatch
  20. @singledispatch
  21. def _dispatched_lookup(words, vocab):
  22. raise TypeError(
  23. "Unsupported type for looking up in vocabulary: {0}".format(type(words))
  24. )
  25. @_dispatched_lookup.register(Iterable)
  26. def _(words, vocab):
  27. """Look up a sequence of words in the vocabulary.
  28. Returns an iterator over looked up words.
  29. """
  30. return tuple(_dispatched_lookup(w, vocab) for w in words)
# Make the name ``basestring`` resolvable on both major Python versions; it is
# used right below to register the single-word lookup handler.
try:
    # Python 2: ``basestring`` is a builtin (unicode + str type), so merely
    # evaluating the name succeeds and nothing needs to be defined.
    basestring
except NameError:
    # Python 3: the builtin does not exist; alias it to ``str``.
    basestring = str
  37. @_dispatched_lookup.register(basestring)
  38. def _string_lookup(word, vocab):
  39. """Looks up one word in the vocabulary."""
  40. return word if word in vocab else vocab.unk_label
  41. @compat.python_2_unicode_compatible
  42. class Vocabulary(object):
  43. """Stores language model vocabulary.
  44. Satisfies two common language modeling requirements for a vocabulary:
  45. - When checking membership and calculating its size, filters items
  46. by comparing their counts to a cutoff value.
  47. - Adds a special "unknown" token which unseen words are mapped to.
  48. >>> words = ['a', 'c', '-', 'd', 'c', 'a', 'b', 'r', 'a', 'c', 'd']
  49. >>> from nltk.lm import Vocabulary
  50. >>> vocab = Vocabulary(words, unk_cutoff=2)
  51. Tokens with counts greater than or equal to the cutoff value will
  52. be considered part of the vocabulary.
  53. >>> vocab['c']
  54. 3
  55. >>> 'c' in vocab
  56. True
  57. >>> vocab['d']
  58. 2
  59. >>> 'd' in vocab
  60. True
  61. Tokens with frequency counts less than the cutoff value will be considered not
  62. part of the vocabulary even though their entries in the count dictionary are
  63. preserved.
  64. >>> vocab['b']
  65. 1
  66. >>> 'b' in vocab
  67. False
  68. >>> vocab['aliens']
  69. 0
  70. >>> 'aliens' in vocab
  71. False
  72. Keeping the count entries for seen words allows us to change the cutoff value
  73. without having to recalculate the counts.
  74. >>> vocab2 = Vocabulary(vocab.counts, unk_cutoff=1)
  75. >>> "b" in vocab2
  76. True
  77. The cutoff value influences not only membership checking but also the result of
  78. getting the size of the vocabulary using the built-in `len`.
  79. Note that while the number of keys in the vocabulary's counter stays the same,
  80. the items in the vocabulary differ depending on the cutoff.
  81. We use `sorted` to demonstrate because it keeps the order consistent.
  82. >>> sorted(vocab2.counts)
  83. ['-', 'a', 'b', 'c', 'd', 'r']
  84. >>> sorted(vocab2)
  85. ['-', '<UNK>', 'a', 'b', 'c', 'd', 'r']
  86. >>> sorted(vocab.counts)
  87. ['-', 'a', 'b', 'c', 'd', 'r']
  88. >>> sorted(vocab)
  89. ['<UNK>', 'a', 'c', 'd']
  90. In addition to items it gets populated with, the vocabulary stores a special
  91. token that stands in for so-called "unknown" items. By default it's "<UNK>".
  92. >>> "<UNK>" in vocab
  93. True
  94. We can look up words in a vocabulary using its `lookup` method.
  95. "Unseen" words (with counts less than cutoff) are looked up as the unknown label.
  96. If given one word (a string) as an input, this method will return a string.
  97. >>> vocab.lookup("a")
  98. 'a'
  99. >>> vocab.lookup("aliens")
  100. '<UNK>'
  101. If given a sequence, it will return an tuple of the looked up words.
  102. >>> vocab.lookup(["p", 'a', 'r', 'd', 'b', 'c'])
  103. ('<UNK>', 'a', '<UNK>', 'd', '<UNK>', 'c')
  104. It's possible to update the counts after the vocabulary has been created.
  105. The interface follows that of `collections.Counter`.
  106. >>> vocab['b']
  107. 1
  108. >>> vocab.update(["b", "b", "c"])
  109. >>> vocab['b']
  110. 3
  111. """
  112. def __init__(self, counts=None, unk_cutoff=1, unk_label="<UNK>"):
  113. """Create a new Vocabulary.
  114. :param counts: Optional iterable or `collections.Counter` instance to
  115. pre-seed the Vocabulary. In case it is iterable, counts
  116. are calculated.
  117. :param int unk_cutoff: Words that occur less frequently than this value
  118. are not considered part of the vocabulary.
  119. :param unk_label: Label for marking words not part of vocabulary.
  120. """
  121. if isinstance(counts, Counter):
  122. self.counts = counts
  123. else:
  124. self.counts = Counter()
  125. if isinstance(counts, Iterable):
  126. self.counts.update(counts)
  127. self.unk_label = unk_label
  128. if unk_cutoff < 1:
  129. raise ValueError(
  130. "Cutoff value cannot be less than 1. Got: {0}".format(unk_cutoff)
  131. )
  132. self._cutoff = unk_cutoff
  133. @property
  134. def cutoff(self):
  135. """Cutoff value.
  136. Items with count below this value are not considered part of vocabulary.
  137. """
  138. return self._cutoff
  139. def update(self, *counter_args, **counter_kwargs):
  140. """Update vocabulary counts.
  141. Wraps `collections.Counter.update` method.
  142. """
  143. self.counts.update(*counter_args, **counter_kwargs)
  144. def lookup(self, words):
  145. """Look up one or more words in the vocabulary.
  146. If passed one word as a string will return that word or `self.unk_label`.
  147. Otherwise will assume it was passed a sequence of words, will try to look
  148. each of them up and return an iterator over the looked up words.
  149. :param words: Word(s) to look up.
  150. :type words: Iterable(str) or str
  151. :rtype: generator(str) or str
  152. :raises: TypeError for types other than strings or iterables
  153. >>> from nltk.lm import Vocabulary
  154. >>> vocab = Vocabulary(["a", "b", "c", "a", "b"], unk_cutoff=2)
  155. >>> vocab.lookup("a")
  156. 'a'
  157. >>> vocab.lookup("aliens")
  158. '<UNK>'
  159. >>> vocab.lookup(["a", "b", "c", ["x", "b"]])
  160. ('a', 'b', '<UNK>', ('<UNK>', 'b'))
  161. """
  162. return _dispatched_lookup(words, self)
  163. def __getitem__(self, item):
  164. return self._cutoff if item == self.unk_label else self.counts[item]
  165. def __contains__(self, item):
  166. """Only consider items with counts GE to cutoff as being in the
  167. vocabulary."""
  168. return self[item] >= self.cutoff
  169. def __iter__(self):
  170. """Building on membership check define how to iterate over
  171. vocabulary."""
  172. return chain(
  173. (item for item in self.counts if item in self),
  174. [self.unk_label] if self.counts else [],
  175. )
  176. def __len__(self):
  177. """Computing size of vocabulary reflects the cutoff."""
  178. return sum(1 for _ in self)
  179. def __eq__(self, other):
  180. return (
  181. self.unk_label == other.unk_label
  182. and self.cutoff == other.cutoff
  183. and self.counts == other.counts
  184. )
  185. if sys.version_info[0] == 2:
  186. # see https://stackoverflow.com/a/35781654/4501212
  187. def __ne__(self, other):
  188. equal = self.__eq__(other)
  189. return equal if equal is NotImplemented else not equal
  190. def __str__(self):
  191. return "<{0} with cutoff={1} unk_label='{2}' and {3} items>".format(
  192. self.__class__.__name__, self.cutoff, self.unk_label, len(self)
  193. )