First, we make an observation about what the shortest distance to some $k$-tuple is. The distance to a $k$-tuple $(v_1, \dots, v_k)$ is the smallest $d$ such that in each graph $G_i$ there is a walk of exactly $d$ steps from the source (vertex $1$) that ends at $v_i$. To use this, we exploit more structure about what lengths of walks can end at a vertex. If there is a walk of length $\ell$ in $G_i$ that ends at $v$, then we know there are also walks of length $\ell + 2, \ell + 4, \dots$ because one can repeatedly take one step away from $v$ and one step back towards $v$.
Moreover, if $e_v$ denotes the length of the shortest walk of even length in $G_i$ that ends at $v$, and $o_v$ the length of the shortest odd walk respectively, then valid distances for $v$ are the union of $\{e_v, e_v + 2, e_v + 4, \dots\}$ and $\{o_v, o_v + 2, o_v + 4, \dots\}$. (If there is no even walk or odd walk, ignore that corresponding set.) We can compute the quantities $e_v$ and $o_v$ by creating a duplicate graph (one copy representing odd parity and one representing even, with edges going between the copies) and using a BFS.
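To make the doubled-graph construction concrete, here is a minimal, self-contained sketch (illustrative, not taken from either solution below; the name parity_bfs and its interface are assumptions). It runs BFS over (vertex, parity) states, which is equivalent to BFS on the duplicated graph, and reports $e_v$ and $o_v$ as the distances to the two copies of $v$.

// Minimal sketch: BFS over (vertex, parity) states from vertex 0.
// evenDist[v] / oddDist[v] end up as the shortest even / odd walk length
// from vertex 0 to v, or INT_MAX if no such walk exists.
#include <bits/stdc++.h>
using namespace std;

void parity_bfs(int n, const vector<vector<int>>& adj,
                vector<int>& evenDist, vector<int>& oddDist) {
    const int INF = INT_MAX;
    vector<array<int, 2>> dist(n, {INF, INF});  // dist[v][p]: shortest walk to v of parity p
    queue<pair<int, int>> q;
    dist[0][0] = 0;                             // the empty walk at the source is even
    q.push({0, 0});
    while (!q.empty()) {
        auto [v, p] = q.front();
        q.pop();
        for (int to : adj[v]) {
            if (dist[to][p ^ 1] == INF) {       // taking one more edge flips the parity
                dist[to][p ^ 1] = dist[v][p] + 1;
                q.push({to, p ^ 1});
            }
        }
    }
    evenDist.resize(n);
    oddDist.resize(n);
    for (int v = 0; v < n; v++) {
        evenDist[v] = dist[v][0];
        oddDist[v] = dist[v][1];
    }
}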
Now, we want to use this structure to compute the sum of all distances. Core to this question is knowing, for some $k$-tuple, what the minimum possible distance is. If we decide this distance will be even (and there exists an even walk for each node in the tuple), then the distance is simply the maximum $e_v$ in the tuple. More concretely, we denote the sum of the distances of these tuples as $f(E)$, where $E$ contains a list of pairs containing each node's $e_v$ and corresponding graph (if it has an even walk). Then, $f(E)$ calculates the sum, over all valid tuples, of the maximum value. (And analogously the same statement with $f(O)$ and the $o_v$ values, if we decide this distance will be odd.)
But making such a decision for a tuple to be even or odd is difficult. Consider, instead, that we immediately calculate $f(E) + f(O)$. If an entire tuple could use even walks or odd walks, then we overcounted the answer for that tuple by exactly the larger quantity (e.g. if odd was the worse decision for that tuple, then the maximum $o_v$ in that tuple). Finally, we can correct this by subtracting $f(B)$, where $B$ denotes a list with $\max(e_v, o_v)$ for the nodes that have both even and odd walks.
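For a concrete (illustrative) example, take $k = 2$ and a tuple whose nodes have $(e_{v_1}, o_{v_1}) = (2, 5)$ and $(e_{v_2}, o_{v_2}) = (4, 3)$. The even decision costs $\max(2, 4) = 4$ and the odd decision costs $\max(5, 3) = 5$, so the true distance is $4$. In $f(E) + f(O)$ this tuple contributes $4 + 5$, while in $f(B)$ it contributes $\max(\max(2, 5), \max(4, 3)) = 5$; subtracting leaves exactly $4$. The correction works in general because $\max_i \max(e_{v_i}, o_{v_i}) = \max(\max_i e_{v_i}, \max_i o_{v_i})$, which is precisely the cost of the worse decision.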
All that remains is how to calculate $f(L)$ for some list $L$. One such way is computing $g(x)$, the number of tuples whose maximum is at most $x$; then the answer is $\sum_x x \cdot (g(x) - g(x-1))$. To calculate $g(x)$, we note that this is equal to the product of $c_{i,x}$ over all graphs $i$, where $c_{i,x}$ represents the number of nodes from graph $i$ whose corresponding value is at most $x$. Directly computing this would be too slow, but we can optimize.
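As a sanity check, here is a direct (too slow) sketch of that formula, assuming the cumulative counts $c_{i,x}$ have already been built; the name f_slow and its interface are illustrative assumptions, not part of either solution. The next paragraph describes how to remove the inner loop over graphs.

// c[i][x] = number of nodes from graph i in the list whose value is <= x,
// for x < 2*n_i; for larger x the count saturates at c[i].back().
// Runs in O(maxd * k), which is too slow but matches the formula above.
#include <bits/stdc++.h>
using namespace std;

long long f_slow(const vector<vector<long long>>& c, int maxd) {
    const long long MOD = 1e9 + 7;
    long long ans = 0, prev = 0;                       // prev = g(x - 1)
    for (int x = 0; x < maxd; x++) {
        long long g = 1;                               // g(x) = product of c_{i,x}
        for (const auto& ci : c) {
            long long cnt = (x < (int)ci.size()) ? ci[x] : ci.back();
            g = g * cnt % MOD;
        }
        ans = (ans + 1LL * x * ((g - prev + MOD) % MOD)) % MOD;
        prev = g;
    }
    return ans;
}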
For each graph $i$, we compute $c_{i,x}$ for all $x < 2n_i$. For all larger $x$, $c_{i,x} = c_{i,2n_i-1}$. To handle this, we maintain a suffix product array (i.e. similar to a prefix sum array, but for suffixes and multiplication instead) and modify it such that all elements in the suffix starting at $2n_i$ will be multiplied by $c_{i,2n_i-1}$. For the values $x < 2n_i$, we can similarly have a prefix product array whose entry at $x$ we multiply by $c_{i,x}$. Then, we can finally compute $g(x)$ as the product of the prefix entry and the suffix entry at $x$. This runs in linear time, so in total our algorithm runs in $O(\sum_i (n_i + m_i))$ time.
It is also possible to calculate $f$ using a segment tree or modular inverses.
Spencer's code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

ll mod = 1e9+7;
int k;
int n[50000];
int inf = 1e8;

// Given (value, graph) pairs, returns the sum over all tuples (one value per
// graph) of the maximum value in the tuple, modulo 1e9+7.
ll compute_sum(vector<pair<int, int> > li){
    int maxn = 0;
    vector<ll> prefix_prod;
    vector<ll> suffix_prod;
    vector<ll> graphs[k];
    for(int i = 0; i<li.size(); i++){
        graphs[li[i].second].push_back(li[i].first);
    }
    for(int i = 0; i<k; i++){
        vector<ll> cnt(2*n[i]);
        maxn = max(maxn,2*n[i]);
        while(prefix_prod.size()<maxn){
            prefix_prod.push_back(1);
        }
        while(suffix_prod.size()<=maxn){
            suffix_prod.push_back(1);
        }
        for(int j = 0; j<graphs[i].size(); j++){
            cnt[graphs[i][j]]++;
        }
        // cnt[j] becomes the number of values from this graph that are <= j;
        // prefix_prod[j] collects these counts for j within this graph's range.
        for(int j = 0; j<2*n[i]; j++){
            if(j>0){
                cnt[j] += cnt[j-1];
            }
            prefix_prod[j] *= cnt[j];
            prefix_prod[j] %= mod;
        }
        // For j >= 2*n[i] the count saturates; record it lazily at index 2*n[i].
        suffix_prod[2*n[i]] *= cnt[2*n[i]-1];
        suffix_prod[2*n[i]] %= mod;
    }
    // Propagate the saturated counts forward so suffix_prod[x] covers every
    // graph whose range ends at or before x.
    for(int i = 1; i<suffix_prod.size(); i++){
        suffix_prod[i] *= suffix_prod[i-1];
        suffix_prod[i] %= mod;
    }
    ll ans = 0LL;
    // g(x) = prefix_prod[x] * suffix_prod[x] counts tuples with maximum <= x;
    // the consecutive difference isolates tuples with maximum exactly x.
    for(int i = 1; i<maxn; i++){
        ll cur_num = (prefix_prod[i]*suffix_prod[i])-(prefix_prod[i-1]*suffix_prod[i-1]);
        cur_num %= mod;
        ans += cur_num * (ll)i;
        ans %= mod;
    }
    if(ans<0LL){
        ans += mod;
    }
    return ans;
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    vector<pair<int, int> > evens;
    vector<pair<int, int> > odds;
    vector<pair<int, int> > both;
    cin >> k;
    for(int i = 0; i<k; i++){
        int m;
        cin >> n[i] >> m;
        // Doubled graph: vertices 0..n[i]-1 are the even-parity copies,
        // n[i]..2*n[i]-1 the odd-parity copies.
        vector<int> adj[2*n[i]];
        for(int j = 0; j<m; j++){
            int a, b;
            cin >> a >> b;
            a--;
            b--;
            adj[a].push_back(n[i]+b);
            adj[b].push_back(n[i]+a);
            adj[n[i]+a].push_back(b);
            adj[n[i]+b].push_back(a);
        }
        // BFS from the even copy of the source (vertex 0).
        vector<int> dist(2*n[i], inf);
        vector<int> li;
        dist[0] = 0;
        li.push_back(0);
        for(int j = 0; j<li.size(); j++){
            int now = li[j];
            for(int a = 0; a<adj[now].size(); a++){
                int to = adj[now][a];
                if(dist[to]==inf){
                    dist[to] = dist[now]+1;
                    li.push_back(to);
                }
            }
        }
        // Classify each vertex by which parities of walks can reach it.
        for(int j = 0; j<n[i]; j++){
            if(dist[j]<inf){
                evens.push_back(make_pair(dist[j],i));
            }
            if(dist[j+n[i]]<inf){
                odds.push_back(make_pair(dist[j+n[i]],i));
            }
            if(max(dist[j],dist[j+n[i]])<inf){
                both.push_back(make_pair(max(dist[j],dist[j+n[i]]),i));
            }
        }
    }
    // Inclusion-exclusion: f(E) + f(O) - f(B).
    ll ans = compute_sum(evens)+compute_sum(odds)-compute_sum(both);
    ans %= mod;
    if(ans<0LL){
        ans += mod;
    }
    cout << ans << "\n";
}
Danny Mittal's code (with modular inverse):
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class SingleSourceShortestPath {
    public static final long MOD = 1000000007L;

    // Modular inverse via Fermat's little theorem (MOD is prime).
    public static long inverse(long base) {
        int exponent = (int) MOD - 2;
        long res = 1;
        while (exponent != 0) {
            if (exponent % 2 == 1) {
                res *= base;
                res %= MOD;
            }
            exponent /= 2;
            base *= base;
            base %= MOD;
        }
        return res;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        // After the prefix passes below, amts[d] counts tuples where every node has
        // a walk of length <= d with parity d mod 2, and amts2[d] counts tuples where
        // every node has walks of both parities with max distance <= d.
        long[] amts = new long[200000];
        long[] amts2 = new long[200000];
        // Zero factors are tracked separately, since zero has no modular inverse.
        boolean[] amtsZero = new boolean[200000];
        boolean[] amts2Zero = new boolean[200000];
        Arrays.fill(amts, 1);
        Arrays.fill(amts2, 1);
        int k = Integer.parseInt(in.readLine());
        boolean anyBipartite = false;
        for (int g = 0; g < k; g++) {
            in.readLine();
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            int n = Integer.parseInt(tokenizer.nextToken());
            int m = Integer.parseInt(tokenizer.nextToken());
            // Doubled graph: vertices 1..n are the even-parity copies, n+1..2n the odd copies.
            List<Integer>[] adj = new List[(2 * n) + 1];
            for (int a = 1; a <= 2 * n; a++) {
                adj[a] = new ArrayList<>();
            }
            for (int j = 1; j <= m; j++) {
                tokenizer = new StringTokenizer(in.readLine());
                int a = Integer.parseInt(tokenizer.nextToken());
                int b = Integer.parseInt(tokenizer.nextToken());
                adj[a].add(n + b);
                adj[n + b].add(a);
                adj[n + a].add(b);
                adj[b].add(n + a);
            }
            // BFS from the even copy of vertex 1.
            int[] dist = new int[(2 * n) + 1];
            Arrays.fill(dist, -1);
            dist[1] = 0;
            LinkedList<Integer> q = new LinkedList<>();
            q.add(1);
            while (!q.isEmpty()) {
                int a = q.remove();
                for (int b : adj[a]) {
                    if (dist[b] == -1) {
                        dist[b] = dist[a] + 1;
                        q.add(b);
                    }
                }
            }
            // No odd walk back to vertex 1 means this graph's component is bipartite,
            // so no tuple can use both parities in every graph.
            if (dist[n + 1] == -1) {
                anyBipartite = true;
            }
            // freq[d]: # vertices with a walk of length <= d and parity d mod 2;
            // freq2[d]: # vertices reachable with both parities and max distance <= d.
            long[] freq = new long[2 * n];
            long[] freq2 = new long[2 * n];
            for (int a = 1; a <= n; a++) {
                if (dist[a] != -1) {
                    freq[dist[a]]++;
                }
                if (dist[n + a] != -1) {
                    freq[dist[n + a]]++;
                }
                if (dist[a] != -1 && dist[n + a] != -1) {
                    freq2[Math.max(dist[a], dist[n + a])]++;
                }
            }
            for (int d = 2; d < 2 * n; d++) {
                freq[d] += freq[d - 2];
                freq2[d] += freq2[d - 1];
            }
            // Fold this graph's counts into the global factors, dividing out the
            // previous entry so the later prefix pass telescopes correctly.
            for (int d = 0; d < 2 * n; d++) {
                if (freq[d] == 0L) {
                    amtsZero[d] = true;
                } else {
                    amts[d] *= freq[d];
                    amts[d] %= MOD;
                    if (d >= 2 && freq[d - 2] != 0L) {
                        amts[d] *= inverse(freq[d - 2]);
                        amts[d] %= MOD;
                    }
                }
                if (freq2[d] == 0L) {
                    amts2Zero[d] = true;
                } else {
                    amts2[d] *= freq2[d];
                    amts2[d] %= MOD;
                    if (d >= 1 && freq2[d - 1] != 0L) {
                        amts2[d] *= inverse(freq2[d - 1]);
                        amts2[d] %= MOD;
                    }
                }
            }
        }
        for (int d = 2; d < 200000; d++) {
            amts[d] *= amts[d - 2];
            amts[d] %= MOD;
            amts2[d] *= amts2[d - 1];
            amts2[d] %= MOD;
        }
        for (int d = 0; d < 200000; d++) {
            if (amtsZero[d]) {
                amts[d] = 0;
            }
            if (amts2Zero[d]) {
                amts2[d] = 0;
            }
        }
        if (anyBipartite) {
            Arrays.fill(amts2, 0);
        }
        // answer = f(E) + f(O) - f(B), accumulated as d * (# tuples with maximum exactly d).
        long answer = 0;
        for (int d = 0; d < 200000; d++) {
            long dl = d;
            answer += dl * amts[d];
            answer -= dl * amts2[d];
            if (d >= 2) {
                answer -= dl * amts[d - 2];
                answer += dl * amts2[d - 1];
            }
            answer %= MOD;
        }
        answer += MOD;
        answer %= MOD;
        System.out.println(answer);
    }
}