SCPC 2016 2차예선 4번 간략한 풀이

아주아주 대충 간단하게 풀이를 써볼까 합니다.

(미리보기 방지)

우선 그래프를 '현재 상태' 로 만들어야 합니다. 1번 점에서부터 다익스트라 따위를 돌려서 1번 점을 루트로 하는 트리를 만들어줍시다.

이제 특정 간선 i - j 를 잡고 이 놈이 전체 거리쌍의 합을 어떻게 변화시키는지 관찰해야 하는데

이건 그림을 그려 살펴보면 그나마 좀 눈에 쉽게 들어옵니다.

위 그림을 토대로 설명을 해볼게요.

우선 pi 와 pj 를 잡아야 합니다. 이건 어떤 점이냐면

dist(1, pi) >= dist(pi, i) 인 pi 중 가장 높은 것, dist(1, pj) >= dist(pj, j) 인 pj 중 가장 높은 것 입니다.

이걸 왜 구했냐? 하면 바로 문제의 조건 때문인데요, 여행의 경로가 달라지는 점들을 알아내기 위함입니다.

구체적으로 i - j 간선을 복구한다면, "pi를 루트로 하는 서브트리의 점들" (하늘색부분) 에서 출발하여 "서브트리 A 이외의 곳"으로 가는 경우에는 무조건 간선 i - j 를 타야 합니다.

pj의 경우에도 마찬가지 입니다. (연두색부분)

그래서 이제 간선 M개를 하나씩 보면서 (물론 '현재 상태' 에 포함된 간선 혹은 lca가 1이 아닌 간선은 제외해야함)

  1. "pi를 루트로 하는 서브트리의 점들" 에서 출발하여 "서브트리 A 이외의 곳"으로 가는 경우

  2. "pj를 루트로 하는 서브트리의 점들" 에서 출발하여 "서브트리 B 이외의 곳"으로 가는 경우

이 두가지에 대해 변화하는 거리합의 양을 O(1) 에 계산해서 전체 답을 구하면 됩니다.

이걸 O(1) 에 구하기 위해선 갖가지 전처리가 필요한데요... 정말 끔찍하게 구현이 오래 걸렸습니다...

구현을 자세하게 설명하기엔 좀 지치네요 ㅠㅠ

더 간단하게 해결하신 분의 코멘트 부탁드려요...

댓글 (1개) 댓글 쓰기


appa 4달 전

92점 풀이는 위의 풀이에서 두 가지에 대해 변화하는 거리의 합의 양을 특정 간선의 정점 i와 j('현재 상태' 그래프에서 LCA가 1이 아닌)에 대해 매번 다익스트라를 돌려주시는 것입니다. 계산식은 모든 정점에 대해 LCA 테이블을 구하시면서 1번 루트에까지 이르는 거리를 O(1)에 답할 수 있다면, 비교적 간단하게 구할 수 있습니다. 제가 92점을 받았던 소스코드는 아래와 같습니다.


#include <stdio.h>
#include <queue>
#include <vector>
#include <algorithm>
using namespace std;
long long res, sum, S, ss[10005];
int LOG_N = 13;
int n, m;
int X[100005], Y[100005], Z[100005], lev[10005], parent[14][10005], dist[10005];
vector <int> v[10005], c[10005], g[10005], gc[10005];
void dfs(int x, int par, int w) {
    for (int i = 0; i < g[x].size(); i++) {
        int y = g[x][i], w2 = gc[x][i];
        if (y == par) continue;
        lev[y] = lev[x] + 1;
        parent[0][y] = x;
        int anc = x;
        for (int j = 0; anc != 0; j++) {
            parent[j + 1][y] = parent[j][anc];
            anc = parent[j][anc];
        }
        dist[y] = w + w2;
        dfs(y, x, w + w2);
    }
}
int R;
void go(int x, int par, int w) {
    sum += w;
    ss[R] += w;
    for (int i = 0; i < g[x].size(); i++) {
        int y = g[x][i], w2 = gc[x][i];
        if (y == par) continue;
        go(y, x, w + w2);
    }
}
struct PQ {
    int x, w;
    bool operator < (const PQ q) const { return w < q.w; }
} t, now;
int d[10005], par[10005], pw[10005], D[3][10005];
void dijkstra(int x) {
    for (int i = 1; i <= n; i++) d[i] = 1e9;
    priority_queue <PQ> pq;
    d[x] = 0;
    t.x = x; t.w = 0; pq.push(t);
    while (!pq.empty()) {
        now = pq.top(); pq.pop();
        if (now.w > d[now.x]) continue;
        for (int i = 0; i < v[now.x].size(); i++) {
            int y = v[now.x][i], w = c[now.x][i];
            if (d[y] > d[now.x] + w) {
                d[y] = d[now.x] + w;
                t.x = y; t.w = d[y];
                par[y] = now.x;
                pw[y] = w;
                pq.push(t);
            }
        }
    }
}
void dijkstra2(int x, int p) {
    for (int i = 1; i <= n; i++) D[p][i] = 1e9;
    priority_queue <PQ> pq;
    D[p][x] = 0;
    if (p == 0) par[x] = 0;
    t.x = x; t.w = 0; pq.push(t);
    while (!pq.empty()) {
        now = pq.top(); pq.pop();
        if (now.w > D[p][now.x]) continue;
        for (int i = 0; i < g[now.x].size(); i++) {
            int y = g[now.x][i], w = gc[now.x][i];
            if (D[p][y] > D[p][now.x] + w) {
                D[p][y] = D[p][now.x] + w;
                t.x = y; t.w = D[p][y];
                if (p == 0) {
                    par[y] = now.x;
                    pw[y] = w;
                }
                pq.push(t);
            }
        }
    }
}
int XX, YY, ZZ;
int lca(int v1, int v2) {
    if (lev[v1] < lev[v2]) swap(v1, v2);
    for (int i = LOG_N; i >= 0; i--) if (((lev[v1] - lev[v2]) >> i) & 1) v1 = parent[i][v1];
    if (v1 == v2) return v1;
    for (int i = LOG_N; i >= 0; i--) {
        if (parent[i][v1] != parent[i][v2]) {
            v1 = parent[i][v1];
            v2 = parent[i][v2];
        }
    }
    return parent[0][v1];
}
void f(int x, int par) {
    if (x != 1) {
        for (int i = 0; i < v[x].size(); i++) {
            int y = v[x][i];
            if (y == 1) continue;
            if (x > y) continue;
            int LCA = lca(x, y);
            if (LCA != 1) continue;
            int cost = c[x][i];
            g[x].push_back(y); gc[x].push_back(cost);
            g[y].push_back(x); gc[y].push_back(cost);
            S = 0;
            XX = x, YY = y; ZZ = cost;
            dijkstra2(1, 0); dijkstra2(x, 1); dijkstra2(y, 2);
            for (int p = 1; p <= n; p++) {
                for (int q = 1; q <= n; q++) {
                    if (p == q) continue;
                    int LCA = lca(p, q);
                    if (LCA != 1) S += (dist[p] + dist[q] - 2 * dist[LCA]);
                    else {
                        if (D[0][p] < D[1][p] && D[0][p] < D[2][p]) S += (dist[p] + dist[q] - 2 * dist[LCA]);
                        else {
                            if (D[1][p] < D[2][p]) S += D[1][p] + ZZ + D[2][q];
                            else S += D[2][p] + ZZ + D[1][q];
                        }
                    }
                }
            }
            if (res < sum - S) res = sum - S;
            g[x].pop_back(); gc[x].pop_back();
            g[y].pop_back(); gc[y].pop_back();
        }
    }
    for (int i = 0; i < g[x].size(); i++) {
        int y = g[x][i];
        if (y == par) continue;
        f(y, x);
    }
}
int main(int argc, char** argv) {
    setbuf(stdout, NULL);
    int T;
    int test_case;
    scanf("%d", &T);
    for (test_case = 1; test_case <= T; test_case++) {
        res = -1e18;
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= m; i++) {
            scanf("%d%d%d", &X[i], &Y[i], &Z[i]);
            v[X[i]].push_back(Y[i]); c[X[i]].push_back(Z[i]);
            v[Y[i]].push_back(X[i]); c[Y[i]].push_back(Z[i]);
        }
        dijkstra(1);
        for (int i = 2; i <= n; i++) {
            int y = par[i], w = pw[i];
            g[i].push_back(y); gc[i].push_back(w);
            g[y].push_back(i); gc[y].push_back(w);
        }
        dfs(1, 0, 0);
        for (int i = 1; i <= n; i++) {
            R = i;
            go(i, 0, 0);
        }
        f(1, 0);
        printf("Case #%d\n", test_case);
        printf("%lld\n", res);
        for (int i = 1; i <= n; i++) {
            v[i].clear(), c[i].clear(), g[i].clear(), gc[i].clear(), ss[i] = 0;
            dist[i] = 0;
            lev[i] = 0; pw[i] = par[i] = 0;
            for (int j = 0; j < LOG_N; j++) parent[j][i] = 0;
        }
        sum = S = 0;
    }
    return 0;
}