# Natural Language Toolkit: Dependency Grammars
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Jason Narad <jason.narad@gmail.com>
#
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
#
from __future__ import print_function

import math
import logging

from six.moves import range

from nltk.parse.dependencygraph import DependencyGraph

logger = logging.getLogger(__name__)
#################################################################
# DependencyScorerI - Interface for Graph-Edge Weight Calculation
#################################################################


class DependencyScorerI(object):
    """
    A scorer for calculating the weights on the edges of a weighted
    dependency graph.  This is used by a
    ``ProbabilisticNonprojectiveParser`` to initialize the edge
    weights of a ``DependencyGraph``.  While typically this would be done
    by training a binary classifier, any class that can return a
    multidimensional list representation of the edge weights can
    implement this interface.  As such, it has no necessary
    fields.
    """

    def __init__(self):
        if self.__class__ == DependencyScorerI:
            raise TypeError('DependencyScorerI is an abstract interface')

    def train(self, graphs):
        """
        :type graphs: list(DependencyGraph)
        :param graphs: A list of dependency graphs to train the scorer.
            Typically the edges present in the graphs can be used as
            positive training examples, and the edges not present as
            negative examples.
        """
        raise NotImplementedError()

    def score(self, graph):
        """
        :type graph: DependencyGraph
        :param graph: A dependency graph whose set of edges needs to be
            scored.
        :rtype: A three-dimensional list of numbers.
        :return: The scores are returned in a multidimensional (3) list,
            such that the outer dimension refers to the head and the
            inner dimension refers to the dependencies.  For instance,
            scores[0][1] would reference the list of scores corresponding to
            arcs from node 0 to node 1.  The node's 'address' field can be
            used to determine its number identification.

            For further illustration, a score list corresponding to Fig.2 of
            Keith Hall's 'K-best Spanning Tree Parsing' paper::

                scores = [[[], [5],  [1],  [1]],
                          [[], [],   [11], [4]],
                          [[], [10], [],   [5]],
                          [[], [8],  [8],  []]]

            When used in conjunction with a MaxEntClassifier, each score would
            correspond to the confidence of a particular edge being classified
            with the positive training examples.
        """
        raise NotImplementedError()
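
# For illustration, reading the Hall score matrix from the docstring above:
#
#     scores = [[[], [5],  [1],  [1]],
#               [[], [],   [11], [4]],
#               [[], [10], [],   [5]],
#               [[], [8],  [8],  []]]
#
#     max(scores[2][1])  # -> 10, the best score for the arc 2 -> 1
#
# An empty inner list (e.g. scores[0][0]) marks an arc that is unavailable.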

#################################################################
# NaiveBayesDependencyScorer
#################################################################


class NaiveBayesDependencyScorer(DependencyScorerI):
    """
    A dependency scorer built around a probabilistic classifier.  In this
    particular class that classifier is a ``NaiveBayesClassifier``.
    It uses head-word, head-tag, child-word, and child-tag features
    for classification.

    >>> from nltk.parse.dependencygraph import DependencyGraph, conll_data2

    >>> graphs = [DependencyGraph(entry) for entry in conll_data2.split('\\n\\n') if entry]
    >>> npp = ProbabilisticNonprojectiveParser()
    >>> npp.train(graphs, NaiveBayesDependencyScorer())
    >>> parses = npp.parse(['Cathy', 'zag', 'hen', 'zwaaien', '.'], ['N', 'V', 'Pron', 'Adj', 'N', 'Punc'])
    >>> len(list(parses))
    1
    """

    def __init__(self):
        pass  # Do nothing without throwing error

    def train(self, graphs):
        """
        Trains a ``NaiveBayesClassifier`` using the edges present in
        the graphs list as positive examples, and the edges not present
        as negative examples.  Uses a feature vector of head-word,
        head-tag, child-word, and child-tag.

        :type graphs: list(DependencyGraph)
        :param graphs: A list of dependency graphs to train the scorer.
        """
        from nltk.classify import NaiveBayesClassifier

        # Create labeled training examples
        labeled_examples = []
        for graph in graphs:
            for head_node in graph.nodes.values():
                for child_index, child_node in graph.nodes.items():
                    if child_index in head_node['deps']:
                        label = "T"
                    else:
                        label = "F"
                    labeled_examples.append(
                        (
                            dict(
                                a=head_node['word'],
                                b=head_node['tag'],
                                c=child_node['word'],
                                d=child_node['tag'],
                            ),
                            label,
                        )
                    )

        self.classifier = NaiveBayesClassifier.train(labeled_examples)

    def score(self, graph):
        """
        Converts the graph into a feature-based representation of
        each edge, and then assigns a score to each based on the
        confidence of the classifier in assigning it to the
        positive label.  Scores are returned in a multidimensional list.

        :type graph: DependencyGraph
        :param graph: A dependency graph to score.
        :rtype: 3 dimensional list
        :return: Edge scores for the graph parameter.
        """
        # Convert graph to feature representation
        edges = []
        for head_node in graph.nodes.values():
            for child_node in graph.nodes.values():
                edges.append(
                    dict(
                        a=head_node['word'],
                        b=head_node['tag'],
                        c=child_node['word'],
                        d=child_node['tag'],
                    )
                )

        # Score edges
        edge_scores = []
        row = []
        count = 0
        for pdist in self.classifier.prob_classify_many(edges):
            logger.debug('%.4f %.4f', pdist.prob('T'), pdist.prob('F'))
            # Smooth in case the probability is 0
            row.append([math.log(pdist.prob("T") + 0.00000000001)])
            count += 1
            if count == len(graph.nodes):
                edge_scores.append(row)
                row = []
                count = 0
        return edge_scores
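
# A minimal usage sketch (assuming `graphs` holds DependencyGraph objects,
# e.g. the CoNLL graphs from the class doctest above):
#
#     scorer = NaiveBayesDependencyScorer()
#     scorer.train(graphs)
#     scores = scorer.score(graphs[0])
#     # scores[head][dep] is a one-element list holding the smoothed
#     # log-probability of the positive ("T") label for that arc.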

#################################################################
# A Scorer for Demo Purposes
#################################################################
# A short class necessary to show the parsing example from the paper


class DemoScorer(DependencyScorerI):
    def train(self, graphs):
        print('Training...')

    def score(self, graph):
        # scores for Keith Hall 'K-best Spanning Tree Parsing' paper
        return [
            [[], [5], [1], [1]],
            [[], [], [11], [4]],
            [[], [10], [], [5]],
            [[], [8], [8], []],
        ]

#################################################################
# Non-Projective Probabilistic Parsing
#################################################################


class ProbabilisticNonprojectiveParser(object):
    """A probabilistic non-projective dependency parser.

    Nonprojective dependencies allow for "crossing branches" in the parse tree
    which is necessary for representing particular linguistic phenomena, or even
    typical parses in some languages.  This parser follows the MST parsing
    algorithm, outlined in McDonald (2005), which likens the search for the best
    non-projective parse to finding the maximum spanning tree in a weighted
    directed graph.

    >>> class Scorer(DependencyScorerI):
    ...     def train(self, graphs):
    ...         pass
    ...
    ...     def score(self, graph):
    ...         return [
    ...             [[], [5], [1], [1]],
    ...             [[], [], [11], [4]],
    ...             [[], [10], [], [5]],
    ...             [[], [8], [8], []],
    ...         ]

    >>> npp = ProbabilisticNonprojectiveParser()
    >>> npp.train([], Scorer())

    >>> parses = npp.parse(['v1', 'v2', 'v3'], [None, None, None])
    >>> len(list(parses))
    1

    Rule-based example
    ------------------

    >>> from nltk.grammar import DependencyGrammar

    >>> grammar = DependencyGrammar.fromstring('''
    ... 'taught' -> 'play' | 'man'
    ... 'man' -> 'the' | 'in'
    ... 'in' -> 'corner'
    ... 'corner' -> 'the'
    ... 'play' -> 'golf' | 'dachshund' | 'to'
    ... 'dachshund' -> 'his'
    ... ''')

    >>> ndp = NonprojectiveDependencyParser(grammar)
    >>> parses = ndp.parse(['the', 'man', 'in', 'the', 'corner', 'taught', 'his', 'dachshund', 'to', 'play', 'golf'])
    >>> len(list(parses))
    4
    """
    def __init__(self):
        """
        Creates a new non-projective parser.
        """
        logger.debug('initializing prob. nonprojective...')

    def train(self, graphs, dependency_scorer):
        """
        Trains a ``DependencyScorerI`` from a set of ``DependencyGraph`` objects,
        and establishes this as the parser's scorer.  This is used to
        initialize the scores on a ``DependencyGraph`` during the parsing
        procedure.

        :type graphs: list(DependencyGraph)
        :param graphs: A list of dependency graphs to train the scorer.
        :type dependency_scorer: DependencyScorerI
        :param dependency_scorer: A scorer which implements the
            ``DependencyScorerI`` interface.
        """
        self._scorer = dependency_scorer
        self._scorer.train(graphs)

    def initialize_edge_scores(self, graph):
        """
        Assigns a score to every edge in the ``DependencyGraph`` graph.
        These scores are generated via the parser's scorer which
        was assigned during the training process.

        :type graph: DependencyGraph
        :param graph: A dependency graph to assign scores to.
        """
        self.scores = self._scorer.score(graph)

    def collapse_nodes(self, new_node, cycle_path, g_graph, b_graph, c_graph):
        """
        Takes a list of nodes that have been identified as belonging to a
        cycle, and collapses them into one larger node.  The arcs of all
        nodes in the graph must be updated to account for this.

        :type new_node: Node.
        :param new_node: A Node (Dictionary) to collapse the cycle nodes into.
        :type cycle_path: A list of integers.
        :param cycle_path: A list of node addresses, each of which is in the cycle.
        :type g_graph, b_graph, c_graph: DependencyGraph
        :param g_graph, b_graph, c_graph: Graphs which need to be updated.
        """
        logger.debug('Collapsing nodes...')
        # Collapse all cycle nodes into v_n+1 in G_Graph
        for cycle_node_index in cycle_path:
            g_graph.remove_by_address(cycle_node_index)
        g_graph.add_node(new_node)
        g_graph.redirect_arcs(cycle_path, new_node['address'])

    def update_edge_scores(self, new_node, cycle_path):
        """
        Updates the edge scores to reflect a collapse operation into
        new_node.

        :type new_node: A Node.
        :param new_node: The node which cycle nodes are collapsed into.
        :type cycle_path: A list of integers.
        :param cycle_path: A list of node addresses that belong to the cycle.
        """
        logger.debug('cycle %s', cycle_path)

        cycle_path = self.compute_original_indexes(cycle_path)

        logger.debug('old cycle %s', cycle_path)
        logger.debug('Prior to update: %s', self.scores)

        for i, row in enumerate(self.scores):
            for j, column in enumerate(self.scores[i]):
                logger.debug(self.scores[i][j])
                if j in cycle_path and i not in cycle_path and self.scores[i][j]:
                    subtract_val = self.compute_max_subtract_score(j, cycle_path)

                    logger.debug('%s - %s', self.scores[i][j], subtract_val)

                    new_vals = []
                    for cur_val in self.scores[i][j]:
                        new_vals.append(cur_val - subtract_val)

                    self.scores[i][j] = new_vals

        for i, row in enumerate(self.scores):
            for j, cell in enumerate(self.scores[i]):
                if i in cycle_path and j in cycle_path:
                    self.scores[i][j] = []

        logger.debug('After update: %s', self.scores)

    def compute_original_indexes(self, new_indexes):
        """
        As nodes are collapsed into others, they are replaced
        by the new node in the graph, but it's still necessary
        to keep track of what these original nodes were.  This
        takes a list of node addresses and replaces any collapsed
        node addresses with their original addresses.

        :type new_indexes: A list of integers.
        :param new_indexes: A list of node addresses to check for
            subsumed nodes.
        """
        swapped = True
        while swapped:
            originals = []
            swapped = False
            for new_index in new_indexes:
                if new_index in self.inner_nodes:
                    for old_val in self.inner_nodes[new_index]:
                        if old_val not in originals:
                            originals.append(old_val)
                            swapped = True
                else:
                    originals.append(new_index)
            new_indexes = originals
        return new_indexes
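
    # For example, if collapsing created node 4 from the cycle [1, 2]
    # (so self.inner_nodes == {4: [1, 2]}), then
    #
    #     self.compute_original_indexes([4, 3])  # -> [1, 2, 3]
    #
    # address 4 expands back into the original addresses 1 and 2, while
    # address 3 was never collapsed and passes through unchanged.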
    def compute_max_subtract_score(self, column_index, cycle_indexes):
        """
        When updating scores, the score of the highest-weighted incoming
        arc is subtracted upon collapse.  This returns the correct
        amount to subtract from that edge.

        :type column_index: integer.
        :param column_index: An index representing the column of incoming arcs
            to a particular node being updated.
        :type cycle_indexes: A list of integers.
        :param cycle_indexes: Only arcs from cycle nodes are considered.  This
            is a list of such nodes' addresses.
        """
        max_score = -100000
        for row_index in cycle_indexes:
            for subtract_val in self.scores[row_index][column_index]:
                if subtract_val > max_score:
                    max_score = subtract_val
        return max_score
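
    # For example, on the Hall score matrix above, with cycle_indexes
    # [1, 2] and column_index 1, the candidates are scores[1][1] == [] and
    # scores[2][1] == [10], so 10 is the amount subtracted from arcs into
    # node 1 that originate outside the cycle during the collapse.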
    def best_incoming_arc(self, node_index):
        """
        Returns the source of the best incoming arc to the
        node with address: node_index

        :type node_index: integer.
        :param node_index: The address of the 'destination' node,
            the node that is arced to.
        """
        originals = self.compute_original_indexes([node_index])
        logger.debug('originals: %s', originals)

        max_arc = None
        max_score = None
        for row_index in range(len(self.scores)):
            for col_index in range(len(self.scores[row_index])):
                if col_index in originals and (
                    max_score is None or self.scores[row_index][col_index] > max_score
                ):
                    max_score = self.scores[row_index][col_index]
                    max_arc = row_index
                    logger.debug('%s, %s', row_index, col_index)

        logger.debug(max_score)

        for key in self.inner_nodes:
            replaced_nodes = self.inner_nodes[key]
            if max_arc in replaced_nodes:
                return key

        return max_arc

    def original_best_arc(self, node_index):
        """
        Returns the best incoming arc to the node with address node_index,
        as a [source, destination] pair of original node addresses.
        """
        originals = self.compute_original_indexes([node_index])
        max_arc = None
        max_score = None
        max_orig = None
        for row_index in range(len(self.scores)):
            for col_index in range(len(self.scores[row_index])):
                if col_index in originals and (
                    max_score is None or self.scores[row_index][col_index] > max_score
                ):
                    max_score = self.scores[row_index][col_index]
                    max_arc = row_index
                    max_orig = col_index
        return [max_arc, max_orig]
    def parse(self, tokens, tags):
        """
        Parses a list of tokens in accordance with the MST parsing algorithm
        for non-projective dependency parses.  Assumes that the tokens to
        be parsed have already been tagged and those tags are provided.  Various
        scoring methods can be used by implementing the ``DependencyScorerI``
        interface and passing it to the training algorithm.

        :type tokens: list(str)
        :param tokens: A list of words or punctuation to be parsed.
        :type tags: list(str)
        :param tags: A list of tags corresponding by index to the words in the tokens list.
        :return: An iterator of non-projective parses.
        :rtype: iter(DependencyGraph)
        """
        self.inner_nodes = {}

        # Initialize g_graph
        g_graph = DependencyGraph()
        for index, token in enumerate(tokens):
            g_graph.nodes[index + 1].update(
                {'word': token, 'tag': tags[index], 'rel': 'NTOP', 'address': index + 1}
            )

        # Fully connect non-root nodes in g_graph
        g_graph.connect_graph()
        original_graph = DependencyGraph()
        for index, token in enumerate(tokens):
            original_graph.nodes[index + 1].update(
                {'word': token, 'tag': tags[index], 'rel': 'NTOP', 'address': index + 1}
            )

        b_graph = DependencyGraph()
        c_graph = DependencyGraph()

        for index, token in enumerate(tokens):
            c_graph.nodes[index + 1].update(
                {'word': token, 'tag': tags[index], 'rel': 'NTOP', 'address': index + 1}
            )

        # Assign initial scores to g_graph edges
        self.initialize_edge_scores(g_graph)
        logger.debug(self.scores)

        # Initialize a list of unvisited vertices (by node address)
        unvisited_vertices = [vertex['address'] for vertex in c_graph.nodes.values()]

        # Iterate over unvisited vertices
        nr_vertices = len(tokens)
        betas = {}
        while unvisited_vertices:
            # Mark current node as visited
            current_vertex = unvisited_vertices.pop(0)
            logger.debug('current_vertex: %s', current_vertex)

            # Get corresponding node n_i to vertex v_i
            current_node = g_graph.get_by_address(current_vertex)
            logger.debug('current_node: %s', current_node)

            # Get best in-edge node b for current node
            best_in_edge = self.best_incoming_arc(current_vertex)
            betas[current_vertex] = self.original_best_arc(current_vertex)
            logger.debug('best in arc: %s --> %s', best_in_edge, current_vertex)

            # b_graph = Union(b_graph, b)
            for new_vertex in [current_vertex, best_in_edge]:
                b_graph.nodes[new_vertex].update(
                    {'word': 'TEMP', 'rel': 'NTOP', 'address': new_vertex}
                )
            b_graph.add_arc(best_in_edge, current_vertex)

            # Beta(current node) = b - stored for parse recovery
            # If b_graph contains a cycle, collapse it
            cycle_path = b_graph.contains_cycle()
            if cycle_path:
                # Create a new node v_n+1 with address = len(nodes) + 1
                new_node = {'word': 'NONE', 'rel': 'NTOP', 'address': nr_vertices + 1}

                # c_graph = Union(c_graph, v_n+1)
                c_graph.add_node(new_node)

                # Collapse all nodes in cycle C into v_n+1
                self.update_edge_scores(new_node, cycle_path)
                self.collapse_nodes(new_node, cycle_path, g_graph, b_graph, c_graph)
                for cycle_index in cycle_path:
                    c_graph.add_arc(new_node['address'], cycle_index)

                self.inner_nodes[new_node['address']] = cycle_path

                # Add v_n+1 to list of unvisited vertices
                unvisited_vertices.insert(0, nr_vertices + 1)

                # Increment the vertex counter
                nr_vertices += 1

                # Remove cycle nodes from b_graph; B = B - cycle c
                for cycle_node_address in cycle_path:
                    b_graph.remove_by_address(cycle_node_address)

            logger.debug('g_graph: %s', g_graph)
            logger.debug('b_graph: %s', b_graph)
            logger.debug('c_graph: %s', c_graph)
            logger.debug('Betas: %s', betas)
            logger.debug('replaced nodes %s', self.inner_nodes)

        # Recover parse tree
        logger.debug('Final scores: %s', self.scores)
        logger.debug('Recovering parse...')
        for i in range(len(tokens) + 1, nr_vertices + 1):
            betas[betas[i][1]] = betas[i]

        logger.debug('Betas: %s', betas)
        for node in original_graph.nodes.values():
            # TODO: It's dangerous to assume that deps is a dictionary
            # because it's a default dictionary.  Ideally, here we should
            # not be concerned with how dependencies are stored inside of
            # a dependency graph.
            node['deps'] = {}
        for i in range(1, len(tokens) + 1):
            original_graph.add_arc(betas[i][0], betas[i][1])

        logger.debug('Done.')
        yield original_graph
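
# The loop in parse() above follows the Chu-Liu/Edmonds maximum spanning
# tree construction that underlies the McDonald (2005) MST parser: each
# vertex greedily takes its best incoming arc, any cycle among the chosen
# arcs is collapsed into a fresh vertex with rescaled scores, and the final
# tree is recovered from the recorded betas.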

#################################################################
# Rule-based Non-Projective Parser
#################################################################


class NonprojectiveDependencyParser(object):
    """
    A non-projective, rule-based, dependency parser.  This parser
    will return the set of all possible non-projective parses based on
    the word-to-word relations defined in the parser's dependency
    grammar, and will allow the branches of the parse tree to cross
    in order to capture a variety of linguistic phenomena that a
    projective parser will not.
    """

    def __init__(self, dependency_grammar):
        """
        Creates a new ``NonprojectiveDependencyParser``.

        :param dependency_grammar: a grammar of word-to-word relations.
        :type dependency_grammar: DependencyGrammar
        """
        self._grammar = dependency_grammar

    def parse(self, tokens):
        """
        Parses the input tokens with respect to the parser's grammar.  Parsing
        is accomplished by representing the search-space of possible parses as
        a fully-connected directed graph.  Arcs that would lead to ungrammatical
        parses are removed and a lattice is constructed of length n, where n is
        the number of input tokens, to represent all possible grammatical
        traversals.  All possible paths through the lattice are then enumerated
        to produce the set of non-projective parses.

        :param tokens: A list of tokens to parse.
        :type tokens: list(str)
        :return: An iterator of non-projective parses.
        :rtype: iter(DependencyGraph)
        """
        # Create graph representation of tokens
        self._graph = DependencyGraph()

        for index, token in enumerate(tokens):
            self._graph.nodes[index] = {
                'word': token,
                'deps': [],
                'rel': 'NTOP',
                'address': index,
            }

        for head_node in self._graph.nodes.values():
            deps = []
            for dep_node in self._graph.nodes.values():
                if (
                    self._grammar.contains(head_node['word'], dep_node['word'])
                    and head_node['word'] != dep_node['word']
                ):
                    deps.append(dep_node['address'])
            head_node['deps'] = deps

        # Create lattice of possible heads
        roots = []
        possible_heads = []
        for i, word in enumerate(tokens):
            heads = []
            for j, head in enumerate(tokens):
                if (i != j) and self._grammar.contains(head, word):
                    heads.append(j)
            if len(heads) == 0:
                roots.append(i)
            possible_heads.append(heads)

        # Set roots to attempt
        if len(roots) < 2:
            if len(roots) == 0:
                for i in range(len(tokens)):
                    roots.append(i)

            # Traverse lattice
            analyses = []
            for root in roots:
                stack = []
                analysis = [[] for i in range(len(possible_heads))]
                i = 0
                forward = True
                while i >= 0:
                    if forward:
                        if len(possible_heads[i]) == 1:
                            analysis[i] = possible_heads[i][0]
                        elif len(possible_heads[i]) == 0:
                            analysis[i] = -1
                        else:
                            head = possible_heads[i].pop()
                            analysis[i] = head
                            stack.append([i, head])
                    if not forward:
                        index_on_stack = False
                        for stack_item in stack:
                            if stack_item[0] == i:
                                index_on_stack = True
                        orig_length = len(possible_heads[i])

                        if index_on_stack and orig_length == 0:
                            for j in range(len(stack) - 1, -1, -1):
                                stack_item = stack[j]
                                if stack_item[0] == i:
                                    possible_heads[i].append(stack.pop(j)[1])

                        elif index_on_stack and orig_length > 0:
                            head = possible_heads[i].pop()
                            analysis[i] = head
                            stack.append([i, head])
                            forward = True

                    if i + 1 == len(possible_heads):
                        analyses.append(analysis[:])
                        forward = False
                    if forward:
                        i += 1
                    else:
                        i -= 1

            # Filter parses
            # Ensure 1 root, and that everything has 1 head
            for analysis in analyses:
                if analysis.count(-1) > 1:
                    # there are several root elements!
                    continue

                graph = DependencyGraph()
                graph.root = graph.nodes[analysis.index(-1) + 1]

                for address, (token, head_index) in enumerate(
                    zip(tokens, analysis), start=1
                ):
                    head_address = head_index + 1

                    node = graph.nodes[address]
                    node.update({'word': token, 'address': address})

                    if head_address == 0:
                        rel = 'ROOT'
                    else:
                        rel = ''
                    graph.nodes[head_index + 1]['deps'][rel].append(address)

                # TODO: check for cycles
                yield graph
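
# A worked sketch of the head lattice, assuming a toy grammar containing
# only the rule 'b' -> 'a' | 'c' and the tokens ['a', 'b', 'c']:
#
#     possible_heads == [[1], [], [1]]   # 'b' (index 1) may head 'a' and 'c'
#     roots == [1]                       # 'b' itself has no possible head
#
# The traversal yields the single analysis [1, -1, 1]: 'a' and 'c' both
# attach to 'b', and 'b' is the root (head index -1).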

#################################################################
# Demos
#################################################################


def demo():
    # hall_demo()
    nonprojective_conll_parse_demo()
    rule_based_demo()


def hall_demo():
    npp = ProbabilisticNonprojectiveParser()
    npp.train([], DemoScorer())
    for parse_graph in npp.parse(['v1', 'v2', 'v3'], [None, None, None]):
        print(parse_graph)


def nonprojective_conll_parse_demo():
    from nltk.parse.dependencygraph import conll_data2

    graphs = [DependencyGraph(entry) for entry in conll_data2.split('\n\n') if entry]
    npp = ProbabilisticNonprojectiveParser()
    npp.train(graphs, NaiveBayesDependencyScorer())
    for parse_graph in npp.parse(
        ['Cathy', 'zag', 'hen', 'zwaaien', '.'], ['N', 'V', 'Pron', 'Adj', 'N', 'Punc']
    ):
        print(parse_graph)


def rule_based_demo():
    from nltk.grammar import DependencyGrammar

    grammar = DependencyGrammar.fromstring(
        """
    'taught' -> 'play' | 'man'
    'man' -> 'the' | 'in'
    'in' -> 'corner'
    'corner' -> 'the'
    'play' -> 'golf' | 'dachshund' | 'to'
    'dachshund' -> 'his'
    """
    )
    print(grammar)
    ndp = NonprojectiveDependencyParser(grammar)
    graphs = ndp.parse(
        [
            'the',
            'man',
            'in',
            'the',
            'corner',
            'taught',
            'his',
            'dachshund',
            'to',
            'play',
            'golf',
        ]
    )
    print('Graphs:')
    for graph in graphs:
        print(graph)


if __name__ == '__main__':
    demo()