音乐播放器
HMF-have much fun Site
 
文章 标签
16

Powered by Gridea | Theme: Fog
载入天数...
载入时分秒...
总访问量:  |   访问人数:

【题解】AT3611 Tree MST

  热度: loading...

xpp杂题清除计划

这题和P8207 [THUPC2022 初赛] 最小公倍树有点像,都可以用优化选边的 kruskal 解决。

很显然,对于一张边数为 n2n^2 级别的完全图,直接做 kruskal 的复杂度时 O(n2logn2)O(n^2logn^2) 的,显然过不了。

但是也正是因为这是一张完全图,所以我们可以用一个定理来解决问题。

定理:对于任意一张完全图 G=(V,E)G=(V,E) ,选取数个边集 (E1,E2,E3,...,Ek)(E_1,E_2,E_3,...,E_k) 使其完全覆盖边集 EE 。对每个集合 EiE_i 求最小生成树,得到边集 EMSTiE_{MST_{i}} 。再对 EMST1,EMST2,EMST3,...,EMSTkE_{MST_{1}},E_{MST_{2}},E_{MST_{3}},...,E_{MST_{k}} 求最小生成树,其最小生成树一定也是 EE 的最小生成树。

证明:考虑反证。假设最终图有一种比用 EMST1,EMST2,EMST3,...,EMSTkE_{MST_{1}},E_{MST_{2}},E_{MST_{3}},...,E_{MST_{k}} 建出来的最小生成树更小的生成树,则该树一定有一条边 (u,v)(u,v) 不被包含在任何的 EMSTiE_{MST_i} 中。考虑这条边所属的一个或多个边集,在该边集中一定有一条连接 (u,v)(u,v) 的链,且此链一定不比 (u,v)(u,v) 更劣。因此在最终的最小生成树中用该链替代 (u,v)(u,v) 一定不会更劣。与该生成树比 EMST1,EMST2,EMST3,...,EMSTkE_{MST_{1}},E_{MST_{2}},E_{MST_{3}},...,E_{MST_{k}} 建出来的最小生成树更小矛盾。

因此考虑淀粉质,假设我们以当前的分治中心 rtrt 为根,设每个点的权值 pip_idisrt,i+widis_{rt,i} + w_i ,则对于不同子树的两个点,他们连边的权值为 pi+pjp_i+p_j 。考虑两两不同子树内的点之间连边组成的边集,设 pip_i 最小的 iimnmn ,则这个边集的最小生成树一定为 iVpi+pmn\sum_{i \in V} p_i+p_{mn} (这里的证明珂以参考一下 kruskal 算法正确性的证明),因此将淀粉质得到的所有边 (i,mn)(i,mn) 加入边集,再做一次最小生成树就是 G=(V,E)G=(V,E) 的最小生成树了。

CodeCode

#include<bits/stdc++.h>
#define N 200010
using namespace std;
typedef long long ll;
int n,en;
ll w[N],INF=1e18;
int cnt,head[N],to[N<<1],nxt[N<<1];
ll val[N<<1];
int S,son[N],siz[N],vis[N],rt;
struct edge{
	int u,v;
	ll w;
	const bool operator < (const edge o) const {return w<o.w;}
}e[N<<4];
void insert(int u,int v,int w) {
	cnt++;
	to[cnt]=v;
	val[cnt]=w;
	nxt[cnt]=head[u];
	head[u]=cnt;
}
int read() {
	int res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) f=ch=='-'?-1:1,ch=getchar();
	while(isdigit(ch)) res=res*10+ch-'0',ch=getchar();
	return f*res;
}
void gtrt(int now,int fa) {
    siz[now]=1,son[now]=0;
    for(int i=head[now]; i; i=nxt[i]) if(to[i]!=fa&&!vis[to[i]]) {
    	gtrt(to[i],now);
    	siz[now]+=siz[to[i]];
    	son[now]=max(son[now],siz[to[i]]);
    }
    son[now]=max(son[now],S-siz[now]);
    if(son[now]<son[rt]) rt=now;
}
ll sk[N];
int id[N],tp,mn;
void dfs(int now,int fa,ll dis) {
	sk[++tp]=w[now]+dis,id[tp]=now;
	if(sk[tp]<sk[mn]) mn=tp;
	for(int i=head[now]; i; i=nxt[i]) if(!vis[to[i]]&&to[i]!=fa) 
		dfs(to[i],now,dis+val[i]);
}
void solve(int now) {
	vis[now]=1;
    tp=mn=0,dfs(now,0,0);
	for(int i=1; i<=tp; i++) if(i!=mn) e[++en]=edge{id[i],id[mn],sk[i]+sk[mn]}; 
	int tmp=S;
	for(int i=head[now]; i; i=nxt[i]) if(!vis[to[i]]) {
	    S=siz[to[i]]<siz[now]?siz[to[i]]:tmp-siz[now];
	    rt=0,gtrt(to[i],now),solve(rt);
	}
}
int fa[N];
int find(int x) {
	while(fa[x]!=x) x=fa[x]=fa[fa[x]];
	return x;
}
int main()
{
	n=read();
	for(int i=1; i<=n; i++) w[i]=read();
	for(int i=1; i<n; i++) {
		int u=read(),v=read(),w=read();
		insert(u,v,w);
		insert(v,u,w);
	}
	S=son[0]=n,rt=0,gtrt(1,0),sk[0]=INF;
	solve(rt);
	for(int i=1; i<=n; i++) fa[i]=i;
	sort(e+1,e+en+1);int tot=0;ll sum=0;
	for(int i=1; i<=en; i++) {
		int u=find(e[i].u),v=find(e[i].v);
		if(u!=v) fa[u]=v,tot++,sum+=e[i].w;
		if(tot==n-1) {printf("%lld",sum);return 0;}
	}
	return 0;
}

我是不会说我开 nlogn 的数组只开了 n*4 导致错了好几发的