相關鏈接Orz AHdoc!!!!!!!!!!!!!
這種神犇算法的關鍵在于真正利用了MST是一棵“樹”的性質。也就是,它在求出MST后把它轉化為有根樹,然后,
按長度遞增順序對于圖中每一條不在MST中的邊(i, j),找到樹中i、j的最近公共祖先(LCA),記為p=LCA(i, j)。這樣,
樹中i->p->j就是從i到j的路徑。然后,依次掃描這條路徑上的所有的邊,將新邊(i, j)的長度與路徑上所有邊的長度比較,找到長度差最小的(不過由于邊(i, j)的長度一定不小于路徑上所有邊的長度,所以只要找到路徑上最長邊,則“刪去這條最長邊,加入邊(i, j)”一定是所有加入的邊為(i, j)的可行變換中代價最小的),取這個長度差最小的即可。不過最為神犇的一點是,這個算法在遍歷完這條路徑后,
會將路徑上所有的點(p點除外)的父結點全部設為p,也就是相當于并查集的路徑壓縮!這雖然會改變樹的形態,但任何兩點的LCA都是不會變的,因此不會影響后面的邊。
注意上面“按長度遞增順序”是重點,原因是“路徑壓縮”可能會改變某些點之間的路徑,也就是將某些點之間的路徑長度減小。但是,很容易發現,
被“壓縮”的這些邊必然是已經訪問過的,也就是說這些邊必然已經作為了前面的某條邊(i, j),i到j路徑上的邊。對于這條邊來說,可行變換中,新加入的邊的長度應盡量小,因此,如果按長度遞增順序,則這些邊在(i, j)之后肯定不會出現代價更小的可行變換,因此就可以將它們壓縮,不會影響最優解。復雜度分析:先不管求LCA的時間復雜度。設樹中結點i的深度為h[i](h[root]=0)。對于樹中的任意一個葉結點v,從root到v的路徑的總長度(總邊數)為h[v],因此,若某次要嘗試的邊(i, j)的某一端點(設為i)在從root到v的這條路徑上,則p=LCA(i, j)一定也在這條路徑上。這樣,訪問從i到p的路徑上的總訪問次數(也就是從i到p路徑上的邊數)為(h[i]-h[p])。在訪問完成后,需要將從i到p路徑上除p外的所有結點的父結點都設為p,也就是
從root到v的路徑的總長度減少了(h[i]-h[p]-1)。因此,在嘗試所有不在MST中的邊的過程中,訪問從root到v的最初路徑上的邊的總次數不會超過(h[v]+這些邊的總數)(這里h[v]指初始的h[v])。因此可以得到:
訪問樹中所有邊的總次數不會超過(最初所有葉結點深度之和+2*M),M為所有不在MST中邊的總數!由于“最初所有葉結點深度之和”不會超過Nlog2N,因此總時間復雜度為O(Mlog2M+M+Nlog2N),其中O(Mlog2M)為用Kruskal求MST的時間,如果忽略這部分時間,則總時間復雜度為O(M+Nlog2N)。
其實這個算法的時間復雜度在忽略排序的情況下是線性的,即O(M+N),但本沙茶搞不懂怎么證明這一步。
下面是具體實現時的注意事項:
(1)將MST轉化為有根樹時,應用BFS,而不要用DFS,否則對于特殊數據可能爆棧;
(2)求LCA時,應用先讓深度大的結點向上的方法(AHdoc神犇的方法),具體見下面的代碼片段1;或者應用兩者同時往上的方法(本沙茶的方法),具體見下面的代碼片段2;否則,對于樹是一條鏈,且每次訪問都是訪問最深的兩個結點時,一次求LCA的時間復雜度可能升到O(N)。
【代碼片段1】
int lca(int a, int b)
{
for (;;)
{
if (a == b) return b;
if (h[a] >= h[b]) a = Fa[a]; else b = Fa[b];
}
}
【代碼片段2】
int LCA(int x, int y)
{
while (x && y) {
if (fl[x] == _s) return x; else fl[x] = _s;
if (fl[y] == _s) return y; else fl[y] = _s;
x = pr[x]; y = pr[y];
}
if (x) while (1) if (fl[x] == _s) return x; else x = pr[x]; else while (1) if (fl[y] == _s) return y; else y = pr[y];
}
【具體題目】Beijing2010 Tree(BZOJ1977)
這題要求嚴格次小生成樹,因此在枚舉的時候要注意,不能使可行變換的代價為0。
#include <iostream>
#include <stdio.h>
#include <algorithm>
using namespace std;
#define re(i, n) for (int i=0; i<n; i++)
#define re1(i, n) for (int i=1; i<=n; i++)
const int MAXN = 100001, MAXM = 300001;
const long long INF = ~0Ull >> 2;
struct edge {
int a, b, len;
friend bool operator< (edge e0, edge e1) {return e0.len < e1.len;}
} E[MAXM];
struct edge0 {
int a, b, id, pre, next;
} E0[MAXM + MAXM];
int n, m, m0, u[MAXN], pr[MAXN], No[MAXN], s[MAXM], Q[MAXN], fl[MAXN], _s;
long long mst_v = 0, res;
bool inmst[MAXM], vst[MAXN];
void init()
{
scanf("%d%d", &n, &m);
re(i, m) scanf("%d%d%d", &E[i].a, &E[i].b, &E[i].len);
}
int find(int x) {int r = x, r0; while (u[r] > 0) r = u[r]; while (u[x] > 0) {r0 = u[x]; u[x] = r; x = r0;} return r;}
void uni(int s1, int s2) {int tmp = u[s1] + u[s2]; if (u[s1] > u[s2]) {u[s1] = s2; u[s2] = tmp;} else {u[s2] = s1; u[s1] = tmp;}}
void init_d()
{
re1(i, n) {E0[i].a = i; E0[i].pre = E0[i].next = i;}
if (n % 2) m0 = n + 1; else m0 = n + 2;
}
void add_edge(int a, int b, int id)
{
E0[m0].a = a; E0[m0].b = b; E0[m0].id = id; E0[m0].pre = E0[a].pre; E0[m0].next = a; E0[a].pre = m0; E0[E0[m0].pre].next = m0++;
E0[m0].a = b; E0[m0].b = a; E0[m0].id = id; E0[m0].pre = E0[b].pre; E0[m0].next = b; E0[b].pre = m0; E0[E0[m0].pre].next = m0++;
}
void prepare()
{
sort(E, E + m);
re1(i, n) u[i] = -1; init_d();
int s1, s2, z = 0;
re(i, m) {
s1 = find(E[i].a); s2 = find(E[i].b);
if (s1 != s2) {z++; mst_v += E[i].len; add_edge(E[i].a, E[i].b, i); inmst[i] = 1; uni(s1, s2); if (z == n - 1) break;}
}
}
void bfs()
{
re1(i, n) vst[i] = 0;
Q[0] = 1; vst[1] = 1;
int i, j;
for (int front=0, rear=0; front<=rear; front++) {
i = Q[front];
for (int p=E0[i].next; p != i; p=E0[p].next) {
j = E0[p].b;
if (!vst[j]) {
vst[j] = 1; Q[++rear] = j; pr[j] = i; No[j] = E0[p].id;
}
}
}
}
int LCA(int x, int y)
{
while (x && y) {
if (fl[x] == _s) return x; else fl[x] = _s;
if (fl[y] == _s) return y; else fl[y] = _s;
x = pr[x]; y = pr[y];
}
if (x) while (1) if (fl[x] == _s) return x; else x = pr[x]; else while (1) if (fl[y] == _s) return y; else y = pr[y];
}
void sol0(int a, int b, int l)
{
int p = LCA(a, b), p0, No0;
while (a != p) {No0 = No[a]; if (!s[No0] && l > E[No0].len) s[No0] = l - E[No0].len; p0 = pr[a]; pr[a] = p; a = p0;}
while (b != p) {No0 = No[b]; if (!s[No0] && l > E[No0].len) s[No0] = l - E[No0].len; p0 = pr[b]; pr[b] = p; b = p0;}
}
void solve()
{
pr[1] = 0; bfs();
re(i, m) if (!inmst[i]) {_s = i + 1; sol0(E[i].a, E[i].b, E[i].len);}
res = INF;
re(i, m) if (inmst[i] && s[i] && s[i] < res) res = s[i];
res += mst_v;
}
void pri()
{
cout << res << endl;
}
int main()
{
init();
prepare();
solve();
pri();
return 0;
}