123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- # Natural Language Toolkit: Confusion Matrices
- #
- # Copyright (C) 2001-2019 NLTK Project
- # Author: Edward Loper <edloper@gmail.com>
- # Steven Bird <stevenbird1@gmail.com>
- # URL: <http://nltk.org/>
- # For license information, see LICENSE.TXT
- from __future__ import print_function, unicode_literals
- from nltk.probability import FreqDist
- from nltk.compat import python_2_unicode_compatible
- @python_2_unicode_compatible
class ConfusionMatrix(object):
    """
    The confusion matrix between a list of reference values and a
    corresponding list of test values.  Entry *[r,t]* of this
    matrix is a count of the number of times that the reference value
    *r* corresponds to the test value *t*.  E.g.:

        >>> from nltk.metrics import ConfusionMatrix
        >>> ref = 'DET NN VB DET JJ NN NN IN DET NN'.split()
        >>> test = 'DET VB VB DET NN NN NN IN DET NN'.split()
        >>> cm = ConfusionMatrix(ref, test)
        >>> print(cm['NN', 'NN'])
        3

    Note that the diagonal entries (*r* == *t*) of this matrix
    correspond to correct values, and the off-diagonal entries
    correspond to incorrect values.
    """

    def __init__(self, reference, test, sort_by_count=False):
        """
        Construct a new confusion matrix from a list of reference
        values and a corresponding list of test values.

        :type reference: list
        :param reference: An ordered list of reference values.
        :type test: list
        :param test: A list of values to compare against the
            corresponding reference values.
        :type sort_by_count: bool
        :param sort_by_count: If true, order the values by decreasing
            combined frequency in ``reference`` and ``test``; otherwise
            order them by their natural sort order.
        :raise ValueError: If ``reference`` and ``test`` do not have
            the same length.
        """
        if len(reference) != len(test):
            raise ValueError('Lists must have the same length.')

        # Collect every distinct value seen in either sequence.  Filling the
        # set in two steps (instead of from ``reference + test``) avoids
        # copying both inputs and also works when the two arguments are
        # different sequence types, e.g. a list and a tuple.
        labels = set(reference)
        labels.update(test)

        if sort_by_count:
            ref_fdist = FreqDist(reference)
            test_fdist = FreqDist(test)

            def key(v):
                # Negated so that the most frequent values sort first.
                return -(ref_fdist[v] + test_fdist[v])

            values = sorted(labels, key=key)
        else:
            values = sorted(labels)

        # Construct a value->index dictionary.
        indices = {val: i for (i, val) in enumerate(values)}

        # Make a confusion matrix table (rows = reference, columns = test).
        size = len(values)
        confusion = [[0] * size for _ in range(size)]
        max_conf = 0  # Maximum confusion
        for ref_val, test_val in zip(reference, test):
            row = confusion[indices[ref_val]]
            col = indices[test_val]
            row[col] += 1
            max_conf = max(max_conf, row[col])

        #: A list of all values in ``reference`` or ``test``.
        self._values = values
        #: A dictionary mapping values in ``self._values`` to their indices.
        self._indices = indices
        #: The confusion matrix itself (as a list of lists of counts).
        self._confusion = confusion
        #: The greatest count in ``self._confusion`` (used for printing).
        self._max_conf = max_conf
        #: The total number of values in the confusion matrix.
        self._total = len(reference)
        #: The number of correct (on-diagonal) values in the matrix.
        self._correct = sum(confusion[i][i] for i in range(size))

    def __getitem__(self, li_lj_tuple):
        """
        :param li_lj_tuple: A pair ``(li, lj)`` of a reference value and
            a test value.
        :return: The number of times that value ``li`` was expected and
            value ``lj`` was given.
        :rtype: int
        :raise KeyError: If either value does not occur in the matrix.
        """
        (li, lj) = li_lj_tuple
        i = self._indices[li]
        j = self._indices[lj]
        return self._confusion[i][j]

    def __repr__(self):
        return '<ConfusionMatrix: %s/%s correct>' % (self._correct, self._total)

    def __str__(self):
        return self.pretty_format()

    def pretty_format(
        self,
        show_percents=False,
        values_in_chart=True,
        truncate=None,
        sort_by_count=False,
    ):
        """
        :return: A multi-line string representation of this confusion matrix.
        :param show_percents: If true, show each non-zero entry as a
            percentage of the total number of values instead of as a
            raw count.
        :param values_in_chart: If true, label rows and columns with
            the values themselves.  Otherwise label them with 1-based
            numbers and append a key mapping each number to its value.
        :type truncate: int
        :param truncate: If specified, then only show the specified
            number of values.  Any sorting (e.g., sort_by_count)
            will be performed before truncation.
        :param sort_by_count: If true, then sort by the count of each
            label in the reference data.  I.e., labels that occur more
            frequently in the reference label will be towards the left
            edge of the matrix, and labels that occur less frequently
            will be towards the right edge.

        @todo: add marginals?
        """
        confusion = self._confusion
        indices = self._indices

        values = self._values
        if sort_by_count:
            # A row sum is the number of times the value occurs in the
            # reference data; negate it to put frequent values first.
            values = sorted(values, key=lambda v: -sum(confusion[indices[v]]))
        if truncate:
            values = values[:truncate]

        if values_in_chart:
            value_strings = ['%s' % val for val in values]
        else:
            value_strings = [str(n + 1) for n in range(len(values))]

        # Construct a format string for row values.  The ``or [0]`` keeps
        # an empty matrix printable instead of raising from ``max``.
        valuelen = max([len(val) for val in value_strings] or [0])
        value_format = '%' + repr(valuelen) + 's | '

        # Construct a format string for matrix entries.
        if show_percents:
            entrylen = 6
            entry_format = '%5.1f%%'
        else:
            entrylen = len(repr(self._max_conf))
            entry_format = '%' + repr(entrylen) + 'd'
        # The zero placeholder must be exactly as wide as a formatted
        # entry in either mode, otherwise the columns do not line up.
        zerostr = ' ' * (entrylen - 1) + '.'
        cellwidth = entrylen + 1

        divider = '%s-+-%s+\n' % ('-' * valuelen, '-' * (cellwidth * len(values)))
        lines = []

        # Write the column values, one character per line, bottom-aligned.
        for i in range(valuelen):
            line = (' ' * valuelen) + ' |'
            for val in value_strings:
                offset = i - valuelen + len(val)
                if offset >= 0:
                    line += val[offset].rjust(cellwidth)
                else:
                    line += ' ' * cellwidth
            lines.append(line + ' |\n')

        # Write a dividing line.
        lines.append(divider)

        # Write the entries.
        for val, li in zip(value_strings, values):
            i = indices[li]
            line = value_format % val
            for lj in values:
                j = indices[lj]
                count = confusion[i][j]
                if count == 0:
                    line += zerostr
                elif show_percents:
                    line += entry_format % (100.0 * count / self._total)
                else:
                    line += entry_format % count
                if i == j:
                    # Mark the diagonal: turn the nearest preceding space
                    # into '<' and close with '>' where the separating
                    # space would otherwise go, keeping the cell width.
                    prevspace = line.rfind(' ')
                    line = line[:prevspace] + '<' + line[prevspace + 1:] + '>'
                else:
                    line += ' '
            lines.append(line + '|\n')

        # Write a dividing line.
        lines.append(divider)

        # Write a key.
        lines.append('(row = reference; col = test)\n')
        if not values_in_chart:
            lines.append('Value key:\n')
            for i, value in enumerate(values):
                lines.append('%6d: %s\n' % (i + 1, value))

        return ''.join(lines)

    def key(self):
        """
        :return: A multi-line string listing every value in the matrix
            next to its 0-based index.  (Note that the key appended by
            ``pretty_format(values_in_chart=False)`` is 1-based.)
        """
        values = self._values
        indexlen = len(repr(len(values) - 1))
        key_format = ' %' + repr(indexlen) + 'd: %s\n'
        entries = [key_format % (i, value) for i, value in enumerate(values)]
        return 'Value key:\n' + ''.join(entries)
def demo():
    """
    Print a small tagging example: the reference and test tag
    sequences, their confusion matrix in the default layout, and the
    same matrix with values sorted by reference count.
    """
    reference = 'DET NN VB DET JJ NN NN IN DET NN'.split()
    test = 'DET VB VB DET NN NN NN IN DET NN'.split()

    for caption, tags in (('Reference =', reference), ('Test =', test)):
        print(caption, tags)

    print('Confusion matrix:')
    matrix = ConfusionMatrix(reference, test)
    print(matrix)
    print(matrix.pretty_format(sort_by_count=True))
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo()
|