QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#288614#7610. Bus LinesSortingWA 6ms37960kbC++209.0kb2023-12-23 06:33:232023-12-23 06:33:24

Judging History

你现在查看的是最新测评结果

  • [2023-12-23 06:33:24]
  • 评测
  • 测评结果:WA
  • 用时:6ms
  • 内存:37960kb
  • [2023-12-23 06:33:23]
  • 提交

answer

#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair < int , int > pii ; 
typedef vector<int> vi;
typedef vector < long long > vl ;
typedef unsigned int uint ;
#define fi first
#define se second
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

const int MAXN = 2e5 + 7 ;
const int LOG = 20 ;

int n , m ;
vector < int > v[ MAXN ] ;
pii a[ MAXN ] ;
int subtree[ MAXN ] ;
vector < int > at_lca[ MAXN ] ;
vector < int > ending_at[ MAXN ] ;
int LCA[ MAXN ][ LOG ] , lvl[ MAXN ] ;
int st[ MAXN ] , en[ MAXN ] , tp ;

void init ( int x , int prv ) {
    st[ x ] = ++ tp ;
    for ( int i = 1 ; i < LOG ; ++ i ) {
        LCA[ x ][ i ] = LCA[ LCA[ x ][ i - 1 ] ][ i - 1 ] ;
    }
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        LCA[ y ][ 0 ] = x ;
        lvl[ y ] = lvl[ x ] + 1 ;
        init ( y , x ) ;
    }
    en[ x ] = tp ;
}

ll down[ MAXN ] , up[ MAXN ] ; // down[ x ] - edge above x
ll child_sum[ MAXN ] ;

class Tree {
public :
    ll tr[ 4 * MAXN ] ;
    void init ( int w , int l , int r ) {
        tr[ w ] = 0 ;
        if ( l == r ) { return ; }
        int mid = ( l + r ) / 2 ;
        init ( 2 * w , l , mid ) ;
        init ( 2 * w + 1 , mid + 1 , r ) ;
    }
    void update ( int w , int l , int r , int st , int en , ll add ) {
        if ( l > r || st > en ) { return ; }
        if ( r < st || en < l ) { return ; }
        if ( st <= l && r <= en ) {
            tr[ w ] += add ;
            return ;
        }
        int mid = ( l + r ) / 2 ;
        update ( 2 * w , l , mid , st , en , add ) ;
        update ( 2 * w + 1 , mid + 1 , r , st , en , add ) ;
    }
    ll query ( int w , int l , int r , int pos ) {
        if ( pos < l || r < pos ) { return 0 ; }
        ll ret = tr[ w ] ;
        if ( l == r ) { return ret ; }
        int mid = ( l + r ) / 2 ;
        if ( pos <= mid ) { ret += query ( 2 * w , l , mid , pos ) ; }
        else { ret += query ( 2 * w + 1 , mid + 1 , r , pos ) ; }
        return ret ;
    }
};
Tree w ;

int get_lca ( int x , int y ) {
    for ( int i = LOG - 1 ; i >= 0 ; -- i ) {
        if ( ( lvl[ x ] - ( 1 << i ) ) >= lvl[ y ] ) { x = LCA[ x ][ i ] ; }
        if ( ( lvl[ y ] - ( 1 << i ) ) >= lvl[ x ] ) { y = LCA[ y ][ i ] ; }
    }
    for ( int i = LOG - 1 ; i >= 0 ; -- i ) {
        if ( LCA[ x ][ i ] != LCA[ y ][ i ] ) {
            x = LCA[ x ][ i ] ;
            y = LCA[ y ][ i ] ;
        }
    }
    if ( x != y ) { x = LCA[ x ][ 0 ] ; }
    return x ;
}

ll path_val ( int x , int y ) {
    int aux = get_lca ( x , y ) ;
    // printf ( "path query %d %d -- %lld\n" , x , y , w.query ( 1 , 1 , n , st[ x ] ) + w.query ( 1 , 1 , n , st[ y ] ) - 2 * w.query ( 1 , 1 , n , st[ aux ] ) ) ;
    return w.query ( 1 , 1 , n , st[ x ] ) + w.query ( 1 , 1 , n , st[ y ] ) - 2 * w.query ( 1 , 1 , n , st[ aux ] ) ;
}

ll hh[ MAXN ] ;
multiset < pii > s[ MAXN ] ;

void add_elem ( int wh , int x ) {
    if ( s[ wh ].empty ( ) == true ) {
        s[ wh ].insert ( { st[ x ] , x } ) ;
        return ;
    }
    auto nxt = s[ wh ].lower_bound ( { st[ x ] , x } ) ;
    auto prv = nxt ;
    if ( nxt == s[ wh ].end ( ) ) { nxt = s[ wh ].begin ( ) ; }
    if ( prv == s[ wh ].begin ( ) ) { prv = s[ wh ].end ( ) ; -- prv ; }
    else { -- prv ; }
    hh[ wh ] -= path_val ( prv->se , nxt->se ) ;
    hh[ wh ] += path_val ( prv->se , x ) ;
    hh[ wh ] += path_val ( x , nxt->se ) ;
    
    s[ wh ].insert ( { st[ x ] , x } ) ;
}

void rem_elem ( int wh , int x ) {
    if ( s[ wh ].size ( ) == 1 ) {
        s[ wh ].clear ( ) ;
        return ;
    }
    auto it = s[ wh ].find ( { st[ x ] , x } ) ;
    auto nxt = it , prv = it ;
    ++ nxt ;
    if ( nxt == s[ wh ].end ( ) ) { nxt = s[ wh ].begin ( ) ; }

    if ( prv == s[ wh ].begin ( ) ) { prv = s[ wh ].end ( ) ; -- prv ; }
    else { -- prv ; }

    hh[ wh ] -= path_val ( prv->se , x ) ;
    hh[ wh ] -= path_val ( x , nxt->se ) ;
    hh[ wh ] += path_val ( prv->se , nxt->se ) ;
    
    s[ wh ].erase ( it ) ;
}

void dfs_down ( int x , int prv ) {
    int wh = 0 ;
    subtree[ x ] = 1 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        dfs_down ( y , x ) ;
        subtree[ x ] += subtree[ y ] ;
        child_sum[ x ] += down[ y ] ;
        if ( s[ y ].size ( ) > s[ wh ].size ( ) ) { wh = y ; }
    }
    swap ( s[ wh ] , s[ x ] ) ;
    hh[ x ] = hh[ wh ] ;
    for ( auto y : v[ x ] ) {
        if ( y == prv || y == wh ) { continue ; }
        for ( auto [ pos , val ] : s[ y ] ) {
            add_elem ( x , val ) ;
        }
    }
    for ( auto id : at_lca[ x ] ) {
        if ( a[ id ].fi != x ) { rem_elem ( x , a[ id ].fi ) ; }
        if ( a[ id ].se != x ) { rem_elem ( x , a[ id ].se ) ; }
    }
    /**
    printf ( "-- %d --> %d %lld %lld\n" , x , subtree[ x ] , hh[ x ] / 2 , child_sum[ x ] ) ;
    for ( auto [ pos , val ] : s[ x ] ) {
        printf ( "%d " , val ) ;
    }
    printf ( "\n" ) ;
    **/
    down[ x ] = subtree[ x ] ;
    if ( s[ x ].empty ( ) == false ) {
        add_elem ( x , x ) ;        
        down[ x ] += hh[ x ] / 2 + child_sum[ x ] ;
        rem_elem ( x , x ) ;
    }
    w.update ( 1 , 1 , n , st[ x ] , en[ x ] , child_sum[ x ] - down[ x ] ) ;
    for ( auto foo : ending_at[ x ] ) { 
        add_elem ( x , x ) ;
    }
}

void dfs_out ( int x , int prv ) {
    int wh = 0 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        dfs_out ( y , x ) ;
        if ( s[ wh ].size ( ) < s[ y ].size ( ) ) { wh = y ; }
    }
    swap ( s[ wh ] , s[ x ] ) ;
    hh[ x ] = hh[ wh ] ;
    for ( auto y : v[ x ] ) {
        if ( y == prv || y == wh ) { continue ; }
        for ( auto [ pos , val ] : s[ y ] ) {
            add_elem ( x , val ) ;
        }
    }
    for ( auto id : ending_at[ x ] ) {
        int y = a[ id ].fi + a[ id ].se - x ;
        add_elem ( x , y ) ;
    }
    for ( auto id : at_lca[ x ] ) {
        if ( a[ id ].fi != x ) { rem_elem ( x , a[ id ].se ) ; }
        if ( a[ id ].se != x ) { rem_elem ( x , a[ id ].fi ) ; }
    }
    /**
    printf ( "AT VERTEX %d\n" , x ) ;
    for ( auto [ pos , val ] : s[ x ] ) {
        printf ( "%d " , val ) ;
    }
    printf ( "\n" ) ;
    **/
    up[ x ] = n - subtree[ x ] ;
    if ( s[ x ].empty ( ) == false ) { 
        add_elem ( x , x ) ;
        int tot_lca = get_lca ( (s[ x ].begin ( ))->se , (s[ x ].rbegin ( ))->se ) ;
        up[ x ] += hh[ x ] / 2 + down[ tot_lca ] - child_sum[ x ] ;
        // printf ( "%lld %lld %lld up = %lld\n" , hh[ x ] , down[ tot_lca ] , child_sum[ x ] , up[ x ] ) ;
        rem_elem ( x , x ) ;
    }
}

int mxup[ MAXN ] ;

void mrk ( int x , int prv ) {
    int wh = 0 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        mrk ( y , x ) ;
        if ( s[ y ].size ( ) > s[ wh ].size ( ) ) { wh = y ; }
    }
    swap ( s[ x ] , s[ wh ] ) ;
    for ( auto y : v[ x ] ) {
        if ( y == prv || y == wh ) { continue ; }
        for ( auto foo : s[ y ] ) {
            s[ x ].insert ( foo ) ;
        }
    }
    for ( auto id : ending_at[ x ] ) {
        int aux = get_lca ( a[ id ].fi , a[ id ].se ) ; 
        s[ x ].insert ( { lvl[ aux ] , aux } ) ;
    }
    for ( auto id : at_lca[ x ] ) {
        auto it = s[ x ].find ( { lvl[ x ] , x } ) ;
        s[ x ].erase ( it ) ;
    }
    if ( s[ x ].empty ( ) == true ) { mxup[ x ] = LCA[ x ][ 0 ] ; }
    else { mxup[ x ] = (s[ x ].begin ( ))->se ; }
}

void dfs_up ( int x , int prv ) {
    if ( mxup[ x ] != 1 ) { 
        up[ x ] += up[ mxup[ x ] ] - ( subtree[ mxup[ x ] ] - subtree[ x ] ) ;
    }
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        dfs_up ( y , x ) ;
    }
}

void solve ( ) {
    cin >> n ;
    for ( int i = 1 , x , y ; i < n ; ++ i ) {
        cin >> x >> y ;
        v[ x ].push_back ( y ) ;
        v[ y ].push_back ( x ) ;
    }
    init ( 1 , -1 ) ;
    cin >> m ;
    for ( int i = 1 ; i <= m ; ++ i ) {
        cin >> a[ i ].fi >> a[ i ].se ;
        int aux = get_lca ( a[ i ].fi , a[ i ].se ) ;
        if ( a[ i ].fi != aux ) { ending_at[ a[ i ].fi ].push_back ( i ) ; }
        if ( a[ i ].se != aux ) { ending_at[ a[ i ].se ].push_back ( i ) ; }
        at_lca[ get_lca ( a[ i ].fi , a[ i ].se ) ].push_back ( i ) ;
    }
    w.init ( 1 , 1 , n ) ;
    dfs_down ( 1 , -1 ) ;
    /**
    for ( int i = 1 ; i <= n ; ++ i ) {
        printf ( "down[ %d ] = %lld\n" , i , down[ i ] ) ;
    }
    **/
    for ( int i = 1 ; i <= n ; ++ i ) {
        s[ i ].clear ( ) ;
        hh[ i ] = 0 ;
    }
    dfs_out ( 1 , -1 ) ;
    /**
    for ( int i = 1 ; i <= n ; ++ i ) {
        printf ( "up[ %d ] = %lld\n" , i , up[ i ] ) ;
    }
    **/
    for ( int i = 1 ; i <= n ; ++ i ) {
        s[ i ].clear ( ) ;
        hh[ i ] = 0 ;
    }
    mrk ( 1 , -1 ) ;
    dfs_up ( 1 , -1 ) ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        cout << up[ i ] + child_sum[ i ] << " " ;
    }
    cout << "\n" ;
}

int main ( ) {
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( NULL ) ;
    int t = 1 ; // cin >> t ; 
    while ( t -- ) { solve ( ) ; }
    return 0 ;
}


Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 3ms
memory: 37960kb

input:

6
1 2
5 4
6 5
3 1
1 5
3
6 1
2 3
6 4

output:

6 9 9 10 7 7 

result:

ok single line: '6 9 9 10 7 7 '

Test #2:

score: -100
Wrong Answer
time: 6ms
memory: 37384kb

input:

2
1 2
1
1 2

output:

3 2 

result:

wrong answer 1st lines differ - expected: '1 1', found: '3 2 '