티스토리 뷰

카테고리 없음

Centroid Decomposition

makesource 2016. 9. 23. 15:17

트리에서 사용하는 분할정복 기법을 Centroid Decomposition이라고 한다.

트리에 존재하는 모든 경로에 대해서 어떤 연산을 수행하려 할때

트리를 분할에 서브 트리에 대해 연산을 수행하고, 그러한 서브트리를 merge해 전체 트리에서의 결과값을 얻어 낼 수 있다.


여기서 중요한 것은 주어진 트리르 어떻게 쪼개는지 이다.

쪼개는 방법에 따라 시간복잡도가 달라지는데, 분할정복 처럼 트리의 크기를 계속해서 반 씩 쪼깨고 각 서브트리를 처음 트리사이즈 시간만에 merge 할 수 있다면 시간복잡도는 마스터 정리에 의해 O(N lg N)이 된다.


이 문제를 푸는 데 있어 중요한 것은

1. Subtree 로 분할할 cutting edge를 어떻게 잡을것인가? 

2. 서로다른 Subtree들의 결과값을 어떻게 리니어 하게 Merge 할 것인가?


우선 트리에서 서브트리의 노드갯수가 원래 트리 노드 갯수 / 2가 되도록 커팅 할 수 있는 정점은 항상 존재한다.

그러한 노드는 한번은 bfs순회를 통해 찾을 수 잇다.

여기서 주의할 점이 있는데, Cutting edge를 찾는 과정에서 dfs를 돌릴때 visit배열을 따로 사용하면 안된다.

이유인 즉슨 Cutting과정은 남은 노드가 1개가 될때 까지 재귀적을 반복되는데 매 Cutting과정에서 visit배열을 초기화하고 사용하면 제곱 시간이 들기 때문이다.

따라서 구현을 할때는 현재 노드가 cutting 되었는지 아닌지를 판단하는 배열, cutted하나와 자기로 온 부모 노드로 따라올라가지 못하도록 부모노드의 인덱스만 가지고 dfs를 돌린다. 트리이기 때문에 dfs를 돌릴때 부모로 가지 않도록 해주면, visit 배열을 따로 사용하지 않고 구현 할 수 있다.


그렇게 cutting할 노드를 찾기 위해 dfs 탐색을 한 후, 노드 리스트를 가지고 현재 노드를 잘랏을때 생기는 서브트리의 정점갯수중 max값을 구한다. 그런다음 그 max값이 원래노드갯수/2보다 작거나 같다면 그 정점을 루트로 다시 분할 하면 된다. 


int dfs(int node, int mom = -1) {
    cnt[node] = 1;
    list.push_back(node);
    for (pair<int,int> e : graph[node]) {
        if (e.first == mom || cutted[e.first]) continue;
        cnt[node] += dfs(e.first, node);
    }
    return cnt[node];
}

int getCut(int node) {
    list.clear();
    int count = dfs(node);
    for (int x : list) {
        int cmax = 0;
        for (pair<int,int> e : graph[x]) {
            int y = e.first;
            if (cutted[y] || cnt[x] < cnt[y]) continue;
            cmax = maxf(cmax, cnt[y]);
        }
        cmax = maxf(cmax, count - cnt[x]);
        if (cmax <= count/2) return x;
    }
    return node;
}



위의 구현을 코드로 나타낸 것이다.


그런 다음 merge하는 과정에서는 같은 그룹의 서브트리들이 계산 과정에서 선택되지 않도록 그룹별로 분리를 잘 해줘야 한다.

이는 문제마다 구현이 달라질 수 있으므로, 문제별로 잘 생각해본다.


머지 과정에서는 아래과정에서 구한 결과값이 전혀 필요없기 때문에

분할하기전에 머지를 먼저해도 되고, 분할하고 나서 머지를 해도 상관없다.


분할을 해서 subtree에 대해 solve를 부르는 것은, 그 subtree 내에 존재하는 모든 경로들을 탐색하는 것이고,

머지하는 과정에서는 cut한 노드를 무조건 거쳐가는 경로에 대한 계산이라 보면 된다.


밑에는 코드포스 C번에 대한 코드이다.

http://codeforces.com/contest/715/problem/C


#include <stdio.h>
#include <vector>
#include <queue>
#include <memory.h>
#include <map>
#define mp make_pair
#define __ 100010
#define maxf(a,b) ((a)>(b)?(a):(b))
typedef long long ll;
using namespace std;

struct Queue {
    int node, d, g;
    ll l, r;
};

vector<Queue> qlist[__];
vector<int> list;
vector<pair<int,int>> graph[__];
ll ten[__], rev[__], ans;
int cnt[__], vis[__], cutted[__];
ll N, M;

int dfs(int node, int mom = -1) {
    cnt[node] = 1;
    list.push_back(node);
    for (pair<int,int> e : graph[node]) {
        if (e.first == mom || cutted[e.first]) continue;
        cnt[node] += dfs(e.first, node);
    }
    return cnt[node];
}

int getCut(int node) {
    list.clear();
    int count = dfs(node);
    for (int x : list) {
        int cmax = 0;
        for (pair<int,int> e : graph[x]) {
            int y = e.first;
            if (cutted[y] || cnt[x] < cnt[y]) continue;
            cmax = maxf(cmax, cnt[y]);
        }
        cmax = maxf(cmax, count - cnt[x]);
        if (cmax <= count/2) return x;
    }
    return node;
}

void calc(int node) {
    queue<Queue> Q;
    int group = 0;
    for (pair<int,int> e : graph[node]) {
        if (cutted[e.first]) continue;
        Q.push({e.first, 0, ++group, e.second%M, e.second%M});
    }
    for (int i=1;i<=group;i++) qlist[i].clear();
    memset(vis,0,sizeof vis);
    vis[node] = 1;
    while (!Q.empty()) {
        Queue n = Q.front(); Q.pop();
        if (vis[n.node]) continue;
        vis[n.node] = 1;
        qlist[n.g].push_back(n);
        for (pair<int,int> e : graph[n.node]) {
            if (vis[e.first] || cutted[e.first]) continue;
            ll L = (n.l * 10 + (ll)e.second)%M;
            ll R = ((ll)e.second * ten[n.d+1] + n.r)%M;
            Q.push({e.first, n.d+1, n.g, L, R});
        }
    }
    map<ll, int> rv;
    for (int i=1;i<=group;i++) {
        for (Queue x : qlist[i]) rv[x.r] ++;
    }
    for (int i=1;i<=group;i++) {
        for (Queue x : qlist[i]) rv[x.r] --;
        for (Queue x : qlist[i]) {
            if (x.l == 0) ans ++;
            if (x.r == 0) ans ++;
            ll t = (-x.l + M)%M;
            t = (t * rev[x.d+1])%M;
            ans += rv[t];
        }
        for (Queue x : qlist[i]) rv[x.r] ++;
    }
}

void solve(int node) {
    int cut = getCut(node);
    calc(cut);
    cutted[cut] = 1;
    for (pair<int,int> e : graph[cut]) {
        if (cutted[e.first]) continue;
        solve(e.first);
    }
    cutted[cut] = 0;
}

pair<ll, ll> extended_gcd(ll a,ll b) {
    if (b == 0) return mp(1,0);
    pair<ll, ll> t = extended_gcd(b, a%b);
    return mp(t.second, t.first - t.second * (a/b));
}

int main() {
    scanf ("%lld%lld",&N,&M);
    ten[0] = 1 % M;
    for (int i=1;i<=N;i++) {
        ten[i] = (ten[i-1] * 10) % M;
        rev[i] = (extended_gcd(M, ten[i]).second + M) % M;
    }
    for (int i=1;i<N;i++) {
        int a,b,c;
        scanf ("%d%d%d",&a,&b,&c);
        graph[a].push_back(mp(b,c));
        graph[b].push_back(mp(a,c));
    }
    solve(0);
    printf ("%lld",ans);
    return 0;
}



댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
글 보관함