# viterbi.py
# Natural Language Toolkit: Viterbi Probabilistic Parser
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
#         Steven Bird <stevenbird1@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
from __future__ import print_function, unicode_literals

from functools import reduce

from nltk.compat import python_2_unicode_compatible
from nltk.parse.api import ParserI
from nltk.tree import Tree, ProbabilisticTree
##//////////////////////////////////////////////////////
##  Viterbi PCFG Parser
##//////////////////////////////////////////////////////
  16. @python_2_unicode_compatible
  17. class ViterbiParser(ParserI):
  18. """
  19. A bottom-up ``PCFG`` parser that uses dynamic programming to find
  20. the single most likely parse for a text. The ``ViterbiParser`` parser
  21. parses texts by filling in a "most likely constituent table".
  22. This table records the most probable tree representation for any
  23. given span and node value. In particular, it has an entry for
  24. every start index, end index, and node value, recording the most
  25. likely subtree that spans from the start index to the end index,
  26. and has the given node value.
  27. The ``ViterbiParser`` parser fills in this table incrementally. It starts
  28. by filling in all entries for constituents that span one element
  29. of text (i.e., entries where the end index is one greater than the
  30. start index). After it has filled in all table entries for
  31. constituents that span one element of text, it fills in the
  32. entries for constitutants that span two elements of text. It
  33. continues filling in the entries for constituents spanning larger
  34. and larger portions of the text, until the entire table has been
  35. filled. Finally, it returns the table entry for a constituent
  36. spanning the entire text, whose node value is the grammar's start
  37. symbol.
  38. In order to find the most likely constituent with a given span and
  39. node value, the ``ViterbiParser`` parser considers all productions that
  40. could produce that node value. For each production, it finds all
  41. children that collectively cover the span and have the node values
  42. specified by the production's right hand side. If the probability
  43. of the tree formed by applying the production to the children is
  44. greater than the probability of the current entry in the table,
  45. then the table is updated with this new tree.
  46. A pseudo-code description of the algorithm used by
  47. ``ViterbiParser`` is:
  48. | Create an empty most likely constituent table, *MLC*.
  49. | For width in 1...len(text):
  50. | For start in 1...len(text)-width:
  51. | For prod in grammar.productions:
  52. | For each sequence of subtrees [t[1], t[2], ..., t[n]] in MLC,
  53. | where t[i].label()==prod.rhs[i],
  54. | and the sequence covers [start:start+width]:
  55. | old_p = MLC[start, start+width, prod.lhs]
  56. | new_p = P(t[1])P(t[1])...P(t[n])P(prod)
  57. | if new_p > old_p:
  58. | new_tree = Tree(prod.lhs, t[1], t[2], ..., t[n])
  59. | MLC[start, start+width, prod.lhs] = new_tree
  60. | Return MLC[0, len(text), start_symbol]
  61. :type _grammar: PCFG
  62. :ivar _grammar: The grammar used to parse sentences.
  63. :type _trace: int
  64. :ivar _trace: The level of tracing output that should be generated
  65. when parsing a text.
  66. """
  67. def __init__(self, grammar, trace=0):
  68. """
  69. Create a new ``ViterbiParser`` parser, that uses ``grammar`` to
  70. parse texts.
  71. :type grammar: PCFG
  72. :param grammar: The grammar used to parse texts.
  73. :type trace: int
  74. :param trace: The level of tracing that should be used when
  75. parsing a text. ``0`` will generate no tracing output;
  76. and higher numbers will produce more verbose tracing
  77. output.
  78. """
  79. self._grammar = grammar
  80. self._trace = trace
  81. def grammar(self):
  82. return self._grammar
  83. def trace(self, trace=2):
  84. """
  85. Set the level of tracing output that should be generated when
  86. parsing a text.
  87. :type trace: int
  88. :param trace: The trace level. A trace level of ``0`` will
  89. generate no tracing output; and higher trace levels will
  90. produce more verbose tracing output.
  91. :rtype: None
  92. """
  93. self._trace = trace
  94. def parse(self, tokens):
  95. # Inherit docs from ParserI
  96. tokens = list(tokens)
  97. self._grammar.check_coverage(tokens)
  98. # The most likely constituent table. This table specifies the
  99. # most likely constituent for a given span and type.
  100. # Constituents can be either Trees or tokens. For Trees,
  101. # the "type" is the Nonterminal for the tree's root node
  102. # value. For Tokens, the "type" is the token's type.
  103. # The table is stored as a dictionary, since it is sparse.
  104. constituents = {}
  105. # Initialize the constituents dictionary with the words from
  106. # the text.
  107. if self._trace:
  108. print(('Inserting tokens into the most likely' + ' constituents table...'))
  109. for index in range(len(tokens)):
  110. token = tokens[index]
  111. constituents[index, index + 1, token] = token
  112. if self._trace > 1:
  113. self._trace_lexical_insertion(token, index, len(tokens))
  114. # Consider each span of length 1, 2, ..., n; and add any trees
  115. # that might cover that span to the constituents dictionary.
  116. for length in range(1, len(tokens) + 1):
  117. if self._trace:
  118. print(
  119. (
  120. 'Finding the most likely constituents'
  121. + ' spanning %d text elements...' % length
  122. )
  123. )
  124. for start in range(len(tokens) - length + 1):
  125. span = (start, start + length)
  126. self._add_constituents_spanning(span, constituents, tokens)
  127. # Return the tree that spans the entire text & have the right cat
  128. tree = constituents.get((0, len(tokens), self._grammar.start()))
  129. if tree is not None:
  130. yield tree
  131. def _add_constituents_spanning(self, span, constituents, tokens):
  132. """
  133. Find any constituents that might cover ``span``, and add them
  134. to the most likely constituents table.
  135. :rtype: None
  136. :type span: tuple(int, int)
  137. :param span: The section of the text for which we are
  138. trying to find possible constituents. The span is
  139. specified as a pair of integers, where the first integer
  140. is the index of the first token that should be included in
  141. the constituent; and the second integer is the index of
  142. the first token that should not be included in the
  143. constituent. I.e., the constituent should cover
  144. ``text[span[0]:span[1]]``, where ``text`` is the text
  145. that we are parsing.
  146. :type constituents: dict(tuple(int,int,Nonterminal) -> ProbabilisticToken or ProbabilisticTree)
  147. :param constituents: The most likely constituents table. This
  148. table records the most probable tree representation for
  149. any given span and node value. In particular,
  150. ``constituents(s,e,nv)`` is the most likely
  151. ``ProbabilisticTree`` that covers ``text[s:e]``
  152. and has a node value ``nv.symbol()``, where ``text``
  153. is the text that we are parsing. When
  154. ``_add_constituents_spanning`` is called, ``constituents``
  155. should contain all possible constituents that are shorter
  156. than ``span``.
  157. :type tokens: list of tokens
  158. :param tokens: The text we are parsing. This is only used for
  159. trace output.
  160. """
  161. # Since some of the grammar productions may be unary, we need to
  162. # repeatedly try all of the productions until none of them add any
  163. # new constituents.
  164. changed = True
  165. while changed:
  166. changed = False
  167. # Find all ways instantiations of the grammar productions that
  168. # cover the span.
  169. instantiations = self._find_instantiations(span, constituents)
  170. # For each production instantiation, add a new
  171. # ProbabilisticTree whose probability is the product
  172. # of the childrens' probabilities and the production's
  173. # probability.
  174. for (production, children) in instantiations:
  175. subtrees = [c for c in children if isinstance(c, Tree)]
  176. p = reduce(lambda pr, t: pr * t.prob(), subtrees, production.prob())
  177. node = production.lhs().symbol()
  178. tree = ProbabilisticTree(node, children, prob=p)
  179. # If it's new a constituent, then add it to the
  180. # constituents dictionary.
  181. c = constituents.get((span[0], span[1], production.lhs()))
  182. if self._trace > 1:
  183. if c is None or c != tree:
  184. if c is None or c.prob() < tree.prob():
  185. print(' Insert:', end=' ')
  186. else:
  187. print(' Discard:', end=' ')
  188. self._trace_production(production, p, span, len(tokens))
  189. if c is None or c.prob() < tree.prob():
  190. constituents[span[0], span[1], production.lhs()] = tree
  191. changed = True
  192. def _find_instantiations(self, span, constituents):
  193. """
  194. :return: a list of the production instantiations that cover a
  195. given span of the text. A "production instantiation" is
  196. a tuple containing a production and a list of children,
  197. where the production's right hand side matches the list of
  198. children; and the children cover ``span``. :rtype: list
  199. of ``pair`` of ``Production``, (list of
  200. (``ProbabilisticTree`` or token.
  201. :type span: tuple(int, int)
  202. :param span: The section of the text for which we are
  203. trying to find production instantiations. The span is
  204. specified as a pair of integers, where the first integer
  205. is the index of the first token that should be covered by
  206. the production instantiation; and the second integer is
  207. the index of the first token that should not be covered by
  208. the production instantiation.
  209. :type constituents: dict(tuple(int,int,Nonterminal) -> ProbabilisticToken or ProbabilisticTree)
  210. :param constituents: The most likely constituents table. This
  211. table records the most probable tree representation for
  212. any given span and node value. See the module
  213. documentation for more information.
  214. """
  215. rv = []
  216. for production in self._grammar.productions():
  217. childlists = self._match_rhs(production.rhs(), span, constituents)
  218. for childlist in childlists:
  219. rv.append((production, childlist))
  220. return rv
  221. def _match_rhs(self, rhs, span, constituents):
  222. """
  223. :return: a set of all the lists of children that cover ``span``
  224. and that match ``rhs``.
  225. :rtype: list(list(ProbabilisticTree or token)
  226. :type rhs: list(Nonterminal or any)
  227. :param rhs: The list specifying what kinds of children need to
  228. cover ``span``. Each nonterminal in ``rhs`` specifies
  229. that the corresponding child should be a tree whose node
  230. value is that nonterminal's symbol. Each terminal in ``rhs``
  231. specifies that the corresponding child should be a token
  232. whose type is that terminal.
  233. :type span: tuple(int, int)
  234. :param span: The section of the text for which we are
  235. trying to find child lists. The span is specified as a
  236. pair of integers, where the first integer is the index of
  237. the first token that should be covered by the child list;
  238. and the second integer is the index of the first token
  239. that should not be covered by the child list.
  240. :type constituents: dict(tuple(int,int,Nonterminal) -> ProbabilisticToken or ProbabilisticTree)
  241. :param constituents: The most likely constituents table. This
  242. table records the most probable tree representation for
  243. any given span and node value. See the module
  244. documentation for more information.
  245. """
  246. (start, end) = span
  247. # Base case
  248. if start >= end and rhs == ():
  249. return [[]]
  250. if start >= end or rhs == ():
  251. return []
  252. # Find everything that matches the 1st symbol of the RHS
  253. childlists = []
  254. for split in range(start, end + 1):
  255. l = constituents.get((start, split, rhs[0]))
  256. if l is not None:
  257. rights = self._match_rhs(rhs[1:], (split, end), constituents)
  258. childlists += [[l] + r for r in rights]
  259. return childlists
  260. def _trace_production(self, production, p, span, width):
  261. """
  262. Print trace output indicating that a given production has been
  263. applied at a given location.
  264. :param production: The production that has been applied
  265. :type production: Production
  266. :param p: The probability of the tree produced by the production.
  267. :type p: float
  268. :param span: The span of the production
  269. :type span: tuple
  270. :rtype: None
  271. """
  272. str = '|' + '.' * span[0]
  273. str += '=' * (span[1] - span[0])
  274. str += '.' * (width - span[1]) + '| '
  275. str += '%s' % production
  276. if self._trace > 2:
  277. str = '%-40s %12.10f ' % (str, p)
  278. print(str)
  279. def _trace_lexical_insertion(self, token, index, width):
  280. str = ' Insert: |' + '.' * index + '=' + '.' * (width - index - 1) + '| '
  281. str += '%s' % (token,)
  282. print(str)
  283. def __repr__(self):
  284. return '<ViterbiParser for %r>' % self._grammar
##//////////////////////////////////////////////////////
##  Test Code
##//////////////////////////////////////////////////////
  288. def demo():
  289. """
  290. A demonstration of the probabilistic parsers. The user is
  291. prompted to select which demo to run, and how many parses should
  292. be found; and then each parser is run on the same demo, and a
  293. summary of the results are displayed.
  294. """
  295. import sys, time
  296. from nltk import tokenize
  297. from nltk.parse import ViterbiParser
  298. from nltk.grammar import toy_pcfg1, toy_pcfg2
  299. # Define two demos. Each demo has a sentence and a grammar.
  300. demos = [
  301. ('I saw the man with my telescope', toy_pcfg1),
  302. ('the boy saw Jack with Bob under the table with a telescope', toy_pcfg2),
  303. ]
  304. # Ask the user which demo they want to use.
  305. print()
  306. for i in range(len(demos)):
  307. print('%3s: %s' % (i + 1, demos[i][0]))
  308. print(' %r' % demos[i][1])
  309. print()
  310. print('Which demo (%d-%d)? ' % (1, len(demos)), end=' ')
  311. try:
  312. snum = int(sys.stdin.readline().strip()) - 1
  313. sent, grammar = demos[snum]
  314. except:
  315. print('Bad sentence number')
  316. return
  317. # Tokenize the sentence.
  318. tokens = sent.split()
  319. parser = ViterbiParser(grammar)
  320. all_parses = {}
  321. print('\nsent: %s\nparser: %s\ngrammar: %s' % (sent, parser, grammar))
  322. parser.trace(3)
  323. t = time.time()
  324. parses = parser.parse_all(tokens)
  325. time = time.time() - t
  326. average = (
  327. reduce(lambda a, b: a + b.prob(), parses, 0) / len(parses) if parses else 0
  328. )
  329. num_parses = len(parses)
  330. for p in parses:
  331. all_parses[p.freeze()] = 1
  332. # Print some summary statistics
  333. print()
  334. print('Time (secs) # Parses Average P(parse)')
  335. print('-----------------------------------------')
  336. print('%11.4f%11d%19.14f' % (time, num_parses, average))
  337. parses = all_parses.keys()
  338. if parses:
  339. p = reduce(lambda a, b: a + b.prob(), parses, 0) / len(parses)
  340. else:
  341. p = 0
  342. print('------------------------------------------')
  343. print('%11s%11d%19.14f' % ('n/a', len(parses), p))
  344. # Ask the user if we should draw the parses.
  345. print()
  346. print('Draw parses (y/n)? ', end=' ')
  347. if sys.stdin.readline().strip().lower().startswith('y'):
  348. from nltk.draw.tree import draw_trees
  349. print(' please wait...')
  350. draw_trees(*parses)
  351. # Ask the user if we should print the parses.
  352. print()
  353. print('Print parses (y/n)? ', end=' ')
  354. if sys.stdin.readline().strip().lower().startswith('y'):
  355. for parse in parses:
  356. print(parse)
  357. if __name__ == '__main__':
  358. demo()