QOJ.ac

QOJ

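| ID | 题目 (Problem) | 提交者 (Submitter) | 结果 (Result) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 文件大小 (File size) | 提交时间 (Submitted) | 测评时间 (Judged) |
| #559604 | #6187. Digit Sum Problem | ucup-team4435 | TL | 1ms | 3816kb | C++20 | 7.7kb | 2024-09-12 03:14:52 | 2024-09-12 03:14:52 |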

Judging History

你现在查看的是最新测评结果

  • [2024-09-12 03:14:52]
  • 评测
  • 测评结果:TL
  • 用时:1ms
  • 内存:3816kb
  • [2024-09-12 03:14:52]
  • 提交

answer

#include <bits/stdc++.h>
using namespace std;

// Short names for the wide integer / floating-point types.
typedef long long ll;
typedef long double ld;

// all(c): the [begin, end) iterator pair of a container, for algorithm calls.
// len(c): the container's size converted to int.
#define all(a) begin(a), end(a)
#define len(a) static_cast<int>((a).size())

/*
 * Modular integer stored in Montgomery form for a compile-time modulus `mod`.
 *
 *  ! WARNING: MOD must be prime (power() reduces exponents modulo MOD - 1 and
 *             inv() is power(-1), i.e. both rely on Fermat's little theorem).
 *  ! WARNING: MOD must be less than 2^30 (the raw value is kept lazily in
 *             [0, 2 * MOD), and the reductions need 4 * MOD < 2^32).
 *  * Use .get() to get the stored value as a canonical residue in [0, MOD).
 */
template<uint32_t mod>
class montgomery {
    static_assert(mod < uint32_t(1) << 30, "mod < 2^30");
    using mint = montgomery<mod>;

private:
    // Montgomery representation x * 2^32 (mod `mod`), kept in [0, 2 * mod).
    uint32_t value;

    // -mod^{-1} modulo 2^32, computed by Newton iteration x <- x * (2 - mod * x).
    static constexpr uint32_t inv_neg_mod = []() {
        uint32_t x = mod;
        for (int i = 0; i < 4; ++i) {
            x *= uint32_t(2) - mod * x;
        }
        return -x;
    }();
    static_assert(mod * inv_neg_mod == -1);

    // 2^64 mod `mod` (= R^2 for R = 2^32); multiplying by it and reducing once
    // converts a plain residue into Montgomery form.
    static constexpr uint32_t neg_mod = (-uint64_t(mod)) % mod;

    // Montgomery reduction: returns value * 2^{-32} (mod `mod`); the result is
    // below 2 * mod whenever value < mod * 2^32.
    static uint32_t reduce(const uint64_t &value) {
        return (value + uint64_t(uint32_t(value) * inv_neg_mod) * mod) >> 32;
    }

    inline static const mint ONE = mint(1);

public:
    montgomery() : value(0) {}
    montgomery(const mint &x) : value(x.value) {}

    // Converts any integral value (negative values included) into a residue.
    template<typename T, typename U = std::enable_if_t<std::is_integral<T>::value>>
    montgomery(const T &x) : value(!x ? 0 : reduce(int64_t(x % int32_t(mod) + int32_t(mod)) * neg_mod)) {}

    static constexpr uint32_t get_mod() {
        return mod;
    }

    // Canonical residue in [0, mod).
    uint32_t get() const {
        auto real_value = reduce(value);
        return real_value < mod ? real_value : real_value - mod;
    }

    // this^degree. Negative degrees are interpreted via Fermat (x^-1 == x^(mod-2)).
    template<typename T>
    mint power(T degree) const {
        // Remember whether the exponent was nonzero before the Fermat reduction:
        // x^(mod-1) == 1 only holds for x != 0, so 0^k must stay 0 for every
        // nonzero k, even when k is a multiple of mod - 1.
        const bool nonzero_degree = degree != 0;
        degree = (degree % int32_t(mod - 1) + int32_t(mod - 1)) % int32_t(mod - 1);
        if (nonzero_degree && degree == 0 && get() == 0) {
            return mint(0);
        }

        mint prod = 1, a = *this;
        for (; degree > 0; degree >>= 1, a *= a)
            if (degree & 1)
                prod *= a;

        return prod;
    }

    // Multiplicative inverse (returns 0 for 0).
    mint inv() const {
        return power(-1);
    }

    mint& operator=(const mint &x) {
        value = x.value;
        return *this;
    }

    mint& operator+=(const mint &x) {
        // Subtract 2*mod up front; a negative (as int32) result means we must add it back.
        if (int32_t(value += x.value - (mod << 1)) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator-=(const mint &x) {
        if (int32_t(value -= x.value) < 0) {
            value += mod << 1;
        }
        return *this;
    }

    mint& operator*=(const mint &x) {
        value = reduce(uint64_t(value) * x.value);
        return *this;
    }

    mint& operator/=(const mint &x) {
        return *this *= x.inv();
    }

    friend mint operator+(const mint &x, const mint &y) {
        return mint(x) += y;
    }

    friend mint operator-(const mint &x, const mint &y) {
        return mint(x) -= y;
    }

    friend mint operator*(const mint &x, const mint &y) {
        return mint(x) *= y;
    }

    friend mint operator/(const mint &x, const mint &y) {
        return mint(x) /= y;
    }

    mint& operator++() {
        return *this += ONE;
    }

    mint& operator--() {
        return *this -= ONE;
    }

    mint operator++(int) {
        mint prev = *this;
        *this += ONE;
        return prev;
    }

    mint operator--(int) {
        mint prev = *this;
        *this -= ONE;
        return prev;
    }

    mint operator-() const {
        return mint(0) - *this;
    }

    // Comparisons act on canonical residues (the raw value may be x or x + mod).
    bool operator==(const mint &x) const {
        return get() == x.get();
    }

    bool operator!=(const mint &x) const {
        return get() != x.get();
    }

    bool operator<(const mint &x) const {
        return get() < x.get();
    }

    template<typename T>
    explicit operator T() const {
        return get();
    }

    // Parses an optionally negative decimal integer of arbitrary length.
    friend std::istream& operator>>(std::istream &in, mint &x) {
        std::string s;
        in >> s;
        x = 0;
        bool neg = s[0] == '-';
        for (const auto c : s)
            if (c != '-')
                x = x * 10 + (c - '0');

        if (neg)
            x *= -1;

        return in;
    }

    friend std::ostream& operator<<(std::ostream &out, const mint &x) {
        return out << x.get();
    }

    // Smallest primitive root of `mod` (hard-coded for common moduli, otherwise
    // found by testing r^((mod-1)/p) != 1 for every prime p dividing mod - 1).
    static int32_t primitive_root() {
        if constexpr (mod == 1'000'000'007)
            return 5;
        if constexpr (mod == 998'244'353)
            return 3;
        if constexpr (mod == 786433)
            return 10;

        static int root = -1;
        if (root != -1)
            return root;

        std::vector<int> primes;
        int value = mod - 1;
        for (int i = 2; i * i <= value; i++)
            if (value % i == 0) {
                primes.push_back(i);
                while (value % i == 0)
                    value /= i;
            }

        if (value != 1)
            primes.push_back(value);

        for (int r = 2;; r++) {
            bool ok = true;
            for (auto p : primes)
                if ((mint(r).power((mod - 1) / p)).get() == 1) {
                    ok = false;
                    break;
                }

            if (ok)
                return root = r;
        }
    }
};

// Prime modulus of the answer: 998244353 = 119 * 2^23 + 1.
// (The alternative modulus 1'000'000'007 can be substituted here.)
inline constexpr uint32_t MOD = 998244353;
using mint = montgomery<MOD>;

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);

    ll n;
    mint a, b, c;
    cin >> n >> a >> b >> c;

    auto digit_sum = [&](ll x, int power) {
        int sum = 0;
        while (x > 0) {
            sum += x % power;
            x /= power;
        }
        return sum;
    };

    if (n <= 1000) {
        mint ans = 0;
        for (int k = 1; k <= n; k++) {
            ans += mint(a).power(k) * mint(b).power(digit_sum(k, 2)) * mint(c).power(digit_sum(k, 3));
        }
        cout << ans << '\n';
        return 0;
    }

    int sq = sqrt(n) / 2 + 1;

    int power2 = 1;
    while (power2 < sq) {
        power2 *= 2;
    }
    int power3 = 1;
    while (power3 < power2) {
        power3 *= 3;
    }

    vector<mint> val3(power3);
    for (int i = 0; i < power3; i++) {
        val3[i] = a.power(i) * c.power(digit_sum(i, 3));
    }
    // binary lifting stuf
    for (int power = 1; power < power2; power *= 2) {
        for (int i = 0; i + power < power3; i++) {
            val3[i] += val3[i + power] * b;
        }
    }

    vector<mint> val2(power2);
    for (int i = 0; i < power2; i++) {
        val2[i] = a.power(i) * b.power(digit_sum(i, 2));
    }
    // another shit binary lifting
    for (int power = 1; power < power3; power *= 3) {
        for (int i = 0; i < power2; i++) {
            if (i + power < power2) {
                val2[i] += val2[i + power] * c;
            }
            if (i + 2 * power < power2) {
                val2[i] += val2[i + 2 * power] * c * c;
            }
        }
    }

    ll left = 1, prev2 = 0, prev3 = 0;
    mint ans = 0;
    while (left <= n) {
        while (prev2 + power2 <= left) {
            prev2 += power2;
        }
        while (prev3 + power3 <= left) {
            prev3 += power3;
        }

        ll next2 = prev2 + power2 - 1;
        ll next3 = prev3 + power3 - 1;
        ll next = min({next2, next3, n});

        int ppc = digit_sum(prev2, 2);
        int sum = digit_sum(prev3, 3);

        if (next == next3 || (next == next2 && left == prev2)) {
            ans += a.power(prev3) * b.power(ppc) * c.power(sum) * val3[left - prev3];
        } else if (next == next2 && left == prev3) {
            ans += a.power(prev2) * b.power(ppc) * c.power(sum) * val2[left - prev2];
        } else {
            for (ll i = left; i <= next; i++) {
                ans += a.power(i) * b.power(ppc + digit_sum(i - prev2, 2))
                    * c.power(sum + digit_sum(i - prev3, 3));
            }
        }

        left = next + 1;
    }
    cout << ans << '\n';
}

详细

Test #1:

score: 100
Accepted
time: 1ms
memory: 3816kb

input:

123456 12345 234567 3456789

output:

664963464

result:

ok 1 number(s): "664963464"

Test #2:

score: -100
Time Limit Exceeded

input:

9876543210987 12816 837595 128478

output:

7972694

result: