From: Mout-sea <2582621015@qq.com> Date: Wed, 27 Jan 2021 07:14:07 +0000 (+0800) Subject: fix(tree-diameter.md): 将代码换为了更可读的版本 X-Git-Url: http://git.osdn.net/view?a=commitdiff_plain;h=9626375be4917704f8862cb60b0da4a47fa7edd1;p=oi-wiki%2Fmain.git fix(tree-diameter.md): 将代码换为了更可读的版本 --- diff --git a/docs/graph/tree-diameter.md b/docs/graph/tree-diameter.md index 5a442e3b..7fd0ca45 100644 --- a/docs/graph/tree-diameter.md +++ b/docs/graph/tree-diameter.md @@ -28,29 +28,33 @@ 因此定理成立。 ```cpp -const int N = 10009; -VI adj[N]; -int d[N], c; -int n; -#define v (*it) -void dfs(int u) { - ECH(it, adj[u]) if (!d[v]) { +#include +using namespace std; + +const int N = 10000 + 10; + +int n, c, d[N]; +vector E[N]; + +void dfs(int u, int fa) { + for (int v: E[u]) { + if (v == fa) continue; d[v] = d[u] + 1; if (d[v] > d[c]) c = v; - dfs(v); + dfs(v, u); } } -#undef v + int main() { - REP_C(i, RD(n) - 1) { - int a, b; - RD(a, b); - --a, --b; - adj[a].PB(b), adj[b].PB(a); + scanf("%d", &n); + for (int i = 1; i < n; i++) { + int u,v; + scanf("%d %d", &u, &v); + E[u].push_back(v), E[v].push_back(u); } - - d[0] = 1, dfs(0); - RST(d), d[c] = 1, dfs(c), OT(d[c] - 1); + dfs(1,0); d[c] = 0, dfs(c,0); + printf("%d\n", d[c]); + return 0; } ``` @@ -59,36 +63,37 @@ int main() { 我们记录每个节点向下,所能延伸的最远距离 $d_1$ ,和次远距离 $d_2$ ,那么直径就是所有 $d_1 + d_2$ 的最大值。 ```cpp -#include -#include +#include using namespace std; -const int N = int(1e4) + 9; -vector adj[N]; -int n, d; -int dfs(int u = 1, int p = -1) { - int d1 = 0, d2 = 0; - for (auto v : adj[u]) { - if (v == p) continue; - int d = dfs(v, u) + 1; - if (d > d1) - d2 = d1, d1 = d; - else if (d > d2) - d2 = d; +const int N = 10000 + 10; + +int n, c, d = 0; +int d1[N], d2[N]; +vector E[N]; + +void dfs(int u, int fa) { + d1[u] = d2[u] = 0; + for (int v: E[u]) { + if (v == fa) continue; + dfs(v, u); + int t = d1[v] + 1; + if (t > d1[u]) d2[u] = d1[u], d1[u] = t; + else if (t > d2[u]) d2[u] = t; } - d = max(d, d1 + d2); - return d1; + d = max(d, d1[u] + d2[u]); } + int main() { - cin >> n; - for (int i = 0; i < n - 1; ++i) { - int a, b; - cin >> a >> b; - adj[a].push_back(b); - adj[b].push_back(a); + scanf("%d", &n); + for (int i = 1; i < n; i++) { + int u,v; + scanf("%d %d", &u, &v); + E[u].push_back(v), E[v].push_back(u); } - dfs(); - cout << d << endl; + dfs(1, 0); + printf("%d\n", d); + return 0; } ```