QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#467851 | #1163. Another Tree Queries Problem | BINYU | WA | 127ms | 26948kb | C++14 | 5.8kb | 2024-07-08 17:52:00 | 2024-07-08 17:52:00 |
Judging History
answer
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5;          // upper bound on the number of vertices
#define ll long long

// Sum of the arithmetic progression st, st + k, ..., st + (len - 1) * k.
// (2 * st + (len - 1) * k) * len is always even, so the division is exact.
ll calc(ll st,ll len,ll k)
{
    return (2 * st + len * k - k) * len / 2;
}

// Lazy segment tree over positions 1..n (used with the tree's dfn order).
//
// Each position p holds a value; node.sum is the sum of those values over the
// node's range [l, r].  Pending (not yet pushed) updates stored on a node:
//   ad1, ad2 : position l + i receives ad1 + ad2 * i  (arithmetic progression);
//   ad3      : position p receives ad3 * siz(p).
// siz(p) and dep(p) are static per-position attributes set through update();
// internal nodes store their sums.
struct Segment_Tree
{
    #define ls (id << 1)
    #define rs (id << 1 | 1)
    struct node
    {
        ll dep;   // sum of dep(p); ll because it can exceed INT_MAX on deep trees
        ll len,sum,siz,ad1,ad2,ad3;
    }a[4 * N + 5];

    // Add first, first + k, first + 2k, ... to the positions of node id, left to right.
    void apply_ap(int id,ll first,ll k)
    {
        a[id].ad1 += first;
        a[id].ad2 += k;
        a[id].sum += calc(first,a[id].len,k);
    }

    // Add v * siz(p) to every position p of node id.
    void apply_siz(int id,ll v)
    {
        a[id].ad3 += v;
        a[id].sum += v * a[id].siz;
    }

    // Move node id's pending tags to its two children.
    void pushdown(int id)
    {
        // The tags are ll: they accumulate values such as v * siz that overflow int.
        ll ad1 = a[id].ad1,ad2 = a[id].ad2,ad3 = a[id].ad3;
        if(ad1 || ad2)
        {
            apply_ap(ls,ad1,ad2);
            // The right child begins a[ls].len positions after the parent's first one.
            apply_ap(rs,ad1 + ad2 * a[ls].len,ad2);
            a[id].ad1 = a[id].ad2 = 0;
        }
        if(ad3)
        {
            apply_siz(ls,ad3);
            apply_siz(rs,ad3);
            a[id].ad3 = 0;
        }
    }

    // Record the length of every node's range.
    void build(int id,int l,int r)
    {
        a[id].len = r - l + 1;
        if(l == r)return;
        int mid = (l + r) >> 1;
        build(ls,l,mid);
        build(rs,mid + 1,r);
    }

    // Set the static attributes siz(x) and dep(x) of position x.
    void update(int id,int l,int r,int x,ll siz,int dep)
    {
        if(l == r)
        {
            a[id].siz = siz;
            a[id].dep = dep;
            return;
        }
        int mid = (l + r) >> 1;
        if(x <= mid)update(ls,l,mid,x,siz,dep);
        else update(rs,mid + 1,r,x,siz,dep);
        a[id].siz = a[ls].siz + a[rs].siz;
        a[id].dep = a[ls].dep + a[rs].dep;
    }

    // Add first + k * (p - st) to every position p in [st, en].
    void add_ap(int id,int l,int r,int st,int en,ll first,ll k)
    {
        if(l >= st && r <= en)
        {
            apply_ap(id,first + k * (l - st),k);
            return;
        }
        int mid = (l + r) >> 1;
        pushdown(id);
        if(st <= mid)add_ap(ls,l,mid,st,en,first,k);
        if(en > mid)add_ap(rs,mid + 1,r,st,en,first,k);
        a[id].sum = a[ls].sum + a[rs].sum;
    }

    // Add the constant v to every position in [st, en].
    void add1(int id,int l,int r,int st,int en,ll v)
    {
        add_ap(id,l,r,st,en,v,0);
    }

    // Add a progression that DECREASES from left to right: position en gets v,
    // en - 1 gets v + 1, ..., st gets v + (en - st).  dfn grows with depth along
    // a heavy chain (dfs2 numbers the heavy child right after its parent), so the
    // deepest vertex of a chain segment receives v.
    // Returns v + (en - st + 1), the value that continues the sequence past st.
    int add2(int id,int l,int r,int st,int en,ll v)
    {
        add_ap(id,l,r,st,en,v + (en - st),-1);
        return static_cast<int>(v + (en - st + 1));
    }

    // Add v * siz(p) to every position p in [st, en].
    void add3(int id,int l,int r,int st,int en,ll v)
    {
        if(l >= st && r <= en)
        {
            apply_siz(id,v);
            return;
        }
        int mid = (l + r) >> 1;
        pushdown(id);
        if(st <= mid)add3(ls,l,mid,st,en,v);
        if(en > mid)add3(rs,mid + 1,r,st,en,v);
        a[id].sum = a[ls].sum + a[rs].sum;
    }

    // Sum of the values over [st, en].
    ll query(int id,int l,int r,int st,int en)
    {
        if(l >= st && r <= en)return a[id].sum;
        int mid = (l + r) >> 1;
        pushdown(id);
        ll res = 0;
        if(st <= mid)res += query(ls,l,mid,st,en);
        if(en > mid)res += query(rs,mid + 1,r,st,en);
        return res;
    }

    // Sum of dep(p) over [st, en].  dep is never touched by lazy tags, so no
    // pushdown is needed here.  Returns ll: the sum can exceed INT_MAX.
    ll query_dep(int id,int l,int r,int st,int en)
    {
        if(l >= st && r <= en)return a[id].dep;
        int mid = (l + r) >> 1;
        ll res = 0;
        if(st <= mid)res += query_dep(ls,l,mid,st,en);
        if(en > mid)res += query_dep(rs,mid + 1,r,st,en);
        return res;
    }
    #undef ls
    #undef rs
}st;
// Input sizes and the scratch variables main() reads each operation into.
int n, q, u, v, op;

// Rooted-tree data (root 1), filled by dfs1 and the lifting loop in main.
int siz[N + 5];      // subtree size
int son[N + 5];      // heavy child, 0 if none
int dep[N + 5];      // depth; the root has depth 1 (dep[0] == 0)
int f[20][N + 5];    // f[i][x]: 2^i-th ancestor of x, 0 above the root

// Heavy-light decomposition data, filled by dfs2.
int dfn[N + 5];      // DFS order index
int cntdfn;          // last dfn handed out
int top[N + 5];      // head of the heavy chain containing x

vector <int> e[N + 5];   // adjacency lists

// Running totals over everything added so far.
ll sum1;   // total amount added over all vertices
ll sum2;   // sum of (amount added to a vertex) * (its depth)
// Root the subtree of u (whose parent is fa): record parent, depth and
// subtree size, and choose as son[u] the first child with the strictly
// largest subtree.
void dfs1(int u,int fa)
{
    f[0][u] = fa;
    dep[u] = dep[fa] + 1;
    siz[u] = 1;
    for(int nxt : e[u])
    {
        if(nxt != fa)
        {
            dfs1(nxt,u);
            siz[u] += siz[nxt];
            if(siz[son[u]] < siz[nxt])son[u] = nxt;
        }
    }
}
// Heavy-light numbering: assign DFS indices with the heavy child visited first,
// so every heavy chain occupies consecutive dfn values, and store each
// vertex's chain head t in top[].
void dfs2(int u,int t)
{
    dfn[u] = ++cntdfn;
    top[u] = t;
    if(son[u] != 0)dfs2(son[u],t);
    for(int nxt : e[u])
    {
        // The parent and the heavy child already have a dfn and are skipped.
        if(dfn[nxt] == 0)dfs2(nxt,nxt);
    }
}
// Lowest common ancestor of u and v by binary lifting over f[0..18].
int lca(int u,int v)
{
    if(dep[v] > dep[u])swap(u,v);
    // Raise u to the depth of v.
    for(int b = 18;b >= 0;--b)
    {
        int up = f[b][u];
        if(dep[up] >= dep[v])u = up;
    }
    if(u == v)return u;
    // Raise both until they sit just below their common ancestor.
    for(int b = 18;b >= 0;--b)
    {
        if(f[b][u] == f[b][v])continue;
        u = f[b][u];
        v = f[b][v];
    }
    return f[0][u];
}
// Return the highest ancestor of u whose depth is greater than d, i.e. the
// ancestor at depth d + 1 when dep[u] > d.  If dep[u] <= d, u is returned.
int jmp(int u,int d)
{
    for(int b = 18;b >= 0;--b)
        if(d < dep[f[b][u]])u = f[b][u];
    return u;
}
// Sum of the segment-tree values of all vertices on the path root -> u,
// taken one heavy chain at a time.
ll query(int u)
{
    ll total = 0;
    for(int x = u;x != 0;x = f[0][top[x]])
        total += st.query(1,1,n,dfn[top[x]],dfn[x]);
    return total;
}
// Add v to the segment-tree value of every vertex on the path u -> root.
// A call with u == 0 does nothing.
void add1(int u,ll v)
{
    for(int x = u;x != 0;x = f[0][top[x]])
        st.add1(1,1,n,dfn[top[x]],dfn[x],v);
}
void add2(int u,int F)
{
sum1 += dep[u] - dep[F] + 1;
int now = 1;
while(top[u] != top[F])
sum2 += st.query_dep(1,1,n,dfn[top[u]],dfn[u]),
now = st.add2(1,1,n,dfn[top[u]],dfn[u],now),
u = f[0][top[u]];
sum2 += st.query_dep(1,1,n,dfn[F],dfn[u]);
st.add2(1,1,n,dfn[F],dfn[u],now);
}
void add3(int u,ll v)
{
sum1 += v * siz[u];
sum2 += v * st.query_dep(1,1,n,dfn[u],dfn[u] + siz[u] - 1);
st.add3(1,1,n,dfn[u],dfn[u] + siz[u] - 1,v);
add1(f[0][u],v * siz[u]);
}
int main()
{
scanf("%d",&n);
for(int i = 1;i < n;i++)
scanf("%d %d",&u,&v),
e[u].push_back(v),e[v].push_back(u);
dfs1(1,0);dfs2(1,1);
for(int i = 0;i < 18;i++)
for(int j = 1;j <= n;j++)
f[i + 1][j] = f[i][f[i][j]];
st.build(1,1,n);
for(int i = 1;i <= n;i++)
st.update(1,1,n,dfn[i],siz[i],dep[i]);
scanf("%d",&q);
while(q--)
{
scanf("%d",&op);
if(op == 1)
{
scanf("%d %d",&u,&v);
if(u == v)add3(1,1);
else if(lca(u,v) == v)
add3(1,1),add3(jmp(u,dep[v]),-1);
else add3(v,1);
}
else if(op == 2)
{
scanf("%d %d",&u,&v);
int F = lca(u,v);
sum2 += dep[F];sum1++;
if(F != u)add2(u,jmp(u,dep[F]));
if(F != v)add2(v,jmp(v,dep[F]));
int cnt = dep[u] + dep[v] - 2 * dep[F] + 1;
add1(F,cnt);
}
else
{
scanf("%d",&u);
ll sum3 = query(u);
cout<<sum1 * dep[u] + sum2 - 2 * sum3<<"\n";
}
}
}
/*
6
1 2
1 3
3 4
3 5
5 6
2
1 4 3
3 4
*/
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 0ms
memory: 26948kb
input:
5 4 2 2 5 1 5 1 3 5 2 2 4 3 4 2 1 5 2 5 5 3 2
output:
1 5
result:
ok 2 number(s): "1 5"
Test #2:
score: -100
Wrong Answer
time: 127ms
memory: 26696kb
input:
200 171 114 50 183 28 68 67 152 139 125 67 55 50 98 106 71 46 42 157 165 42 49 113 12 81 145 105 13 38 96 34 156 24 17 21 191 135 54 174 116 177 157 123 71 95 130 135 193 150 129 25 190 96 93 188 173 90 160 86 187 20 132 199 75 59 195 189 24 40 68 163 83 25 13 73 33 59 50 154 19 146 21 151 67 89 69 ...
output:
834 908 851 1841 2354 1386 3104 2471 4887 4009 4413 4009 6139 4879 6410 7689 4752 7741 7571 9975 8111 6134 11010 7216 13171 14835 9837 9231 8709 11659 18440 21642 16353 18620 11519 15553 12045 19736 12368 17545 23430 21726 25330 20101 15240 22646 14309 20460 14429 26308 19828 30676 20386 33312 19127...
result:
wrong answer 1st numbers differ - expected: '826', found: '834'