api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # Natural Language Toolkit: API for alignment and translation objects
  2. #
  3. # Copyright (C) 2001-2019 NLTK Project
  4. # Author: Will Zhang <wilzzha@gmail.com>
  5. # Guan Gui <ggui@student.unimelb.edu.au>
  6. # Steven Bird <stevenbird1@gmail.com>
  7. # Tah Wei Hoon <hoon.tw@gmail.com>
  8. # URL: <http://nltk.org/>
  9. # For license information, see LICENSE.TXT
  10. from __future__ import print_function, unicode_literals
  11. import subprocess
  12. from collections import namedtuple
  13. from nltk.compat import python_2_unicode_compatible
  14. @python_2_unicode_compatible
  15. class AlignedSent(object):
  16. """
  17. Return an aligned sentence object, which encapsulates two sentences
  18. along with an ``Alignment`` between them.
  19. Typically used in machine translation to represent a sentence and
  20. its translation.
  21. >>> from nltk.translate import AlignedSent, Alignment
  22. >>> algnsent = AlignedSent(['klein', 'ist', 'das', 'Haus'],
  23. ... ['the', 'house', 'is', 'small'], Alignment.fromstring('0-3 1-2 2-0 3-1'))
  24. >>> algnsent.words
  25. ['klein', 'ist', 'das', 'Haus']
  26. >>> algnsent.mots
  27. ['the', 'house', 'is', 'small']
  28. >>> algnsent.alignment
  29. Alignment([(0, 3), (1, 2), (2, 0), (3, 1)])
  30. >>> from nltk.corpus import comtrans
  31. >>> print(comtrans.aligned_sents()[54])
  32. <AlignedSent: 'Weshalb also sollten...' -> 'So why should EU arm...'>
  33. >>> print(comtrans.aligned_sents()[54].alignment)
  34. 0-0 0-1 1-0 2-2 3-4 3-5 4-7 5-8 6-3 7-9 8-9 9-10 9-11 10-12 11-6 12-6 13-13
  35. :param words: Words in the target language sentence
  36. :type words: list(str)
  37. :param mots: Words in the source language sentence
  38. :type mots: list(str)
  39. :param alignment: Word-level alignments between ``words`` and ``mots``.
  40. Each alignment is represented as a 2-tuple (words_index, mots_index).
  41. :type alignment: Alignment
  42. """
  43. def __init__(self, words, mots, alignment=None):
  44. self._words = words
  45. self._mots = mots
  46. if alignment is None:
  47. self.alignment = Alignment([])
  48. else:
  49. assert type(alignment) is Alignment
  50. self.alignment = alignment
  51. @property
  52. def words(self):
  53. return self._words
  54. @property
  55. def mots(self):
  56. return self._mots
  57. def _get_alignment(self):
  58. return self._alignment
  59. def _set_alignment(self, alignment):
  60. _check_alignment(len(self.words), len(self.mots), alignment)
  61. self._alignment = alignment
  62. alignment = property(_get_alignment, _set_alignment)
  63. def __repr__(self):
  64. """
  65. Return a string representation for this ``AlignedSent``.
  66. :rtype: str
  67. """
  68. words = "[%s]" % (", ".join("'%s'" % w for w in self._words))
  69. mots = "[%s]" % (", ".join("'%s'" % w for w in self._mots))
  70. return "AlignedSent(%s, %s, %r)" % (words, mots, self._alignment)
  71. def _to_dot(self):
  72. """
  73. Dot representation of the aligned sentence
  74. """
  75. s = 'graph align {\n'
  76. s += 'node[shape=plaintext]\n'
  77. # Declare node
  78. for w in self._words:
  79. s += '"%s_source" [label="%s"] \n' % (w, w)
  80. for w in self._mots:
  81. s += '"%s_target" [label="%s"] \n' % (w, w)
  82. # Alignment
  83. for u, v in self._alignment:
  84. s += '"%s_source" -- "%s_target" \n' % (self._words[u], self._mots[v])
  85. # Connect the source words
  86. for i in range(len(self._words) - 1):
  87. s += '"%s_source" -- "%s_source" [style=invis]\n' % (
  88. self._words[i],
  89. self._words[i + 1],
  90. )
  91. # Connect the target words
  92. for i in range(len(self._mots) - 1):
  93. s += '"%s_target" -- "%s_target" [style=invis]\n' % (
  94. self._mots[i],
  95. self._mots[i + 1],
  96. )
  97. # Put it in the same rank
  98. s += '{rank = same; %s}\n' % (' '.join('"%s_source"' % w for w in self._words))
  99. s += '{rank = same; %s}\n' % (' '.join('"%s_target"' % w for w in self._mots))
  100. s += '}'
  101. return s
  102. def _repr_svg_(self):
  103. """
  104. Ipython magic : show SVG representation of this ``AlignedSent``.
  105. """
  106. dot_string = self._to_dot().encode('utf8')
  107. output_format = 'svg'
  108. try:
  109. process = subprocess.Popen(
  110. ['dot', '-T%s' % output_format],
  111. stdin=subprocess.PIPE,
  112. stdout=subprocess.PIPE,
  113. stderr=subprocess.PIPE,
  114. )
  115. except OSError:
  116. raise Exception('Cannot find the dot binary from Graphviz package')
  117. out, err = process.communicate(dot_string)
  118. return out.decode('utf8')
  119. def __str__(self):
  120. """
  121. Return a human-readable string representation for this ``AlignedSent``.
  122. :rtype: str
  123. """
  124. source = " ".join(self._words)[:20] + "..."
  125. target = " ".join(self._mots)[:20] + "..."
  126. return "<AlignedSent: '%s' -> '%s'>" % (source, target)
  127. def invert(self):
  128. """
  129. Return the aligned sentence pair, reversing the directionality
  130. :rtype: AlignedSent
  131. """
  132. return AlignedSent(self._mots, self._words, self._alignment.invert())
  133. @python_2_unicode_compatible
  134. class Alignment(frozenset):
  135. """
  136. A storage class for representing alignment between two sequences, s1, s2.
  137. In general, an alignment is a set of tuples of the form (i, j, ...)
  138. representing an alignment between the i-th element of s1 and the
  139. j-th element of s2. Tuples are extensible (they might contain
  140. additional data, such as a boolean to indicate sure vs possible alignments).
  141. >>> from nltk.translate import Alignment
  142. >>> a = Alignment([(0, 0), (0, 1), (1, 2), (2, 2)])
  143. >>> a.invert()
  144. Alignment([(0, 0), (1, 0), (2, 1), (2, 2)])
  145. >>> print(a.invert())
  146. 0-0 1-0 2-1 2-2
  147. >>> a[0]
  148. [(0, 1), (0, 0)]
  149. >>> a.invert()[2]
  150. [(2, 1), (2, 2)]
  151. >>> b = Alignment([(0, 0), (0, 1)])
  152. >>> b.issubset(a)
  153. True
  154. >>> c = Alignment.fromstring('0-0 0-1')
  155. >>> b == c
  156. True
  157. """
  158. def __new__(cls, pairs):
  159. self = frozenset.__new__(cls, pairs)
  160. self._len = max(p[0] for p in self) if self != frozenset([]) else 0
  161. self._index = None
  162. return self
  163. @classmethod
  164. def fromstring(cls, s):
  165. """
  166. Read a giza-formatted string and return an Alignment object.
  167. >>> Alignment.fromstring('0-0 2-1 9-2 21-3 10-4 7-5')
  168. Alignment([(0, 0), (2, 1), (7, 5), (9, 2), (10, 4), (21, 3)])
  169. :type s: str
  170. :param s: the positional alignments in giza format
  171. :rtype: Alignment
  172. :return: An Alignment object corresponding to the string representation ``s``.
  173. """
  174. return Alignment([_giza2pair(a) for a in s.split()])
  175. def __getitem__(self, key):
  176. """
  177. Look up the alignments that map from a given index or slice.
  178. """
  179. if not self._index:
  180. self._build_index()
  181. return self._index.__getitem__(key)
  182. def invert(self):
  183. """
  184. Return an Alignment object, being the inverted mapping.
  185. """
  186. return Alignment(((p[1], p[0]) + p[2:]) for p in self)
  187. def range(self, positions=None):
  188. """
  189. Work out the range of the mapping from the given positions.
  190. If no positions are specified, compute the range of the entire mapping.
  191. """
  192. image = set()
  193. if not self._index:
  194. self._build_index()
  195. if not positions:
  196. positions = list(range(len(self._index)))
  197. for p in positions:
  198. image.update(f for _, f in self._index[p])
  199. return sorted(image)
  200. def __repr__(self):
  201. """
  202. Produce a Giza-formatted string representing the alignment.
  203. """
  204. return "Alignment(%r)" % sorted(self)
  205. def __str__(self):
  206. """
  207. Produce a Giza-formatted string representing the alignment.
  208. """
  209. return " ".join("%d-%d" % p[:2] for p in sorted(self))
  210. def _build_index(self):
  211. """
  212. Build a list self._index such that self._index[i] is a list
  213. of the alignments originating from word i.
  214. """
  215. self._index = [[] for _ in range(self._len + 1)]
  216. for p in self:
  217. self._index[p[0]].append(p)
  218. def _giza2pair(pair_string):
  219. i, j = pair_string.split("-")
  220. return int(i), int(j)
  221. def _naacl2pair(pair_string):
  222. i, j, p = pair_string.split("-")
  223. return int(i), int(j)
  224. def _check_alignment(num_words, num_mots, alignment):
  225. """
  226. Check whether the alignments are legal.
  227. :param num_words: the number of source language words
  228. :type num_words: int
  229. :param num_mots: the number of target language words
  230. :type num_mots: int
  231. :param alignment: alignment to be checked
  232. :type alignment: Alignment
  233. :raise IndexError: if alignment falls outside the sentence
  234. """
  235. assert type(alignment) is Alignment
  236. if not all(0 <= pair[0] < num_words for pair in alignment):
  237. raise IndexError("Alignment is outside boundary of words")
  238. if not all(pair[1] is None or 0 <= pair[1] < num_mots for pair in alignment):
  239. raise IndexError("Alignment is outside boundary of mots")
  240. PhraseTableEntry = namedtuple('PhraseTableEntry', ['trg_phrase', 'log_prob'])
  241. class PhraseTable(object):
  242. """
  243. In-memory store of translations for a given phrase, and the log
  244. probability of the those translations
  245. """
  246. def __init__(self):
  247. self.src_phrases = dict()
  248. def translations_for(self, src_phrase):
  249. """
  250. Get the translations for a source language phrase
  251. :param src_phrase: Source language phrase of interest
  252. :type src_phrase: tuple(str)
  253. :return: A list of target language phrases that are translations
  254. of ``src_phrase``, ordered in decreasing order of
  255. likelihood. Each list element is a tuple of the target
  256. phrase and its log probability.
  257. :rtype: list(PhraseTableEntry)
  258. """
  259. return self.src_phrases[src_phrase]
  260. def add(self, src_phrase, trg_phrase, log_prob):
  261. """
  262. :type src_phrase: tuple(str)
  263. :type trg_phrase: tuple(str)
  264. :param log_prob: Log probability that given ``src_phrase``,
  265. ``trg_phrase`` is its translation
  266. :type log_prob: float
  267. """
  268. entry = PhraseTableEntry(trg_phrase=trg_phrase, log_prob=log_prob)
  269. if src_phrase not in self.src_phrases:
  270. self.src_phrases[src_phrase] = []
  271. self.src_phrases[src_phrase].append(entry)
  272. self.src_phrases[src_phrase].sort(key=lambda e: e.log_prob, reverse=True)
  273. def __contains__(self, src_phrase):
  274. return src_phrase in self.src_phrases