evaluate.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Natural Language Toolkit: evaluation of dependency parser
  2. #
  3. # Author: Long Duong <longdt219@gmail.com>
  4. #
  5. # Copyright (C) 2001-2019 NLTK Project
  6. # URL: <http://nltk.org/>
  7. # For license information, see LICENSE.TXT
  8. from __future__ import division
  9. import unicodedata
  10. class DependencyEvaluator(object):
  11. """
  12. Class for measuring labelled and unlabelled attachment score for
  13. dependency parsing. Note that the evaluation ignores punctuation.
  14. >>> from nltk.parse import DependencyGraph, DependencyEvaluator
  15. >>> gold_sent = DependencyGraph(\"""
  16. ... Pierre NNP 2 NMOD
  17. ... Vinken NNP 8 SUB
  18. ... , , 2 P
  19. ... 61 CD 5 NMOD
  20. ... years NNS 6 AMOD
  21. ... old JJ 2 NMOD
  22. ... , , 2 P
  23. ... will MD 0 ROOT
  24. ... join VB 8 VC
  25. ... the DT 11 NMOD
  26. ... board NN 9 OBJ
  27. ... as IN 9 VMOD
  28. ... a DT 15 NMOD
  29. ... nonexecutive JJ 15 NMOD
  30. ... director NN 12 PMOD
  31. ... Nov. NNP 9 VMOD
  32. ... 29 CD 16 NMOD
  33. ... . . 9 VMOD
  34. ... \""")
  35. >>> parsed_sent = DependencyGraph(\"""
  36. ... Pierre NNP 8 NMOD
  37. ... Vinken NNP 1 SUB
  38. ... , , 3 P
  39. ... 61 CD 6 NMOD
  40. ... years NNS 6 AMOD
  41. ... old JJ 2 NMOD
  42. ... , , 3 AMOD
  43. ... will MD 0 ROOT
  44. ... join VB 8 VC
  45. ... the DT 11 AMOD
  46. ... board NN 9 OBJECT
  47. ... as IN 9 NMOD
  48. ... a DT 15 NMOD
  49. ... nonexecutive JJ 15 NMOD
  50. ... director NN 12 PMOD
  51. ... Nov. NNP 9 VMOD
  52. ... 29 CD 16 NMOD
  53. ... . . 9 VMOD
  54. ... \""")
  55. >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
  56. >>> las, uas = de.eval()
  57. >>> las
  58. 0.6...
  59. >>> uas
  60. 0.8...
  61. >>> abs(uas - 0.8) < 0.00001
  62. True
  63. """
  64. def __init__(self, parsed_sents, gold_sents):
  65. """
  66. :param parsed_sents: the list of parsed_sents as the output of parser
  67. :type parsed_sents: list(DependencyGraph)
  68. """
  69. self._parsed_sents = parsed_sents
  70. self._gold_sents = gold_sents
  71. def _remove_punct(self, inStr):
  72. """
  73. Function to remove punctuation from Unicode string.
  74. :param input: the input string
  75. :return: Unicode string after remove all punctuation
  76. """
  77. punc_cat = set(["Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"])
  78. return "".join(x for x in inStr if unicodedata.category(x) not in punc_cat)
  79. def eval(self):
  80. """
  81. Return the Labeled Attachment Score (LAS) and Unlabeled Attachment Score (UAS)
  82. :return : tuple(float,float)
  83. """
  84. if len(self._parsed_sents) != len(self._gold_sents):
  85. raise ValueError(
  86. " Number of parsed sentence is different with number of gold sentence."
  87. )
  88. corr = 0
  89. corrL = 0
  90. total = 0
  91. for i in range(len(self._parsed_sents)):
  92. parsed_sent_nodes = self._parsed_sents[i].nodes
  93. gold_sent_nodes = self._gold_sents[i].nodes
  94. if len(parsed_sent_nodes) != len(gold_sent_nodes):
  95. raise ValueError("Sentences must have equal length.")
  96. for parsed_node_address, parsed_node in parsed_sent_nodes.items():
  97. gold_node = gold_sent_nodes[parsed_node_address]
  98. if parsed_node["word"] is None:
  99. continue
  100. if parsed_node["word"] != gold_node["word"]:
  101. raise ValueError("Sentence sequence is not matched.")
  102. # Ignore if word is punctuation by default
  103. # if (parsed_sent[j]["word"] in string.punctuation):
  104. if self._remove_punct(parsed_node["word"]) == "":
  105. continue
  106. total += 1
  107. if parsed_node["head"] == gold_node["head"]:
  108. corr += 1
  109. if parsed_node["rel"] == gold_node["rel"]:
  110. corrL += 1
  111. return corrL / total, corr / total