QOJ.ac
QOJ
ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time |
---|---|---|---|---|---|---|---|---|---|
#87804 | #5739. Super Meat Bros | xiaoyaowudi | RE | 5ms | 5896kb | C++17 | 46.5kb | 2023-03-14 11:49:24 | 2023-03-14 11:49:27 |
Judging History
answer
#define __AVX__ 1
#define __AVX2__ 1
#define __AVX512F__ 1
#define __AVX512DQ__ 1
#pragma GCC target("avx,avx2,avx512f,avx512dq")
// #pragma GCC target("avx,avx2")
#include <immintrin.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <exception>
#include <iostream>
#include <memory>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <vector>
namespace library
{
// Short aliases for the fixed-width integer types used throughout the library.
using ui = std::uint32_t;
using i32 = std::int32_t;
using ll = std::int64_t;
using ull = std::uint64_t;
// 128-bit integers are compiler extensions (GCC/Clang).
using u128 = __uint128_t;
using i128 = __int128;
// Modulus installed by default in every modular-arithmetic helper below.
constexpr ui default_mod = 998244353;
// Division-free reduction modulo a fixed divisor `b`, using the precomputed
// reciprocal m = floor(2^64 / b) truncated to 64 bits.
// NOTE(review): for b == 1 the reciprocal truncates to 0 and reduce() no longer
// yields a % b; callers are assumed to use b >= 2 - confirm.
struct fast_mod_32 {
    ull b = 0, m = 0;
    // Precompute the reciprocal for divisor `b` (b must be non-zero).
    fast_mod_32(ull b) : b(b), m(ull((u128(1) << 64) / b)) {}
    // A default-constructed reducer is zeroed; install a real one by assignment
    // before calling reduce().
    fast_mod_32() = default;
    fast_mod_32(const fast_mod_32 &) = default;
    fast_mod_32 &operator=(const fast_mod_32 &) = default;
    // Returns a reduced modulo b: the quotient estimate q can be one too small,
    // which the single conditional subtraction at the end corrects.
    ull reduce(ull a) const {
        const ull q = (ull)((u128(m) * a) >> 64);
        const ull r = a - q * b;
        return r >= b ? r - b : r;
    }
};
// Process-wide reducer for the current scalar modulus; replaced by set_mod_mi().
fast_mod_32 global_fast_mod_32(default_mod);
// Scalar Montgomery arithmetic for an odd modulus P, with R = 2^32.
// Values are kept lazily reduced in [0, 2P) unless stated otherwise.
struct montgomery_mi_lib{
    // Returns -MOD^{-1} modulo 2^len for odd MOD.
    // After j steps the accumulator equals ((MOD+1)^j - 1) / MOD.
    constexpr static ui calc_k(ui MOD,ui len){
        ui acc=1;
        ui step=1;
        while(step<len){
            acc=acc*(MOD+1)+1;
            ++step;
        }
        return acc;
    }
    // P: modulus, P2: 2P, NP: 2^64 mod P (i.e. R^2 mod P), Pk: -P^{-1} mod 2^32.
    ui P,P2,NP,Pk;
    static constexpr ui ui_len = sizeof(ui)*8;
    montgomery_mi_lib(ui P0):P(P0),P2(P0*2),NP((-ull(P0))%P0),Pk(calc_k(P0,ui_len)){}
    // Leaves every field uninitialised; assign a real instance before use.
    montgomery_mi_lib(){}
#define INLINE_OP __attribute__((__always_inline__))
    // Conditional subtraction of 2P.
    INLINE_OP ui redd(ui k) const {
        if(k>=P2) return k-P2;
        return k;
    }
    // Conditional subtraction of P.
    INLINE_OP ui reds(ui k) const {
        if(k>=P) return k-P;
        return k;
    }
    // Montgomery reduction: (k + ((k mod 2^32) * Pk mod 2^32) * P) / 2^32.
    INLINE_OP ui redu(ull k) const {
        const ui q=ui(k)*Pk;
        return (k+ull(q)*P)>>ui_len;
    }
    INLINE_OP ui add(ui a,ui b) const {
        const ui s=a+b;
        return redd(s);
    }
    INLINE_OP ui sub(ui a,ui b) const {
        const ui d=a-b+P2;
        return redd(d);
    }
    INLINE_OP ui mul(ui a,ui b) const {
        const ull wide=ull(a)*b;
        return redu(wide);
    }
    INLINE_OP ui neg(ui a) const {
        const ui flipped=P2-a;
        return redd(flipped);
    }
    // Plain residue -> Montgomery form (multiply by R^2, then reduce).
    INLINE_OP ui v(ui a) const {
        const ull scaled=ull(a)*NP;
        return redu(scaled);
    }
    // Montgomery form -> plain residue, conditionally reduced by P.
    INLINE_OP ui rv(ui a) const {
        const ui plain=redu(a);
        return reds(plain);
    }
#undef INLINE_OP
};
// Modular integer stored in Montgomery form; all instances share the static
// parameter block `mlib`. `val` is kept lazily reduced (see montgomery_mi_lib).
struct montgomery_mi{
    static montgomery_mi_lib mlib;
    ui val;
    // Store the plain residue `a` converted to Montgomery form.
    void init(ui a){val=mlib.redu(ull(a)*mlib.NP);}
    montgomery_mi():val(0){}
    montgomery_mi(const montgomery_mi &a)=default;
    // Implicit conversion from a plain residue (converted to Montgomery form).
    montgomery_mi(ui v):val(mlib.redu(ull(v)*mlib.NP)){}
    montgomery_mi& operator=(const montgomery_mi &b)=default;
    montgomery_mi& operator+=(const montgomery_mi &b) {val=mlib.redd(val+b.val);return *this;}
    montgomery_mi& operator-=(const montgomery_mi &b) {val=mlib.redd(val-b.val+mlib.P2);return *this;}
    montgomery_mi operator+(const montgomery_mi &b) const {return montgomery_mi(*this)+=b;}
    montgomery_mi operator-(const montgomery_mi &b) const {return montgomery_mi(*this)-=b;}
    // Returns a reference, like += and -=, so chained compound assignment
    // acts on *this rather than on a temporary copy.
    montgomery_mi& operator*=(const montgomery_mi &b) {val=mlib.redu(ull(val)*b.val);return *this;}
    montgomery_mi operator*(const montgomery_mi &b) const {return montgomery_mi(*this)*=b;}
    // Additive inverse, computed as 0 - *this.
    montgomery_mi operator-() const {montgomery_mi b;return b-=(*this);}
    // Plain residue: leaves Montgomery form and conditionally subtracts P.
    ui real_val() const {return mlib.reds(mlib.redu(val));}
    // Raw Montgomery-form representation.
    ui get_val() const {return val;}
    // Reads a plain unsigned integer and converts it to Montgomery form.
    friend std::istream& operator>>(std::istream &in, montgomery_mi &m_int){ui inp;in>>inp;m_int.init(inp);return in;}
    // Writes the plain residue.
    friend std::ostream& operator<<(std::ostream &out, const montgomery_mi &m_int){out<<m_int.real_val();return out;}
};
// Wraps a value that is ALREADY in Montgomery form into a montgomery_mi without
// converting it. Assigns the public `val` field directly rather than punning
// through a union, because reading an inactive union member is undefined
// behaviour in C++.
static __attribute__((__always_inline__)) inline montgomery_mi ui2mi(ui k){
    montgomery_mi wrapped;
    wrapped.val=k;
    return wrapped;
}
// Short names for the scalar Montgomery integer and its parameter block.
using mi = montgomery_mi;
using lmi = montgomery_mi_lib;
#if defined(__AVX__) && defined(__AVX2__)
// AVX2 counterpart of montgomery_mi_lib: the same operations on eight 32-bit
// lanes at once. Every operation is const, matching the scalar library, so the
// parameter block can be used through const references.
struct montgomery_mm256_lib{
    // Broadcast copies of P, 2P, 2^64 mod P and -P^{-1} mod 2^32.
    alignas(32) __m256i P,P2,NP,Pk;
    static constexpr ui ui_len=sizeof(ui)*8;
    montgomery_mm256_lib(ui P0){
        P=_mm256_set1_epi32(P0);
        P2=_mm256_set1_epi32(P0*2);
        NP=_mm256_set1_epi32(ui((-ull(P0))%P0));
        Pk=_mm256_set1_epi32(montgomery_mi_lib::calc_k(P0,ui_len));
    }
    // Leaves every field uninitialised; assign a real instance before use.
    montgomery_mm256_lib(){}
#define INLINE_OP __attribute__((__always_inline__))
    // Per lane: subtract 2P, then add it back where the signed result is negative.
    INLINE_OP __m256i redd(__m256i k) const {
        const __m256i diff=_mm256_sub_epi32(k,P2);
        const __m256i fix=_mm256_and_si256(_mm256_cmpgt_epi32(_mm256_setzero_si256(),diff),P2);
        return _mm256_add_epi32(diff,fix);
    }
    // Per lane: subtract P, then add it back where the signed result is negative.
    INLINE_OP __m256i reds(__m256i k) const {
        const __m256i diff=_mm256_sub_epi32(k,P);
        const __m256i fix=_mm256_and_si256(_mm256_cmpgt_epi32(_mm256_setzero_si256(),diff),P);
        return _mm256_add_epi32(diff,fix);
    }
    // Montgomery reduction of the four 64-bit lanes of k; each result lands in
    // the low 32 bits of its 64-bit lane.
    INLINE_OP __m256i redu(__m256i k) const {
        const __m256i q=_mm256_mul_epu32(k,Pk);
        const __m256i qp=_mm256_mul_epu32(q,P);
        return _mm256_srli_epi64(_mm256_add_epi64(qp,k),32);
    }
    // Lane-wise Montgomery product: even 32-bit lanes and odd 32-bit lanes are
    // multiplied and reduced separately, then merged.
    INLINE_OP __m256i mul(__m256i k1,__m256i k2) const {
        const __m256i even=redu(_mm256_mul_epu32(k1,k2));
        const __m256i odd=redu(_mm256_mul_epu32(_mm256_srli_epi64(k1,32),_mm256_srli_epi64(k2,32)));
        return _mm256_or_si256(even,_mm256_slli_epi64(odd,32));
    }
    INLINE_OP __m256i add(__m256i k1,__m256i k2) const {
        return redd(_mm256_add_epi32(k1,k2));
    }
    INLINE_OP __m256i sub(__m256i k1,__m256i k2) const {
        return redd(_mm256_add_epi32(P2,_mm256_sub_epi32(k1,k2)));
    }
#undef INLINE_OP
};
// Eight modular integers in Montgomery form packed in one __m256i; all
// instances share the static parameter block `mlib`.
struct montgomery_mm256_int{
    static montgomery_mm256_lib mlib;
    __m256i val;
    // Store eight plain residues converted to Montgomery form.
    void init(__m256i a){val=mlib.mul(a,mlib.NP);}
    montgomery_mm256_int(){val=_mm256_setzero_si256();}
    montgomery_mm256_int(const montgomery_mm256_int &a)=default;
    // From eight plain residues.
    montgomery_mm256_int(__m256i v):val(mlib.mul(v,mlib.NP)){}
    // Broadcast of a single plain residue.
    montgomery_mm256_int(ui v){init(_mm256_set1_epi32(v));}
    montgomery_mm256_int& operator=(const montgomery_mm256_int &b)=default;
    montgomery_mm256_int& operator+=(const montgomery_mm256_int &b) {val=mlib.add(val,b.val);return *this;}
    montgomery_mm256_int& operator-=(const montgomery_mm256_int &b) {val=mlib.sub(val,b.val);return *this;}
    montgomery_mm256_int operator+(const montgomery_mm256_int &b) const {return montgomery_mm256_int(*this)+=b;}
    montgomery_mm256_int operator-(const montgomery_mm256_int &b) const {return montgomery_mm256_int(*this)-=b;}
    // Returns a reference, like += and -=, so chained compound assignment
    // acts on *this rather than on a temporary copy.
    montgomery_mm256_int& operator*=(const montgomery_mm256_int &b) {val=mlib.mul(val,b.val);return *this;}
    montgomery_mm256_int operator*(const montgomery_mm256_int &b) const {return montgomery_mm256_int(*this)*=b;}
    // Additive inverse, computed as 0 - *this.
    montgomery_mm256_int operator-() const {montgomery_mm256_int b;return b-=(*this);}
    // Plain residues: odd and even 32-bit lanes are reduced separately, merged,
    // then conditionally reduced by P.
    __m256i real_val() const {
        const __m256i odd=_mm256_slli_epi64(mlib.redu(_mm256_srli_epi64(val,32)),32);
        const __m256i even=mlib.redu(_mm256_srli_epi64(_mm256_slli_epi64(val,32),32));
        return mlib.reds(_mm256_or_si256(odd,even));
    }
    // Raw Montgomery-form lanes.
    __m256i get_val() const {return val;}
};
// Short names for the AVX2 Montgomery vector and its parameter block.
using mai = montgomery_mm256_int;
using lma = montgomery_mm256_lib;
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// AVX-512 counterpart of montgomery_mi_lib: the same operations on sixteen
// 32-bit lanes at once. Every operation is const, matching the scalar library.
struct montgomery_mm512_lib{
    // Broadcast copies of P, 2P, 2^64 mod P and -P^{-1} mod 2^32.
    alignas(64) __m512i P,P2,NP,Pk;
    static constexpr ui ui_len=sizeof(ui)*8;
    montgomery_mm512_lib(ui P0){
        P=_mm512_set1_epi32(P0);
        P2=_mm512_set1_epi32(P0*2);
        NP=_mm512_set1_epi32(ui((-ull(P0))%P0));
        Pk=_mm512_set1_epi32(montgomery_mi_lib::calc_k(P0,ui_len));
    }
    // Leaves every field uninitialised; assign a real instance before use.
    montgomery_mm512_lib(){}
#define INLINE_OP __attribute__((__always_inline__))
    // Per lane: subtract 2P, then add it back where the sign bit is set.
    INLINE_OP __m512i redd(__m512i k) const {
        const __m512i diff=_mm512_sub_epi32(k,P2);
        return _mm512_mask_add_epi32(diff,_mm512_movepi32_mask(diff),diff,P2);
    }
    // Per lane: subtract P, then add it back where the sign bit is set.
    INLINE_OP __m512i reds(__m512i k) const {
        const __m512i diff=_mm512_sub_epi32(k,P);
        return _mm512_mask_add_epi32(diff,_mm512_movepi32_mask(diff),diff,P);
    }
    // Montgomery reduction of the eight 64-bit lanes of k; each result lands in
    // the low 32 bits of its 64-bit lane.
    INLINE_OP __m512i redu(__m512i k) const {
        const __m512i q=_mm512_mul_epu32(k,Pk);
        const __m512i qp=_mm512_mul_epu32(q,P);
        return _mm512_srli_epi64(_mm512_add_epi64(qp,k),32);
    }
    // Lane-wise Montgomery product: even 32-bit lanes and odd 32-bit lanes are
    // multiplied and reduced separately, then merged.
    INLINE_OP __m512i mul(__m512i k1,__m512i k2) const {
        const __m512i even=redu(_mm512_mul_epu32(k1,k2));
        const __m512i odd=redu(_mm512_mul_epu32(_mm512_srli_epi64(k1,32),_mm512_srli_epi64(k2,32)));
        return _mm512_or_si512(even,_mm512_slli_epi64(odd,32));
    }
    INLINE_OP __m512i add(__m512i k1,__m512i k2) const {
        return redd(_mm512_add_epi32(k1,k2));
    }
    INLINE_OP __m512i sub(__m512i k1,__m512i k2) const {
        return redd(_mm512_add_epi32(P2,_mm512_sub_epi32(k1,k2)));
    }
#undef INLINE_OP
};
// Sixteen modular integers in Montgomery form packed in one __m512i; all
// instances share the static parameter block `mlib`.
struct montgomery_mm512_int{
    static montgomery_mm512_lib mlib;
    __m512i val;
    // Store sixteen plain residues converted to Montgomery form.
    void init(__m512i a){val=mlib.mul(a,mlib.NP);}
    montgomery_mm512_int(){val=_mm512_setzero_si512();}
    montgomery_mm512_int(const montgomery_mm512_int &a)=default;
    // From sixteen plain residues.
    montgomery_mm512_int(__m512i v):val(mlib.mul(v,mlib.NP)){}
    // Broadcast of a single plain residue.
    montgomery_mm512_int(ui v){init(_mm512_set1_epi32(v));}
    montgomery_mm512_int& operator=(const montgomery_mm512_int &b)=default;
    montgomery_mm512_int& operator+=(const montgomery_mm512_int &b) {val=mlib.add(val,b.val);return *this;}
    montgomery_mm512_int& operator-=(const montgomery_mm512_int &b) {val=mlib.sub(val,b.val);return *this;}
    montgomery_mm512_int operator+(const montgomery_mm512_int &b) const {return montgomery_mm512_int(*this)+=b;}
    montgomery_mm512_int operator-(const montgomery_mm512_int &b) const {return montgomery_mm512_int(*this)-=b;}
    // Returns a reference, like += and -=, so chained compound assignment
    // acts on *this rather than on a temporary copy.
    montgomery_mm512_int& operator*=(const montgomery_mm512_int &b) {val=mlib.mul(val,b.val);return *this;}
    montgomery_mm512_int operator*(const montgomery_mm512_int &b) const {return montgomery_mm512_int(*this)*=b;}
    // Additive inverse, computed as 0 - *this.
    montgomery_mm512_int operator-() const {montgomery_mm512_int b;return b-=(*this);}
    // Plain residues: odd and even 32-bit lanes are reduced separately, merged,
    // then conditionally reduced by P.
    __m512i real_val() const {
        const __m512i odd=_mm512_slli_epi64(mlib.redu(_mm512_srli_epi64(val,32)),32);
        const __m512i even=mlib.redu(_mm512_srli_epi64(_mm512_slli_epi64(val,32),32));
        return mlib.reds(_mm512_or_si512(odd,even));
    }
    // Raw Montgomery-form lanes.
    __m512i get_val() const {return val;}
};
// Short names for the AVX-512 Montgomery vector and its parameter block.
using m5i = montgomery_mm512_int;
using lm5 = montgomery_mm512_lib;
#endif
// Modulus currently installed for the scalar Montgomery type (written by set_mod_mi).
ui global_mod_mi=default_mod;
// Installs modulus p for the scalar Montgomery type; defined below.
void set_mod_mi(ui p);
#if defined(__AVX__) && defined(__AVX2__)
// Modulus currently installed for the AVX2 Montgomery type (written by set_mod_mai).
ui global_mod_mai=default_mod;
// Installs modulus p for the AVX2 Montgomery type; defined below.
void set_mod_mai(ui p);
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// Modulus currently installed for the AVX-512 Montgomery type (written by set_mod_m5i).
ui global_mod_m5i=default_mod;
// Installs modulus p for the AVX-512 Montgomery type; defined below.
void set_mod_m5i(ui p);
#endif
// Shared parameter block of the scalar Montgomery type, initially for default_mod.
montgomery_mi_lib mi::mlib(default_mod);
// Switches the scalar Montgomery type, the global reducer and the recorded
// modulus over to p. Existing mi values are not converted.
void set_mod_mi(ui p){
    const montgomery_mi_lib fresh_lib(p);
    const fast_mod_32 fresh_reducer(p);
    global_mod_mi=p;
    global_fast_mod_32=fresh_reducer;
    mi::mlib=fresh_lib;
}
#if defined(__AVX__) && defined(__AVX2__)
// Shared parameter block of the AVX2 Montgomery type, initially for default_mod.
montgomery_mm256_lib mai::mlib(default_mod);
// Switches the AVX2 Montgomery type and its recorded modulus over to p.
void set_mod_mai(ui p){
    const montgomery_mm256_lib fresh_lib(p);
    global_mod_mai=p;
    mai::mlib=fresh_lib;
}
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// Shared parameter block of the AVX-512 Montgomery type, initially for default_mod.
montgomery_mm512_lib m5i::mlib(default_mod);
// Switches the AVX-512 Montgomery type and its recorded modulus over to p.
void set_mod_m5i(ui p){
    const montgomery_mm512_lib fresh_lib(p);
    global_mod_m5i=p;
    m5i::mlib=fresh_lib;
}
#endif
// Deleter for arrays obtained from create_aligned_array.
// It releases the storage with the aligned `operator delete[]` and never runs
// element destructors, so it is only valid for trivially destructible T
// (for other types a new[]-expression may also prepend an array cookie, making
// the raw deallocation call mismatched).
template<typename T,size_t align_val>
struct aligned_delete {
    static_assert(std::is_trivially_destructible<T>::value,
                  "aligned_delete frees raw storage and requires a trivially destructible T");
    void operator()(T* ptr) const {
        operator delete[](ptr,std::align_val_t(align_val));
    }
};
// Owning handle for an array of T whose storage is aligned to align_val bytes.
template<typename T,size_t align_val>
using aligned_array=std::unique_ptr<T[],aligned_delete<T,align_val>>;
// Allocates `size` default-initialised elements of T aligned to align_val bytes.
// Throws std::bad_alloc when the allocation fails.
template<typename T,size_t align_val>
aligned_array<T,align_val> create_aligned_array(size_t size){
    static_assert(std::is_trivially_destructible<T>::value,
                  "create_aligned_array requires a trivially destructible T");
    return aligned_array<T,align_val>(new(std::align_val_t(align_val)) T[size]);
}
// Map the C99 spelling onto the compiler's __restrict extension.
#define restrict __restrict
// log2 of the largest transform length handled by the flat vector branch of
// dif()/dit(); longer transforms take the four-way split branch.
#define NTT_partition_size 10
// A polynomial is a vector of Montgomery-form coefficients.
typedef std::vector<mi> poly;
// Declared early so it can be named as a friend of polynomial_kernel_ntt.
class polynomial_kernel_mtt;
class polynomial_kernel_ntt
{
friend class polynomial_kernel_mtt;
private:
// Number of scratch buffers held in tt.
static constexpr ui tmp_size=9;
// 64-byte-aligned tables allocated by init(): ws0 / ws1 hold the twiddles read
// by dif() / dit(), _inv[i-1] is filled by the modular-inverse recurrence for i,
// num[i-1] holds i, and tt[] are scratch buffers. Entries are in Montgomery form.
// P and G are the modulus and generator passed to init().
aligned_array<ui,64> ws0,ws1,_inv,tt[tmp_size],num;ui P,G;
// fn: table/transform capacity (a power of two), fb: log2(fn),
// mx: the (clamped) max_conv_size given to init(). All are 0 when released.
ui fn,fb,mx;
// Frees every table and scratch buffer and marks the kernel as empty.
void release(){
    for(auto &scratch:tt) scratch.reset();
    num.reset();
    _inv.reset();
    ws1.reset();
    ws0.reset();
    fn=0;
    fb=0;
    mx=0;
}
// Binary exponentiation: returns a^b, with a and the result in Montgomery form
// (b is a plain exponent).
ui _fastpow(ui a,ui b){
    ui result=li.v(1);
    for(ui base=a;b;b>>=1){
        if(b&1) result=li.mul(result,base);
        base=li.mul(base,base);
    }
    return result;
}
// In-place forward transform of p[0 .. 2^n) using decimation-in-frequency
// butterflies (sum kept in the first half, twiddled difference in the second)
// and the ws0 twiddle table, where ws0[mid+i] is the i-th twiddle for a
// butterfly of half-width mid. No bit-reversal permutation is applied here.
// The vector branches dereference __m512i pointers directly, so p is assumed to
// be 64-byte aligned whenever 2^n >= 16.
void dif(ui* restrict p,ui n){
ui len=(1<<n);
ui* restrict ws=ws0.get();
if(len<16){
// Scalar path for tiny transforms.
ui t1,t2;
for(ui l=len;l>=2;l>>=1) for(ui j=0,mid=(l>>1);j<len;j+=l){
ui restrict *p1=p+j,*p2=p+j+mid,*ww=ws+mid;
for(ui i=0;i<mid;++i,++p1,++p2,++ww) t1=*p1,t2=*p2,*p1=li.add(t1,t2),*p2=li.mul(li.sub(t1,t2),(*ww));
}
}else if(len<=(1<<NTT_partition_size)){
// Flat AVX-512 path: layers wider than 16 elements are processed 16 lanes at a
// time; the last four layers (widths 16, 8, 4, 2) are done inside one register.
__m512i* pp=(__m512i*)p,*p1,*p2,*ww;
__m512i msk,val;__mmask16 smsk;
for(ui l=len;l>16;l>>=1){
ui mid=(l>>1);
for(ui j=0;j<len;j+=l){
p1=(__m512i*)(p+j),p2=(__m512i*)(p+j+mid),ww=(__m512i*)(ws+mid);
for(ui i=0;i<mid;i+=16,++p1,++p2,++ww){
__m512i x=*p1,y=*p2;
*p1=l5.add(x,y);
*p2=l5.mul(l5.sub(x,y),*ww);
}
}
}
// Width-16 layer: x is the register with its 256-bit halves swapped; y keeps
// the low half and replaces the high half by 2P - value, so x+y is (a+b) in the
// low lanes and (a-b) in the high lanes. val multiplies the sum lanes by ws[8]
// and the difference lanes by ws[8..15].
val=_mm512_setr_epi32(ws[8],ws[8],ws[8],ws[8],
ws[8],ws[8],ws[8],ws[8],
ws[8],ws[9],ws[10],ws[11],
ws[12],ws[13],ws[14],ws[15]);
msk=_mm512_setr_epi32(0,0,0,0,0,0,0,0,P*2,P*2,P*2,P*2,P*2,P*2,P*2,P*2);
smsk=0xff00;
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_BADC);
__m512i y=_mm512_mask_sub_epi32(*pp,smsk,msk,*pp);
*pp=l5.mul(l5.add(x,y),val);
}
// Width-8 layer: same scheme with adjacent 128-bit lanes swapped.
val=_mm512_setr_epi32(ws[4],ws[4],ws[4],ws[4],
ws[4],ws[5],ws[6],ws[7],
ws[4],ws[4],ws[4],ws[4],
ws[4],ws[5],ws[6],ws[7]);
smsk=0xf0f0;
msk=_mm512_setr_epi32(0,0,0,0,P*2,P*2,P*2,P*2,0,0,0,0,P*2,P*2,P*2,P*2);
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_CDAB);
__m512i y=_mm512_mask_sub_epi32(*pp,smsk,msk,*pp);
*pp=l5.mul(l5.add(x,y),val);
}
// Width-4 layer: 64-bit pairs swapped inside each 128-bit lane.
val=_mm512_setr_epi32(ws[2],ws[2],ws[2],ws[3],
ws[2],ws[2],ws[2],ws[3],
ws[2],ws[2],ws[2],ws[3],
ws[2],ws[2],ws[2],ws[3]);
msk=_mm512_setr_epi32(0,0,P*2,P*2,0,0,P*2,P*2,0,0,P*2,P*2,0,0,P*2,P*2);
pp=(__m512i*)p;
smsk=0xcccc;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_epi32(*pp,_MM_PERM_BADC);
__m512i y=_mm512_mask_sub_epi32(*pp,smsk,msk,*pp);
*pp=l5.mul(l5.add(x,y),val);
}
// Width-2 layer: adjacent elements swapped; no twiddle multiplication is done.
msk=_mm512_setr_epi32(0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2);
pp=(__m512i*)p;
smsk=0xaaaa;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_epi32(*pp,_MM_PERM_CDAB);
__m512i y=_mm512_mask_sub_epi32(*pp,smsk,msk,*pp);
*pp=l5.add(x,y);
}
}
else{
// Large transforms: perform the two outermost layers together over the four
// quarters of p, then recurse on each quarter with n-2.
__m512i *p1=(__m512i*)(p),*p2=(__m512i*)(p+(len>>2)),*p3=(__m512i*)(p+(len>>1)),*p4=(__m512i*)(p+(len>>2)*3),*w1=(__m512i*)(ws0.get()+(len>>1)),
*w2=(__m512i*)(ws0.get()+(len>>1)+(len>>2)),*w3=(__m512i*)(ws0.get()+(len>>2));
for(ui i=0;i<(len>>2);i+=16,++p1,++p2,++p3,++p4,++w2,++w3,++w1){
__m512i x=(*(p1)),y=(*(p2)),z=(*(p3)),w=(*(p4));
__m512i r=l5.add(x,z),s=l5.mul(l5.sub(x,z),*w1);
__m512i t=l5.add(y,w),q=l5.mul(l5.sub(y,w),*w2);
(*(p1))=l5.add(r,t);(*(p2))=l5.mul(l5.sub(r,t),*w3);
(*(p3))=l5.add(s,q);(*(p4))=l5.mul(l5.sub(s,q),*w3);
}
dif(p,n-2);dif(p+(1<<(n-2)),n-2);dif(p+(1<<(n-1)),n-2);dif(p+(1<<(n-2))*3,n-2);
}
}
// In-place inverse transform of p[0 .. 2^n) using decimation-in-time
// butterflies (second half is twiddled first, then sum / difference) and the
// ws1 twiddle table. It processes layers from width 2 upwards, i.e. in the
// opposite order to dif().
// inverse_coef: when true the result is multiplied by _inv[len-1]; the
// recursive branch passes false to its sub-calls and scales once at the end.
// Note that the scalar branch (len < 16) scales unconditionally.
// The vector branches dereference __m512i pointers directly, so p is assumed to
// be 64-byte aligned whenever 2^n >= 16.
void dit(ui* restrict p,ui n,bool inverse_coef=true){
ui len=(1<<n);
ui* restrict ws=ws1.get();
if(len<16){
// Scalar path for tiny transforms.
ui t1,t2;
for(ui l=2;l<=len;l<<=1) for(ui j=0,mid=(l>>1);j<len;j+=l){
ui restrict *p1=p+j,*p2=p+j+mid,*ww=ws+mid;
for(ui i=0;i<mid;++i,++p1,++p2,++ww) t1=*p1,t2=li.mul((*p2),(*ww)),*p1=li.add(t1,t2),*p2=li.sub(t1,t2);
}
ui co=_inv[len-1];ui* restrict p1=p;
for(ui i=0;i<len;++i,++p1) (*p1)=li.mul(co,(*p1));
}else if(len<=(1<<NTT_partition_size)){
// Flat AVX-512 path: the first four layers (widths 2, 4, 8, 16) are done inside
// one register, the remaining layers 16 lanes at a time.
__m512i* pp=(__m512i*)p,*p1,*p2,*ww;
__m512i msk,val;__mmask16 smsk;
// Width-2 layer: x swaps adjacent elements, y replaces odd lanes by 2P - value,
// so x+y is (a+b) in even lanes and (a-b) in odd lanes.
msk=_mm512_setr_epi32(0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2,0,P*2);
smsk=0xaaaa;
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_epi32(*pp,_MM_PERM_CDAB);
__m512i y=_mm512_mask_sub_epi32(*pp,smsk,msk,*pp);
*pp=l5.add(x,y);
}
// Width-4 layer: x duplicates the low pair, y the high pair of each 128-bit
// lane; val carries +twiddle for the sum lanes and -twiddle for the difference
// lanes, so one add(x, y*val) yields both outputs.
val=_mm512_setr_epi32(ws[2],ws[3],li.neg(ws[2]),li.neg(ws[3]),
ws[2],ws[3],li.neg(ws[2]),li.neg(ws[3]),
ws[2],ws[3],li.neg(ws[2]),li.neg(ws[3]),
ws[2],ws[3],li.neg(ws[2]),li.neg(ws[3]));
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_epi32(*pp,_MM_PERM_BABA);
__m512i y=_mm512_shuffle_epi32(*pp,_MM_PERM_DCDC);
*pp=l5.add(x,l5.mul(y,val));
}
// Width-8 layer: same scheme on 128-bit lanes.
val=_mm512_setr_epi32( ws[4], ws[5], ws[6], ws[7],
li.neg(ws[4]),li.neg(ws[5]),li.neg(ws[6]),li.neg(ws[7]),
ws[4], ws[5], ws[6], ws[7],
li.neg(ws[4]),li.neg(ws[5]),li.neg(ws[6]),li.neg(ws[7]));
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_CCAA);
__m512i y=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_DDBB);
*pp=l5.add(x,l5.mul(y,val));
}
// Width-16 layer: same scheme on 256-bit halves.
val=_mm512_setr_epi32( ws[8], ws[9], ws[10], ws[11],
ws[12], ws[13], ws[14], ws[15],
li.neg(ws[8]), li.neg(ws[9]), li.neg(ws[10]),li.neg(ws[11]),
li.neg(ws[12]),li.neg(ws[13]),li.neg(ws[14]),li.neg(ws[15]));
pp=(__m512i*)p;
for(ui j=0;j<len;j+=16,++pp){
__m512i x=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_BABA);
__m512i y=_mm512_shuffle_i64x2(*pp,*pp,_MM_PERM_DCDC);
*pp=l5.add(x,l5.mul(y,val));
}
// Remaining layers, 16 lanes per step.
for(ui l=32;l<=len;l<<=1){
ui mid=(l>>1);
for(ui j=0;j<len;j+=l){
p1=(__m512i*)(p+j),p2=(__m512i*)(p+j+mid),ww=(__m512i*)(ws+mid);
for(ui i=0;i<mid;i+=16,++p1,++p2,++ww){
__m512i x=*p1,y=l5.mul(*p2,*ww);
*p1=l5.add(x,y);
*p2=l5.sub(x,y);
}
}
}
if(inverse_coef){
__m512i co=_mm512_set1_epi32(_inv[len-1]);
pp=(__m512i*)p;
for(ui i=0;i<len;i+=16,++pp) (*pp)=l5.mul(*pp,co);
}
}
else{
// Large transforms: invert each quarter (unscaled), then perform the two
// outermost layers together over the four quarters.
dit(p,n-2,false);dit(p+(1<<(n-2)),n-2,false);dit(p+(1<<(n-1)),n-2,false);dit(p+(1<<(n-2))*3,n-2,false);
__m512i *p1=(__m512i*)(p),*p2=(__m512i*)(p+(len>>2)),*p3=(__m512i*)(p+(len>>1)),*p4=(__m512i*)(p+(len>>2)*3),*w1=(__m512i*)(ws+(len>>1)),
*w2=(__m512i*)(ws+(len>>1)+(len>>2)),*w3=(__m512i*)(ws+(len>>2));
for(ui i=0;i<(len>>2);i+=16,++p1,++p2,++p3,++p4,++w2,++w3,++w1){
__m512i x=(*(p1)),y=(*(p2)),z=(*(p3)),w=(*(p4));
__m512i h=l5.mul(y,*w3),
k=l5.mul(w,*w3);
__m512i t=l5.mul(l5.add(z,k),*w1),q=l5.mul(l5.sub(z,k),*w2);
__m512i r=l5.add(x,h),s=l5.sub(x,h);
(*(p1))=l5.add(r,t);(*(p2))=l5.add(s,q);
(*(p3))=l5.sub(r,t);(*(p4))=l5.sub(s,q);
}
if(inverse_coef){
__m512i co=_mm512_set1_epi32(_inv[len-1]);
p1=(__m512i*)p;
for(ui i=0;i<len;i+=16,++p1) (*p1)=l5.mul(*p1,co);
}
}
}
// Multiplies arr[i] by ws0[2*len + i] for i in [0, len), len = 2^n, and then
// runs the forward transform dif() on the result.
// The vector path dereferences __m256i pointers, so arr is assumed to be
// 32-byte aligned when len >= 8.
void dif_xni(ui* restrict arr,ui n){
    const ui len=(1<<n);
    ui* restrict twist=ws0.get()+(len<<1);
    if(len>4){
        __m256i* vec=(__m256i*)arr;
        __m256i* tw=(__m256i*)twist;
        const ui blocks=(len>>3);
        for(ui blk=0;blk<blocks;++blk) vec[blk]=la.mul(vec[blk],tw[blk]);
    }else{
        for(ui idx=0;idx<len;++idx) arr[idx]=li.mul(arr[idx],twist[idx]);
    }
    dif(arr,n);
}
// Runs the inverse transform dit() on arr[0 .. len), len = 2^n, and then
// multiplies arr[i] by ws1[2*len + i].
// The vector path dereferences __m256i pointers, so arr is assumed to be
// 32-byte aligned when len >= 8.
void dit_xni(ui* restrict arr,ui n){
    dit(arr,n);
    const ui len=(1<<n);
    ui* restrict untwist=ws1.get()+(len<<1);
    if(len>4){
        __m256i* vec=(__m256i*)arr;
        __m256i* tw=(__m256i*)untwist;
        const ui blocks=(len>>3);
        for(ui blk=0;blk<blocks;++blk) vec[blk]=la.mul(vec[blk],tw[blk]);
    }else{
        for(ui idx=0;idx<len;++idx) arr[idx]=li.mul(arr[idx],untwist[idx]);
    }
}
// Length-2^m cyclic product: transforms src1 and src2 in place (both are
// overwritten), multiplies them pointwise into dst and inverse-transforms dst.
// The vector path dereferences __m256i pointers, so the buffers are assumed to
// be suitably aligned when 2^m >= 8.
void internal_mul(ui* restrict src1,ui* restrict src2,ui* restrict dst,ui m){
    dif(src1,m);
    dif(src2,m);
    const ui len=(1<<m);
    if(len>=8){
        __m256i* lhs=(__m256i*)src1;
        __m256i* rhs=(__m256i*)src2;
        __m256i* out=(__m256i*)dst;
        const ui blocks=(len>>3);
        for(ui blk=0;blk<blocks;++blk) out[blk]=la.mul(lhs[blk],rhs[blk]);
    }else{
        for(ui idx=0;idx<len;++idx) dst[idx]=li.mul(src1[idx],src2[idx]);
    }
    dit(dst,m);
}
// Reverses src1, multiplies it with src2 through internal_mul (length 2^m),
// then reverses the product stored in dst. src1 and src2 are overwritten.
void internal_transpose_mul(ui* restrict src1,ui* restrict src2,ui* restrict dst,ui m){
    ui* const src_end=src1+(1<<m);
    std::reverse(src1,src_end);
    internal_mul(src1,src2,dst,m);
    ui* const dst_end=dst+(1<<m);
    std::reverse(dst,dst_end);
}
// Recursive doubling computation that fills dst[0 .. len) from src[0 .. len);
// presumably the power-series inverse of src modulo x^len (the base case is
// dst[0] = src[0]^(P-2)) - confirm against callers.
// len must be a power of two. tmp and tmp2 are scratch buffers of at least len
// elements; dst must also hold len elements. Values are in Montgomery form.
void internal_inv(ui* restrict src,ui* restrict dst,ui* restrict tmp,ui* restrict tmp2,ui len){//10E(n) x^n->x^{2n}
if(len==1){dst[0]=_fastpow(src[0],P-2);return;}
internal_inv(src,dst,tmp,tmp2,len>>1);
// tmp = src, tmp2 = current half-length result padded with zeros; clear the
// upper half of dst, which is written at the end.
std::memcpy(tmp,src,sizeof(ui)*len);std::memcpy(tmp2,dst,sizeof(ui)*(len>>1));std::memset(tmp2+(len>>1),0,sizeof(ui)*(len>>1));
std::memset(dst+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp,__builtin_ctz(len));dif(tmp2,__builtin_ctz(len));
// First pointwise product (tmp *= tmp2) in the transform domain.
if(len<=4){
for(ui i=0;i<len;++i) tmp[i]=li.mul(tmp[i],tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)tmp;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p2)=la.mul((*p1),(*p2));
}
// Back to coefficients, drop the low half, and transform again.
dit(tmp,__builtin_ctz(len));std::memset(tmp,0,sizeof(ui)*(len>>1));dif(tmp,__builtin_ctz(len));
// Second pointwise product with the same transformed tmp2.
if(len<=4){
for(ui i=0;i<len;++i) tmp[i]=li.mul(tmp[i],tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)tmp;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p2)=la.mul((*p1),(*p2));
}
dit(tmp,__builtin_ctz(len));
// The new upper half of dst is the negated upper half of tmp.
if(len<=8){
for(ui i=(len>>1);i<len;++i) dst[i]=li.neg(tmp[i]);
}else{
__m256i restrict *p1=(__m256i*)(tmp+(len>>1)),*p2=(__m256i*)(dst+(len>>1));
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p2)=la.sub(_mm256_setzero_si256(),(*p1));
}
}
// Variant of internal_inv that combines a half-length transform taken through
// dif_xni/dit_xni with a full-length transform of the squared iterate;
// presumably it produces the same result as internal_inv with fewer transform
// operations (the original author's note says 9E(n)) - confirm.
// len must be a power of two. tmp, tmp2 and tmp3 are scratch buffers of at
// least len elements. On return tmp3 holds the transform of the previous
// (half-length, zero-padded) iterate, which internal_ln_faster relies on.
void internal_inv_faster(ui* restrict src,ui* restrict dst,ui* restrict tmp,ui* restrict tmp2,ui* restrict tmp3,ui len){//9E(n) x^n->x^{2n}
if(len==1){dst[0]=_fastpow(src[0],P-2);return;}
internal_inv_faster(src,dst,tmp,tmp2,tmp3,len>>1);
std::memcpy(tmp,src,sizeof(ui)*(len>>1));std::memcpy(tmp2,dst,sizeof(ui)*(len>>1));
// Fold the upper half of src onto the lower half, scaled by ws0[3].
if(len<=8){
ui mip=ws0[3];
for(ui i=0;i<(len>>1);++i) tmp[i]=li.add(tmp[i],li.mul(mip,src[i+(len>>1)]));
}
else{
__m256i mip=_mm256_set1_epi32(ws0[3]);
__m256i restrict *p1=(__m256i*)(src+(len>>1)),*p2=(__m256i*)tmp;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p2)=la.add((*p2),la.mul((*p1),mip));
}
dif_xni(tmp,__builtin_ctz(len>>1));
dif_xni(tmp2,__builtin_ctz(len>>1));
// tmp = tmp * tmp2^2 in the half-length twisted transform domain.
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp[i]=li.mul(li.mul(tmp2[i],tmp2[i]),tmp[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)tmp;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p2)=la.mul((*p2),la.mul((*p1),(*p1)));
}
dit_xni(tmp,__builtin_ctz(len>>1));
// tmp2 = src, tmp3 = half-length result padded with zeros, both transformed at
// full length; then tmp2 = tmp2 * tmp3^2.
std::memcpy(tmp2,src,sizeof(ui)*len);std::memcpy(tmp3,dst,sizeof(ui)*(len>>1));std::memset(tmp3+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp2,__builtin_ctz(len));dif(tmp3,__builtin_ctz(len));
if(len<=8){
for(ui i=0;i<len;++i) tmp2[i]=li.mul(li.mul(tmp3[i],tmp3[i]),tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp3,*p2=(__m256i*)tmp2;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p2)=la.mul((*p2),la.mul((*p1),(*p1)));
}
dit(tmp2,__builtin_ctz(len));
// Combine both partial results into the new upper half of dst:
// ((tmp + tmp2_low - 2*dst_low) * ws0[3] - tmp2_high) * _inv[1].
if(len<=8){
ui mip=ws0[3],iv2=_inv[1];
for(ui i=0;i<(len>>1);++i) dst[i+(len>>1)]=li.mul(li.sub(li.mul(li.sub(li.add(tmp[i],tmp2[i]),li.add(dst[i],dst[i])),mip),tmp2[i+(len>>1)]),iv2);
}
else{
__m256i restrict *p1=(__m256i*)tmp,*p2=(__m256i*)(tmp2+(len>>1)),*p3=(__m256i*)(tmp2),*p4=(__m256i*)(dst+(len>>1)),*p5=(__m256i*)(dst);
__m256i mip=_mm256_set1_epi32(ws0[3]),iv2=_mm256_set1_epi32(_inv[1]);
for(ui i=0;i<(len>>1);i+=8,++p1,++p2,++p3,++p4,++p5) (*p4)=la.mul(la.sub(la.mul(la.sub(la.add((*p1),(*p3)),la.add((*p5),(*p5))),mip),(*p2)),iv2);
}
}
// Fills dst[0 .. len) from src[0 .. len) in three steps: tmp1[i-1] = i*src[i],
// dst = internal_inv(src), and finally dst[i] = _inv[i-1] * (tmp1*dst)[i-1] with
// dst[0] = 0. Presumably this is the power-series logarithm of src modulo x^len
// (derivative / series, then term-wise integration) - confirm.
// tmp1, tmp2 and dst must hold 2*len elements (their upper halves are cleared
// and a length-2*len product is taken); tmp3 is scratch for internal_inv.
void internal_ln(ui* restrict src,ui* restrict dst,ui* restrict tmp1,ui* restrict tmp2,ui* restrict tmp3,ui len){
#if defined(__AVX__) && defined(__AVX2__)
// tmp1[i-1] = src[i] * num[i-1]: eight lanes at a time (unaligned loads from
// src+1), then a scalar tail.
ui pos=1;__m256i restrict *pp=(__m256i*)tmp1,*iv=(__m256i*)num.get();ui restrict *p1=src+1;
for(;pos+8<=len;pos+=8,p1+=8,++pp,++iv) *pp=la.mul(_mm256_loadu_si256((__m256i*)p1),*iv);
// Only the multiplication belongs to the loop below; the final store of zero
// runs once after it.
for(;pos<len;++pos) tmp1[pos-1]=li.mul(src[pos],num[pos-1]);tmp1[len-1]=li.v(0);
#else
ui restrict *p1=src+1,*p2=tmp1,*p3=num.get();
for(ui i=1;i<len;++i,++p1,++p2,++p3) *p2=li.mul((*p1),(*p3));
tmp1[len-1]=li.v(0);
#endif
internal_inv(src,dst,tmp2,tmp3,len);
// Zero-pad both factors to 2*len and multiply them into tmp2.
std::memset(dst+len,0,sizeof(ui)*len);std::memset(tmp1+len,0,sizeof(ui)*len);
internal_mul(tmp1,dst,tmp2,__builtin_ctz(len<<1));
#if defined(__AVX__) && defined(__AVX2__)
// dst[i] = tmp2[i-1] * _inv[i-1] for i >= 1 (unaligned stores to dst+1).
ui ps=1;__m256i restrict *pp0=(__m256i*)tmp2,*iv0=(__m256i*)_inv.get();ui restrict *p10=dst+1;
for(;ps+8<=len;ps+=8,p10+=8,++pp0,++iv0) _mm256_storeu_si256((__m256i*)p10,la.mul(*pp0,*iv0));
dst[0]=li.v(0);
for(;ps<len;++ps) dst[ps]=li.mul(_inv[ps-1],tmp2[ps-1]);
#else
dst[0]=li.v(0);
ui restrict *p10=dst+1,*p20=tmp2,*p30=_inv.get();
for(ui i=1;i<len;++i,++p10,++p20,++p30) *p10=li.mul((*p20),(*p30));
#endif
}
// Variant of internal_ln that only inverts src to half length and assembles
// the quotient block by block with quarter/half-length transforms; presumably
// it yields the same result as internal_ln - confirm.
// For len < 4 it simply delegates to internal_ln. len must be a power of two;
// tmp1..tmp4 are scratch buffers (tmp4 receives the transformed iterate left
// behind by internal_inv_faster).
// NOTE(review): the final section uses AVX2 intrinsics without the #if guard
// that protects the first section.
void internal_ln_faster(ui* restrict src,ui* restrict dst,ui* restrict tmp1,ui* restrict tmp2,ui* restrict tmp3,ui* restrict tmp4,ui len){
if(len<4){internal_ln(src,dst,tmp1,tmp2,tmp3,len);return;}
#if defined(__AVX__) && defined(__AVX2__)
// tmp1[i-1] = src[i] * num[i-1], then tmp1[len-1] = 0.
ui pos=1;__m256i restrict *pp=(__m256i*)tmp1,*iv=(__m256i*)num.get();ui restrict *p1=src+1;
for(;pos+8<=len;pos+=8,p1+=8,++pp,++iv) *pp=la.mul(_mm256_loadu_si256((__m256i*)p1),*iv);
for(;pos<len;++pos) tmp1[pos-1]=li.mul(src[pos],num[pos-1]);tmp1[len-1]=li.v(0);
#else
ui restrict *p1=src+1,*p2=tmp1,*p3=num.get();
for(ui i=1;i<len;++i,++p1,++p2,++p3) *p2=li.mul((*p1),(*p3));
tmp1[len-1]=li.v(0);
#endif
internal_inv_faster(src,dst,tmp2,tmp3,tmp4,(len>>1));//tmp4=F_{n/2}(g0 mod x^{n/4})
// Move the second quarter of the half-length inverse to the front of dst, pad
// with zeros and transform at length len/2.
std::memcpy(dst,dst+(len>>2),sizeof(ui)*(len>>2));std::memset(dst+(len>>2),0,sizeof(ui)*(len>>2));
dif(dst,__builtin_ctz(len>>1));
// Split tmp1 into its first two quarters (tmp1 low, tmp2 high), each padded
// and transformed at length len/2.
std::memcpy(tmp2,tmp1+(len>>2),sizeof(ui)*(len>>2));std::memset(tmp2+(len>>2),0,sizeof(ui)*(len>>2));
dif(tmp2,__builtin_ctz(len>>1));
std::memset(tmp1+(len>>2),0,sizeof(ui)*(len>>2));
dif(tmp1,__builtin_ctz(len>>1));
// tmp2 = tmp2*tmp4 + dst*tmp1 and tmp1 = tmp1*tmp4, pointwise.
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp2[i]=li.add(li.mul(tmp2[i],tmp4[i]),li.mul(dst[i],tmp1[i])),tmp1[i]=li.mul(tmp1[i],tmp4[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)tmp4,*p3=(__m256i*)dst,*p4=(__m256i*)tmp1;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2,++p3,++p4) (*p1)=la.add(la.mul((*p1),(*p2)),la.mul((*p3),(*p4))),(*p4)=la.mul((*p2),(*p4));
}
dit(tmp1,__builtin_ctz(len>>1));dit(tmp2,__builtin_ctz(len>>1));
// Add the low quarter of tmp2 onto the second quarter of tmp1.
if(len<=16){
for(ui i=0;i<(len>>2);++i) tmp1[i+(len>>2)]=li.add(tmp1[i+(len>>2)],tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)(tmp1+(len>>2)),*p2=(__m256i*)tmp2;
for(ui i=0;i<(len>>2);i+=8,++p1,++p2) (*p1)=la.add((*p1),(*p2));
}
// tmp3 = src * (first half of tmp1, zero padded) at full length.
std::memcpy(tmp2,tmp1,sizeof(ui)*(len>>1));
std::memset(tmp2+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp2,__builtin_ctz(len));std::memcpy(tmp3,src,sizeof(ui)*len);
dif(tmp3,__builtin_ctz(len));
if(len<=4){
for(ui i=0;i<len;++i) tmp3[i]=li.mul(tmp3[i],tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)(tmp3),*p2=(__m256i*)tmp2;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p1)=la.mul((*p1),(*p2));
}
dit(tmp3,__builtin_ctz(len));
// Upper half of tmp3 = upper half of tmp1 - upper half of tmp3.
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp3[i+(len>>1)]=li.sub(tmp1[i+(len>>1)],tmp3[i+(len>>1)]);
}
else{
__m256i restrict *p1=(__m256i*)(tmp3+(len>>1)),*p2=(__m256i*)(tmp1+(len>>1));
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p1)=la.sub((*p2),(*p1));
}
// Split that upper half into quarters (tmp3 gets the last, tmp2 the third),
// pad, transform and repeat the same pointwise combination as above.
std::memcpy(tmp3,tmp3+(len>>2)*3,sizeof(ui)*(len>>2));std::memset(tmp3+(len>>2),0,sizeof(ui)*(len>>2));
std::memcpy(tmp2,tmp3+(len>>1),sizeof(ui)*(len>>2));std::memset(tmp2+(len>>2),0,sizeof(ui)*(len>>2));
dif(tmp3,__builtin_ctz(len>>1));dif(tmp2,__builtin_ctz(len>>1));
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp3[i]=li.add(li.mul(tmp3[i],tmp4[i]),li.mul(dst[i],tmp2[i])),tmp2[i]=li.mul(tmp2[i],tmp4[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp3,*p2=(__m256i*)tmp4,*p3=(__m256i*)dst,*p4=(__m256i*)tmp2;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2,++p3,++p4) (*p1)=la.add(la.mul((*p1),(*p2)),la.mul((*p3),(*p4))),(*p4)=la.mul((*p2),(*p4));
}
dit(tmp3,__builtin_ctz(len>>1));dit(tmp2,__builtin_ctz(len>>1));
// NOTE(review): this copy uses sizeof(mi) where the sibling copies use
// sizeof(ui); montgomery_mi holds a single ui so the sizes presumably agree -
// confirm.
std::memcpy(tmp1+(len>>1),tmp2,sizeof(mi)*(len>>2));
if(len<=16){
for(ui i=0;i<(len>>2);++i) tmp1[i+(len>>2)*3]=li.add(tmp2[i+(len>>2)],tmp3[i]);
}
else{
__m256i restrict *p1=(__m256i*)(tmp1+(len>>2)*3),*p2=(__m256i*)(tmp2+(len>>2)),*p3=(__m256i*)(tmp3);
for(ui i=0;i<(len>>2);i+=8,++p1,++p2,++p3) (*p1)=la.add((*p2),(*p3));
}
// dst[i] = tmp1[i-1] * _inv[i-1] for i >= 1, dst[0] = 0.
ui ps=1;__m256i restrict *pp0=(__m256i*)tmp1,*iv0=(__m256i*)_inv.get();ui restrict *p10=dst+1;
for(;ps+8<=len;ps+=8,p10+=8,++pp0,++iv0) _mm256_storeu_si256((__m256i*)p10,la.mul(*pp0,*iv0));
dst[0]=li.v(0);
for(;ps<len;++ps) dst[ps]=li.mul(_inv[ps-1],tmp1[ps-1]);
}
// Recursive doubling computation that fills dst[0 .. len) from src[0 .. len);
// presumably the power-series exponential of src modulo x^len (base cases set
// dst[0] = 1 and dst[1] = src[1], so src[0] is assumed to be 0) - confirm.
// Auxiliary state carried between recursion levels:
//   gn   - forward transform of the current dst prefix,
//   gxni - dif_xni transform of the folded dst prefix,
//   h    - a companion series updated by the second half of each step
//          (presumably the inverse of dst) - confirm.
// tmp1..tmp3 are scratch buffers. calc_h: when false, the function returns
// right after extending dst and skips updating gn/gxni/h; recursive calls
// always pass true. len must be a power of two.
void internal_exp(ui* restrict src,ui* restrict dst,ui* restrict gn,ui* restrict gxni,
ui* restrict h,ui* restrict tmp1,ui* restrict tmp2,ui* restrict tmp3,ui len,bool calc_h=false){
if(len==1){dst[0]=li.v(1);return;}
else if(len==2){dst[0]=li.v(1);dst[1]=src[1];gn[0]=li.add(dst[0],dst[1]),gn[1]=li.sub(dst[0],dst[1]);gxni[0]=li.add(li.mul(dst[1],ws0[3]),dst[0]);h[0]=li.v(1);h[1]=li.neg(dst[1]);return;}
internal_exp(src,dst,gn,gxni,h,tmp1,tmp2,tmp3,(len>>1),true);
{
// tmp1[i-1] = src[i] * num[i-1] for i < len/2, then tmp1[len/2-1] = 0.
ui pos=1;__m256i restrict *pp=(__m256i*)tmp1,*iv=(__m256i*)num.get();ui restrict *p1=src+1;
for(;pos+8<=(len>>1);pos+=8,p1+=8,++pp,++iv) *pp=la.mul(_mm256_loadu_si256((__m256i*)p1),*iv);
for(;pos<(len>>1);++pos) tmp1[pos-1]=li.mul(src[pos],num[pos-1]);tmp1[(len>>1)-1]=li.v(0);
}
// Cyclic product of tmp1 with the previous dst (via its stored transform gn).
dif(tmp1,__builtin_ctz(len>>1));
if(len<8){
for(ui i=0;i<(len>>1);++i) tmp1[i]=li.mul(tmp1[i],gn[i]);
}else{
__m256i restrict *p1=(__m256i*)(tmp1),*p2=(__m256i*)(gn);
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p1)=la.mul((*p1),(*p2));
}
dit(tmp1,__builtin_ctz(len>>1));
{
// tmp1[i-1] = dst[i]*num[i-1] - tmp1[i-1]; the last entry is just negated.
ui pos=1;__m256i restrict *pp=(__m256i*)tmp1,*iv=(__m256i*)num.get();ui restrict *p1=dst+1;
for(;pos+8<=(len>>1);pos+=8,p1+=8,++pp,++iv) *pp=la.sub(la.mul(_mm256_loadu_si256((__m256i*)p1),*iv),(*pp));
for(;pos<(len>>1);++pos) tmp1[pos-1]=li.sub(li.mul(dst[pos],num[pos-1]),tmp1[pos-1]);tmp1[(len>>1)-1]=li.neg(tmp1[(len>>1)-1]);
}
// Rotate tmp1 right by one position within the first len/2 entries, zero-pad
// to len and multiply by h (tmp3 keeps the transform of h for later reuse).
std::memmove(tmp1+1,tmp1,sizeof(ui)*(len>>1));tmp1[0]=tmp1[(len>>1)];
std::memset(tmp1+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp1,__builtin_ctz(len));
std::memcpy(tmp3,h,sizeof(ui)*(len>>1));std::memset(tmp3+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp3,__builtin_ctz(len));
if(len<8){
for(ui i=0;i<len;++i) tmp1[i]=li.mul(tmp3[i],tmp1[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp1,*p2=(__m256i*)tmp3;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p1)=la.mul((*p1),(*p2));
}
dit(tmp1,__builtin_ctz(len));
// tmp2[i] = src[i+len/2] - _inv[i+len/2-1] * tmp1[i] for i < len/2.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp2[i]=li.sub(src[i+(len>>1)],li.mul(_inv[i+(len>>1)-1],tmp1[i]));
}else{
__m256i restrict *p1=(__m256i*)tmp1,*p3=(__m256i*)(tmp2),*p4=(__m256i*)(src+(len>>1));ui* restrict p2=_inv.get()+(len>>1)-1;
for(ui i=0;i<(len>>1);i+=8,++p1,p2+=8,++p3,++p4) (*p3)=la.sub((*p4),la.mul((*p1),_mm256_loadu_si256((__m256i*)p2)));
}
#else
{
for(ui i=0;i<(len>>1);++i) tmp2[i]=li.sub(src[i+(len>>1)],li.mul(_inv[i+(len>>1)-1],tmp1[i]));
}
#endif
std::memset(tmp2+(len>>1),0,sizeof(ui)*(len>>1));
dif(tmp2,__builtin_ctz(len));
// Build the missing quarter of the length-len transform of the previous dst:
// fold its two quarters with ws1[3], apply two twiddle factors, and run a
// quarter-length transform; gn and gxni supply the remaining three quarters.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=16){
ui mip=ws1[3];
for(ui i=0;i<(len>>2);++i) tmp1[i]=li.mul(li.mul(li.add(dst[i],li.mul(dst[i+(len>>2)],mip)),ws0[(len>>1)+i]),ws0[(len>>2)+i]);
}else{
__m256i restrict *p1=(__m256i*)dst,*p2=(__m256i*)(dst+(len>>2)),*p3=(__m256i*)(tmp1),*p4=(__m256i*)(ws0.get()+(len>>1)),*p5=(__m256i*)(ws0.get()+(len>>2));
__m256i mip=_mm256_set1_epi32(ws1[3]);
for(ui i=0;i<(len>>2);i+=8,++p1,++p2,++p3,++p4,++p5) (*p3)=la.mul(la.add((*p1),la.mul((*p2),mip)),la.mul((*p4),(*p5)));
}
#else
{
ui mip=ws1[3];
for(ui i=0;i<(len>>2);++i) tmp1[i]=li.mul(li.mul(li.add(dst[i],li.mul(dst[i+(len>>2)],mip)),ws0[(len>>1)+i]),ws0[(len>>2)+i]);
}
#endif
dif(tmp1,__builtin_ctz(len>>2));
std::memcpy(tmp1+(len>>2)*3,tmp1,sizeof(ui)*(len>>2));
std::memcpy(tmp1,gn,sizeof(ui)*(len>>1));
std::memcpy(tmp1+(len>>1),gxni,sizeof(ui)*(len>>2));
// Pointwise product with tmp2, inverse transform, and store the first len/2
// results as the new upper half of dst.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=4){
for(ui i=0;i<len;++i) tmp1[i]=li.mul(tmp2[i],tmp1[i]);
}else{
__m256i restrict *p1=(__m256i*)tmp1,*p2=(__m256i*)(tmp2);
for(ui i=0;i<len;i+=8,++p1,++p2) (*p1)=la.mul((*p1),(*p2));
}
#else
for(ui i=0;i<len;++i) tmp1[i]=li.mul(tmp2[i],tmp1[i]);
#endif
dit(tmp1,__builtin_ctz(len));
std::memcpy(dst+(len>>1),tmp1,sizeof(ui)*(len>>1));
//inv iteration start
// Everything below refreshes gn, gxni and h for the next recursion level.
if(!calc_h) return;
std::memcpy(gxni,dst,sizeof(ui)*(len>>1));std::memcpy(tmp2,h,sizeof(ui)*(len>>1));
// gxni = low half of dst + ws0[3] * high half of dst.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=8){
ui mip=ws0[3];
for(ui i=0;i<(len>>1);++i) gxni[i]=li.add(gxni[i],li.mul(mip,dst[i+(len>>1)]));
}
else{
__m256i mip=_mm256_set1_epi32(ws0[3]);
__m256i restrict *p1=(__m256i*)(dst+(len>>1)),*p2=(__m256i*)gxni;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p2)=la.add((*p2),la.mul((*p1),mip));
}
#else
{
ui mip=ws0[3];
for(ui i=0;i<(len>>1);++i) gxni[i]=li.add(gxni[i],li.mul(mip,dst[i+(len>>1)]));
}
#endif
dif_xni(gxni,__builtin_ctz(len>>1));
dif_xni(tmp2,__builtin_ctz(len>>1));
// tmp2 = gxni * tmp2^2 in the twisted half-length domain.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=8){
for(ui i=0;i<(len>>1);++i) tmp2[i]=li.mul(li.mul(tmp2[i],gxni[i]),tmp2[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)gxni;
for(ui i=0;i<(len>>1);i+=8,++p1,++p2) (*p1)=la.mul((*p2),la.mul((*p1),(*p1)));
}
#else
for(ui i=0;i<(len>>1);++i) tmp2[i]=li.mul(li.mul(tmp2[i],gxni[i]),tmp2[i]);
#endif
dit_xni(tmp2,__builtin_ctz(len>>1));
// gn = transform of the full new dst; tmp3 (still the transform of h) becomes
// gn * tmp3^2.
std::memcpy(gn,dst,sizeof(ui)*len);
dif(gn,__builtin_ctz(len));
#if defined(__AVX__) && defined(__AVX2__)
if(len<=8){
for(ui i=0;i<len;++i) tmp3[i]=li.mul(li.mul(tmp3[i],gn[i]),tmp3[i]);
}
else{
__m256i restrict *p1=(__m256i*)tmp3,*p2=(__m256i*)gn;
for(ui i=0;i<len;i+=8,++p1,++p2) (*p1)=la.mul((*p2),la.mul((*p1),(*p1)));
}
#else
for(ui i=0;i<len;++i) tmp3[i]=li.mul(li.mul(tmp3[i],gn[i]),tmp3[i]);
#endif
dit(tmp3,__builtin_ctz(len));
// New upper half of h, using the same combination formula as the last step of
// internal_inv_faster.
#if defined(__AVX__) && defined(__AVX2__)
if(len<=8){
ui mip=ws0[3],iv2=_inv[1];
for(ui i=0;i<(len>>1);++i) h[i+(len>>1)]=li.mul(li.sub(li.mul(li.sub(li.add(tmp2[i],tmp3[i]),li.add(h[i],h[i])),mip),tmp3[i+(len>>1)]),iv2);
}
else{
__m256i restrict *p1=(__m256i*)tmp2,*p2=(__m256i*)(tmp3+(len>>1)),*p3=(__m256i*)(tmp3),*p4=(__m256i*)(h+(len>>1)),*p5=(__m256i*)(h);
__m256i mip=_mm256_set1_epi32(ws0[3]),iv2=_mm256_set1_epi32(_inv[1]);
for(ui i=0;i<(len>>1);i+=8,++p1,++p2,++p3,++p4,++p5) (*p4)=la.mul(la.sub(la.mul(la.sub(la.add((*p1),(*p3)),la.add((*p5),(*p5))),mip),(*p2)),iv2);
}
#else
{
ui mip=ws0[3],iv2=_inv[1];
for(ui i=0;i<(len>>1);++i) h[i+(len>>1)]=li.mul(li.sub(li.mul(li.sub(li.add(tmp2[i],tmp3[i]),li.add(h[i],h[i])),mip),tmp3[i+(len>>1)]),iv2);
}
#endif
}
// Helpers for multipoint evaluation and Lagrange interpolation; only declared
// here, their definitions are not part of this section of the file.
void internal_multipoint_eval_interpolation_calc_Q(std::vector<poly> &Q_storage,const poly &input_coef,ui l,ui r,ui id);
void internal_multipoint_eval_interpolation_calc_P(const std::vector<poly> &Q_storage,std::vector<poly> &P_stack,poly &result_coef,ui l,ui r,ui id,ui dep);
void internal_lagrange_interpolation_dvc_mul(ui l,ui r,const poly &a,ui id,std::vector<std::pair<poly,poly>> &R_storage);
void internal_lagrange_interpolation_calc_P(const std::vector<std::pair<poly,poly>> &R_storage,std::vector<poly> &P_stack,poly &result_coef,ui l,ui r,ui id,ui dep);
poly internal_lagrange_interpolation_dvc_mul_ans(ui l,ui r,const poly &a,ui id,const std::vector<std::pair<poly,poly>> &R_storage);
// Scalar modular-arithmetic helper; init() rebuilds it from the modulus (li=lmi(P0)).
lmi li;
#if defined(__AVX__) && defined(__AVX2__)
// Counterpart of li used on __m256i operands in the AVX2 code paths; only present
// when AVX and AVX2 are enabled. init() rebuilds it from the same modulus.
lma la;
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// Only present when AVX-512F and AVX-512DQ are enabled; init() rebuilds it from the
// same modulus as li.
lm5 l5;
#endif
public:
// Builds a kernel and sets it up right away by forwarding all three arguments to
// init(); any exception thrown by init() propagates to the caller.
polynomial_kernel_ntt(ui max_conv_size,ui P0,ui G0)
{
    init(max_conv_size,P0,G0);
}
/**
 * (Re)initialises the kernel for modulus P0 with generator G0.
 *
 * max_conv_size is raised to at least 16. The transform length fn becomes the
 * smallest power of two that is >= 2*max_conv_size, and fb is its base-2 logarithm.
 * Previously held storage is dropped through release(); afterwards the inverse
 * table, the integer table, both root tables and every scratch buffer tt[i] are
 * allocated with fn+32 entries each.
 *
 * Throws std::runtime_error("invalid range!") when
 *   - max_conv_size >= 2^30,
 *   - fn does not divide P0-1, or
 *   - P0 <= fn+32 (the inverse table below would then have to read _inv[-1]).
 * As before, the arithmetic helpers and the size fields have already been
 * overwritten when one of the later checks throws.
 */
void init(ui max_conv_size,ui P0,ui G0){
    if(max_conv_size>=(1u<<30)) throw std::runtime_error("invalid range!");
    max_conv_size=std::max(max_conv_size,16u);
    li=lmi(P0);
#if defined(__AVX__) && defined(__AVX2__)
    la=lma(P0);
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    l5=lm5(P0);
#endif
    release();P=P0,G=G0;mx=max_conv_size;
    fn=1;fb=0;while(fn<(max_conv_size<<1)) fn<<=1,++fb;
    if((P0-1)%fn) throw std::runtime_error("invalid range!");
    // The recurrence for _inv reads _inv[(P%i)-1] for every i in [2,fn+32], so P%i
    // must never be 0. polynomial_kernel_mtt::init performs the same check.
    if(P0<=fn+32) throw std::runtime_error("invalid range!");
    _inv=create_aligned_array<ui,64>(fn+32);ws0 =create_aligned_array<ui,64>(fn+32);
    ws1 =create_aligned_array<ui,64>(fn+32);num =create_aligned_array<ui,64>(fn+32);
    for(ui i=0;i<tmp_size;++i) tt[i] =create_aligned_array<ui,64>(fn+32);
    // _inv[i-1] = 1/i via inv(i) = -(P/i) * inv(P mod i), stored in li's representation.
    _inv[0]=li.v(1);
    for(ui i=2;i<=fn+32;++i) _inv[i-1]=li.mul(li.v(P-P/i),_inv[(P%i)-1]);
    // num[i-1] = i in li's representation.
    for(ui i=1;i<=fn+32;++i) num[i-1]=li.v(i);
    // j0 = G^((P-1)/fn), j1 = (G^(P-2))^((P-1)/fn). For every power-of-two mid from
    // fn/2 down to 1, ws0[mid+i] / ws1[mid+i] hold the i-th powers of the current
    // j0 / j1; both are squared each time mid halves.
    ui j0=_fastpow(li.v(G),(P-1)/fn),j1=_fastpow(_fastpow(li.v(G),(P-2)),(P-1)/fn);
    for(ui mid=(fn>>1);mid>=1;mid>>=1,j0=li.mul(j0,j0),j1=li.mul(j1,j1)){
        ui w0=li.v(1),w1=li.v(1);
        for(ui i=0;i<mid;++i,w0=li.mul(w0,j0),w1=li.mul(w1,j1)) ws0[i+mid]=w0,ws1[i+mid]=w1;
    }
}
// Copy constructor: declared here, its body is not part of this section of the file.
polynomial_kernel_ntt(const polynomial_kernel_ntt &d);
// Default construction only zeroes the size fields; init() must be called before
// the kernel can be used.
polynomial_kernel_ntt()
{
    mx=0;
    fb=0;
    fn=0;
}
// Cleanup is delegated to release().
~polynomial_kernel_ntt()
{
    release();
}
// Returns a new polynomial holding the coefficients of `a` in reverse order;
// `a` itself is left untouched.
poly rev(const poly &a){
    poly flipped(a.size());
    std::reverse_copy(a.begin(),a.end(),flipped.begin());
    return flipped;
}
/**
 * Product of two polynomials.
 *
 * Returns a.size()+b.size()-1 coefficients, or an empty polynomial when either
 * operand is empty. Uses the scratch buffers tt[0], tt[1] and tt[2].
 *
 * Throws std::runtime_error("Convolution size out of range!") when either
 * operand has more than mx coefficients.
 */
poly mul(const poly &a,const poly &b){
    const ui la=a.size(),lb=b.size();
    if(la>mx || lb>mx) throw std::runtime_error("Convolution size out of range!");
    // A product with an empty operand is empty. The old guard only caught the case
    // where both were empty; with sizes such as 0 and 2 the padding length
    // (1<<m)-size below wrapped around and memset ran far past the buffer.
    if((!la) || (!lb)) return poly();
    const ui out_len=la+lb-1;
    // 1<<m is the smallest power of two that is >= out_len.
    ui m=0;
    if(out_len>1) m=32-__builtin_clz(out_len-1);
    std::memcpy(tt[0].get(),&a[0],sizeof(ui)*la);
    std::memset(tt[0].get()+la,0,sizeof(ui)*((1<<m)-la));
    std::memcpy(tt[1].get(),&b[0],sizeof(ui)*lb);
    std::memset(tt[1].get()+lb,0,sizeof(ui)*((1<<m)-lb));
    internal_mul(tt[0].get(),tt[1].get(),tt[2].get(),m);
    poly ret(out_len);
    std::memcpy(&ret[0],tt[2].get(),sizeof(ui)*out_len);
    return ret;
}
/**
 * Runs internal_transpose_mul on zero-padded copies of a and b and returns the
 * first a.size() coefficients of its output. Uses tt[0], tt[1] and tt[2].
 *
 * An empty `a` gives an empty result. An empty `b` is handled as a single zero
 * coefficient, i.e. exactly like the one-element all-zero input the padding
 * already produced for a.size()==1.
 *
 * Throws std::runtime_error("Convolution size out of range!") when either
 * operand has more than mx coefficients.
 */
poly transpose_mul(const poly &a,const poly &b){
    const ui la=a.size(),lb=b.size();
    if(la>mx || lb>mx) throw std::runtime_error("Convolution size out of range!");
    // The result has la coefficients, so there is nothing to compute for empty a.
    // (Previously sizes such as 0 and 2 made the padding length wrap around.)
    if(!la) return poly();
    // Size the transform as if b had at least one coefficient, so 1<<m is never
    // smaller than la and the padding lengths below cannot wrap.
    const ui span=la+(lb?lb:1u);
    ui m=0;
    if(span>2) m=32-__builtin_clz(span-2);
    std::memcpy(tt[0].get(),&a[0],sizeof(ui)*la);
    std::memset(tt[0].get()+la,0,sizeof(ui)*((1<<m)-la));
    if(lb) std::memcpy(tt[1].get(),&b[0],sizeof(ui)*lb);
    std::memset(tt[1].get()+lb,0,sizeof(ui)*((1<<m)-lb));
    internal_transpose_mul(tt[0].get(),tt[1].get(),tt[2].get(),m);
    poly ret(la);
    std::memcpy(&ret[0],tt[2].get(),sizeof(ui)*(la));
    return ret;
}
// Public entry points declared without bodies; their definitions are not part of
// this section of the file.
poly multipoint_eval_interpolation(const poly &a,const poly &b);
poly lagrange_interpolation(const std::vector<std::pair<mi,mi>> &a);
/**
 * Inverse of a power series, truncated to src.size() coefficients.
 *
 * The input is zero-padded to the next power of two and handed to
 * internal_inv_faster, whose result is read back from tt[1]. Uses tt[0]..tt[4].
 *
 * Throws std::runtime_error when src is empty, when 4*src.size() > fn, or when
 * the constant term is 0 (no inverse exists in that case).
 */
poly inv(const poly &src){
    const ui la=src.size();
    if(!la) throw std::runtime_error("Inversion calculation of empty polynomial!");
    if((la*4)>fn) throw std::runtime_error("Inversion calculation size out of range!");
    if(!li.rv(src[0].get_val())){
        // The check rejects a zero constant term; the previous message ("not equal
        // to 1") had been copied from ln() and described a different condition.
        throw std::runtime_error("Inversion calculation of polynomial which has constant equal to 0!");
    }
    // 1<<m is the smallest power of two that is >= la.
    ui m=0;
    if(la>1) m=32- __builtin_clz(la-1);
    std::memcpy(tt[0].get(),&src[0],sizeof(ui)*la);
    std::memset(tt[0].get()+la,0,sizeof(ui)*((1<<m)-la));
    internal_inv_faster(tt[0].get(),tt[1].get(),tt[2].get(),tt[3].get(),tt[4].get(),(1<<m));
    poly ret(la);
    std::memcpy(&ret[0],tt[1].get(),sizeof(ui)*la);
    return ret;
}
// Logarithm of a power series, truncated to src.size() coefficients.
// Throws std::runtime_error when src is empty, when 2*src.size() > fn, or when the
// constant term is not 1. The input is zero-padded to the next power of two in
// tt[0]; internal_ln_faster leaves the result in tt[1] (tt[2]..tt[5] are scratch).
poly ln(const poly &src){
    const ui la=src.size();
    if(la==0) throw std::runtime_error("Ln calculation of empty polynomial!");
    if(fn<(la*2)) throw std::runtime_error("Ln calculation size out of range!");
    const bool unit_constant=(li.rv(src[0].get_val())==1);
    if(!unit_constant){
        throw std::runtime_error("Ln calculation of polynomial which has constant not equal to 1!");
    }
    // 1<<m is the smallest power of two that is >= la.
    ui m=0;
    if(la>1){
        m=32- __builtin_clz(la-1);
    }
    std::memcpy(tt[0].get(),&src[0],sizeof(ui)*la);
    std::memset(tt[0].get()+la,0,sizeof(ui)*((1<<m)-la));
    internal_ln_faster(tt[0].get(),tt[1].get(),tt[2].get(),tt[3].get(),tt[4].get(),tt[5].get(),(1<<m));
    poly result(la);
    std::memcpy(&result[0],tt[1].get(),sizeof(ui)*la);
    return result;
}
// Exponential of a power series, truncated to src.size() coefficients.
// Throws std::runtime_error when src is empty, when 2*src.size() > fn, or when the
// constant term is not 0. The input is zero-padded to the next power of two in
// tt[0]; internal_exp leaves the result in tt[1] (tt[2]..tt[7] are scratch).
poly exp(const poly &src){
    const ui la=src.size();
    if(la==0) throw std::runtime_error("Exp calculation of empty polynomial!");
    if(fn<(la*2)) throw std::runtime_error("Exp calculation size out of range!");
    const bool zero_constant=(li.rv(src[0].get_val())==0);
    if(!zero_constant){
        throw std::runtime_error("Exp calculation of polynomial which has constant not equal to 0!");
    }
    // 1<<m is the smallest power of two that is >= la.
    ui m=0;
    if(la>1){
        m=32- __builtin_clz(la-1);
    }
    std::memcpy(tt[0].get(),&src[0],sizeof(ui)*la);
    std::memset(tt[0].get()+la,0,sizeof(ui)*((1<<m)-la));
    internal_exp(tt[0].get(),tt[1].get(),tt[2].get(),tt[3].get(),tt[4].get(),tt[5].get(),tt[6].get(),tt[7].get(),(1<<m));
    poly result(la);
    std::memcpy(&result[0],tt[1].get(),sizeof(ui)*la);
    return result;
}
// Formal derivative: output coefficient k-1 is k * a[k], so the result has one
// coefficient fewer than `a`.
// Throws std::runtime_error when `a` is empty or has more than fn coefficients.
poly derivative(const poly &a){
    if(!a.size()) throw std::runtime_error("Derivative calculation of empty polynomial!");
    if(a.size()>fn) throw std::runtime_error("Derivative calculation size out of range!");
    const ui len=a.size();
    poly ret(len-1);
    for(ui k=1;k<len;++k){
        const ui factor=li.v(k);
        ret[k-1]=ui2mi(li.mul(factor,a[k].get_val()));
    }
    return ret;
}
// Formal antiderivative with constant term 0: output coefficient k+1 is
// a[k] * _inv[k], so the result has one coefficient more than `a`.
// Throws std::runtime_error when `a` is empty or has fn or more coefficients.
poly integrate(const poly &a){
    if(!a.size()) throw std::runtime_error("Integrate calculation of empty polynomial!");
    if(a.size()>=fn) throw std::runtime_error("Integrate calculation size out of range!");
    const ui len=a.size();
    poly ret(len+1);
    ret[0]=ui2mi(li.v(0));
    for(ui k=0;k<len;++k){
        ret[k+1]=ui2mi(li.mul(_inv[k],a[k].get_val()));
    }
    return ret;
}
// Coefficient-wise sum. The result is as long as the longer operand: b is copied
// in first, then every coefficient of a is added on top with li.add.
poly add(const poly &a,const poly &b){
    const ui la=a.size();
    const ui lb=b.size();
    poly ret(std::max(la,lb));
    std::memcpy(&ret[0],&b[0],sizeof(mi)*(lb));
    ui idx=0;
    while(idx<la){
        ret[idx]=ui2mi(li.add(ret[idx].get_val(),a[idx].get_val()));
        ++idx;
    }
    return ret;
}
// Coefficient-wise difference a - b. The result is as long as the longer operand:
// a is copied in first, then every coefficient of b is subtracted with li.sub.
poly sub(const poly &a,const poly &b){
    const ui la=a.size();
    const ui lb=b.size();
    poly ret(std::max(la,lb));
    std::memcpy(&ret[0],&a[0],sizeof(mi)*(la));
    ui idx=0;
    while(idx<lb){
        ret[idx]=ui2mi(li.sub(ret[idx].get_val(),b[idx].get_val()));
        ++idx;
    }
    return ret;
}
};
class polynomial_kernel_mtt
{
private:
static constexpr ui P1=167772161,G1=3,P2=469762049,G2=3,P3=754974721,G3=11,I1=104391568,I2=190329765;
polynomial_kernel_ntt k1,k2,k3;ui P,fn;lmi li,li1,li2,li3;
fast_mod_32 F,F1,F2,F3;
aligned_array<ui,64> _inv;
void release(){
k1.release();k2.release();k3.release();P=fn=0;
_inv.reset();
}
ui _fastpow(ui a,ui b){ui ans=li.v(1),off=a;while(b){if(b&1) ans=li.mul(ans,off);off=li.mul(off,off);b>>=1;}return ans;}
public:
void init(ui max_conv_size,ui P0){
if(P0>=(1u<<30) || !P0) throw std::runtime_error("invalid prime!");
// if(!factorization::miller_rabin_u32(P0)) throw std::runtime_error("invalid prime!");
try{
release();
k1.init(max_conv_size,P1,G1);
k2.init(max_conv_size,P2,G2);
k3.init(max_conv_size,P3,G3);
P=P0;fn=k1.fn;li=lmi(P);
F=fast_mod_32(P);
if(P<=fn+32) throw std::runtime_error("invalid prime!");
_inv=create_aligned_array<ui,64>(fn+32);
_inv[0]=li.v(1);for(ui i=2;i<=fn+32;++i) _inv[i-1]=li.mul(li.v(P-P/i),_inv[(P%i)-1]);
}catch(std::exception &e){
P=0;
throw std::runtime_error(e.what());
}
}
polynomial_kernel_mtt(ui max_conv_size,ui P0):F1(P1),F2(P2),F3(P3),li1(P1),li2(P2),li3(P3){
init(max_conv_size,P0);
}
polynomial_kernel_mtt(const polynomial_kernel_mtt &d):
P(d.P),fn(d.fn),k1(d.k1),k2(d.k2),k3(d.k3),F1(P1),F2(P2),F3(P3),F(d.F),li(d.li),li1(P1),li2(P2),li3(P3){}
polynomial_kernel_mtt():F1(P1),F2(P2),F3(P3),li1(P1),li2(P2),li3(P3){
P=fn=0;
}
~polynomial_kernel_mtt(){release();}
poly add(const poly &a,const poly &b){
ui la=a.size(),lb=b.size(),len=std::max(la,lb);poly ret(len);
std::memcpy(&ret[0],&b[0],sizeof(mi)*(lb));
for(ui i=0;i<la;++i) ret[i]=ui2mi(li.add(ret[i].get_val(),a[i].get_val()));
return ret;
}
poly sub(const poly &a,const poly &b){
ui la=a.size(),lb=b.size(),len=std::max(la,lb);poly ret(len);
std::memcpy(&ret[0],&a[0],sizeof(mi)*(la));
for(ui i=0;i<lb;++i) ret[i]=ui2mi(li.sub(ret[i].get_val(),b[i].get_val()));
return ret;
}
poly mul(const poly &a,const poly &b){
ui la=a.size(),lb=b.size();if((!la) && (!lb)) return poly();
if((la+lb-1)>fn) throw std::runtime_error("Convolution size out of range!");
poly a1(a.size()),a2(a.size()),a3(a.size()),b1(b.size()),b2(b.size()),b3(b.size());
for(ui i=0;i<a.size();++i){
ui ra=li.rv(a[i].get_val());
a1[i]=ui2mi(li1.v(ra));
a2[i]=ui2mi(li2.v(ra));
a3[i]=ui2mi(li3.v(ra));
}
for(ui i=0;i<b.size();++i){
ui rb=li.rv(b[i].get_val());
b1[i]=ui2mi(li1.v(rb));
b2[i]=ui2mi(li2.v(rb));
b3[i]=ui2mi(li3.v(rb));
}
poly r1=k1.mul(a1,b1),r2=k2.mul(a2,b2),r3=k3.mul(a3,b3);
poly ret(la+lb-1);
ui I3=F.reduce(1ull*P1*P2);
for(int i=0;i<la+lb-1;++i){
ui x1=li1.rv(r1[i].get_val());
ui x2=li2.rv(r2[i].get_val());
ui x3=li3.rv(r3[i].get_val());
ui k1=F2.reduce(1ull*(x2-x1+P2)*I1);
ull x4=1ull*k1*P1+x1;
ui k2=F3.reduce((x3-F3.reduce(x4)+P3)*I2);
ui x=F.reduce(x4+1ull*k2*I3);
ret[i]=ui2mi(li.v(x));
}
return ret;
}
poly inv(const poly &a)
{
ui l(a.size());if(!l) throw std::runtime_error("Inversion calculation of empty polynomial!");
if(l==1) return {ui2mi(_fastpow(a[0].get_val(),P-2))};
poly b(a);ui k((l+1)/2);b.resize(k);
b=inv(b);poly r(mul(b,a));r.resize(l);
for(mi &v:r) v=ui2mi(li.neg(v.get_val()));
r[0]=ui2mi(li.add(li.v(2),r[0].get_val()));
r=mul(b,r);r.resize(l);return r;
}
poly ln(const poly &a)
{
ui l(a.size());poly b(l-1);
for(ui i(1);i<l;++i) b[i-1]=ui2mi(li.mul(a[i].get_val(),li.v(i)));
b=mul(b,inv(a));
poly r(l);r[0]=ui2mi(li.v(0));
for(ui i(1);i<l;++i) r[i]=ui2mi(li.mul(_inv[i-1],b[i-1].get_val()));
return r;
}
poly exp(const poly &a)
{
ui l(a.size());
if(l==1) return {ui2mi(li.v(1))};
poly b(a);ui k((l+1)/2);b.resize(k);
b=exp(b);poly c(b);c.resize(l);
c=sub(a,ln(c));c[0]=ui2mi(li.add(li.v(1),c[0].get_val()));
c=mul(b,c);c.resize(l);return c;
}
mi linear_recurrence(poly P,poly Q,ll k)
{
while(k)
{
poly nQ(Q);
{
ui l(nQ.size());
for(ui i(1);i<l;i+=2) nQ[i]=ui2mi(li.neg(nQ[i].get_val()));
}
poly _lQ(mul(nQ,Q));
{
ui l(Q.size());
for(ui i(0);i<l;i++) Q[i]=_lQ[i*2];
}
poly _lP(mul(P,nQ));
{
ui l(_lP.size());
ui t((l+1-(k&1))/2);
P.resize(t);
for(ui i(0),j(k&1);i<t;++i,j+=2) P[i]=_lP[j];
}
k>>=1;
}
return P[0];
}
};
#undef restrict
#undef NTT_partition_size
}
// Names pulled in from the library namespace for use in main().
using library::polynomial_kernel_ntt;
using library::poly;
using library::mi;
using library::set_mod_mi;
using library::polynomial_kernel_mtt;
// N bounds every table below; p is the modulus handed to set_mod_mi() and to the
// MTT kernel in main().
constexpr int N = 90010;
constexpr int M = 310;
constexpr int p = 1000000009;
// inv / ifac / fac are the inverse, inverse-factorial and factorial tables filled
// in main(); a and b receive the two input sequences.
mi inv[N];
mi ifac[N];
mi fac[N];
mi a[N];
mi b[N];
mi ap[N];
mi bp[N];
/**
 * Reads n, m and two sequences a[1..n], b[1..n] from standard input, builds the
 * series K and F below (each kept to n*n+1 coefficients) and prints the value of
 * ker.linear_recurrence(F,K,m) followed by a newline.
 */
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    set_mod_mi(p);
    // inv[i] = 1/i, fac[i] = i!, ifac[i] = 1/i! modulo p.
    inv[1]=1;for(int i(2);i<N;++i) inv[i]=inv[p%i]*(p-p/i);
    fac[0]=ifac[0]=1;for(int i(1);i<N;++i) ifac[i]=ifac[i-1]*inv[i],fac[i]=fac[i-1]*i;
    int n,m;std::cin>>n>>m;
    // Every series below is kept to len coefficients; the kernel is sized for that.
    const int len=n*n+1;
    polynomial_kernel_mtt ker(len,p);
    for(int i(1);i<=n;++i) std::cin>>a[i];
    for(int i(1);i<=n;++i) std::cin>>b[i];
    // pa = 1 - sum a[i] x^i, pb = 1 - sum b[i] x^i.
    poly pa(len),pb(len);
    pa[0]=pb[0]=1;
    for(int i=1;i<=n;i++) pa[i]=-a[i],pb[i]=-b[i];
    // Coefficient i becomes -i * [x^i] ln(.) / i!, with the constant term set to n.
    pa=ker.ln(pa);pb=ker.ln(pb);
    for(int i(0);i<len;++i) pa[i]=(-pa[i])*i*ifac[i],pb[i]=(-pb[i])*i*ifac[i];
    pa[0]=n;pb[0]=n;
    // K: product of the two, rescaled by -i!/i per coefficient, constant term 0.
    poly K(ker.mul(pa,pb));K.resize(len);
    for(int i(0);i<len;++i) K[i]=-K[i]*fac[i]*inv[i];
    K[0]=0;
    // pc: product of the two inverse series, each divided by i! beforehand and
    // multiplied by i! afterwards.
    pa=poly(len);pb=poly(len);
    for(int i=1;i<=n;i++) pa[i]=-a[i],pb[i]=-b[i];
    pa[0]=pb[0]=1;
    pa=ker.inv(pa);pb=ker.inv(pb);
    for(int i=0;i<len;i++) pa[i]*=ifac[i],pb[i]*=ifac[i];
    poly pc=ker.mul(pa,pb);
    // Keep pc at len coefficients. The raw product has 2*n*n+1 of them, which for
    // n>=3 is longer than the kernel accepts as an operand, so ker.mul(K,pc) below
    // threw an uncaught std::runtime_error. Only the first len coefficients of F
    // are kept, and those do not depend on the dropped part of pc.
    pc.resize(len);
    for(int i=0;i<len;i++) pc[i]*=fac[i];
    K=ker.exp(K);
    poly F(ker.mul(K,pc));F.resize(len);
    std::cout<<ker.linear_recurrence(F,K,m)<<'\n';
    return 0;
}
Details
Tip: Click on the bar to expand more detailed information
Test #1:
score: 100
Accepted
time: 5ms
memory: 5896kb
input:
2 3 1 1 1 1
output:
18
result:
ok 1 number(s): "18"
Test #2:
score: -100
Dangerous Syscalls
input:
3 4 1 2 3 1 3 2