gdfa.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
# -*- coding: utf-8 -*-
# Natural Language Toolkit: GDFA word alignment symmetrization
#
# Copyright (C) 2001-2019 NLTK Project
# Authors: Liling Tan
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
  8. from collections import defaultdict
  9. def grow_diag_final_and(srclen, trglen, e2f, f2e):
  10. """
  11. This module symmetrisatizes the source-to-target and target-to-source
  12. word alignment output and produces, aka. GDFA algorithm (Koehn, 2005).
  13. Step 1: Find the intersection of the bidirectional alignment.
  14. Step 2: Search for additional neighbor alignment points to be added, given
  15. these criteria: (i) neighbor alignments points are not in the
  16. intersection and (ii) neighbor alignments are in the union.
  17. Step 3: Add all other alignment points thats not in the intersection, not in
  18. the neighboring alignments that met the criteria but in the original
  19. foward/backward alignment outputs.
  20. >>> forw = ('0-0 2-1 9-2 21-3 10-4 7-5 11-6 9-7 12-8 1-9 3-10 '
  21. ... '4-11 17-12 17-13 25-14 13-15 24-16 11-17 28-18')
  22. >>> back = ('0-0 1-9 2-9 3-10 4-11 5-12 6-6 7-5 8-6 9-7 10-4 '
  23. ... '11-6 12-8 13-12 15-12 17-13 18-13 19-12 20-13 '
  24. ... '21-3 22-12 23-14 24-17 25-15 26-17 27-18 28-18')
  25. >>> srctext = ("この よう な ハロー 白色 わい 星 の L 関数 "
  26. ... "は L と 共 に 不連続 に 増加 する こと が "
  27. ... "期待 さ れる こと を 示し た 。")
  28. >>> trgtext = ("Therefore , we expect that the luminosity function "
  29. ... "of such halo white dwarfs increases discontinuously "
  30. ... "with the luminosity .")
  31. >>> srclen = len(srctext.split())
  32. >>> trglen = len(trgtext.split())
  33. >>>
  34. >>> gdfa = grow_diag_final_and(srclen, trglen, forw, back)
  35. >>> gdfa == sorted(set([(28, 18), (6, 6), (24, 17), (2, 1), (15, 12), (13, 12),
  36. ... (2, 9), (3, 10), (26, 17), (25, 15), (8, 6), (9, 7), (20,
  37. ... 13), (18, 13), (0, 0), (10, 4), (13, 15), (23, 14), (7, 5),
  38. ... (25, 14), (1, 9), (17, 13), (4, 11), (11, 17), (9, 2), (22,
  39. ... 12), (27, 18), (24, 16), (21, 3), (19, 12), (17, 12), (5,
  40. ... 12), (11, 6), (12, 8)]))
  41. True
  42. References:
  43. Koehn, P., A. Axelrod, A. Birch, C. Callison, M. Osborne, and D. Talbot.
  44. 2005. Edinburgh System Description for the 2005 IWSLT Speech
  45. Translation Evaluation. In MT Eval Workshop.
  46. :type srclen: int
  47. :param srclen: the number of tokens in the source language
  48. :type trglen: int
  49. :param trglen: the number of tokens in the target language
  50. :type e2f: str
  51. :param e2f: the forward word alignment outputs from source-to-target
  52. language (in pharaoh output format)
  53. :type f2e: str
  54. :param f2e: the backward word alignment outputs from target-to-source
  55. language (in pharaoh output format)
  56. :rtype: set(tuple(int))
  57. :return: the symmetrized alignment points from the GDFA algorithm
  58. """
  59. # Converts pharaoh text format into list of tuples.
  60. e2f = [tuple(map(int, a.split('-'))) for a in e2f.split()]
  61. f2e = [tuple(map(int, a.split('-'))) for a in f2e.split()]
  62. neighbors = [(-1, 0), (0, -1), (1, 0), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
  63. alignment = set(e2f).intersection(set(f2e)) # Find the intersection.
  64. union = set(e2f).union(set(f2e))
  65. # *aligned* is used to check if neighbors are aligned in grow_diag()
  66. aligned = defaultdict(set)
  67. for i, j in alignment:
  68. aligned['e'].add(i)
  69. aligned['f'].add(j)
  70. def grow_diag():
  71. """
  72. Search for the neighbor points and them to the intersected alignment
  73. points if criteria are met.
  74. """
  75. prev_len = len(alignment) - 1
  76. # iterate until no new points added
  77. while prev_len < len(alignment):
  78. no_new_points = True
  79. # for english word e = 0 ... en
  80. for e in range(srclen):
  81. # for foreign word f = 0 ... fn
  82. for f in range(trglen):
  83. # if ( e aligned with f)
  84. if (e, f) in alignment:
  85. # for each neighboring point (e-new, f-new)
  86. for neighbor in neighbors:
  87. neighbor = tuple(i + j for i, j in zip((e, f), neighbor))
  88. e_new, f_new = neighbor
  89. # if ( ( e-new not aligned and f-new not aligned)
  90. # and (e-new, f-new in union(e2f, f2e) )
  91. if (
  92. e_new not in aligned and f_new not in aligned
  93. ) and neighbor in union:
  94. alignment.add(neighbor)
  95. aligned['e'].add(e_new)
  96. aligned['f'].add(f_new)
  97. prev_len += 1
  98. no_new_points = False
  99. # iterate until no new points added
  100. if no_new_points:
  101. break
  102. def final_and(a):
  103. """
  104. Adds remaining points that are not in the intersection, not in the
  105. neighboring alignments but in the original *e2f* and *f2e* alignments
  106. """
  107. # for english word e = 0 ... en
  108. for e_new in range(srclen):
  109. # for foreign word f = 0 ... fn
  110. for f_new in range(trglen):
  111. # if ( ( e-new not aligned and f-new not aligned)
  112. # and (e-new, f-new in union(e2f, f2e) )
  113. if (
  114. e_new not in aligned
  115. and f_new not in aligned
  116. and (e_new, f_new) in union
  117. ):
  118. alignment.add((e_new, f_new))
  119. aligned['e'].add(e_new)
  120. aligned['f'].add(f_new)
  121. grow_diag()
  122. final_and(e2f)
  123. final_and(f2e)
  124. return sorted(alignment)