感觉是一个 \(\mathcal O(n)\) 做法。
下文中 \(deg\) 指以 \(1\) 为根儿子个数。
这里省略了一些转化。
首先是一个结论:将所有点按以度数为第一关键字,是否为关键点为第二关键字排序得到 \(p_1,p_2,\dots,p_n\)(钦定 \(p_1=1\)) 那么一定存在一个 \(1\le k\le n\),使得最优答案由保留 \(p_{1\sim k}\),将 \(p_{k+1\sim n}\) 按照以是否为关键点为第一关键字,度数为第二关键字重排后得到的排列取到。也就是依次填完 \(p_{1\sim k}\) 后接着填剩下的所有关键点,再填剩下的非关键点。
证明考虑如果这样取不到最优,最优的树长什么样。可以发现会存在第 \(i\) 层一个关键点 \(x\),第 \(i+1\) 层一个非关键点 \(y\) 使得 \(deg_y>deg_x\),且第 \(\ge i+2\) 层存在关键点。此时交换 \(x,y\)(原来 \(x\) 的子树都接在 \(y\) 下面,\(y\) 的前 \(deg_x\) 个子树接在 \(x\) 下面),只有一个关键点深度 \(+1\),如果 \(y\) 的后 \(deg_y-deg_x\) 个子树中有关键点则调整不劣,否则将其中任意一个子树和 \(\ge i+2\) 层的关键点交换也不劣。
考虑枚举 \(k\) 求答案,设现在的树在第 \(d\) 层还有 \(y\) 个空位,第 \(d+1\) 层还有 \(x\) 个空位,现在需要求出将 \(p_{k+1\sim n}\) 中的关键点依次填进去后他们的深度和。
先把前 \(y\) 个填进去,得到 \(x^{\prime}\) 个都在 \(d+1\) 层的空位。接下来考虑暴力:先把前 \(x_0=x^{\prime}\) 个点填在 \(d+1\) 层,设 \(x_1\) 为他们的度数和,则 \(d+2\) 层有 \(x_1\) 个空位,再填接下来的 \(x_1\) 个点,然后得到 \(x_2\),以此类推。
考虑优化。可以发现,在填入 \(deg\le 1\) 的关键点之前,都有 \(x_i\ge 2x_{i-1}\),所以这样至多填 \(\log\frac{n}{x_0}\) 层,而后面度数均为 \(1,0\) 的部分是容易 \(\mathcal O(1)\) 算的。而且可以认为,\(p_{2\sim k}\) 度数个不小于 \(2\),所以有 \(x_0=x^{\prime}\ge k\),那么这部分总复杂度为 \(\sum_{k=1}\mathcal O(\log\frac{n}{k})=\mathcal O(n)\)。
这里的复杂度分析是:
如果不特殊处理 \(deg=1\),由于 \(deg=0\) 至多填满一层,所以复杂度是 \(\sum_{k=1}\mathcal O(\frac{n}{k})=\mathcal O(n\log n)\)。
前面的排序和后面的构造容易线性,故总复杂度 \(\mathcal O(n)\)。
参考实现:
#include <bits/stdc++.h>
typedef long long LL;
typedef std::pair<int, int> pii;
#define fi first
#define se second
#define MP std::make_pairLL read()
{LL s = 0; int f = 1, c = getchar();for (; !isdigit(c); c = getchar()) f ^= (c == '-');for (; isdigit(c); c = getchar()) s = s * 10 + (c ^ 48);return f ? s : -s;
}
template<typename T>
void write(T x, char end = '\n')
{if (x < 0) putchar('-'), x = -x;static int d[100]; int cur = 0;do { d[++cur] = x % 10; } while (x /= 10);while (cur) putchar(d[cur--] ^ 48);putchar(end);
}
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3fll;
const int MAXN = 100005;
int n, m, LIM, key[MAXN], fa[MAXN], ff[MAXN];
int port[MAXN][2];
bool isKey[MAXN];
pii edge[MAXN];
std::vector<int> e[MAXN];
void setup(int x, int fat)
{fa[x] = fat;if (fat) e[x].erase(std::find(e[x].begin(), e[x].end(), fat));for (int y : e[x]) setup(y, x);
}
LL calc(int id[])
{static int fid[MAXN];LL ans = 0;int cur = 0; fid[++cur] = 0;for (int i = 1, d = 0; i <= n; d++){for (int j = i; j <= cur; j++)if (isKey[id[j]]) ans += d;int p = cur;for (int j = i; j <= p; j++)for (int t = 1; t <= e[id[j]].size(); t++)fid[++cur] = id[j];i = p + 1;if (i <= n && i > cur) return INF;}for (int i = 1; i <= n; i++) ff[id[i]] = fid[i];return ans;
}
int sdeg[MAXN];
int P;
LL query(int X, int d, int m)
{if (!X) return INF;if (m < X) return (LL)m * d;if (e[key[m]].size() >= 2)return (LL)X * d + query(sdeg[m] - sdeg[m - X], d + 1, m - X);if (e[key[m]].size() == 0)return m >= X ? INF : (LL)m * d;int q = (m - P) / X;if (!q) return (LL)X * d + query(sdeg[m] - sdeg[m - X], d + 1, m - X);return (LL)X * (d + (d + q - 1)) * q / 2 + query(X, d + q, m - q * X);
}
LL query(int Y, int X, int d, int m)
{if (m < Y) return (LL)m * d;return (LL)Y * d + query(X + sdeg[m] - sdeg[m - Y], d + 1, m - Y);
}
void calc()
{static int raw[MAXN], id[MAXN];static int cnt[MAXN];memset(cnt, 0, (n + 1) << 2);cnt[n + 1] = 1;for (int i = 2; i <= n; i++) cnt[e[i].size()]++;for (int i = n; i >= 0; i--) cnt[i] += cnt[i + 1];for (int i = 2; i <= n; i++) if (!isKey[i]) raw[cnt[e[i].size()]--] = i;for (int i = 2; i <= n; i++) if (isKey[i]) raw[cnt[e[i].size()]--] = i;raw[1] = 1;int cur = 0;for (int i = n; i; i--)if (isKey[raw[i]]) key[++cur] = raw[i];for (int i = 1; i <= cur; i++)sdeg[i] = sdeg[i - 1] + e[key[i]].size();for (P = 1; P < cur && e[key[P + 1]].size() == 0; P++) ;LL ans = calc(raw);int pt = n; int Y = e[1].size(), d = 1, X = 0;LL sum = 0;for (int i = 2; i <= n; i++){if (!cur || cur == n - i + 1) break;LL v = sum + query(Y, X, d, cur);if (v < ans) ans = v, pt = i;if (isKey[raw[i]]) cur--;Y--, X += e[raw[i]].size();if (isKey[raw[i]]) sum += d;if (!Y) d++, Y = X, X = 0;}cur = pt - 1;memcpy(id + 1, raw + 1, n << 2);for (int i = pt; i <= n; i++)if (isKey[raw[i]]) id[++cur] = raw[i]; for (int i = pt; i <= n; i++)if (!isKey[raw[i]]) id[++cur] = raw[i]; calc(id);write(ans);
}
void construct()
{static int cur[MAXN];memset(cur + 1, 0, n << 2);int T = 0;for (int i = 2; i <= n; i++){int x = ff[i];int y = e[x][cur[x]++];port[y][0] = port[i][1] = ++T;}for (int i = 1; i < n; i++){int u = edge[i].se;if (fa[edge[i].fi] == u) u = edge[i].fi, std::swap(port[u][0], port[u][1]);write(2, ' ');write(port[u][0], ' '), write(0, ' ');write(port[u][1], ' '), write(1);}
}
void mian()
{n = read(), m = read(), LIM = read();for (int i = 1; i < n; i++){int u = read(), v = read();e[u].push_back(v), e[v].push_back(u);edge[i] = MP(u, v);}memset(isKey + 1, false, n);for (int i = 1; i <= m; i++) isKey[read()] = true;if (n == 1) return write(0);setup(1, 0);calc();construct();
}
int main()
{for (int Tcnt = read(); Tcnt--; ){mian();for (int i = 1; i <= n; i++) e[i].clear();}return 0;
}