QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#186399#6399. Classic: Classical Problemucup-team228#WA 1ms3836kbC++209.3kb2023-09-23 19:18:382023-09-23 19:18:39

Judging History

你现在查看的是最新测评结果

  • [2023-09-23 19:18:39]
  • 评测
  • 测评结果:WA
  • 用时:1ms
  • 内存:3836kb
  • [2023-09-23 19:18:38]
  • 提交

answer

#include <bits/stdc++.h>

using namespace std;

#define all(v) (v).begin(), (v).end()
#define rall(v) (v).rbegin(), (v).rend()

typedef long long ll;
typedef pair<int, int> pii;

mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

template<typename T>
std::ostream& operator << (std::ostream& os, const vector<T>& a) {
    for (const T& x : a) {
        os << x << ' ';
    }
    return os;
}

/*
 * TESTED ON: https://judge.yosupo.jp/problem/convolution_mod
 * 260 ms on one multiplication with n = m = 5e5
 */

template<int mod>
class Modular {
public:
    int val;
    Modular() : val(0) {}
    Modular(int new_val) : val(new_val) {
    }
    friend Modular operator+(const Modular& a, const Modular& b) {
        if (a.val + b.val >= mod) return a.val + b.val - mod;
        else return a.val + b.val;
    }
    friend Modular operator-(const Modular& a, const Modular& b) {
        if (a.val - b.val < 0) return a.val - b.val + mod;
        else return a.val - b.val;
    }
    friend Modular operator*(const Modular& a, const Modular& b) {
        return 1ll * a.val * b.val % mod;
    }
    friend Modular binpow(Modular a, long long n) {
        Modular res = 1;
        for (; n; n >>= 1) {
            if (n & 1) res *= a;
            a *= a;
        }
        return res;
    }
    /* ALTERNATIVE INVERSE FUNCTION USING EXTENDED EUCLIDEAN ALGORITHM
    friend void gcd(int a, int b, Modular& x, Modular& y) {
        if (a == 0) {
            x = Modular(0);
            y = Modular(1);
            return;
        }
        Modular x1, y1;
        gcd(b % a, a, x1, y1);
        x = y1 - (b / a) * x1;
        y = x1;
    }
    friend Modular inv(const Modular& a) {
        Modular x, y;
        gcd(a.val, mod, x, y);
        return x;
    }
    */
    friend Modular inv(const Modular& a) {
        return binpow(a, mod - 2);
    }
    Modular operator/(const Modular& ot) const {
        return *this * inv(ot);
    }
    Modular& operator++() {
        if (val + 1 == mod) val = 0;
        else ++val;
        return *this;
    }
    Modular operator++(int) {
        Modular tmp = *this;
        ++(*this);
        return tmp;
    }
    Modular operator+() const {
        return *this;
    }
    Modular operator-() const {
        return 0 - *this;
    }
    Modular& operator+=(const Modular& ot) {
        return *this = *this + ot;
    }
    Modular& operator-=(const Modular& ot) {
        return *this = *this - ot;
    }
    Modular& operator*=(const Modular& ot) {
        return *this = *this * ot;
    }
    Modular& operator/=(const Modular& ot) {
        return *this = *this / ot;
    }
    bool operator==(const Modular& ot) const {
        return val == ot.val;
    }
    bool operator!=(const Modular& ot) const {
        return val != ot.val;
    }
    bool operator<(const Modular& ot) const {
        return val < ot.val;
    }
    bool operator>(const Modular& ot) const {
        return val > ot.val;
    }
    explicit operator int() const {
        return val;
    }
};

template <int mod>
Modular<mod> any_to_mint(ll a) {
    a %= mod;
    return a < 0 ? a + mod : a;
}

template<int mod>
istream& operator>>(istream& istr, Modular<mod>& x) {
    return istr >> x.val;
}

template<int mod>
ostream& operator<<(ostream& ostr, const Modular<mod>& x) {
    return ostr << x.val;
}

template <int mod = 998244353, int root = 3>
class NTT {
    using Mint = Modular<mod>;
public:
    static vector<int> mult(const vector<int>& a, const vector<int>& b) {
        vector<Mint> amod(a.size());
        vector<Mint> bmod(b.size());
        for (int i = 0; i < a.size(); i++) {
            amod[i] = any_to_mint<mod>(a[i]);
        }
        for (int i = 0; i < b.size(); i++) {
            bmod[i] = any_to_mint<mod>(b[i]);
        }
        vector<Mint> resmod = mult(amod, bmod);
        vector<int> res(resmod.size());
        for (int i = 0; i < res.size(); i++) {
            res[i] = resmod[i].val;
        }
        return res;
    }
    static vector<Mint> mult(const vector<Mint>& a, const vector<Mint>& b) {
        int n = int(a.size()), m = int(b.size());
        if (!n || !m) return {};
        int lg = 0;
        while ((1 << lg) < n + m - 1) lg++;
        int z = 1 << lg;
        auto a2 = a, b2 = b;
        a2.resize(z);
        b2.resize(z);
        nft(false, a2);
        nft(false, b2);
        for (int i = 0; i < z; i++) a2[i] *= b2[i];
        nft(true, a2);
        a2.resize(n + m - 1);
        Mint iz = inv(Mint(z));
        for (int i = 0; i < n + m - 1; i++) a2[i] *= iz;
        return a2;
    }

private:
    static void nft(bool type, vector<Modular<mod>> &a) {
        int n = int(a.size()), s = 0;
        while ((1 << s) < n) s++;
        assert(1 << s == n);
        static vector<Mint> ep, iep;
        while (int(ep.size()) <= s) {
            ep.push_back(binpow(Mint(root), (mod - 1) / (1 << ep.size())));
            iep.push_back(inv(ep.back()));
        }
        vector<Mint> b(n);
        for (int i = 1; i <= s; i++) {
            int w = 1 << (s - i);
            Mint base = type ? iep[i] : ep[i], now = 1;
            for (int y = 0; y < n / 2; y += w) {
                for (int x = 0; x < w; x++) {
                    auto l = a[y << 1 | x];
                    auto r = now * a[y << 1 | x | w];
                    b[y | x] = l + r;
                    b[y | x | n >> 1] = l - r;
                }
                now *= base;
            }
            swap(a, b);
        }
    }
};


int mult(int a, int b, int MOD) {
    return (a * 1LL * b) % MOD;
}
int bin_pow(int a, int n, int MOD) {
    int res = 1;
    for (; n; n >>= 1, a = mult(a, a, MOD)) {
        if (n & 1) {
            res = mult(res, a, MOD);
        }
    }
    return res;
}

vector<int> get_primes(int n) {
    vector<int> res;
    for (int d = 2; d * d <= n; ++d) {
        if (n % d == 0) {
            res.push_back(d);
            while (n % d == 0) {
                n /= d;
            }
        }
    }
    if (n > 1) {
        res.push_back(n);
    }
    return res;
}

bool check_primitive_root(int g, int P) {
    vector<int> prime_divs = get_primes(P - 1);
    for (int p : prime_divs) {
        if (bin_pow(g, (P - 1) / p, P) == 1) {
            return false;
        }
    }
    return true;
}

int get_primitive_root(int P) {
    int g = 2;
    while (!check_primitive_root(g, P)) {
        ++g;
    }
    //cout << g << '\n';
    return g;
}

void solve() {
    int n, P;
    cin >> n >> P;
    int g = get_primitive_root(P);

    vector<int> log_g(P);
    vector<int> pw_g(P - 1);
    for (int i = 0, power = 1; i < P - 1; ++i, power = mult(power, g, P)) {
        //cout << power << ' ' << i << '\n';
        pw_g[i] = power;
        log_g[power] = i;
    }
    //cout << endl;

    vector<int> a(P - 1);
    bool zero = false;
    for (int i = 0; i < n; ++i) {
        int y;
        cin >> y;
        if (y == 0) {
            zero = true;
            continue;
        }
        a[log_g[y]] = 1;
    }

    if (!zero) {
        cout << "1 1\n";
        cout << "0\n";
        return;
    }

    if (n == 1) {
        cout << P << " 1\n";
        for (int c = 0; c < P; ++c) {
            cout << c << ' ';
        }
        cout << '\n';
        return;
    }

    for (int i = 0; i < P - 1; ++i) {
        a.push_back(a[i]);
    }
    //cout << a << endl;
    int k = 0;
    while ((1 << k) < 3 * (P - 1)) {
        ++k;
    }
    int N = 1 << k;

    /*
    vector<Mint> res_a(N);
    while (a.size() < N) {
        a.push_back(0);
    }
    fft(a.data(), res_a.data(), N);
    */

    vector<int> out;

    auto check = [&](int mid) {
        out.clear();
        //cout << "mid = " << mid << endl;

        vector<int> b(P - 1);
        for (int x = 1; x <= mid; ++x) {
            b[log_g[x]] = 1;
        }
        //cout << b << '\n';
        reverse(all(b));
        /*
        while (b.size() < N) {
            b.push_back(0);
        }
        */
        //cout << a << '\n';
        //cout << b << '\n';
        /*
        vector<Mint> res_b(N);
        fft(b.data(), res_b.data(), N);
        for (int i = 0; i < N; ++i) {
            res_a[i] *= res_b[i];
        }

        vector<Mint> res = inter(res_a);
        */
        vector<int> res = NTT<998244353, 3>::mult(a, b);
        //cout << res << '\n';

        bool ok = false;
        for (int i = P - 2; i <= 2 * (P - 2); ++i) {
            if (res[i] == mid) {
                out.push_back(i - (P - 2));
                ok = true;
            }
        }

        return ok;
    };

    int low = 0, high = P - 1;
    while (high - low > 1) {
        int mid = (low + high) / 2;
        if (check(mid)) {
            low = mid;
        } else {
            high = mid;
        }
    }

    check(low);

    cout << out.size() << ' ' << low + 1 << '\n';
    vector<int> output;
    for (int c : out) {
        int cc = (P - 1) - c;
        cc %= P - 1;
        output.push_back(pw_g[c]);
    }
    sort(all(output));
    cout << output << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;
    while (t--) {
        solve();
    }

    return 0;
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1ms
memory: 3568kb

input:

3
2 3
0 2
3 5
2 3 4
3 5
0 2 3

output:

1 2
2 
1 1
0
2 2
2 3 

result:

ok 6 lines

Test #2:

score: -100
Wrong Answer
time: 0ms
memory: 3836kb

input:

3
1 2
0
1 2
1
2 2
1 0

output:

2 1
0 1 
1 1
0
1 1
1 

result:

wrong answer 5th lines differ - expected: '1 2', found: '1 1'