QOJ.ac

QOJ

ID题目提交者结果用时内存语言文件大小提交时间测评时间
#512086#9167. Coprime Arrayucup-team133#WA 10ms11020kbPython35.6kb2024-08-10 13:29:292024-08-10 13:29:31

Judging History

你现在查看的是最新测评结果

  • [2024-08-11 17:38:28]
  • hack成功,自动添加数据
  • (/hack/775)
  • [2024-08-10 13:29:31]
  • 评测
  • 测评结果:WA
  • 用时:10ms
  • 内存:11020kb
  • [2024-08-10 13:29:29]
  • 提交

answer

import sys
from itertools import permutations
from heapq import heappop,heappush
from collections import deque
import random
import bisect
from math import gcd

input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())

def floor_sum(n: int, m: int, a: int, b: int) -> int:
    ans = 0

    if a >= m:
        ans += (n - 1) * n * (a // m) // 2
        a %= m

    if b >= m:
        ans += n * (b // m)
        b %= m

    y_max = (a * n + b) // m
    x_max = y_max * m - b

    if y_max == 0:
        return ans

    ans += (n - (x_max + a - 1) // a) * y_max
    ans += floor_sum(y_max, a, m, (a - x_max % a) % a)

    return ans

def _inv_gcd(a,b):
    a %= b
    if a == 0:
        return (b, 0)
 
    # Contracts:
    # [1] s - m0 * a = 0 (mod b)
    # [2] t - m1 * a = 0 (mod b)
    # [3] s * |m1| + t * |m0| <= b
    s = b
    t = a
    m0 = 0
    m1 = 1
 
    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u  # |m1 * u| <= |m1| * s <= b
 
        # [3]:
        # (s - t * u) * |m1| + t * |m0 - m1 * u|
        # <= s * |m1| - t * u * |m1| + t * (|m0| + |m1| * u)
        # = s * |m1| + t * |m0| <= b
 
        s, t = t, s
        m0, m1 = m1, m0
 
    # by [3]: |m0| <= b/g
    # by g != b: |m0| < b/g
    if m0 < 0:
        m0 += b // s
 
    return (s, m0)
 
def crt(r,m):
    assert len(r) == len(m)
 
    n = len(r)
 
    # Contracts: 0 <= r0 < m0
    r0 = 0
    m0 = 1
    for i in range(n):
        assert 1 <= m[i]
        r1 = r[i] % m[i]
        m1 = m[i]
        if m0 < m1:
            r0, r1 = r1, r0
            m0, m1 = m1, m0
        if m0 % m1 == 0:
            if r0 % m1 != r1:
                return (0, 0)
            continue
 
        # assume: m0 > m1, lcm(m0, m1) >= 2 * max(m0, m1)
 
        '''
        (r0, m0), (r1, m1) -> (r2, m2 = lcm(m0, m1));
        r2 % m0 = r0
        r2 % m1 = r1
        -> (r0 + x*m0) % m1 = r1
        -> x*u0*g % (u1*g) = (r1 - r0) (u0*g = m0, u1*g = m1)
        -> x = (r1 - r0) / g * inv(u0) (mod u1)
        '''
 
        # im = inv(u0) (mod u1) (0 <= im < u1)
        g, im = _inv_gcd(m0, m1)
 
        u1 = m1 // g
        # |r1 - r0| < (m0 + m1) <= lcm(m0, m1)
        if (r1 - r0) % g:
            return (0, 0)
 
        # u1 * u1 <= m1 * m1 / g / g <= m0 * m1 / g = lcm(m0, m1)
        x = (r1 - r0) // g % u1 * im % u1
 
        '''
        |r0| + |m0 * x|
        < m0 + m0 * (u1 - 1)
        = m0 + m0 * m1 / g - m0
        = lcm(m0, m1)
        '''
 
        r0 += x * m0
        m0 *= u1  # -> lcm(m0, m1)
        if r0 < 0:
            r0 += m0
 
    return (r0, m0)

def isPrimeMR(n):
    if n==1:
        return 0
    d = n - 1
    d = d // (d & -d)
    L = [2, 3, 5, 7, 11, 13, 17]
    if n in L:
        return 1
    for a in L:
        t = d
        y = pow(a, t, n)
        if y == 1: continue
        while y != n - 1:
            y = (y * y) % n
            if y == 1 or t == n - 1: return 0
            t <<= 1
    return 1
def findFactorRho(n):
    from math import gcd
    m = 1 << n.bit_length() // 8
    for c in range(1, 99):
        f = lambda x: (x * x + c) % n
        y, r, q, g = 2, 1, 1, 1
        while g == 1:
            x = y
            for i in range(r):
                y = f(y)
            k = 0
            while k < r and g == 1:
                ys = y
                for i in range(min(m, r - k)):
                    y = f(y)
                    q = q * abs(x - y) % n
                g = gcd(q, n)
                k += m
            r <<= 1
        if g == n:
            g = 1
            while g == 1:
                ys = f(ys)
                g = gcd(abs(x - ys), n)
        if g < n:
            if isPrimeMR(g): return g
            elif isPrimeMR(n // g): return n // g
            return findFactorRho(g)
def primeFactor(n):
    i = 2
    ret = {}
    rhoFlg = 0
    while i*i <= n:
        k = 0
        while n % i == 0:
            n //= i
            k += 1
        if k: ret[i] = k
        i += 1 + i % 2
        if i == 101 and n >= 2 ** 20:
            while n > 1:
                if isPrimeMR(n):
                    ret[n], n = 1, 1
                else:
                    rhoFlg = 1
                    j = findFactorRho(n)
                    k = 0
                    while n % j == 0:
                        n //= j
                        k += 1
                    ret[j] = k

    if n > 1: ret[n] = 1
    if rhoFlg: ret = {x: ret[x] for x in sorted(ret)}
    return ret

def solve(s,x):
    if gcd(s,x) == 1:
        return [s]
    
    if s & 1 == 1 and x & 1 == 0:
        res = solve(s-1,x)
        return [1] + res
    
    """
    [s+t,-t] such that
    t != 0 mod p
    t != -s mod p
    """
    pf = primeFactor(x)
    R = []
    M = []
    for p in pf:
        for r in range(1,p):
            if r!=((-s) % p):
                R.append(r)
                M.append(p)
                break
        else:
            assert False
    
    t,m = crt(R,M)
    assert gcd(s+t,p) == 1
    assert gcd(t,p) == 1

    return [s-t,t]

def brute(s,x):
    if gcd(s,x) == 1:
        return [s]
    
    for t in range(-1000,1000):
        if gcd(abs(s+t),x) == 1 and gcd(abs(t),x) == 1:
            return [s+t,-t]
    
    return [1,1,s-2]




while False:
    s,x = random.randint(2,10**9),random.randint(2,10**9)
    A = solve(s,x)
    print(A)
    assert max(abs(a) for a in A) <= 10**9
s,x = mi()
A = solve(s,x)
assert sum(A) == s
print(len(A))
print(*A)

详细

Test #1:

score: 0
Wrong Answer
time: 10ms
memory: 11020kb

input:

9 6

output:

3
1 3 5

result:

wrong answer Element at position 2 is not coprime to x