OSDN Git Service

fix(tree-diameter.md): 将代码换为了更可读的版本
authorMout-sea <2582621015@qq.com>
Wed, 27 Jan 2021 07:14:07 +0000 (15:14 +0800)
committerMout-sea <2582621015@qq.com>
Wed, 27 Jan 2021 07:14:07 +0000 (15:14 +0800)
docs/graph/tree-diameter.md

index 5a442e3..7fd0ca4 100644 (file)
 因此定理成立。
 
 ```cpp
-const int N = 10009;
-VI adj[N];
-int d[N], c;
-int n;
-#define v (*it)
-void dfs(int u) {
-  ECH(it, adj[u]) if (!d[v]) {
+#include <bits/stdc++.h>
+using namespace std;
+
+const int N = 10000 + 10;
+
+int n, c, d[N];
+vector<int> E[N];
+
+void dfs(int u, int fa) {
+  for (int v: E[u]) {
+    if (v == fa) continue;
     d[v] = d[u] + 1;
     if (d[v] > d[c]) c = v;
-    dfs(v);
+    dfs(v, u);
   }
 }
-#undef v
+
 int main() {
-  REP_C(i, RD(n) - 1) {
-    int a, b;
-    RD(a, b);
-    --a, --b;
-    adj[a].PB(b), adj[b].PB(a);
+  scanf("%d", &n);
+  for (int i = 1; i < n; i++) {
+    int u,v;
+    scanf("%d %d", &u, &v);
+    E[u].push_back(v), E[v].push_back(u);
   }
-
-  d[0] = 1, dfs(0);
-  RST(d), d[c] = 1, dfs(c), OT(d[c] - 1);
+  dfs(1,0); d[c] = 0, dfs(c,0);
+  printf("%d\n", d[c]);
+  return 0;
 }
 ```
 
@@ -59,36 +63,37 @@ int main() {
 我们记录每个节点向下,所能延伸的最远距离 $d_1$ ,和次远距离 $d_2$ ,那么直径就是所有 $d_1 + d_2$ 的最大值。
 
 ```cpp
-#include <iostream>
-#include <vector>
+#include <bits/stdc++.h>
 using namespace std;
 
-const int N = int(1e4) + 9;
-vector<int> adj[N];
-int n, d;
-int dfs(int u = 1, int p = -1) {
-  int d1 = 0, d2 = 0;
-  for (auto v : adj[u]) {
-    if (v == p) continue;
-    int d = dfs(v, u) + 1;
-    if (d > d1)
-      d2 = d1, d1 = d;
-    else if (d > d2)
-      d2 = d;
+const int N = 10000 + 10;
+
+int n, c, d = 0;
+int d1[N], d2[N];
+vector<int> E[N];
+
+void dfs(int u, int fa) {
+  d1[u] = d2[u] = 0;
+  for (int v: E[u]) {
+    if (v == fa) continue;
+    dfs(v, u);
+    int t = d1[v] + 1;
+    if (t > d1[u]) d2[u] = d1[u], d1[u] = t;
+    else if (t > d2[u]) d2[u] = t;
   }
-  d = max(d, d1 + d2);
-  return d1;
+  d = max(d, d1[u] + d2[u]);
 }
+
 int main() {
-  cin >> n;
-  for (int i = 0; i < n - 1; ++i) {
-    int a, b;
-    cin >> a >> b;
-    adj[a].push_back(b);
-    adj[b].push_back(a);
+  scanf("%d", &n);
+  for (int i = 1; i < n; i++) {
+    int u,v;
+    scanf("%d %d", &u, &v);
+    E[u].push_back(v), E[v].push_back(u);
   }
-  dfs();
-  cout << d << endl;
+  dfs(1, 0);
+  printf("%d\n", d);
+  return 0;
 }
 ```