Length Of Cycle

2 secs 1024 MB
sgsw

ペアを選択した時に出来る閉路の長さはです。

通りに対する閉路の長さの総和をとするとき、明らかに

=

と言い換えることができます。

この値は、全頂点からdfsやbfsを行うことで計算量で求めることができますが、の条件下ではで求めることは絶望的です。

今、ある頂点(ここではとします)から()を求めましょう。

これはdfsやbfsを用いることでで求めることが可能です。これをとします。

さらに、を根に持つ木として、と定義します。

これも、dfsやbfsを用いることでで全頂点について求められます。

すると、の子を選び、グラフ上を移動したとすると、() は、

よりだけ減り、だけ増加することがわかります。

すなわち、() = です。

この関係を用いグラフ上でを行うと、全頂点に対する()を、計算量で求めることができます。

全体の時間計算量はです。

以下、による解答例です。

C++での解答例(227ms)
#include<bits/stdc++.h>
#define ll long long
#define rep(i,n) for (int i = 0; i < n; i++)

using namespace std;

template <class T = int>T extgcd(T a,T b,T &x,T &y){T g = a;x = 1;y = 0;if (b != 0) {g = extgcd(b, a % b, y, x), y -= (a / b) * x;}return g;}
template<class T = int> T invMod(T a,T m){T x,y;if (extgcd(a, m, x, y) == 1) {return (x + m) % m;}else{return -1;}}

const ll MOD = 998244353;
const int inf = 1e9;

signed main(){
    int n;
    cin>>n;
    vector<vector<int>> g(n);
    rep(i,n - 1){
        int u,v;
        cin>>u>>v;
        u--;v--;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    vector<int> DIST(n,inf);
    vector<int> child(n);
    vector<int> par(n,-1);
    ll tot = 0;
    int S_Node = 0;
    DIST[S_Node] = 0;

    function<void(int)> dfs1 = [&](int S){
        for (int adj : g[S]){
            if (DIST[adj] == inf){
                par[adj] = S;
                DIST[adj] = DIST[S] + 1;
                dfs1(adj);
            }
        }
        child[S]++;
        for (int adj : g[S]){
            if (adj != par[S]){
                child[S] += child[adj];
            }
        }
    };

    dfs1(S_Node);

    ll val = 0;rep(i,n){val += DIST[i];}

    vector<int> trace_back;
    rep(i,n){if(i!= S_Node)DIST[i] = inf;}

    function<void(int)> dfs2 = [&](int S){
        tot = (tot + val + n - 1) % MOD;
        for (int adj : g[S]){
            if (DIST[adj] == inf){
                DIST[adj] = DIST[S] + 1;
                trace_back.push_back(val);
                val = (val + n - 2 * child[adj]) % MOD;
                dfs2(adj);
            }
        }
        if (S != S_Node){
            val = trace_back.back();
            trace_back.pop_back();
        }
    };

    dfs2(S_Node);

    ll ans = tot * invMod<ll>((ll)n*(n - 1),MOD) % MOD;
    if (ans < 0){ans += MOD;}
    cout << ans << endl;
    return 0;
}
Python3での解答例(713ms)
import sys
sys.setrecursionlimit(1000000)
INF = 1 << 32
MOD = 998244353


def input():
    return sys.stdin.readline().rstrip()


n = int(input())
g = [[] for i in range(n)]

for i in range(n - 1):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    g[u].append(v)
    g[v].append(u)

DIST = [INF]*(n)
child = [0]*(n)
par = [-1]*(n)



def simple_dfs(S):
    for adj in g[S]:
        if DIST[adj] == INF:
            par[adj] = S
            DIST[adj] = DIST[S] + 1
            simple_dfs(adj)
    child[S] += 1
    for adj in g[S]:
        if adj != par[S]:
            child[S] += child[adj]


X = 0
DIST[X] = 0
simple_dfs(S=X)
tot = 0
val = sum(DIST[i] for i in range(n))


trace_back = []
for i in range(n):
    if i != X:
        DIST[i] = INF


def dfs(S):
    global val
    global tot
    tot += val + n - 1
    for adj in g[S]:
        if DIST[adj] == INF:
            tmp = val + n - 2 * child[adj]
            DIST[adj] = DIST[S] + 1
            trace_back.append(val)
            val = tmp
            dfs(adj)
    if S != X:
        val = trace_back.pop()


dfs(S=X)
ans = tot * pow(n * (n - 1), MOD - 2, MOD) % MOD
print(ans)