# transitionparser.py
# Natural Language Toolkit: Arc-Standard and Arc-Eager Transition-Based Parsers
#
# Author: Long Duong <longdt219@gmail.com>
#
# Copyright (C) 2001-2019 NLTK Project
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
  8. from __future__ import absolute_import
  9. from __future__ import division
  10. from __future__ import print_function
  11. import tempfile
  12. import pickle
  13. from os import remove
  14. from copy import deepcopy
  15. from operator import itemgetter
  16. try:
  17. from numpy import array
  18. from scipy import sparse
  19. from sklearn.datasets import load_svmlight_file
  20. from sklearn import svm
  21. except ImportError:
  22. pass
  23. from nltk.parse import ParserI, DependencyGraph, DependencyEvaluator
  24. class Configuration(object):
  25. """
  26. Class for holding configuration which is the partial analysis of the input sentence.
  27. The transition based parser aims at finding set of operators that transfer the initial
  28. configuration to the terminal configuration.
  29. The configuration includes:
  30. - Stack: for storing partially proceeded words
  31. - Buffer: for storing remaining input words
  32. - Set of arcs: for storing partially built dependency tree
  33. This class also provides a method to represent a configuration as list of features.
  34. """
  35. def __init__(self, dep_graph):
  36. """
  37. :param dep_graph: the representation of an input in the form of dependency graph.
  38. :type dep_graph: DependencyGraph where the dependencies are not specified.
  39. """
  40. # dep_graph.nodes contain list of token for a sentence
  41. self.stack = [0] # The root element
  42. self.buffer = list(range(1, len(dep_graph.nodes))) # The rest is in the buffer
  43. self.arcs = [] # empty set of arc
  44. self._tokens = dep_graph.nodes
  45. self._max_address = len(self.buffer)
  46. def __str__(self):
  47. return (
  48. 'Stack : '
  49. + str(self.stack)
  50. + ' Buffer : '
  51. + str(self.buffer)
  52. + ' Arcs : '
  53. + str(self.arcs)
  54. )
  55. def _check_informative(self, feat, flag=False):
  56. """
  57. Check whether a feature is informative
  58. The flag control whether "_" is informative or not
  59. """
  60. if feat is None:
  61. return False
  62. if feat == '':
  63. return False
  64. if flag is False:
  65. if feat == '_':
  66. return False
  67. return True
  68. def extract_features(self):
  69. """
  70. Extract the set of features for the current configuration. Implement standard features as describe in
  71. Table 3.2 (page 31) in Dependency Parsing book by Sandra Kubler, Ryan McDonal, Joakim Nivre.
  72. Please note that these features are very basic.
  73. :return: list(str)
  74. """
  75. result = []
  76. # Todo : can come up with more complicated features set for better
  77. # performance.
  78. if len(self.stack) > 0:
  79. # Stack 0
  80. stack_idx0 = self.stack[len(self.stack) - 1]
  81. token = self._tokens[stack_idx0]
  82. if self._check_informative(token['word'], True):
  83. result.append('STK_0_FORM_' + token['word'])
  84. if 'lemma' in token and self._check_informative(token['lemma']):
  85. result.append('STK_0_LEMMA_' + token['lemma'])
  86. if self._check_informative(token['tag']):
  87. result.append('STK_0_POS_' + token['tag'])
  88. if 'feats' in token and self._check_informative(token['feats']):
  89. feats = token['feats'].split("|")
  90. for feat in feats:
  91. result.append('STK_0_FEATS_' + feat)
  92. # Stack 1
  93. if len(self.stack) > 1:
  94. stack_idx1 = self.stack[len(self.stack) - 2]
  95. token = self._tokens[stack_idx1]
  96. if self._check_informative(token['tag']):
  97. result.append('STK_1_POS_' + token['tag'])
  98. # Left most, right most dependency of stack[0]
  99. left_most = 1000000
  100. right_most = -1
  101. dep_left_most = ''
  102. dep_right_most = ''
  103. for (wi, r, wj) in self.arcs:
  104. if wi == stack_idx0:
  105. if (wj > wi) and (wj > right_most):
  106. right_most = wj
  107. dep_right_most = r
  108. if (wj < wi) and (wj < left_most):
  109. left_most = wj
  110. dep_left_most = r
  111. if self._check_informative(dep_left_most):
  112. result.append('STK_0_LDEP_' + dep_left_most)
  113. if self._check_informative(dep_right_most):
  114. result.append('STK_0_RDEP_' + dep_right_most)
  115. # Check Buffered 0
  116. if len(self.buffer) > 0:
  117. # Buffer 0
  118. buffer_idx0 = self.buffer[0]
  119. token = self._tokens[buffer_idx0]
  120. if self._check_informative(token['word'], True):
  121. result.append('BUF_0_FORM_' + token['word'])
  122. if 'lemma' in token and self._check_informative(token['lemma']):
  123. result.append('BUF_0_LEMMA_' + token['lemma'])
  124. if self._check_informative(token['tag']):
  125. result.append('BUF_0_POS_' + token['tag'])
  126. if 'feats' in token and self._check_informative(token['feats']):
  127. feats = token['feats'].split("|")
  128. for feat in feats:
  129. result.append('BUF_0_FEATS_' + feat)
  130. # Buffer 1
  131. if len(self.buffer) > 1:
  132. buffer_idx1 = self.buffer[1]
  133. token = self._tokens[buffer_idx1]
  134. if self._check_informative(token['word'], True):
  135. result.append('BUF_1_FORM_' + token['word'])
  136. if self._check_informative(token['tag']):
  137. result.append('BUF_1_POS_' + token['tag'])
  138. if len(self.buffer) > 2:
  139. buffer_idx2 = self.buffer[2]
  140. token = self._tokens[buffer_idx2]
  141. if self._check_informative(token['tag']):
  142. result.append('BUF_2_POS_' + token['tag'])
  143. if len(self.buffer) > 3:
  144. buffer_idx3 = self.buffer[3]
  145. token = self._tokens[buffer_idx3]
  146. if self._check_informative(token['tag']):
  147. result.append('BUF_3_POS_' + token['tag'])
  148. # Left most, right most dependency of stack[0]
  149. left_most = 1000000
  150. right_most = -1
  151. dep_left_most = ''
  152. dep_right_most = ''
  153. for (wi, r, wj) in self.arcs:
  154. if wi == buffer_idx0:
  155. if (wj > wi) and (wj > right_most):
  156. right_most = wj
  157. dep_right_most = r
  158. if (wj < wi) and (wj < left_most):
  159. left_most = wj
  160. dep_left_most = r
  161. if self._check_informative(dep_left_most):
  162. result.append('BUF_0_LDEP_' + dep_left_most)
  163. if self._check_informative(dep_right_most):
  164. result.append('BUF_0_RDEP_' + dep_right_most)
  165. return result
  166. class Transition(object):
  167. """
  168. This class defines a set of transition which is applied to a configuration to get another configuration
  169. Note that for different parsing algorithm, the transition is different.
  170. """
  171. # Define set of transitions
  172. LEFT_ARC = 'LEFTARC'
  173. RIGHT_ARC = 'RIGHTARC'
  174. SHIFT = 'SHIFT'
  175. REDUCE = 'REDUCE'
  176. def __init__(self, alg_option):
  177. """
  178. :param alg_option: the algorithm option of this parser. Currently support `arc-standard` and `arc-eager` algorithm
  179. :type alg_option: str
  180. """
  181. self._algo = alg_option
  182. if alg_option not in [
  183. TransitionParser.ARC_STANDARD,
  184. TransitionParser.ARC_EAGER,
  185. ]:
  186. raise ValueError(
  187. " Currently we only support %s and %s "
  188. % (TransitionParser.ARC_STANDARD, TransitionParser.ARC_EAGER)
  189. )
  190. def left_arc(self, conf, relation):
  191. """
  192. Note that the algorithm for left-arc is quite similar except for precondition for both arc-standard and arc-eager
  193. :param configuration: is the current configuration
  194. :return : A new configuration or -1 if the pre-condition is not satisfied
  195. """
  196. if (len(conf.buffer) <= 0) or (len(conf.stack) <= 0):
  197. return -1
  198. if conf.buffer[0] == 0:
  199. # here is the Root element
  200. return -1
  201. idx_wi = conf.stack[len(conf.stack) - 1]
  202. flag = True
  203. if self._algo == TransitionParser.ARC_EAGER:
  204. for (idx_parent, r, idx_child) in conf.arcs:
  205. if idx_child == idx_wi:
  206. flag = False
  207. if flag:
  208. conf.stack.pop()
  209. idx_wj = conf.buffer[0]
  210. conf.arcs.append((idx_wj, relation, idx_wi))
  211. else:
  212. return -1
  213. def right_arc(self, conf, relation):
  214. """
  215. Note that the algorithm for right-arc is DIFFERENT for arc-standard and arc-eager
  216. :param configuration: is the current configuration
  217. :return : A new configuration or -1 if the pre-condition is not satisfied
  218. """
  219. if (len(conf.buffer) <= 0) or (len(conf.stack) <= 0):
  220. return -1
  221. if self._algo == TransitionParser.ARC_STANDARD:
  222. idx_wi = conf.stack.pop()
  223. idx_wj = conf.buffer[0]
  224. conf.buffer[0] = idx_wi
  225. conf.arcs.append((idx_wi, relation, idx_wj))
  226. else: # arc-eager
  227. idx_wi = conf.stack[len(conf.stack) - 1]
  228. idx_wj = conf.buffer.pop(0)
  229. conf.stack.append(idx_wj)
  230. conf.arcs.append((idx_wi, relation, idx_wj))
  231. def reduce(self, conf):
  232. """
  233. Note that the algorithm for reduce is only available for arc-eager
  234. :param configuration: is the current configuration
  235. :return : A new configuration or -1 if the pre-condition is not satisfied
  236. """
  237. if self._algo != TransitionParser.ARC_EAGER:
  238. return -1
  239. if len(conf.stack) <= 0:
  240. return -1
  241. idx_wi = conf.stack[len(conf.stack) - 1]
  242. flag = False
  243. for (idx_parent, r, idx_child) in conf.arcs:
  244. if idx_child == idx_wi:
  245. flag = True
  246. if flag:
  247. conf.stack.pop() # reduce it
  248. else:
  249. return -1
  250. def shift(self, conf):
  251. """
  252. Note that the algorithm for shift is the SAME for arc-standard and arc-eager
  253. :param configuration: is the current configuration
  254. :return : A new configuration or -1 if the pre-condition is not satisfied
  255. """
  256. if len(conf.buffer) <= 0:
  257. return -1
  258. idx_wi = conf.buffer.pop(0)
  259. conf.stack.append(idx_wi)
  260. class TransitionParser(ParserI):
  261. """
  262. Class for transition based parser. Implement 2 algorithms which are "arc-standard" and "arc-eager"
  263. """
  264. ARC_STANDARD = 'arc-standard'
  265. ARC_EAGER = 'arc-eager'
  266. def __init__(self, algorithm):
  267. """
  268. :param algorithm: the algorithm option of this parser. Currently support `arc-standard` and `arc-eager` algorithm
  269. :type algorithm: str
  270. """
  271. if not (algorithm in [self.ARC_STANDARD, self.ARC_EAGER]):
  272. raise ValueError(
  273. " Currently we only support %s and %s "
  274. % (self.ARC_STANDARD, self.ARC_EAGER)
  275. )
  276. self._algorithm = algorithm
  277. self._dictionary = {}
  278. self._transition = {}
  279. self._match_transition = {}
  280. def _get_dep_relation(self, idx_parent, idx_child, depgraph):
  281. p_node = depgraph.nodes[idx_parent]
  282. c_node = depgraph.nodes[idx_child]
  283. if c_node['word'] is None:
  284. return None # Root word
  285. if c_node['head'] == p_node['address']:
  286. return c_node['rel']
  287. else:
  288. return None
  289. def _convert_to_binary_features(self, features):
  290. """
  291. :param features: list of feature string which is needed to convert to binary features
  292. :type features: list(str)
  293. :return : string of binary features in libsvm format which is 'featureID:value' pairs
  294. """
  295. unsorted_result = []
  296. for feature in features:
  297. self._dictionary.setdefault(feature, len(self._dictionary))
  298. unsorted_result.append(self._dictionary[feature])
  299. # Default value of each feature is 1.0
  300. return ' '.join(
  301. str(featureID) + ':1.0' for featureID in sorted(unsorted_result)
  302. )
  303. def _is_projective(self, depgraph):
  304. arc_list = []
  305. for key in depgraph.nodes:
  306. node = depgraph.nodes[key]
  307. if 'head' in node:
  308. childIdx = node['address']
  309. parentIdx = node['head']
  310. if parentIdx is not None:
  311. arc_list.append((parentIdx, childIdx))
  312. for (parentIdx, childIdx) in arc_list:
  313. # Ensure that childIdx < parentIdx
  314. if childIdx > parentIdx:
  315. temp = childIdx
  316. childIdx = parentIdx
  317. parentIdx = temp
  318. for k in range(childIdx + 1, parentIdx):
  319. for m in range(len(depgraph.nodes)):
  320. if (m < childIdx) or (m > parentIdx):
  321. if (k, m) in arc_list:
  322. return False
  323. if (m, k) in arc_list:
  324. return False
  325. return True
  326. def _write_to_file(self, key, binary_features, input_file):
  327. """
  328. write the binary features to input file and update the transition dictionary
  329. """
  330. self._transition.setdefault(key, len(self._transition) + 1)
  331. self._match_transition[self._transition[key]] = key
  332. input_str = str(self._transition[key]) + ' ' + binary_features + '\n'
  333. input_file.write(input_str.encode('utf-8'))
  334. def _create_training_examples_arc_std(self, depgraphs, input_file):
  335. """
  336. Create the training example in the libsvm format and write it to the input_file.
  337. Reference : Page 32, Chapter 3. Dependency Parsing by Sandra Kubler, Ryan McDonal and Joakim Nivre (2009)
  338. """
  339. operation = Transition(self.ARC_STANDARD)
  340. count_proj = 0
  341. training_seq = []
  342. for depgraph in depgraphs:
  343. if not self._is_projective(depgraph):
  344. continue
  345. count_proj += 1
  346. conf = Configuration(depgraph)
  347. while len(conf.buffer) > 0:
  348. b0 = conf.buffer[0]
  349. features = conf.extract_features()
  350. binary_features = self._convert_to_binary_features(features)
  351. if len(conf.stack) > 0:
  352. s0 = conf.stack[len(conf.stack) - 1]
  353. # Left-arc operation
  354. rel = self._get_dep_relation(b0, s0, depgraph)
  355. if rel is not None:
  356. key = Transition.LEFT_ARC + ':' + rel
  357. self._write_to_file(key, binary_features, input_file)
  358. operation.left_arc(conf, rel)
  359. training_seq.append(key)
  360. continue
  361. # Right-arc operation
  362. rel = self._get_dep_relation(s0, b0, depgraph)
  363. if rel is not None:
  364. precondition = True
  365. # Get the max-index of buffer
  366. maxID = conf._max_address
  367. for w in range(maxID + 1):
  368. if w != b0:
  369. relw = self._get_dep_relation(b0, w, depgraph)
  370. if relw is not None:
  371. if (b0, relw, w) not in conf.arcs:
  372. precondition = False
  373. if precondition:
  374. key = Transition.RIGHT_ARC + ':' + rel
  375. self._write_to_file(key, binary_features, input_file)
  376. operation.right_arc(conf, rel)
  377. training_seq.append(key)
  378. continue
  379. # Shift operation as the default
  380. key = Transition.SHIFT
  381. self._write_to_file(key, binary_features, input_file)
  382. operation.shift(conf)
  383. training_seq.append(key)
  384. print(" Number of training examples : " + str(len(depgraphs)))
  385. print(" Number of valid (projective) examples : " + str(count_proj))
  386. return training_seq
  387. def _create_training_examples_arc_eager(self, depgraphs, input_file):
  388. """
  389. Create the training example in the libsvm format and write it to the input_file.
  390. Reference : 'A Dynamic Oracle for Arc-Eager Dependency Parsing' by Joav Goldberg and Joakim Nivre
  391. """
  392. operation = Transition(self.ARC_EAGER)
  393. countProj = 0
  394. training_seq = []
  395. for depgraph in depgraphs:
  396. if not self._is_projective(depgraph):
  397. continue
  398. countProj += 1
  399. conf = Configuration(depgraph)
  400. while len(conf.buffer) > 0:
  401. b0 = conf.buffer[0]
  402. features = conf.extract_features()
  403. binary_features = self._convert_to_binary_features(features)
  404. if len(conf.stack) > 0:
  405. s0 = conf.stack[len(conf.stack) - 1]
  406. # Left-arc operation
  407. rel = self._get_dep_relation(b0, s0, depgraph)
  408. if rel is not None:
  409. key = Transition.LEFT_ARC + ':' + rel
  410. self._write_to_file(key, binary_features, input_file)
  411. operation.left_arc(conf, rel)
  412. training_seq.append(key)
  413. continue
  414. # Right-arc operation
  415. rel = self._get_dep_relation(s0, b0, depgraph)
  416. if rel is not None:
  417. key = Transition.RIGHT_ARC + ':' + rel
  418. self._write_to_file(key, binary_features, input_file)
  419. operation.right_arc(conf, rel)
  420. training_seq.append(key)
  421. continue
  422. # reduce operation
  423. flag = False
  424. for k in range(s0):
  425. if self._get_dep_relation(k, b0, depgraph) is not None:
  426. flag = True
  427. if self._get_dep_relation(b0, k, depgraph) is not None:
  428. flag = True
  429. if flag:
  430. key = Transition.REDUCE
  431. self._write_to_file(key, binary_features, input_file)
  432. operation.reduce(conf)
  433. training_seq.append(key)
  434. continue
  435. # Shift operation as the default
  436. key = Transition.SHIFT
  437. self._write_to_file(key, binary_features, input_file)
  438. operation.shift(conf)
  439. training_seq.append(key)
  440. print(" Number of training examples : " + str(len(depgraphs)))
  441. print(" Number of valid (projective) examples : " + str(countProj))
  442. return training_seq
  443. def train(self, depgraphs, modelfile, verbose=True):
  444. """
  445. :param depgraphs : list of DependencyGraph as the training data
  446. :type depgraphs : DependencyGraph
  447. :param modelfile : file name to save the trained model
  448. :type modelfile : str
  449. """
  450. try:
  451. input_file = tempfile.NamedTemporaryFile(
  452. prefix='transition_parse.train', dir=tempfile.gettempdir(), delete=False
  453. )
  454. if self._algorithm == self.ARC_STANDARD:
  455. self._create_training_examples_arc_std(depgraphs, input_file)
  456. else:
  457. self._create_training_examples_arc_eager(depgraphs, input_file)
  458. input_file.close()
  459. # Using the temporary file to train the libsvm classifier
  460. x_train, y_train = load_svmlight_file(input_file.name)
  461. # The parameter is set according to the paper:
  462. # Algorithms for Deterministic Incremental Dependency Parsing by Joakim Nivre
  463. # Todo : because of probability = True => very slow due to
  464. # cross-validation. Need to improve the speed here
  465. model = svm.SVC(
  466. kernel='poly',
  467. degree=2,
  468. coef0=0,
  469. gamma=0.2,
  470. C=0.5,
  471. verbose=verbose,
  472. probability=True,
  473. )
  474. model.fit(x_train, y_train)
  475. # Save the model to file name (as pickle)
  476. pickle.dump(model, open(modelfile, 'wb'))
  477. finally:
  478. remove(input_file.name)
  479. def parse(self, depgraphs, modelFile):
  480. """
  481. :param depgraphs: the list of test sentence, each sentence is represented as a dependency graph where the 'head' information is dummy
  482. :type depgraphs: list(DependencyGraph)
  483. :param modelfile: the model file
  484. :type modelfile: str
  485. :return: list (DependencyGraph) with the 'head' and 'rel' information
  486. """
  487. result = []
  488. # First load the model
  489. model = pickle.load(open(modelFile, 'rb'))
  490. operation = Transition(self._algorithm)
  491. for depgraph in depgraphs:
  492. conf = Configuration(depgraph)
  493. while len(conf.buffer) > 0:
  494. features = conf.extract_features()
  495. col = []
  496. row = []
  497. data = []
  498. for feature in features:
  499. if feature in self._dictionary:
  500. col.append(self._dictionary[feature])
  501. row.append(0)
  502. data.append(1.0)
  503. np_col = array(sorted(col)) # NB : index must be sorted
  504. np_row = array(row)
  505. np_data = array(data)
  506. x_test = sparse.csr_matrix(
  507. (np_data, (np_row, np_col)), shape=(1, len(self._dictionary))
  508. )
  509. # It's best to use decision function as follow BUT it's not supported yet for sparse SVM
  510. # Using decision funcion to build the votes array
  511. # dec_func = model.decision_function(x_test)[0]
  512. # votes = {}
  513. # k = 0
  514. # for i in range(len(model.classes_)):
  515. # for j in range(i+1, len(model.classes_)):
  516. # #if dec_func[k] > 0:
  517. # votes.setdefault(i,0)
  518. # votes[i] +=1
  519. # else:
  520. # votes.setdefault(j,0)
  521. # votes[j] +=1
  522. # k +=1
  523. # Sort votes according to the values
  524. # sorted_votes = sorted(votes.items(), key=itemgetter(1), reverse=True)
  525. # We will use predict_proba instead of decision_function
  526. prob_dict = {}
  527. pred_prob = model.predict_proba(x_test)[0]
  528. for i in range(len(pred_prob)):
  529. prob_dict[i] = pred_prob[i]
  530. sorted_Prob = sorted(prob_dict.items(), key=itemgetter(1), reverse=True)
  531. # Note that SHIFT is always a valid operation
  532. for (y_pred_idx, confidence) in sorted_Prob:
  533. # y_pred = model.predict(x_test)[0]
  534. # From the prediction match to the operation
  535. y_pred = model.classes_[y_pred_idx]
  536. if y_pred in self._match_transition:
  537. strTransition = self._match_transition[y_pred]
  538. baseTransition = strTransition.split(":")[0]
  539. if baseTransition == Transition.LEFT_ARC:
  540. if (
  541. operation.left_arc(conf, strTransition.split(":")[1])
  542. != -1
  543. ):
  544. break
  545. elif baseTransition == Transition.RIGHT_ARC:
  546. if (
  547. operation.right_arc(conf, strTransition.split(":")[1])
  548. != -1
  549. ):
  550. break
  551. elif baseTransition == Transition.REDUCE:
  552. if operation.reduce(conf) != -1:
  553. break
  554. elif baseTransition == Transition.SHIFT:
  555. if operation.shift(conf) != -1:
  556. break
  557. else:
  558. raise ValueError(
  559. "The predicted transition is not recognized, expected errors"
  560. )
  561. # Finish with operations build the dependency graph from Conf.arcs
  562. new_depgraph = deepcopy(depgraph)
  563. for key in new_depgraph.nodes:
  564. node = new_depgraph.nodes[key]
  565. node['rel'] = ''
  566. # With the default, all the token depend on the Root
  567. node['head'] = 0
  568. for (head, rel, child) in conf.arcs:
  569. c_node = new_depgraph.nodes[child]
  570. c_node['head'] = head
  571. c_node['rel'] = rel
  572. result.append(new_depgraph)
  573. return result
  574. def demo():
  575. """
  576. >>> from nltk.parse import DependencyGraph, DependencyEvaluator
  577. >>> from nltk.parse.transitionparser import TransitionParser, Configuration, Transition
  578. >>> gold_sent = DependencyGraph(\"""
  579. ... Economic JJ 2 ATT
  580. ... news NN 3 SBJ
  581. ... has VBD 0 ROOT
  582. ... little JJ 5 ATT
  583. ... effect NN 3 OBJ
  584. ... on IN 5 ATT
  585. ... financial JJ 8 ATT
  586. ... markets NNS 6 PC
  587. ... . . 3 PU
  588. ... \""")
  589. >>> conf = Configuration(gold_sent)
  590. ###################### Check the Initial Feature ########################
  591. >>> print(', '.join(conf.extract_features()))
  592. STK_0_POS_TOP, BUF_0_FORM_Economic, BUF_0_LEMMA_Economic, BUF_0_POS_JJ, BUF_1_FORM_news, BUF_1_POS_NN, BUF_2_POS_VBD, BUF_3_POS_JJ
  593. ###################### Check The Transition #######################
  594. Check the Initialized Configuration
  595. >>> print(conf)
  596. Stack : [0] Buffer : [1, 2, 3, 4, 5, 6, 7, 8, 9] Arcs : []
  597. A. Do some transition checks for ARC-STANDARD
  598. >>> operation = Transition('arc-standard')
  599. >>> operation.shift(conf)
  600. >>> operation.left_arc(conf, "ATT")
  601. >>> operation.shift(conf)
  602. >>> operation.left_arc(conf,"SBJ")
  603. >>> operation.shift(conf)
  604. >>> operation.shift(conf)
  605. >>> operation.left_arc(conf, "ATT")
  606. >>> operation.shift(conf)
  607. >>> operation.shift(conf)
  608. >>> operation.shift(conf)
  609. >>> operation.left_arc(conf, "ATT")
  610. Middle Configuration and Features Check
  611. >>> print(conf)
  612. Stack : [0, 3, 5, 6] Buffer : [8, 9] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (5, 'ATT', 4), (8, 'ATT', 7)]
  613. >>> print(', '.join(conf.extract_features()))
  614. STK_0_FORM_on, STK_0_LEMMA_on, STK_0_POS_IN, STK_1_POS_NN, BUF_0_FORM_markets, BUF_0_LEMMA_markets, BUF_0_POS_NNS, BUF_1_FORM_., BUF_1_POS_., BUF_0_LDEP_ATT
  615. >>> operation.right_arc(conf, "PC")
  616. >>> operation.right_arc(conf, "ATT")
  617. >>> operation.right_arc(conf, "OBJ")
  618. >>> operation.shift(conf)
  619. >>> operation.right_arc(conf, "PU")
  620. >>> operation.right_arc(conf, "ROOT")
  621. >>> operation.shift(conf)
  622. Terminated Configuration Check
  623. >>> print(conf)
  624. Stack : [0] Buffer : [] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (5, 'ATT', 4), (8, 'ATT', 7), (6, 'PC', 8), (5, 'ATT', 6), (3, 'OBJ', 5), (3, 'PU', 9), (0, 'ROOT', 3)]
  625. B. Do some transition checks for ARC-EAGER
  626. >>> conf = Configuration(gold_sent)
  627. >>> operation = Transition('arc-eager')
  628. >>> operation.shift(conf)
  629. >>> operation.left_arc(conf,'ATT')
  630. >>> operation.shift(conf)
  631. >>> operation.left_arc(conf,'SBJ')
  632. >>> operation.right_arc(conf,'ROOT')
  633. >>> operation.shift(conf)
  634. >>> operation.left_arc(conf,'ATT')
  635. >>> operation.right_arc(conf,'OBJ')
  636. >>> operation.right_arc(conf,'ATT')
  637. >>> operation.shift(conf)
  638. >>> operation.left_arc(conf,'ATT')
  639. >>> operation.right_arc(conf,'PC')
  640. >>> operation.reduce(conf)
  641. >>> operation.reduce(conf)
  642. >>> operation.reduce(conf)
  643. >>> operation.right_arc(conf,'PU')
  644. >>> print(conf)
  645. Stack : [0, 3, 9] Buffer : [] Arcs : [(2, 'ATT', 1), (3, 'SBJ', 2), (0, 'ROOT', 3), (5, 'ATT', 4), (3, 'OBJ', 5), (5, 'ATT', 6), (8, 'ATT', 7), (6, 'PC', 8), (3, 'PU', 9)]
  646. ###################### Check The Training Function #######################
  647. A. Check the ARC-STANDARD training
  648. >>> import tempfile
  649. >>> import os
  650. >>> input_file = tempfile.NamedTemporaryFile(prefix='transition_parse.train', dir=tempfile.gettempdir(), delete=False)
  651. >>> parser_std = TransitionParser('arc-standard')
  652. >>> print(', '.join(parser_std._create_training_examples_arc_std([gold_sent], input_file)))
  653. Number of training examples : 1
  654. Number of valid (projective) examples : 1
  655. SHIFT, LEFTARC:ATT, SHIFT, LEFTARC:SBJ, SHIFT, SHIFT, LEFTARC:ATT, SHIFT, SHIFT, SHIFT, LEFTARC:ATT, RIGHTARC:PC, RIGHTARC:ATT, RIGHTARC:OBJ, SHIFT, RIGHTARC:PU, RIGHTARC:ROOT, SHIFT
  656. >>> parser_std.train([gold_sent],'temp.arcstd.model', verbose=False)
  657. Number of training examples : 1
  658. Number of valid (projective) examples : 1
  659. >>> remove(input_file.name)
  660. B. Check the ARC-EAGER training
  661. >>> input_file = tempfile.NamedTemporaryFile(prefix='transition_parse.train', dir=tempfile.gettempdir(),delete=False)
  662. >>> parser_eager = TransitionParser('arc-eager')
  663. >>> print(', '.join(parser_eager._create_training_examples_arc_eager([gold_sent], input_file)))
  664. Number of training examples : 1
  665. Number of valid (projective) examples : 1
  666. SHIFT, LEFTARC:ATT, SHIFT, LEFTARC:SBJ, RIGHTARC:ROOT, SHIFT, LEFTARC:ATT, RIGHTARC:OBJ, RIGHTARC:ATT, SHIFT, LEFTARC:ATT, RIGHTARC:PC, REDUCE, REDUCE, REDUCE, RIGHTARC:PU
  667. >>> parser_eager.train([gold_sent],'temp.arceager.model', verbose=False)
  668. Number of training examples : 1
  669. Number of valid (projective) examples : 1
  670. >>> remove(input_file.name)
  671. ###################### Check The Parsing Function ########################
  672. A. Check the ARC-STANDARD parser
  673. >>> result = parser_std.parse([gold_sent], 'temp.arcstd.model')
  674. >>> de = DependencyEvaluator(result, [gold_sent])
  675. >>> de.eval() >= (0, 0)
  676. True
  677. B. Check the ARC-EAGER parser
  678. >>> result = parser_eager.parse([gold_sent], 'temp.arceager.model')
  679. >>> de = DependencyEvaluator(result, [gold_sent])
  680. >>> de.eval() >= (0, 0)
  681. True
  682. Remove test temporary files
  683. >>> remove('temp.arceager.model')
  684. >>> remove('temp.arcstd.model')
  685. Note that result is very poor because of only one training example.
  686. """