백준 1289 : 트리의 가중치
https://www.acmicpc.net/problem/1289
트리는 거들 뿐, 수학적인 배경 지식이 필요한 문제다.
노드의 개수는 최대 10만개이고, 각 노드들 간의 간선의 가중치의 합을 구해야한다.
일반적인 방법을 사용한다면 O(N^2)으로 무조건 시간초과가 발생한다.
우리가 필요한 정보는 각 노드들 사이의 간선의 가중치가 아니고 그 가중치들의 합이 필요하므로 O(N)으로 원하는 값을 얻을 수 있다.
예를 들어보자,
만약 루트 노드 아래에 4개의 자식노드가 있고 각각의 노드까지의 가중치가 a,b,c,d라고 할 때 각각의 노드들 사이의 간선의 가중치의 합을 구하는 과정은 아래와 같을 것이다.
위에 그림의 마지막 줄에 나온 식을 잘 정리하면 DP 방법으로 O(N)의 시간복잡도로 모든 간선들 사이의 가중치의 합을 구할 수 있다.
코드로 나타내면 아래와 같다.
int sum_dist(vector<int>& dist) {
int ans = 0;
int sum = 0;
for (int i = 0; i < dist.size(); i++) {
ans += dist[i] * sum;
sum += dist[i];
}
return ans;
}
위의 그림 예시같은 상황이라면 dist배열은 { 1, a, b, c, d }일 것이다.
여기서 1은 자기 자신까지의 거리라고 생각하자. 1을 넣어주지 않으면 자신을 제외한 자식 노드들끼리의 거리만 구해서 더해질 것이다.
그러면 자식노드에도 자식노드가 존재한다면, 즉 자식노드를 루트노드로 하는 서브트리 또한 존재한다면 어떻게 해줘야 할까?
방법은 같다. 간선의 가중치가 (자식노드로 가는 간선의 가중치 * 자식노드의 자식노드로 가는 간선의 가중치) 인 자식노드가 하나 더 생긴다고 생각하면 쉽다.
그림으로 보면 아래와 같다.
위 그림의 치환된 그래프의 dist배열은 { 1, a, b, c, ab, ac } 일 것이다.
위의 상황에서 노드들의 간선의 가중치의 합이 ANS라고 할 때,
그럼 만약에 여기서 가중치가 d인 간선과 함께 부모노드가 하나 추가된다고 가정하면 어떻게 해줘야할까?
간선 d가 추가되면 새로운 간선 1*d, ad, bd, cd, abd, acd 가 생기므로, ANS 새로 생기는 간선들의 가중치의 합을 더해주면 된다.
첫 번째 그림에서 확인한 자식 노드들의 가중치의 합을 처리 하는 방법과 두 번째 그림에서 확인한 서브트리의 가중치의 합을 처리하는 방법을 조합하면, 재귀 방식으로 문제를 해결할 수 있다.
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
const int MAX_NODE = 100'001;
const int MOD = 1e9 + 7;
int n;
ll ans;
vector<pii> edge[MAX_NODE];
int DP(int cur, int prev) {
int next, cur_dist, sum_new_edge, sum_edge = 1;
for (int i = 0; i < edge[cur].size(); i++) {
next = edge[cur][i].first;
cur_dist = edge[cur][i].second;
if (next == prev) continue;
// 새로 생기는 간선의 가중치의 합
// (cur노드에서 next를 루트로 하는 서브트리의 모든 노드로 가는 간선의 가중치의 합)
sum_new_edge = ((ll)DP(next, cur) * cur_dist) % MOD;
// 새로 생긴 간선들의 가중치의 합 ans에 더해줌
ans += (ll)sum_new_edge * sum_edge, ans %= MOD;
// 위의 예시 코드에서 sum_dist함수 내부의 sum변수와 같은 역할
sum_edge += sum_new_edge, sum_edge %= MOD;
}
// 부모 노드에 현재 노드에서 서브 트리 내부의 모든 노드로 가는
// 간선들의 가중치의 합 리턴
return sum_edge;
}
void solution() {
int a, b, w;
cin >> n;
for (int i = 1; i < n; i++) {
cin >> a >> b >> w;
edge[a].emplace_back(b, w);
edge[b].emplace_back(a, w);
}
DP(1, 0);
cout << ans;
}
int main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
solution();
return 0;
}
메모리: 9496 kb | 시간: 48 ms |