matrix.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright (c) 2018 Manfred Moitzi
  2. # License: MIT License
  3. from typing import Iterable, Tuple, List, Sequence, Union, Any
  4. from itertools import repeat
  5. def zip_to_list(*args) -> Iterable[List]:
  6. for e in zip(*args): # returns immutable tuples
  7. yield list(e) # need mutable list
  8. class Matrix:
  9. """
  10. Simple unoptimized Matrix math.
  11. Initialization:
  12. - Matrix(shape=(rows, cols)) -> Matrix filled with zeros
  13. - Matrix(matrix[, shape=(rows, cols)]) -> Matrix by copy a Matrix and optional reshape
  14. - Matrix([[row_0], [row_1], ..., [row_n]]) -> Matrix from List[List[float]]
  15. - Matrix([a1, a2, ..., an], shape=(rows, cols)) -> Matrix from List[float] and shape
  16. """
  17. def __init__(self, items: Any = None,
  18. shape: Tuple[int, int] = None,
  19. matrix: List[List[float]] = None):
  20. self.matrix = matrix # type: List[List[float]]
  21. if items is None:
  22. if shape is not None:
  23. self.matrix = Matrix.reshape(repeat(0.), shape).matrix
  24. else: # items is None, shape is None
  25. pass
  26. elif isinstance(items, Matrix):
  27. if shape is None:
  28. shape = items.shape
  29. self.matrix = Matrix.reshape(items, shape=shape).matrix
  30. else:
  31. items = list(items)
  32. try:
  33. self.matrix = [list(row) for row in items]
  34. except TypeError:
  35. self.matrix = Matrix.reshape(items, shape).matrix
  36. def __iter__(self) -> Iterable[float]:
  37. for row in self.matrix:
  38. for item in row:
  39. yield item
  40. @staticmethod
  41. def reshape(items: Iterable[float], shape: Tuple[int, int]) -> 'Matrix':
  42. items = iter(items)
  43. rows, cols = shape
  44. return Matrix(matrix=[[next(items) for _ in range(cols)] for _ in range(rows)])
  45. @property
  46. def nrows(self) -> int:
  47. return len(self.matrix)
  48. @property
  49. def ncols(self) -> int:
  50. return len(self.matrix[0])
  51. @property
  52. def shape(self) -> Tuple[int, int]:
  53. return self.nrows, self.ncols
  54. def row(self, index) -> List[float]:
  55. return self.matrix[index]
  56. def col(self, index) -> List[float]:
  57. return [row[index] for row in self.matrix]
  58. def rows(self) -> List[List[float]]:
  59. return self.matrix
  60. def cols(self) -> List[List[float]]:
  61. return [self.col(i) for i in range(self.ncols)]
  62. def append_row(self, items: Sequence[float]) -> None:
  63. if self.matrix is None:
  64. self.matrix = [list(items)]
  65. elif len(items) == self.ncols:
  66. self.matrix.append(items)
  67. else:
  68. raise ValueError('Invalid item count.')
  69. def append_col(self, items: Sequence[float]) -> None:
  70. if self.matrix is None:
  71. self.matrix = [[item] for item in items]
  72. elif len(items) == self.nrows:
  73. for row, item in zip(self.matrix, items):
  74. row.append(item)
  75. else:
  76. raise ValueError('Invalid item count.')
  77. def __getitem__(self, item: Tuple[int, int]) -> float:
  78. row, col = item
  79. return self.matrix[row][col]
  80. def __setitem__(self, key: Tuple[int, int], value: float):
  81. row, col = key
  82. self.matrix[row][col] = value
  83. def __eq__(self, other: 'Matrix') -> bool:
  84. if not isinstance(other, Matrix):
  85. raise TypeError('Only comparable to class Matrix.')
  86. if self.shape != other.shape:
  87. raise TypeError('Matrices has to have the same shape.')
  88. for row1, row2 in zip(self.matrix, other.matrix):
  89. if list(row1) != list(row2):
  90. return False
  91. return True
  92. def __mul__(self, other: Union['Matrix', float]) -> 'Matrix':
  93. if isinstance(other, Matrix):
  94. matrix = Matrix(
  95. matrix=[[sum(a * b for a, b in zip(X_row, Y_col)) for Y_col in zip(*other.matrix)] for X_row in
  96. self.matrix])
  97. else:
  98. factor = float(other)
  99. matrix = Matrix.reshape([item * factor for item in self], shape=self.shape)
  100. return matrix
  101. __imul__ = __mul__
  102. def __add__(self, other: Union['Matrix', float]) -> 'Matrix':
  103. if isinstance(other, Matrix):
  104. matrix = Matrix.reshape([a + b for a, b in zip(self, other)], shape=self.shape)
  105. else:
  106. other = float(other)
  107. matrix = Matrix.reshape([item + other for item in self], shape=self.shape)
  108. return matrix
  109. __iadd__ = __add__
  110. def __sub__(self, other: Union['Matrix', float]) -> 'Matrix':
  111. if isinstance(other, Matrix):
  112. matrix = Matrix.reshape([a - b for a, b in zip(self, other)], shape=self.shape)
  113. else:
  114. other = float(other)
  115. matrix = Matrix.reshape([item - other for item in self], shape=self.shape)
  116. return matrix
  117. __isub__ = __sub__
  118. def transpose(self) -> 'Matrix':
  119. return Matrix(matrix=list(zip_to_list(*self.matrix)))
  120. def gauss(self, col):
  121. m = Matrix(self)
  122. m.append_col(col)
  123. return gauss(m.matrix)
  124. def gauss_matrix(self, matrix) -> 'Matrix':
  125. B = Matrix(matrix)
  126. if self.nrows != B.nrows:
  127. raise ValueError('Row count of matrices do not match.')
  128. result = [self.gauss(col) for col in B.cols()]
  129. return Matrix(items=zip(*result))
  130. def gauss(matrix: List[List[float]]) -> List[float]:
  131. """
  132. Solves a nxn Matrix A x = b, Matrix has 1 column more than rows.
  133. Args:
  134. matrix: matrix [[a11, a12, ..., a1n, b1],
  135. [a21, a22, ..., a2n, b2],
  136. [a21, a22, ..., a2n, b3],
  137. ...
  138. [an1, an2, ..., ann, bn],]
  139. Returns: x vector as list
  140. """
  141. n = len(matrix)
  142. for i in range(0, n):
  143. # Search for maximum in this column
  144. max_element = abs(matrix[i][i])
  145. max_row = i
  146. for k in range(i + 1, n):
  147. if abs(matrix[k][i]) > max_element:
  148. max_element = abs(matrix[k][i])
  149. max_row = k
  150. # Swap maximum row with current row (column by column)
  151. for k in range(i, n + 1):
  152. tmp = matrix[max_row][k]
  153. matrix[max_row][k] = matrix[i][k]
  154. matrix[i][k] = tmp
  155. # Make all rows below this one 0 in current column
  156. for k in range(i + 1, n):
  157. c = -matrix[k][i] / matrix[i][i]
  158. for j in range(i, n + 1):
  159. if i == j:
  160. matrix[k][j] = 0
  161. else:
  162. matrix[k][j] += c * matrix[i][j]
  163. # Solve equation Ax=b for an upper triangular matrix A
  164. x = [0.] * n
  165. for i in range(n - 1, -1, -1):
  166. x[i] = matrix[i][n] / matrix[i][i]
  167. for k in range(i - 1, -1, -1):
  168. matrix[k][n] -= matrix[k][i] * x[i]
  169. return x