컴퓨터 사이언스/1고리즘

백준 1289 : 트리의 가중치

저세상 개발자 2021. 8. 23. 03:18

https://www.acmicpc.net/problem/1289

 

1289번: 트리의 가중치

첫째 줄에 트리의 정점의 개수 N(1 ≤ N ≤ 100,000)이 주어진다. 다음 N-1개의 줄에 대해 각 줄에는 세 개의 정수 A, B, W(1 ≤ A, B ≤ N, 0 ≤ W ≤ 1,000)가 입력되는데 이는 A점과 B점이 연결되어 있고 이

www.acmicpc.net

 

트리는 거들 뿐, 수학적인 배경 지식이 필요한 문제다.

 

노드의 개수는 최대 10만개이고, 각 노드들 간의 간선의 가중치의 합을 구해야한다.

일반적인 방법을 사용한다면 O(N^2)으로 무조건 시간초과가 발생한다. 

우리가 필요한 정보는 각 노드들 사이의 간선의 가중치가 아니고 그 가중치들의 합이 필요하므로 O(N)으로 원하는 값을 얻을 수 있다.

 

예를 들어보자,

만약 루트 노드 아래에 4개의 자식노드가 있고 각각의 노드까지의 가중치가 a,b,c,d라고 할 때 각각의 노드들 사이의 간선의 가중치의 합을 구하는 과정은 아래와 같을 것이다.

위에 그림의 마지막 줄에 나온 식을 잘 정리하면 DP 방법으로 O(N)의 시간복잡도로 모든 간선들 사이의 가중치의 합을 구할 수 있다.

 

코드로 나타내면 아래와 같다.

int sum_dist(vector<int>& dist) {
	int ans = 0;
	int sum = 0;
    
	for (int i = 0; i < dist.size(); i++) {
		ans += dist[i] * sum;
		sum += dist[i];
	}

	return ans;
}

위의 그림 예시같은 상황이라면 dist배열은 { 1, a, b, c, d }일 것이다.

여기서 1은 자기 자신까지의 거리라고 생각하자. 1을 넣어주지 않으면 자신을 제외한 자식 노드들끼리의 거리만 구해서 더해질 것이다.

 

그러면 자식노드에도 자식노드가 존재한다면, 즉 자식노드를 루트노드로 하는 서브트리 또한 존재한다면 어떻게 해줘야 할까?

방법은 같다. 간선의 가중치가 (자식노드로 가는 간선의 가중치 * 자식노드의 자식노드로 가는 간선의 가중치) 인 자식노드가 하나 더 생긴다고 생각하면 쉽다.

 

그림으로 보면 아래와 같다.

위 그림의 치환된 그래프의 dist배열은 { 1, a, b, c, ab, ac } 일 것이다.

 

위의 상황에서 노드들의 간선의 가중치의 합이 ANS라고 할 때,

그럼 만약에 여기서 가중치가 d인 간선과 함께 부모노드가 하나 추가된다고 가정하면 어떻게 해줘야할까?

간선 d가 추가되면 새로운 간선 1*d, ad, bd, cd, abd, acd 가 생기므로, ANS 새로 생기는 간선들의 가중치의 합을 더해주면 된다.

 

첫 번째 그림에서 확인한 자식 노드들의 가중치의 합을 처리 하는 방법과 두 번째 그림에서 확인한 서브트리의 가중치의 합을 처리하는 방법을 조합하면, 재귀 방식으로 문제를 해결할 수 있다.

 

#include <iostream>
#include <vector>

using namespace std;
using ll = long long;
using pii = pair<int, int>;

const int MAX_NODE = 100'001;
const int MOD = 1e9 + 7;

int n;
ll ans;
vector<pii> edge[MAX_NODE];

int DP(int cur, int prev) {
	int next, cur_dist, sum_new_edge, sum_edge = 1;
	for (int i = 0; i < edge[cur].size(); i++) {
		next = edge[cur][i].first;
		cur_dist = edge[cur][i].second;

		if (next == prev) continue;
		
        // 새로 생기는 간선의 가중치의 합
        // (cur노드에서 next를 루트로 하는 서브트리의 모든 노드로 가는 간선의 가중치의 합)
		sum_new_edge = ((ll)DP(next, cur) * cur_dist) % MOD;
        // 새로 생긴 간선들의 가중치의 합 ans에 더해줌
		ans += (ll)sum_new_edge * sum_edge, ans %= MOD;
        // 위의 예시 코드에서 sum_dist함수 내부의 sum변수와 같은 역할
		sum_edge += sum_new_edge, sum_edge %= MOD;
	}

	// 부모 노드에 현재 노드에서 서브 트리 내부의 모든 노드로 가는
    // 간선들의 가중치의 합 리턴
	return sum_edge;
}

void solution() {
	int a, b, w;
	cin >> n;
	for (int i = 1; i < n; i++) {
		cin >> a >> b >> w;
		edge[a].emplace_back(b, w);
		edge[b].emplace_back(a, w);
	}

	DP(1, 0);
	cout << ans;
}

int main() {
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	solution();

	return 0;
}
메모리: 9496 kb 시간: 48 ms