백준 10830 : 행렬 제곱
https://www.acmicpc.net/problem/10830
10830번: 행렬 제곱
크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.
www.acmicpc.net
이 문제는 행렬 곱에 대한 배경지식이 있어야 풀 수 있는 문제다.
행렬곱은 3중 포문으로 간단하게 구현할 수 있다. 행렬 곱 구현은 아래 코드를 참고하자.
여기서 중요한 것은 행렬곱보다도 문제를 어떻게 해결할까 하는 아이디어인데,
A^b인 행렬을 만들기 위해서 A*A를 b번 반복하면 제일 간단하지만 이 문제에서 b의 범위가 1e11이므로 당연히 시간초과가 발생한다.
따라서 우리는 계속 A를 곱하는게 아니라 초기 행렬을 계속해서 제곱해서 2^x 제곱 형태의 행렬을 만들어 주고, 초기에 주어지는 b값을 이진수로 변환하여 필요한 2^x 제곱 형태의 행렬만 정답에 곱해주면 된다.
예를 들어,
b = 5일 때, b는 이진수로 101 이므로, ans = A^4 * A^1 이다. 따라서 우리는 초기 행렬을 재귀적으로 제곱시켜주고 우리가 필요한 A^1과 A^4만 ans에 곱해주면 된다.
ans의 초기값은 단위행렬로 설정한다.
(※ 참고로 단위행렬은 대각행렬이라고도 하며 주대각성분이 1이고 그 외의 값이 모두 0인 행렬이다)
b가 최대 값을 가질 때 2진수로 변환하면 최대 37자리의 2진수로 변환되므로 넉넉하게 길이 40짜리의 bitset을 사용했다. bool 배열을 사용해도 상관없다. 2진수를 만드는 작업이 귀찮기 때문에 bitset을 사용했다.
아 그리고 문제를 풀다가 행렬 곱은 교환법칙이 성립하지 않으므로 ans값을 어떤 순서로 곱해줘야할지 고민이 되었는데 행렬 A와 행렬 B를 곱한다고 할 때, B = P(A) 꼴이면 AB = BA 로 곱의 교환법칙이 성립한다고 한다.
여기서 P(A) = (시그마 0<= x <=n) a_x * A^x 꼴이다. (수식 입력이 안돼서 죄송.. ㅎ, 다들 이해하셨을거라 믿는다.)
따라서 이 문제에서는 곱의 교환법칙이 성립하므로 행렬곱 순서는 전혀 신경쓸 필요가 없다.
#include <iostream>
#include <bitset>
#define MOD 1000
using namespace std;
using ll = long long;
int n;
ll b;
int matrix[5][5];
int ans[5][5];
void matrix_mul(bool ispow) {
int tmp_matrix[5][5] = { 0, };
for (int y = 0; y < n; y++) {
for (int x = 0; x < n; x++) {
for (int i = 0; i < n; i++) {
if (ispow) tmp_matrix[y][x] += matrix[y][i] * matrix[i][x];
else tmp_matrix[y][x] += ans[y][i] * matrix[i][x];
}
tmp_matrix[y][x] %= MOD;
}
}
for (int y = 0; y < n; y++) {
for (int x = 0; x < n; x++) {
if (ispow) matrix[y][x] = tmp_matrix[y][x];
else ans[y][x] = tmp_matrix[y][x];
}
}
}
void recursion(int cur_idx, bitset<40>& bs) {
if (cur_idx >= 39) return;
//b를 2진수 변환했을 때 cur_idx 지점의 값이 1이면
//ans에 현재 행렬을 곱해줌 (현재 행렬은 초기행렬^(2^cur_idx)인 상태)
if (bs[cur_idx]) {
matrix_mul(0);
}
matrix_mul(1);
recursion(cur_idx + 1, bs);
}
void solution() {
cin >> n >> b;
for (int y = 0; y < n; y++) {
for (int x = 0; x < n; x++) {
cin >> matrix[y][x];
if (x == y) ans[y][x] = 1;
}
}
bitset<40> bs(b);
recursion(0, bs);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
cout << ans[i][j] << ' ';
}
cout << '\n';
}
}
int main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
solution();
return 0;
}
메모리: 2020 kb | 시간: 0 ms |