QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#76637 | #2564. Two Trees | Sorting | WA | 4ms | 38732kb | C++ | 6.0kb | 2023-02-11 07:29:39 | 2023-02-11 07:29:41

Judging History

你现在查看的是最新测评结果

  • [2023-08-10 23:21:45]
  • System Update: QOJ starts to keep a history of the judgings of all the submissions.
  • [2023-02-11 07:29:41]
  • 评测
  • 测评结果:WA
  • 用时:4ms
  • 内存:38732kb
  • [2023-02-11 07:29:39]
  • 提交

answer

#include<bits/stdc++.h>
#include<bits/extc++.h>
using namespace std;

// Short aliases for frequently used types.
using ll = long long;
using ull = unsigned long long;
using uint = unsigned int;
using pii = pair<int, int>;
using vi = vector<int>;

// Member shorthands for std::pair.
#define fi first
#define se second

// Loop and container helper macros.
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

// Random generator seeded from the high-resolution clock.
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Capacity used for the file's fixed-size global arrays.
const int MAXN = 200000 + 7;
// Number of sparse-table levels allocated for the LCA structure.
const int LOG = 18;

// Number of vertices; read from standard input in solve().
int n ;
// v[wh][x]: adjacency list of vertex x in tree wh (wh is 0 or 1).
vector < int > v[ 2 ][ MAXN ] ;
// st[wh][x]: 1-based position of x's first appearance in ord[wh] (set by dfs()).
// lvl[wh][x]: depth of x in tree wh; dfs() is started from vertex 1, whose depth stays 0.
int st[ 2 ][ MAXN ] , lvl[ 2 ][ MAXN ] ;
// mxst[len]: sparse-table level that get_lca() uses for a window of `len` tour entries.
// NOTE(review): solve() fills indices up to 2*n-1, so this sizing presumes 2*n-1 < MAXN.
int mxst[ MAXN ] ;
// ord[wh]: Euler tour of tree wh as produced by dfs() (a vertex is appended on entry
// and again after each child returns).
vector < int > ord[ 2 ] ;
// rmq[wh][k][i]: vertex of smallest lvl among the 2^k consecutive tour entries that
// start at 1-based position i.
// NOTE(review): the last index also runs up to 2*n-1 — same sizing assumption as mxst.
int rmq[ 2 ][ LOG ][ MAXN ] ;

// Builds the Euler tour of tree `wh` below vertex `x`.
//   x   - vertex being entered
//   prv - neighbour we arrived from (skipped when scanning the adjacency list)
//   wh  - which tree (0 or 1)
// Side effects: appends to ord[wh], records the 1-based first-visit position in
// st[wh][x], and assigns lvl[wh][child] = lvl[wh][x] + 1 for every child.
void dfs(int x, int prv, int wh) {
    auto& tour = ord[wh];
    tour.push_back(x);
    st[wh][x] = tour.size();  // size after the push == 1-based index of this entry
    for (int child : v[wh][x]) {
        if (child != prv) {
            lvl[wh][child] = lvl[wh][x] + 1;
            dfs(child, x, wh);
            tour.push_back(x);  // re-enter x after finishing the child's subtree
        }
    }
}

int get_lca ( int wh , int x , int y ) {
    int l = st[ wh ][ x ] , r = st[ wh ][ y ] ;
    if ( l > r ) { swap ( l , r ) ; }
    int len = r - l + 1 ;
    int id = mxst[ len ] ;
    int cand1 = rmq[ wh ][ id ][ l ] , cand2 = rmq[ wh ][ id ][ r - ( 1 << id ) + 1 ] ;
    if ( lvl[ wh ][ cand1 ] < lvl[ wh ][ cand2 ] ) { return cand1 ; }
    return cand2 ;
}

// Number of edges between x and y in tree `wh`, computed as
// depth(x) + depth(y) - 2 * depth(get_lca(wh, x, y)).
int get_dist(int wh, int x, int y) {
    const int* depth = lvl[wh];
    const int top = get_lca(wh, x, y);
    return depth[x] + depth[y] - 2 * depth[top];
}

// Running answer, accumulated by query() and printed by solve(); unsigned, so
// overflow wraps around.
uint ans ;
// vals[c]: (count, sum, sum of squares) accumulated by update() for vertex c on a
// cen_prv[1] chain, where each summand is (tree-1 distance to c) + aux.
tuple < uint , uint , uint > vals[ MAXN ] ;
// to_parent[c]: the same three sums, but with the tree-1 distance measured to
// cen_prv[1][c] instead of c; only maintained when cen_prv[1][c] > 0.
tuple < uint , uint , uint > to_parent[ MAXN ] ;

// used[wh][x]: set once decompose() has picked x as a centroid of tree wh.
bool used[ 2 ][ MAXN ] ;
// cen_prv[wh][x]: the `prv` value decompose() recorded for centroid x
// (the centroid chosen one level above, or -1 for the first one).
int cen_prv[ 2 ][ MAXN ] ;
// cnt[wh][x]: subtree size of x computed by the most recent init() that visited it.
int cnt[ 2 ][ MAXN ] ;
// Number of vertices visited by init() since decompose() last reset it to 0.
int tot ;

// Computes subtree sizes inside the current component of tree `wh`, rooted at
// the vertex the outermost call started from. Vertices already marked in
// used[wh] are treated as removed. Stores the sizes in cnt[wh][*] and
// increments the global `tot` once per visited vertex.
void init(int x, int prv, int wh) {
    ++tot;
    int size = 1;
    for (int nxt : v[wh][x]) {
        if (nxt == prv || used[wh][nxt]) {
            continue;
        }
        init(nxt, x, wh);
        size += cnt[wh][nxt];
    }
    cnt[wh][x] = size;
}

// Starting at x, repeatedly steps into the first eligible neighbour whose
// cnt exceeds half of `tot`; returns the vertex where no such neighbour exists.
// Relies on cnt[wh][*] and tot having been filled by init() for this component.
int get_centroid(int x, int prv, int wh) {
    for (;;) {
        int heavy = -1;  // -1: no oversized neighbour found
        for (int nxt : v[wh][x]) {
            if (nxt == prv || used[wh][nxt]) {
                continue;
            }
            if (2 * cnt[wh][nxt] > tot) {
                heavy = nxt;
                break;
            }
        }
        if (heavy == -1) {
            return x;
        }
        prv = x;
        x = heavy;
    }
}

void update ( int x , int coef , uint aux ) {
    int pos = x ;
    while ( pos > 0 ) {
        uint hh = get_dist ( 1 , x , pos ) ;
        get < 0 > ( vals[ pos ] ) += coef ;
        get < 1 > ( vals[ pos ] ) += coef * ( hh + aux ) ;
        get < 2 > ( vals[ pos ] ) += coef * ( hh + aux ) * ( hh + aux ) ;

        if ( cen_prv[ 1 ][ pos ] > 0 ) {
            hh = get_dist ( 1 , x , cen_prv[ 1 ][ pos ] ) ;
            get < 0 > ( to_parent[ pos ] ) += coef ;
            get < 1 > ( to_parent[ pos ] ) += coef * ( hh + aux ) ;
            get < 2 > ( to_parent[ pos ] ) += coef * ( hh + aux ) * ( hh + aux ) ;
        }

        pos = cen_prv[ 1 ][ pos ] ;
    }
}

void query ( int x , uint aux ) {
    int pos = x ;
    tuple < uint , uint , uint > act = { 0 , 0 , 0 } ;
    while ( pos > 0 ) {
        uint hh = get_dist ( 1 , x , pos ) ;
        get < 0 > ( act ) = get < 0 > ( vals[ pos ] ) ;
        get < 1 > ( act ) = get < 1 > ( vals[ pos ] ) ;
        get < 2 > ( act ) = get < 2 > ( vals[ pos ] ) ;

        ans += get < 0 > ( act ) * ( aux + hh ) * ( aux + hh ) ;
        ans += 2 * get < 1 > ( act ) * ( aux + hh ) ;
        ans += get < 2 > ( act ) ;

        if ( cen_prv[ 1 ][ pos ] > 0 ) {
            get < 0 > ( act ) -= get < 0 > ( to_parent[ pos ] ) ;
            get < 1 > ( act ) -= get < 1 > ( to_parent[ pos ] ) ;
            get < 2 > ( act ) -= get < 2 > ( to_parent[ pos ] ) ;
        }
        
        pos = cen_prv[ 1 ][ pos ] ;
    }
}

void mrk ( int x , int prv , int coef , int ori ) {    
    update ( x , coef , get_dist ( 0 , ori , x ) ) ;
    for ( auto y : v[ 0 ][ x ] ) {
        if ( used[ 0 ][ y ] == true || y == prv ) { continue ; }
        mrk ( y , x , coef , ori ) ;
    }
}

void add_up ( int x , int prv , int ori ) {
    query ( x , get_dist ( 0 , ori , x ) ) ;
    for ( auto y : v[ 0 ][ x ] ) {
        if ( used[ 0 ][ y ] == true || y == prv ) { continue ; }
        add_up ( y , x , ori ) ;
    }
}


// Centroid-decomposes the component of tree `wh` that contains x.
//   prv - value stored in cen_prv[wh] for the centroid found here
// For wh == 1 this only records the cen_prv[1] links used by update()/query().
// For wh == 0 it additionally accumulates into `ans`: every vertex of the
// component is registered with its tree-0 distance to the centroid, then each
// neighbouring branch is unregistered and queried in turn, and finally the
// centroid itself is unregistered.
void decompose(int x, int prv, int wh) {
    tot = 0;
    init(x, -1, wh);
    const int centre = get_centroid(x, -1, wh);
    cen_prv[wh][centre] = prv;
    used[wh][centre] = true;

    if (wh == 0) {
        mrk(centre, -1, 1, centre);
        for (int nxt : v[0][centre]) {
            if (used[0][nxt]) {
                continue;
            }
            // Remove the branch first so it is only paired with what remains.
            mrk(nxt, -1, -1, centre);
            add_up(nxt, -1, centre);
        }
        update(centre, -1, 0);
    }

    for (int nxt : v[wh][centre]) {
        if (!used[wh][nxt]) {
            decompose(nxt, centre, wh);
        }
    }
}

// Reads n and two lists of n-1 edges (tree 0, then tree 1) from standard input,
// builds the Euler tour and sparse table for each tree, fills the window-level
// table mxst, runs the centroid decomposition on tree 1 and then on tree 0,
// and prints the accumulated `ans` followed by a newline.
void solve() {
    cin >> n;
    for (int wh = 0; wh < 2; ++wh) {
        int a, b;
        for (int e = 1; e < n; ++e) {
            cin >> a >> b;
            v[wh][a].push_back(b);
            v[wh][b].push_back(a);
        }
    }

    const int tour_len = 2 * n - 1;
    for (int wh = 0; wh < 2; ++wh) {
        dfs(1, -1, wh);

        // Level 0 of the sparse table is the tour itself, shifted to 1-based.
        for (int i = 0; i < tour_len; ++i) {
            rmq[wh][0][i + 1] = ord[wh][i];
        }
        // Level k combines two adjacent level k-1 windows; ties keep the left one.
        for (int k = 1; k < LOG; ++k) {
            const int half = 1 << (k - 1);
            for (int i = 1; i + (1 << k) - 1 <= tour_len; ++i) {
                const int left = rmq[wh][k - 1][i];
                const int right = rmq[wh][k - 1][i + half];
                rmq[wh][k][i] = (lvl[wh][left] <= lvl[wh][right]) ? left : right;
            }
        }
    }

    // mxst[len] is bumped whenever two windows of the current level no longer
    // cover `len` entries.
    for (int len = 1; len <= tour_len; ++len) {
        int level = mxst[len - 1];
        if (2 * (1 << level) < len) {
            ++level;
        }
        mxst[len] = level;
    }

    decompose(1, -1, 1);
    decompose(1, -1, 0);
    cout << ans << "\n";
}

// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for a single test case (reading a test count is disabled).
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1ms
memory: 24412kb

input:

3
1 2
1 3
1 2
1 3

output:

24

result:

ok 1 number(s): "24"

Test #2:

score: 0
Accepted
time: 2ms
memory: 20680kb

input:

3
1 2
1 3
1 2
2 3

output:

22

result:

ok 1 number(s): "22"

Test #3:

score: -100
Wrong Answer
time: 4ms
memory: 38732kb

input:

500
30 198
198 333
198 17
333 430
333 44
17 99
17 19
430 160
430 162
44 154
44 253
99 466
99 397
19 301
19 101
160 416
160 446
162 375
162 174
154 256
154 170
253 67
253 248
466 462
466 216
397 104
397 306
301 460
301 464
101 226
101 50
416 137
416 456
446 443
446 465
375 92
375 266
174 209
174 84
2...

output:

124304379

result:

wrong answer 1st numbers differ - expected: '75020868', found: '124304379'