1. Review
- 분할 정복(Divide & Conquer)중 하나인 Strassen Algorithm에 대해서 알아보자.
- 우리는 앞서 4-1 에서 행렬곱을 계산하는 알고리즘에 대해서 알아보았다.
- 이때 코드를 다시한번 바라봐보자.
A = [[1,2],[3,4]]
B = [[5,6],[7,8]]
C = [[0,0],[0,0]]
for i in range(2):
for j in range(2):
for k in range(2):
C[i][j] += A[i][k]*B[k][j]
C
------------------------------------------------------
[[19, 22], [43, 50]]
- 이를 시간 복잡도로 표현하면 O(n^3)과 같다.
- 그 이유는 3중 반복문이 행렬의 크기 n만큼 돌기 때문이다.
- 과거엔 행렬곱을 계산하기 위한 알고리즘의 시간복잡도는 O(n^3)보다 작을 순 없다고 생각하였다.
- 하지만 Strassen 알고리즘이 등장한 순간부터 그 고정관념이 부셔지게 되었다
- Strassen 알고리즘에 대해서 자세히 알아보자.
2. Strassen’s algorithm
- 원리를 간단하게 설명하면 행렬들을 부분행렬로 나눠서 곱하고 더하는 과정을 재귀적으로 반복하는데
- 행렬의 덧셈이 곱셈보다 더 빠른 점을 이용하기 위해 부분 행렬들의 곱셈 횟수를 줄이고 덧셈 횟수를 늘린다.
- 4-1 에서 행렬을 부분 행렬로 나눠서 계산하는 방법에 대해서 알아봤었는데 다시한번 알아보면 다음과 같다.
- 이제 Strassen 알고리즘에 대해 자세히 알아보자.
- 우선 본격적으로 알아가기 전에 다음과 같이 7개의 행렬을 정의하고 시작한다.
- 이를 C11,C12,C21,C22의 부분행렬들을 M1~7까지를이용하여 나타내면 다음과 같다.
- 이를 바탕으로 시간복잡도를 계산하면 다음과 같다.
- Strassen알고리즘은 그러면 기존의 O(n^3)을 가진 행렬곱보다 무조건 실행시간이 빠를까?
- 여기에 대답은 매우 큰 n이 아니면 더 느려진다.
- 그 이유는 cO(n^log7)에서 c가 매우큰 상수이기 때문에 더 느려지게 된다.
- 넘파이를 이용해서 부분행렬을 간단하게 나타내면 Strassen 알고리즘을 다음과 같이 간단하게 구현이 가능하다.
import numpy as np
def strassen(A, B):
n = len(A)
if n <= 2: # Base case
return np.dot(A, B)
# Partition matrices into submatrices
mid = n // 2
A11 = A[:mid, :mid]
A12 = A[:mid, mid:]
A21 = A[mid:, :mid]
A22 = A[mid:, mid:]
B11 = B[:mid, :mid]
B12 = B[:mid, mid:]
B21 = B[mid:, :mid]
B22 = B[mid:, mid:]
# Recursive multiplication
P1 = strassen(A11, B12 - B22)
P2 = strassen(A11 + A12, B22)
P3 = strassen(A21 + A22, B11)
P4 = strassen(A22, B21 - B11)
P5 = strassen(A11 + A22, B11 + B22)
P6 = strassen(A12 - A22, B21 + B22)
P7 = strassen(A11 - A21, B11 + B12)
# Combine results to form C
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
# Combine quadrants to form C
C = np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
return C
# Example usage:
A = np.array([[1, 3], [7, 5]])
B = np.array([[6, 8], [4, 2]])
C = strassen(A, B)
print("Matrix C (Result of A * B):\n", C)
3.reference
https://www.geeksforgeeks.org/strassen-algorithm-in-python/
Strassen algorithm in Python - GeeksforGeeks
A Computer Science portal for geeks. It contains well written, well thought and well explained computer science and programming articles, quizzes and practice/competitive programming/company interview Questions.
www.geeksforgeeks.org
'CS Study > CLRS (자료구조 | 알고리즘)' 카테고리의 다른 글
[CLRS] [4-4] Recursion Tree Method(재귀 트리) (0) | 2024.09.06 |
---|---|
[CLRS] [4-3] The substitution method (1) | 2024.09.06 |
[CLRS] [4-1] Multiplying square matrices(행렬 곱) (0) | 2024.09.05 |
[CLRS] [3-1] 시간 복잡도 (big-O, big-Ω, big-θ) (0) | 2024.09.04 |
[CLRS] [2-4] Bubble Sort(버블 정렬) (0) | 2024.09.04 |