티스토리 뷰
트리에서 사용하는 분할정복 기법을 Centroid Decomposition이라고 한다.
트리에 존재하는 모든 경로에 대해서 어떤 연산을 수행하려 할때
트리를 분할에 서브 트리에 대해 연산을 수행하고, 그러한 서브트리를 merge해 전체 트리에서의 결과값을 얻어 낼 수 있다.
여기서 중요한 것은 주어진 트리르 어떻게 쪼개는지 이다.
쪼개는 방법에 따라 시간복잡도가 달라지는데, 분할정복 처럼 트리의 크기를 계속해서 반 씩 쪼깨고 각 서브트리를 처음 트리사이즈 시간만에 merge 할 수 있다면 시간복잡도는 마스터 정리에 의해 O(N lg N)이 된다.
이 문제를 푸는 데 있어 중요한 것은
1. Subtree 로 분할할 cutting edge를 어떻게 잡을것인가?
2. 서로다른 Subtree들의 결과값을 어떻게 리니어 하게 Merge 할 것인가?
우선 트리에서 서브트리의 노드갯수가 원래 트리 노드 갯수 / 2가 되도록 커팅 할 수 있는 정점은 항상 존재한다.
그러한 노드는 한번은 bfs순회를 통해 찾을 수 잇다.
여기서 주의할 점이 있는데, Cutting edge를 찾는 과정에서 dfs를 돌릴때 visit배열을 따로 사용하면 안된다.
이유인 즉슨 Cutting과정은 남은 노드가 1개가 될때 까지 재귀적을 반복되는데 매 Cutting과정에서 visit배열을 초기화하고 사용하면 제곱 시간이 들기 때문이다.
따라서 구현을 할때는 현재 노드가 cutting 되었는지 아닌지를 판단하는 배열, cutted하나와 자기로 온 부모 노드로 따라올라가지 못하도록 부모노드의 인덱스만 가지고 dfs를 돌린다. 트리이기 때문에 dfs를 돌릴때 부모로 가지 않도록 해주면, visit 배열을 따로 사용하지 않고 구현 할 수 있다.
그렇게 cutting할 노드를 찾기 위해 dfs 탐색을 한 후, 노드 리스트를 가지고 현재 노드를 잘랏을때 생기는 서브트리의 정점갯수중 max값을 구한다. 그런다음 그 max값이 원래노드갯수/2보다 작거나 같다면 그 정점을 루트로 다시 분할 하면 된다.
int dfs(int node, int mom = -1) { cnt[node] = 1; list.push_back(node); for (pair<int,int> e : graph[node]) { if (e.first == mom || cutted[e.first]) continue; cnt[node] += dfs(e.first, node); } return cnt[node]; } int getCut(int node) { list.clear(); int count = dfs(node); for (int x : list) { int cmax = 0; for (pair<int,int> e : graph[x]) { int y = e.first; if (cutted[y] || cnt[x] < cnt[y]) continue; cmax = maxf(cmax, cnt[y]); } cmax = maxf(cmax, count - cnt[x]); if (cmax <= count/2) return x; } return node; }
위의 구현을 코드로 나타낸 것이다.
그런 다음 merge하는 과정에서는 같은 그룹의 서브트리들이 계산 과정에서 선택되지 않도록 그룹별로 분리를 잘 해줘야 한다.
이는 문제마다 구현이 달라질 수 있으므로, 문제별로 잘 생각해본다.
머지 과정에서는 아래과정에서 구한 결과값이 전혀 필요없기 때문에
분할하기전에 머지를 먼저해도 되고, 분할하고 나서 머지를 해도 상관없다.
분할을 해서 subtree에 대해 solve를 부르는 것은, 그 subtree 내에 존재하는 모든 경로들을 탐색하는 것이고,
머지하는 과정에서는 cut한 노드를 무조건 거쳐가는 경로에 대한 계산이라 보면 된다.
밑에는 코드포스 C번에 대한 코드이다.
http://codeforces.com/contest/715/problem/C
#include <stdio.h> #include <vector> #include <queue> #include <memory.h> #include <map> #define mp make_pair #define __ 100010 #define maxf(a,b) ((a)>(b)?(a):(b)) typedef long long ll; using namespace std; struct Queue { int node, d, g; ll l, r; }; vector<Queue> qlist[__]; vector<int> list; vector<pair<int,int>> graph[__]; ll ten[__], rev[__], ans; int cnt[__], vis[__], cutted[__]; ll N, M; int dfs(int node, int mom = -1) { cnt[node] = 1; list.push_back(node); for (pair<int,int> e : graph[node]) { if (e.first == mom || cutted[e.first]) continue; cnt[node] += dfs(e.first, node); } return cnt[node]; } int getCut(int node) { list.clear(); int count = dfs(node); for (int x : list) { int cmax = 0; for (pair<int,int> e : graph[x]) { int y = e.first; if (cutted[y] || cnt[x] < cnt[y]) continue; cmax = maxf(cmax, cnt[y]); } cmax = maxf(cmax, count - cnt[x]); if (cmax <= count/2) return x; } return node; } void calc(int node) { queue<Queue> Q; int group = 0; for (pair<int,int> e : graph[node]) { if (cutted[e.first]) continue; Q.push({e.first, 0, ++group, e.second%M, e.second%M}); } for (int i=1;i<=group;i++) qlist[i].clear(); memset(vis,0,sizeof vis); vis[node] = 1; while (!Q.empty()) { Queue n = Q.front(); Q.pop(); if (vis[n.node]) continue; vis[n.node] = 1; qlist[n.g].push_back(n); for (pair<int,int> e : graph[n.node]) { if (vis[e.first] || cutted[e.first]) continue; ll L = (n.l * 10 + (ll)e.second)%M; ll R = ((ll)e.second * ten[n.d+1] + n.r)%M; Q.push({e.first, n.d+1, n.g, L, R}); } } map<ll, int> rv; for (int i=1;i<=group;i++) { for (Queue x : qlist[i]) rv[x.r] ++; } for (int i=1;i<=group;i++) { for (Queue x : qlist[i]) rv[x.r] --; for (Queue x : qlist[i]) { if (x.l == 0) ans ++; if (x.r == 0) ans ++; ll t = (-x.l + M)%M; t = (t * rev[x.d+1])%M; ans += rv[t]; } for (Queue x : qlist[i]) rv[x.r] ++; } } void solve(int node) { int cut = getCut(node); calc(cut); cutted[cut] = 1; for (pair<int,int> e : graph[cut]) { if (cutted[e.first]) continue; solve(e.first); } cutted[cut] = 0; } pair<ll, ll> extended_gcd(ll a,ll b) { if (b == 0) return mp(1,0); pair<ll, ll> t = extended_gcd(b, a%b); return mp(t.second, t.first - t.second * (a/b)); } int main() { scanf ("%lld%lld",&N,&M); ten[0] = 1 % M; for (int i=1;i<=N;i++) { ten[i] = (ten[i-1] * 10) % M; rev[i] = (extended_gcd(M, ten[i]).second + M) % M; } for (int i=1;i<N;i++) { int a,b,c; scanf ("%d%d%d",&a,&b,&c); graph[a].push_back(mp(b,c)); graph[b].push_back(mp(a,c)); } solve(0); printf ("%lld",ans); return 0; }