QOJ.ac

QOJ

IDProblemSubmitterResultTimeMemoryLanguageFile sizeSubmit timeJudge time
#356751#8303. Junior MathematicianFOY#WA 1ms3600kbC++143.9kb2024-03-18 10:42:412024-03-18 10:42:41

Judging History

你现在查看的是最新测评结果

  • [2024-03-18 10:42:41]
  • 评测
  • 测评结果:WA
  • 用时:1ms
  • 内存:3600kb
  • [2024-03-18 10:42:41]
  • 提交

answer

#include <iostream>
#include <vector>
#include <array>
#include <cassert>
using ll = long long;
using namespace std;
using pll = pair<ll, ll>;
using a2 = array<ll, 2>;
const ll mod = 1e9+7;

ll sq(ll x) {
	return x*x;
}

ll fastFix(ll x, ll m) {
	if (x >= m) x-=m;
	return x;
}
ll midFix(ll x, ll m){
	return x%m;
}
ll fix(ll x, ll m) {
	return ((x%m)+m)%m;
}

ll test(ll sum, ll sumsq, ll m) {
	ll x = fix(sum*sum-sumsq, m);
	return x/2;
}

ll modPow(ll n, ll p, ll mod) {
	ll b = n, o = 1;
	while (p) {
		if (p&1) o = (o*b)%mod;
		b = (b*b)%mod;
		p/=2;
	}
	return o;
}

void solve(string l, string r, int m) {
	string s = "";
	while (s.size() + l.size() < r.size()) s += "0";
	l = s+l;
	l = "0" + l;
	r = "0" + r;
	int om = m;
	m *= 2;
	// dp[i][j][k][l] -> i=0 => low, i=1 => mid, i=2 => hi
	// cnt such that dig sum is j, dig squared sum is k
	vector<vector<ll>> dp(m, vector<ll>(m));
	vector<a2> top(l.size()), bot(l.size());
	vector<ll> topMod(l.size()), botMod(l.size());
	vector<ll> pref(l.size());
	for (int i = 1; i < l.size(); i++) {
		top[i] = top[i-1];
		bot[i] = bot[i-1];
		
		top[i][0] = midFix(top[i][0] + sq(r[i]-'0'), m);
		top[i][1] = midFix(top[i][1] + r[i]-'0', m);
		bot[i][0] = midFix(bot[i][0] + sq(l[i]-'0'), m);
		bot[i][1] = midFix(bot[i][1] + l[i]-'0', m);
	}
	for (int i = 1; i < l.size(); i++) {
		ll suf = 2*modPow(10, l.size()-i-1, m);
		pref[i] = fix(suf, m);
		topMod[i] = topMod[i-1]*10 + r[i]-'0';
		topMod[i] = midFix(topMod[i],m);

		botMod[i] = botMod[i-1]*10 + l[i]-'0';
		botMod[i] = midFix(botMod[i],m);
	}
	dp[0][0] = 1;
	ll out = 0;

	if (fix(top.back()[0] + topMod.back()*2 - sq(top.back()[1]), m) == 0) {
		out++;
	}
	if (l != r && fix(bot.back()[0] + botMod.back()*2 - sq(bot.back()[1]), m) == 0) {
		out++;
	}

	if (l == r) {
		cout << out << endl;
		return;
	}

	auto test = [&](int x, int d, vector<a2> &prefVals, vector<ll> &prefMod) {
		// count total such that top[i][0] - (top[i][1])^2 + j - k = 0
		ll curSum = midFix(prefVals[x][1] + d, m);
		ll curSq = midFix(prefVals[x][0] + prefMod[x]*pref[x] + d*pref[x+1] + d*d, m);
		for (int j = 0; j < m; j++) {
			for (int k = 0; k < m; k++) {
				if (fix(sq(curSum + j) - curSq - k, m) == 0) {
					out = fastFix((out+dp[j][k]), mod);
				}
			}
		}
	};
	vector<bool> eq(l.size()+1);
	bool eqNow = true;
	for (int i =0; i < l.size(); i++) {
		if (l[i] != r[i]) eqNow = false;
		eq[i+1] = eqNow;
	}
	ll base = 1;
	while (l.size()) {
		int lo = l.back() - '0';
		int hi = r.back() - '0';
		l.pop_back();
		r.pop_back();
		vector<vector<ll>> ndp(m, vector<ll>(m));
		for (int j = 0; j < m; j++) {
			for (int k = 0; k < m; k++) {
				for (int l = 0; l < 10; l++) {
					ll a = midFix(j+l, m);
					ll b = midFix(k+sq(l) + 2*base*l,m);
					ndp[a][b] += dp[j][k];
					ndp[a][b] = fastFix(ndp[a][b], mod);
				}
			}
		}
		base = fix(base*10, m);
		int x = r.size()-1;
		for (int i = 0; i < 10; i++) {
			// count total such that top[i][0] - (top[i][1])^2 + j - k = 0
			if (eq[l.size()] && i < hi && i > lo) {
				test(x, i, top, topMod);
			}
			else if (!eq[l.size()]) {
				if (i < hi) {
					test(x, i, top, topMod);
				}
				if (i > lo) {
					test(x, i, bot, botMod);
				}
			}
		}
		dp = ndp;
	}
  cout << fix(out, mod)  << '\n';
	//if (out != cnt) assert(out == cnt);
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0);
			/*for (int i = 10; i < 100; i++) {
			for (int j = i; j < 100; j++) {
				for (int k = 2; k <= 20; k++) {
					cout << i << ' ' << j << ' ' << k << endl;
					solve(to_string(i), to_string(j), k);
				}
			}
		}*/
	int t; cin >> t;
	while (t--) {
		string l, r; cin >> l >> r;
		int m; cin >> m;

		//int m = fix(rand(), 60)+1;
		//string l = to_string(fix(rand(), 10000));
		//string r = to_string(fix(rand(), 10000) + stoi(l));
		//cout << l <<' ' << r << ' ' << m << endl;
		solve(l, r, m);
	}
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 0
Wrong Answer
time: 1ms
memory: 3600kb

input:

2
10
50
17
33
33
3

output:

58
1

result:

wrong answer 1st lines differ - expected: '2', found: '58'