QOJ.ac

QOJ

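| ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
|----|---------|-----------|--------|------|--------|----------|-----------|-------------|------------|
| #320653 | #8213. Graffiti | ucup-team197 | WA | 3ms | 27532kb | C++17 | 3.8kb | 2024-02-03 19:39:14 | 2024-02-03 19:39:15 |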

Judging History

You are viewing the latest judging result (你现在查看的是最新测评结果).

  • [2024-02-03 19:39:15]
  • Judged (评测)
  • Result (测评结果): WA
  • Time (用时): 3ms
  • Memory (内存): 27532kb
  • [2024-02-03 19:39:14]
  • Submitted (提交)

answer

#include<bits/stdc++.h>
using namespace std ;

// ---- Short type aliases shared by the whole solution ----
using ll = long long ;
using ull = unsigned long long ;
using pii = pair < int , int > ;
using vi = vector < int > ;

// ---- Contest helper macros (fi is used by the DP in dfs) ----
#define fi first
#define se second
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

// Clock-seeded random engine; not referenced by the visible code.
mt19937 rng ( chrono::high_resolution_clock::now ( ).time_since_epoch ( ).count ( ) ) ;

// Capacity of every per-vertex array; vertex ids are used directly as indices.
constexpr int MAXN = 300000 + 7 ;

int n ;                      // number of vertices of the current test
string a ;                   // the pattern whose occurrences are counted
vector < int > v[ MAXN ] ;   // adjacency lists of the tree
int ch[ MAXN ] ;             // number of children of each vertex once rooted at 1

// dp[x][col][cnt]: best count inside the subtree of x when x gets colour col
// and exactly cnt of its children get colour 0.
//   normal case: colour 0 = side (end) character, colour 1 = middle character
//   bad case   : colour 0 = non-middle label,      colour 1 = middle label
// When the two end characters differ, a middle vertex should split its cnt side
// neighbours as evenly as possible: (cnt + 1) / 2 on one end character.
vector < ll > dp[ MAXN ][ 2 ] ;

// best[y][col]: scratch value for child y with colour col, including the extra
// occurrences caused by its parent's colour (filled inside dfs).
ll best[ MAXN ][ 2 ] ;

// Set by solve when the middle character equals one end of the pattern.
bool bad = false ;

void dfs ( int x , int prv ) {
    ch[ x ] = v[ x ].size ( ) ;
    if ( prv > 0 ) { -- ch[ x ] ; }
    dp[ x ][ 0 ].resize ( ch[ x ] + 1 , 0 ) ;
    dp[ x ][ 1 ].resize ( ch[ x ] + 1 , 0 ) ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        dfs ( y , x ) ;
    }
    for ( int col = 0 ; col < 2 ; ++ col ) {
        vector < pair < ll , int > > srt ;
        ll aux = 0 ;
        for ( auto y : v[ x ] ) {
            if ( y == prv ) { continue ; }
            best[ y ][ 0 ] = best[ y ][ 1 ] = 0 ;
            for ( int j = 0 ; j <= ch[ y ] ; ++ j ) {
                // y is 0 
                ll cand = dp[ y ][ 0 ][ j ] ;
                best[ y ][ 0 ] = max ( best[ y ][ 0 ] , cand ) ;

                // y is 1
                cand = dp[ y ][ 1 ][ j ] ;
                if ( bad == true ) {
                    if ( col == 0 ) { cand += ( ch[ y ] - j ) ; }
                    else { cand += j ; }
                }
                else if ( col == 0 ) {
                    if ( a[ 0 ] == a[ 2 ] ) { cand += 2 * j ; }
                    else { cand += ( j + 1 ) / 2 ; }
                }
                best[ y ][ 1 ] = max ( best[ y ][ 1 ] , cand ) ;
            }
            srt.push_back ( { best[ y ][ 0 ] - best[ y ][ 1 ] , y } ) ;
            aux = aux + best[ y ][ 1 ] ;
        }
        sort ( srt.begin ( ) , srt.end ( ) ) ;
        for ( int j = 0 ; j <= ch[ x ] ; ++ j ) {
            dp[ x ][ col ][ j ] = aux ;
            if ( bad == true ) {
                if ( col == 1 ) { dp[ x ][ col ][ j ] += 1LL * j * ( ch[ x ] - j ) ; }
            }
            else if ( col == 1 ) {
                if ( a[ 0 ] == a[ 2 ] ) { dp[ x ][ col ][ j ] += 1LL * j * ( j - 1 ) ; }
                else { dp[ x ][ col ][ j ] += 1LL * j * ( j - 1 ) / 2 ; }
            }
            if ( j < ch[ x ] ) { 
                aux += srt[ j ].fi ;
            }
            // printf ( "dp[ %d ][ %d ][ %d ] = %lld\n" , x , col , j , dp[ x ][ col ][ j ] ) ;
        }
    }
}

void solve ( ) {
    cin >> n ;
    cin >> a ;
    for ( int i = 1 , x , y ; i < n ; ++ i ) {
        cin >> x >> y ;
        v[ x ].push_back ( y ) ;
        v[ y ].push_back ( x ) ;
    }
    if ( a.size ( ) == 1 ) {
        cout << n << "\n" ;
        return ;
    }
    else if ( a.size ( ) == 2 ) {
        if ( a[ 0 ] == a[ 1 ] ) { cout << 2 * ( n - 1 ) << "\n" ; }
        else { cout << ( n - 1 ) << "\n" ; }
        return ;
    }
    if ( a[ 0 ] == a[ 1 ] && a[ 1 ] == a[ 2 ] ) {
        ll ans = 0 ;
        for ( int i = 1 ; i <= n ; ++ i ) {
            ll hh = v[ i ].size ( ) ;
            ans += hh * ( hh - 1 ) ;
        }
        cout << ans << "\n" ;
        return ;
    }
    if ( a[ 0 ] == a[ 1 ] || a[ 1 ] == a[ 2 ] ) { bad = true ; }
    dfs ( 1 , -1 ) ;
    int sz = v[ 1 ].size ( ) ;
    ll ans = 0 ;
    for ( int i = 0 ; i <= sz ; ++ i ) {
        ans = max ( ans , dp[ 1 ][ 0 ][ i ] ) ;
        ans = max ( ans , dp[ 1 ][ 1 ][ i ] ) ;
    }
    cout << ans << "\n" ;
}

// Program entry point: switch to fast unsynchronised iostreams, then run the test(s).
int main ( ) {
    std::ios_base::sync_with_stdio ( false ) ;
    std::cin.tie ( nullptr ) ;
    int tests = 1 ;   // single test; a multi-test input would read it via std::cin >> tests
    for ( ; tests > 0 ; -- tests ) {
        solve ( ) ;
    }
    return 0 ;
}


Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 3ms
memory: 25940kb

input:

1
a

output:

1

result:

ok 1 number(s): "1"

Test #2:

score: 0
Accepted
time: 0ms
memory: 27120kb

input:

3
orz
1 2
2 3

output:

1

result:

ok 1 number(s): "1"

Test #3:

score: 0
Accepted
time: 0ms
memory: 26964kb

input:

2
ab
1 2

output:

1

result:

ok 1 number(s): "1"

Test #4:

score: 0
Accepted
time: 0ms
memory: 26936kb

input:

5
bob
3 2
5 1
1 4
2 4

output:

4

result:

ok 1 number(s): "4"

Test #5:

score: -100
Wrong Answer
time: 3ms
memory: 27532kb

input:

50
abc
23 14
24 25
1 3
47 46
2 26
22 41
34 19
7 14
50 24
29 38
17 25
4 26
35 37
21 14
11 4
13 27
8 25
5 10
20 27
44 27
15 39
19 9
30 12
38 27
39 27
41 40
14 48
32 7
16 37
3 13
42 5
48 27
49 25
6 5
26 9
31 17
36 7
43 29
9 5
45 9
18 9
40 42
27 5
25 42
46 10
37 42
12 48
28 26
33 5

output:

44

result:

wrong answer 1st numbers differ - expected: '37', found: '44'