QOJ.ac
QOJ
ID | 题目 | 提交者 | 结果 | 用时 | 内存 | 语言 | 文件大小 | 提交时间 | 测评时间 |
---|---|---|---|---|---|---|---|---|---|
#73615 | #1254. Biggest Set Ever | upobir | WA | 1ms | 3476kb | C++14 | 5.1kb | 2023-01-26 18:57:31 | 2023-01-26 18:57:31 |
Judging History
answer
#include<bits/stdc++.h>
using namespace std;
using ll = long long int;
//7340033 = 7*2^20, 645922817 = 77*2^23, G = 3
//897581057=107*2^23, 998244353=119*2^23, G = 3
namespace NTT {
vector<int> perm, wp[2]; int root, inv, N, invN;
const int mod = 998244353, G = 3; ///G prim root
int power(int a, int p) {
int ans = 1;
while (p) {
if (p & 1) ans = (1LL*ans*a)%mod;
a = (1LL*a*a)%mod; p >>= 1;
} return ans;
}
void precalculate(int n) {
assert( (n&(n-1)) == 0 && (mod-1)%n==0);
N = n; invN = power(N, mod-2);
perm = wp[0] = wp[1] = vector<int>(N);
perm[0] = 0;
for (int k=1; k<N; k<<=1)
for (int i=0; i<k; i++) {
perm[i] <<= 1; perm[i+k] = 1 + perm[i]; }
root=power(G,(mod-1)/N); inv=power(root, mod-2);
wp[0][0]=wp[1][0]=1;
for (int i=1; i<N; i++) {
wp[0][i] = (wp[0][i-1]*1LL*root)%mod;
wp[1][i] = (wp[1][i-1]*1LL*inv)%mod;
}
}
void fft(vector<int> &v, bool invert = false) {
if (v.size()!=perm.size())precalculate(v.size());
for (int i=0; i<N; i++)
if (i < perm[i]) swap(v[i], v[perm[i]]);
for (int len = 2; len <= N; len *= 2) {
for (int i=0, d = N/len; i<N; i+=len) {
for (int j=0, idx=0; j<len/2; j++, idx+=d) {
int x=v[i+j], y =
(wp[invert][idx]*1LL*v[i+j+len/2])%mod;
v[i+j] = (x+y>=mod ? x+y-mod : x+y);
v[i+j+len/2] = (x-y>=0 ? x-y : x-y+mod);
}
}
} if (invert) {
for (int &x : v) x = (x*1LL*invN)%mod; }
}
vector<int> multiply(vector<int> a, vector<int> b){
int n = 1; while (n < a.size()+ b.size()) n<<=1;
a.resize(n); b.resize(n);
fft(a); fft(b);
for (int i=0;i<n;i++) a[i]=(a[i]*1LL*b[i])%mod;
fft(a, true); return a;
}
vector<int> conv(vector<int> a, vector<int> b, int NN){
int n = 1; while (n < a.size()+ b.size()) n<<=1;
a.resize(n); b.resize(n);
fft(a); fft(b);
for (int i=0;i<n;i++) a[i]=(a[i]*1LL*b[i])%mod;
fft(a, true);
for(int i = NN; i<a.size(); i++){
a[i%NN] += a[i];
if(a[i%NN] >= mod) a[i%NN] -= mod;
}
a.resize(NN);
return a;
}
vector<int> selfconv(vector<int> a, int NN){
int n = 1; while (n < a.size()+ a.size()) n<<=1;
a.resize(n);
fft(a);
for (int i=0;i<n;i++) a[i]=(a[i]*1LL*a[i])%mod;
fft(a, true);
for(int i = NN; i<a.size(); i++){
a[i%NN] += a[i];
if(a[i%NN] >= mod) a[i%NN] -= mod;
}
a.resize(NN);
return a;
}
};
const int base = 1e9;
const ll mod = 998244353;
int divide(vector<int>& a, int b){
int carry = 0;
for (int i=(int)a.size()-1; i>=0; --i) {
long long cur = a[i] + carry * 1ll * base;
a[i] = int (cur / b);
carry = int (cur % b);
}
while (a.size() > 1 && a.back() == 0)
a.pop_back();
return carry;
}
vector<int> tonum(string& s){
vector<int> a;
for (int i=(int)s.length(); i>0; i-=9)
if (i < 9)
a.push_back (atoi (s.substr (0, i).c_str()));
else
a.push_back (atoi (s.substr (i-9, 9).c_str()));
return a;
}
int dp[2][10005];
int main(){
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin>>n;
int rem;
cin>>rem;
assert(0 <= rem && rem < n && n <= 10000);
string ss;
cin>>ss;
assert(ss.size() <= 100000 && 1 <= ss.size());
assert(ss[0] != '0');
for(auto c : ss){
assert(isdigit(c));
}
auto T = tonum(ss);
auto tmp = T;
auto kk = divide(tmp, mod-1);
int pref = divide(T, n);
auto k = divide(T, mod-1);
if(k == 0){
cerr<<"noo"<<endl;
}
// pref--;
// if(pref < 0){
// pref += n;
// k--;
// if(k < 0)
// k += mod-1;
// }
// auto T = stoll(s);
// auto pref = T % n;
// T /= n;
// auto k = T % (mod-1);
vector<int> P(n), Full(n);
P[0] = 1;
// cout<<k<<" "<<pref<<endl;
dp[0][0] = 1;
int cur = 0;
for(int i = 0; i<n; i++){
cur^=1;
for(int s = 0; s<n; s++){
if(s-i >= 0)
dp[cur][s] = dp[cur^1][s] + dp[cur^1][s-i];
else
dp[cur][s] = dp[cur^1][s] + dp[cur^1][s-i+n];
if(dp[cur][s] >= mod)
dp[cur][s] -= mod;
}
if(i+1 == pref){
for(int j = 0; j<n; j++)
P[j] = dp[cur][j];
}
}
for(int i = 0; i<n; i++){
Full[i] = dp[cur][i];
}
for(int i = 0; i<n; i++)
cout<<P[i]<<" ";
cout<<endl;
for(int i = 0; i<n; i++)
cout<<Full[i]<<" ";
cout<<endl;
/// FULL^k
vector<int> ans(n);
ans[0] = 1;
while(k){
if(k&1){
ans = NTT::conv(ans, Full, n);
}
Full = NTT::selfconv(Full, n);
k >>= 1;
}
ans = NTT::conv(ans, P, n);
ll sum = 0;
for(int i = 0; i<n; i++){
sum = (sum + ans[i]) % mod;
}
assert(sum == NTT::power(2, kk));
assert(ans[rem] >= 0 && ans[rem] < mod);
cout<<ans[rem]<<endl;
return 0;
}
详细
Test #1:
score: 0
Wrong Answer
time: 1ms
memory: 3476kb
input:
3 2 5
output:
2 2 0 4 2 2 8
result:
wrong answer 1st lines differ - expected: '8', found: '2 2 0 '