QOJ.ac

QOJ

ID | Problem | Submitter | Result | Time | Memory | Language | File size | Submit time | Judge time
#356825 | #8303. Junior Mathematician | FOY | RE | 1ms | 3840kb | C++14 | 4.2kb | 2024-03-18 12:54:27 | 2024-03-18 12:54:27

Judging History

你现在查看的是最新测评结果

  • [2024-03-18 12:54:27]
  • 评测
  • 测评结果:RE
  • 用时:1ms
  • 内存:3840kb
  • [2024-03-18 12:54:27]
  • 提交

answer

#pragma GCC optimize "Ofast"
#include <iostream>
#include <vector>
#include <array>
using ll = long long;
using namespace std;
using a2 = array<int, 2>;
const ll mod = 1e9+7;

// Returns x squared, computed in plain int arithmetic.
int sq(int x) {
	const int squared = x * x;
	return squared;
}

// Cheap reduction for 64-bit values: subtracts m at most once.
// Only yields a value below m when the caller guarantees x < 2*m.
ll fastFix(ll x, ll m) {
	return (x >= m) ? (x - m) : x;
}
// Cheap reduction for int values: subtracts m at most once.
// Only yields a value below m when the caller guarantees x < 2*m.
int fastFix(int x, int m) {
	return (x >= m) ? (x - m) : x;
}
// Plain 64-bit remainder x % m; the result keeps the sign of x.
ll midFix(ll x, ll m) {
	const ll remainder = x % m;
	return remainder;
}
// Plain int remainder x % m; the result keeps the sign of x.
int midFix(int x, int m) {
	const int remainder = x % m;
	return remainder;
}
// Full reduction: for a positive m the result lies in [0, m) even when x is
// negative, because m is added to the first remainder before reducing again.
ll fix(ll x, ll m) {
	const ll shifted = (x % m) + m;
	return shifted % m;
}

// Computes n^p modulo `mod` by binary exponentiation.
//
//   n   - base; reduced modulo `mod` before the first multiplication so that
//         squaring it cannot overflow just because the caller passed a large n.
//   p   - exponent; callers in this file pass non-negative values.
//   mod - modulus (this parameter shadows the file-level constant `mod`);
//         products stay within 64 bits as long as mod*mod fits in a long long.
//
// Returns the residue produced by `%`, i.e. in [0, mod) for a non-negative
// base. With mod == 1 the result is 0 for every exponent, including p == 0.
ll modPow(ll n, ll p, ll mod) {
	ll b = n % mod;   // reduce up front: b*b below must not overflow
	ll o = 1 % mod;   // 1 in general, 0 when mod == 1
	while (p != 0) {
		if (p & 1) o = (o * b) % mod;
		b = (b * b) % mod;
		p /= 2;
	}
	return o;
}

// Prints (followed by a newline) the number of integers x with l <= x <= r,
// taken modulo the file-level constant `mod`, that satisfy
//
//     (sum of squared digits of x) + 2*x - (digit sum of x)^2  ==  0   (mod 2*m)
//
// Because (digit sum)^2 - (sum of squared digits) is twice the sum of all
// pairwise digit products, this is the same as "x is congruent to the sum of
// its pairwise digit products modulo m".
//
//   l, r - decimal strings of the range ends; assumes l <= r and that l is not
//          longer than r (l is left-padded with zeros to r's length).
//   m    - modulus from the input; assumes m >= 1.
//
// Method: the two endpoints are checked directly. Every other number in the
// range shares a prefix with l or r, then has a digit strictly below r's
// (or strictly above l's) digit, and then an unconstrained suffix. Suffixes
// are counted with a DP over (digit sum, 2*value + squared digit sum), both
// taken modulo M = 2*m, that is extended by one position per iteration from
// the least significant digit upwards.
void solve(string l, string r, int m) {
	// Pad l to r's length, then prepend a shared sentinel '0' so that real
	// digits start at index 1 and index 0 is an always-equal empty prefix.
	if (l.size() < r.size()) l.insert(0, r.size() - l.size(), '0');
	l.insert(0, 1, '0');
	r.insert(0, 1, '0');
	const int n = static_cast<int>(l.size());

	m *= 2; // from here on m is the doubled modulus M

	// Prefix data modulo M for r ("top") and l ("bot"):
	//   top[i][0] / bot[i][0] - sum of squared digits of the first i digits
	//   top[i][1] / bot[i][1] - sum of the first i digits
	//   topMod[i] / botMod[i] - value of the first i digits read as a number
	//   pref[i]               - 2 * 10^(n-1-i), twice the place value of index i
	vector<a2> top(n), bot(n);
	vector<int> topMod(n), botMod(n);
	vector<int> pref(n);
	for (int i = 1; i < n; i++) {
		const int hiDigit = r[i] - '0';
		const int loDigit = l[i] - '0';

		top[i][0] = midFix(top[i-1][0] + sq(hiDigit), m);
		top[i][1] = midFix(top[i-1][1] + hiDigit, m);
		bot[i][0] = midFix(bot[i-1][0] + sq(loDigit), m);
		bot[i][1] = midFix(bot[i-1][1] + loDigit, m);

		topMod[i] = midFix(topMod[i-1]*10 + hiDigit, m);
		botMod[i] = midFix(botMod[i-1]*10 + loDigit, m);

		pref[i] = fix(2*modPow(10, n-i-1, m), m);
	}

	// Checks the target congruence for a complete number given its sums and value.
	const auto satisfies = [&](const a2 &sums, int value) {
		return fix(sums[0] + value*2 - sq(sums[1]), m) == 0;
	};

	ll out = 0;
	if (satisfies(top.back(), topMod.back())) out++;            // x == r
	if (l != r && satisfies(bot.back(), botMod.back())) out++;  // x == l

	if (l == r) {
		cout << out << '\n';
		return;
	}

	// dp[j][k]: number of digit strings for the positions processed so far
	// (a suffix of the number) whose digit sum is j and whose
	// 2*(suffix value) + (sum of squared digits) is k, both modulo M.
	vector<vector<ll>> dp(m, vector<ll>(m));
	vector<vector<ll>> ndp(m, vector<ll>(m));
	dp[0][0] = 1;

	// Adds to `out` the number of suffixes that complete the prefix
	// [0..x] followed by digit d into a number satisfying the congruence.
	// For each suffix digit sum j exactly one residue k fits, so a single
	// lookup per j replaces a scan over all k.
	auto addFreeSuffixes = [&](int x, int d, const vector<a2> &prefVals, const vector<int> &prefMod) {
		const ll curSum = (prefVals[x][1] + d) % m;
		const ll curSq = (prefVals[x][0] + prefMod[x]*pref[x] + d*pref[x+1] + d*d) % m;
		for (int j = 0; j < m; j++) {
			const ll total = curSum + j;
			const ll k = fix(total*total - curSq, m);
			out = fastFix(out + dp[j][k], mod);
		}
	};

	// samePrefix[t] is true iff the first t characters of l and r coincide.
	vector<char> samePrefix(n + 1, 0);
	samePrefix[0] = 1;
	for (int i = 0; i < n && samePrefix[i]; i++) {
		samePrefix[i+1] = (l[i] == r[i]);
	}

	int base = 1; // place value (mod M) of the position about to be added to the DP
	for (int pos = n - 1; pos >= 0; pos--) {
		const int lo = l[pos] - '0';
		const int hi = r[pos] - '0';

		for (auto &row : ndp) row.assign(m, 0);

		// Extend every suffix by one more leading digit.
		for (int digit = 0; digit < 10; digit++) {
			const int shift = midFix(2*base*digit + sq(digit), m);
			for (int j = 0; j < m; j++) {
				// j + digit can reach M-1+9, which is >= 2*M when M <= 9, so a
				// single conditional subtraction would leave the index out of
				// range; a full remainder keeps it inside [0, M).
				const int a = midFix(j + digit, m);
				const vector<ll> &src = dp[j];
				vector<ll> &dst = ndp[a];
				for (int k = 0; k + shift < m; k++) {
					dst[k+shift] += src[k];
				}
				for (int k = m - shift; k < m; k++) {
					dst[k+shift-m] += src[k];
				}
			}
		}
		// At most ten values below `mod` were added per cell, so one reduction suffices.
		for (auto &row : ndp) {
			for (ll &cell : row) cell = midFix(cell, mod);
		}
		base = fix(base*10, m);

		// Count numbers that first leave the bound(s) at index `pos`; their
		// suffix is free and is described by dp (positions after `pos`).
		const int x = pos - 1;
		const bool tight = samePrefix[pos];
		for (int i = 0; i < 10; i++) {
			if (tight) {
				// l and r still agree before pos: the digit must lie strictly between.
				if (i < hi && i > lo) addFreeSuffixes(x, i, top, topMod);
			} else {
				if (i < hi) addFreeSuffixes(x, i, top, topMod); // drops below r here
				if (i > lo) addFreeSuffixes(x, i, bot, botMod); // rises above l here
			}
		}
		swap(dp, ndp);
	}
	cout << fix(out, mod) << '\n';
}

int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(0);
			/*for (int i = 10; i < 100; i++) {
			for (int j = i; j < 100; j++) {
				for (int k = 2; k <= 20; k++) {
					cout << i << ' ' << j << ' ' << k << endl;
					solve(to_string(i), to_string(j), k);
				}
			}
		}*/
	int t; cin >> t;
	while (t--) {
		string l, r; cin >> l >> r;
		int m; cin >> m;

		//int m = fix(rand(), 60)+1;
		//string l = to_string(fix(rand(), 10000));
		//string r = to_string(fix(rand(), 10000) + stoi(l));
		//cout << l <<' ' << r << ' ' << m << endl;
		solve(l, r, m);
	}
}

Details

Tip: Click on the bar to expand more detailed information

Test #1:

score: 100
Accepted
time: 1ms
memory: 3840kb

input:

2
10
50
17
33
33
3

output:

2
1

result:

ok 2 lines

Test #2:

score: -100
Runtime Error

input:

1252
3893
6798
7
5883
8853
7
2999
4351
2
565
1767
7
1759
4751
10
79
8631
2
2128
8721
7
7890
8423
6
4708
7458
9
4501
6027
4
932
2708
2
3518
5859
7
4355
8296
3
2642
4470
10
7408
8939
8
4892
6777
7
4962
7976
6
2722
3171
7
6616
7527
6
7070
7612
5
429
2087
7
8786
8823
3
8831
8994
2
6346
8524
4
6026
8648
...

output:


result: