Codeforces 1101D (树形DP + GCD)

Codeforces 1101D
题意:给出一棵无向带点权树,定义$dist(x,y)$为$x,y$之间的点(包含$x,y$)的个数,$g(x,y)$为$x,y$之间所有点权$gcd$,求最大的$dist(x,y)$,并且$g(x,y) \geq 2$

由于$2 \times 10^5$以内的数最多$7$种质数相乘,所以可以对点权分解质因数,指数不重要,只考虑质因子,因为质因子相同的$gcd$一定$\geq 2$。

所以我一开始想对每个质因子建一棵树,然后在树上做类似直径的DP。
这样的树的总大小不会超过$7n$,但是空间就比较爆炸,虽然可以用map之类的搞但是麻烦得要死还带了$log$。

其实不需要建那么多棵树,直接一次DFS,找到孩子如果和自己有公质因数,那么像树直径DP那样转移即可。

注意这里距离为点的个数而不是边的条数,DP 方程要特别注意

即某个点初始化长度即更新为1,然后转移时

ans = max(ans, st_dfs[u][k].second + st_dfs[v][j].second);
st_dfs[u][k].second = max(st_dfs[u][k].second, st_dfs[v][j].second + 1);

注意加一的位置,画画图更好。

知识点:
1、时间不够了也要冷静分析,一定相信自己能找出问题
2、点权边权的问题要考虑清楚再写

//==========================Head files==========================
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#include<cmath>
#include<set>
#include<iostream>
#include<map>
#define LL long long
#define db double
#define mp make_pair
#define pr pair<int, int>
#define fir first
#define sec second
#define pb push_back
#define ms(i, j) memset(i, j, sizeof i)
using namespace std;
//==========================Templates==========================
inline int read() {
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9'){if (c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9'){x = x * 10 + c - '0'; c = getchar();}
return x * f;
}
inline LL readl() {
LL x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9'){if (c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9'){x = x * 10 + c - '0'; c = getchar();}
return x * f;
}
int power(int a, int b) {
int ans = 1;
while (b) {
if(b & 1) ans = ans * a;
b >>= 1; a = a * a;
}
return ans;
}
int power_mod(int a, int b, int mod) {
a %= mod;
int ans = 1;
while (b) {
if(b & 1) ans = (ans * a) % mod;
b >>= 1, a = (a * a) % mod;
}
return ans;
}
LL powerl(LL a, LL b) {
LL ans = 1ll;
while (b) {
if(b & 1ll) ans = ans * a;
b >>= 1ll;a = a * a;
}
return ans;
}
LL power_modl(LL a, LL b, LL mod) {
a %= mod;
LL ans = 1ll;
while (b) {
if(b & 1ll) ans = (ans * a) % mod;
b >>= 1ll, a = (a * a) % mod;
}
return ans;
}
LL gcdl(LL a, LL b) {return b == 0 ? a : gcdl(b, a % b);}
LL abssl(LL a) {return a > 0 ? a : -a;}
int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}
int abss(int a) {return a > 0 ? a : -a;}
//==========================Main body==========================
#define LD "%I64d"
#define D "%d"
#define pt printf
#define sn scanf
#define pty printf("YES\n")
#define ptn printf("NO\n")
//==========================Code here==========================
const LL MAXN = 200000 + 5;
LL n, a[MAXN], ans = 1ll;
vector<LL > st_a[MAXN];
vector<LL > G[MAXN];
vector<pair<LL, LL> > st_dfs[MAXN];
void dfs(LL u, LL fa) {
for (LL i = 0; i < (LL)st_a[u].size(); ++i) st_dfs[u].push_back(make_pair(st_a[u][i], 1));
for (LL i = 0; i < (LL)G[u].size(); ++i) {
LL v = G[u][i];
if (v != fa) {
dfs(v, u);
for (LL j = 0; j < (LL)st_dfs[v].size(); ++j) { // st_dfs[v][j]
LL whw = st_dfs[v][j].first;
for (LL k = 0; k < (LL)st_dfs[u].size(); ++k) { // st_dfs[u][k]
LL gg = st_dfs[u][k].first;
if (whw == gg) {
ans = max(ans, st_dfs[u][k].second + st_dfs[v][j].second);
//printf("ans=%d, u=%d, v=%d, whw=%d\n", ans, u, v, whw);
st_dfs[u][k].second = max(st_dfs[u][k].second, st_dfs[v][j].second + 1);
}
}
}
}
}
}
int main() {
cin >> n;
for (LL i = 1; i <= n; ++i) scanf("%lld", &a[i]);
int fl = 0;
for (LL i = 1; i <= n; ++i) if (a[i] != 1) fl = 1;
if (!fl) return printf("0\n"), 0;
for (LL u = 1; u <= n; ++u) {
LL tmp = a[u];
for (LL i = 2; i * i <= tmp; ++i) if (tmp % i == 0) {
st_a[u].push_back(i);
while (tmp % i == 0) tmp /= i;
}
if (tmp != 1) st_a[u].push_back(tmp);
sort(st_a[u].begin(), st_a[u].end());
}
/*for (LL u = 1; u <= n; ++u) {
for (LL i = 0; i < (LL)st_a[u].size(); ++i) printf("%d ", st_a[u][i]);
puts("");
}*/
for (LL x, y, i = 1; i < n; ++i) {
scanf("%lld%lld", &x, &y);
G[x].push_back(y), G[y].push_back(x);
}
dfs(1, 0);
cout << ans;
return 0;
}
/*
5
2 2 2 2 2
1 2
2 3
1 4
4 5
3
2 2 2
1 2
2 3
*/

------ 本文结束 ------