gaac.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Natural Language Toolkit: Group Average Agglomerative Clusterer
  2. #
  3. # Copyright (C) 2001-2019 NLTK Project
  4. # Author: Trevor Cohn <tacohn@cs.mu.oz.au>
  5. # URL: <http://nltk.org/>
  6. # For license information, see LICENSE.TXT
  7. from __future__ import print_function, unicode_literals, division
  8. try:
  9. import numpy
  10. except ImportError:
  11. pass
  12. from nltk.cluster.util import VectorSpaceClusterer, Dendrogram, cosine_distance
  13. from nltk.compat import python_2_unicode_compatible
  14. @python_2_unicode_compatible
  15. class GAAClusterer(VectorSpaceClusterer):
  16. """
  17. The Group Average Agglomerative starts with each of the N vectors as singleton
  18. clusters. It then iteratively merges pairs of clusters which have the
  19. closest centroids. This continues until there is only one cluster. The
  20. order of merges gives rise to a dendrogram: a tree with the earlier merges
  21. lower than later merges. The membership of a given number of clusters c, 1
  22. <= c <= N, can be found by cutting the dendrogram at depth c.
  23. This clusterer uses the cosine similarity metric only, which allows for
  24. efficient speed-up in the clustering process.
  25. """
  26. def __init__(self, num_clusters=1, normalise=True, svd_dimensions=None):
  27. VectorSpaceClusterer.__init__(self, normalise, svd_dimensions)
  28. self._num_clusters = num_clusters
  29. self._dendrogram = None
  30. self._groups_values = None
  31. def cluster(self, vectors, assign_clusters=False, trace=False):
  32. # stores the merge order
  33. self._dendrogram = Dendrogram(
  34. [numpy.array(vector, numpy.float64) for vector in vectors]
  35. )
  36. return VectorSpaceClusterer.cluster(self, vectors, assign_clusters, trace)
  37. def cluster_vectorspace(self, vectors, trace=False):
  38. # variables describing the initial situation
  39. N = len(vectors)
  40. cluster_len = [1] * N
  41. cluster_count = N
  42. index_map = numpy.arange(N)
  43. # construct the similarity matrix
  44. dims = (N, N)
  45. dist = numpy.ones(dims, dtype=numpy.float) * numpy.inf
  46. for i in range(N):
  47. for j in range(i + 1, N):
  48. dist[i, j] = cosine_distance(vectors[i], vectors[j])
  49. while cluster_count > max(self._num_clusters, 1):
  50. i, j = numpy.unravel_index(dist.argmin(), dims)
  51. if trace:
  52. print("merging %d and %d" % (i, j))
  53. # update similarities for merging i and j
  54. self._merge_similarities(dist, cluster_len, i, j)
  55. # remove j
  56. dist[:, j] = numpy.inf
  57. dist[j, :] = numpy.inf
  58. # merge the clusters
  59. cluster_len[i] = cluster_len[i] + cluster_len[j]
  60. self._dendrogram.merge(index_map[i], index_map[j])
  61. cluster_count -= 1
  62. # update the index map to reflect the indexes if we
  63. # had removed j
  64. index_map[j + 1 :] -= 1
  65. index_map[j] = N
  66. self.update_clusters(self._num_clusters)
  67. def _merge_similarities(self, dist, cluster_len, i, j):
  68. # the new cluster i merged from i and j adopts the average of
  69. # i and j's similarity to each other cluster, weighted by the
  70. # number of points in the clusters i and j
  71. i_weight = cluster_len[i]
  72. j_weight = cluster_len[j]
  73. weight_sum = i_weight + j_weight
  74. # update for x<i
  75. dist[:i, i] = dist[:i, i] * i_weight + dist[:i, j] * j_weight
  76. dist[:i, i] /= weight_sum
  77. # update for i<x<j
  78. dist[i, i + 1 : j] = (
  79. dist[i, i + 1 : j] * i_weight + dist[i + 1 : j, j] * j_weight
  80. )
  81. # update for i<j<x
  82. dist[i, j + 1 :] = dist[i, j + 1 :] * i_weight + dist[j, j + 1 :] * j_weight
  83. dist[i, i + 1 :] /= weight_sum
  84. def update_clusters(self, num_clusters):
  85. clusters = self._dendrogram.groups(num_clusters)
  86. self._centroids = []
  87. for cluster in clusters:
  88. assert len(cluster) > 0
  89. if self._should_normalise:
  90. centroid = self._normalise(cluster[0])
  91. else:
  92. centroid = numpy.array(cluster[0])
  93. for vector in cluster[1:]:
  94. if self._should_normalise:
  95. centroid += self._normalise(vector)
  96. else:
  97. centroid += vector
  98. centroid /= len(cluster)
  99. self._centroids.append(centroid)
  100. self._num_clusters = len(self._centroids)
  101. def classify_vectorspace(self, vector):
  102. best = None
  103. for i in range(self._num_clusters):
  104. centroid = self._centroids[i]
  105. dist = cosine_distance(vector, centroid)
  106. if not best or dist < best[0]:
  107. best = (dist, i)
  108. return best[1]
  109. def dendrogram(self):
  110. """
  111. :return: The dendrogram representing the current clustering
  112. :rtype: Dendrogram
  113. """
  114. return self._dendrogram
  115. def num_clusters(self):
  116. return self._num_clusters
  117. def __repr__(self):
  118. return '<GroupAverageAgglomerative Clusterer n=%d>' % self._num_clusters
  119. def demo():
  120. """
  121. Non-interactive demonstration of the clusterers with simple 2-D data.
  122. """
  123. from nltk.cluster import GAAClusterer
  124. # use a set of tokens with 2D indices
  125. vectors = [numpy.array(f) for f in [[3, 3], [1, 2], [4, 2], [4, 0], [2, 3], [3, 1]]]
  126. # test the GAAC clusterer with 4 clusters
  127. clusterer = GAAClusterer(4)
  128. clusters = clusterer.cluster(vectors, True)
  129. print('Clusterer:', clusterer)
  130. print('Clustered:', vectors)
  131. print('As:', clusters)
  132. print()
  133. # show the dendrogram
  134. clusterer.dendrogram().show()
  135. # classify a new vector
  136. vector = numpy.array([3, 3])
  137. print('classify(%s):' % vector, end=' ')
  138. print(clusterer.classify(vector))
  139. print()
  140. if __name__ == '__main__':
  141. demo()