# Sasha and a Very Easy Test

# 解题思路

因为模数 PP 不一定是质数,不一定存在逆元,所以除法的操作需要特殊处理。

考虑一个数只有与模数 PP 不互质时才没有逆元,我们可以考虑将 PP 质因数分解,那么一个数被分成两部分:与 PP 互质的部分 和 与 PP 不互质的部分。与 PP 互质的部分是 PP 中不包含的质因数 (的乘积),不互质是包含的部分。

互质的部分除就直接乘逆元,不互质的部分需要记录一个数所含 PP 的各个质因子的个数,一个数的值将这两部分乘起来就可以,除法的话直接指数相减再更新答案。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define long long long
using namespace std;
const int MAXN = 5e5+15;
struct node {
    int cnt[10];
    long mul;
    node(){
        memset(cnt, 0, sizeof(cnt));
        mul = 1;
    }
};
namespace Tree {
    struct tree {
        int l, r;
        long sum; node lz;
    } a[MAXN*4];
    long val[MAXN];
    auto build(int l, int r, int i = 1) -> void;
    auto modify(int l, int r, node &tmp, int i = 1) -> void;
    auto div(int x, node &tmp, int i = 1) -> void;
    auto query(int l, int r ,int i = 1) -> long;
}
using namespace Tree;
int n, P, q;
int data[MAXN], opt, l, r, x;
int p[10], tot, pow_v[10][MAXN*20];
auto prework() -> void;
auto get_pow(int i, int j) -> long;
signed main()
{
    scanf("%d%d", &n, &P);
    for(int i = 1; i <= n; i += 1)
        scanf("%d", &data[i]);
    prework(); Tree::build(1, n);
    scanf("%d", &q);
    for(int i = 1; i <= q; i += 1)
    {
        scanf("%d", &opt);
        if(opt == 1)
        {
            scanf("%d%d%d", &l, &r, &x);
            node tmp;
            for(int j = 1; j <= tot; j += 1)
                while(x%p[j] == 0) x /= p[j], tmp.cnt[j] += 1;
            tmp.mul = x;
            Tree::modify(l, r, tmp);
        }
        else if(opt == 2)
        {
            scanf("%d%d", &l, &r);
            node tmp;
            for(int j = 1; j <= tot; j += 1)
                while(x%p[j] == 0) x /= p[j], tmp.cnt[j] += 1;
            tmp.mul = x;
            Tree::div(l, tmp);
        }
        else if(opt == 3)
        {
            scanf("%d%d", &l, &r);
            printf("%lld\n", Tree::query(l, r));
        }
    }
}
auto prework() -> void
{
    int tmp = P;
    for(int i = 2; (long)i*i <= tmp; i += 1)
    {
        if(tmp%i == 0)
        {
            p[++tot] = i;
            while(tmp%i == 0) tmp /= i;
        }
    }
    if(tmp != 1) p[++tot] = tmp;
    for(int i = 1; i <= tot; i += 1)
    {
        pow_v[i][0] = 1;
        for(int j = 1; j < MAXN*20; j += 1)
            pow_v[i][j] = (long)pow_v[i][j-1]*p[i]%P;
    }
}
auto get_pow(int i, int j) -> long
{
    return pow_v[i][j];
}
namespace Tree {
    #define ls (i<<1)
    #define rs (i<<1|1)
    auto exgcd(long a, long b, long &x, long &y) -> long
    {
        if(!b) return x = 1, y = 0, b;
        long d = exgcd(b, a%b, y, x);
        y -= (a/b)*x;
        return d;
    }
    auto inv(long a, long p) -> long
    {
        long x, y;
        exgcd(a, p, x, y);
        x = (x%p+p)%p;
        return x;
    }
    auto modify(tree &i, node &tmp) -> void
    {
        (i.sum *= tmp.mul) %= P;
        (i.lz.mul *= tmp.mul) %= P;
        if(i.l == i.r) (val[i.l] *= tmp.mul) %= P;
        for(int j = 1; j <= tot; j++)
        {
            i.lz.cnt[j] += tmp.cnt[j];
            (i.sum *= get_pow(j,tmp.cnt[j])) %= P;
        }
    }
    auto push_up(int i) -> void
    {
        a[i].sum = (a[ls].sum + a[rs].sum)%P;
    }
    auto push_down(int i)
    {
        modify(a[ls], a[i].lz);
        modify(a[rs], a[i].lz);
        memset(a[i].lz.cnt, 0, sizeof(a[i].lz.cnt));
        a[i].lz.mul = 1;
    }
    auto build(int l, int r, int i) -> void
    {
        a[i].l = l, a[i].r = r;
        if(l == r)
        {
            int x = data[l];
            for(int j = 1; j <= tot; j += 1)
            {
                int cnt = 0;
                while(x%p[j] == 0)
                    x /= p[j], cnt++;
                a[i].lz.cnt[j] = cnt;
            }
            val[l] = x;
            a[i].sum = data[l]%P;
            return void();
        }
        int mid = (l+r)>>1;
        build(l, mid, ls);
        build(mid+1, r, rs);
        push_up(i);
    }
    auto modify(int l, int r, node &tmp, int i) -> void
    {
        if(l <= a[i].l && a[i].r <= r)
            return modify(a[i], tmp);
        push_down(i);
        if(l <= a[ls].r) modify(l, r, tmp, ls);
        if(r >= a[rs].l) modify(l, r, tmp, rs);
        push_up(i);
    }
    auto div(int x, node &tmp, int i) -> void
    {
        if(a[i].l == a[i].r)
        {
            (val[a[i].l] *= inv(tmp.mul, P)) %= P;
            a[i].sum = val[a[i].l];
            for(int j = 1; j <= tot; j++)
            {
                a[i].lz.cnt[j] -= tmp.cnt[j];
                (a[i].sum *= get_pow(j,a[i].lz.cnt[j])) %= P;
            }
            return void();
        }
        push_down(i);
        if(x <= a[ls].r) div(x, tmp, ls);
        if(x >= a[rs].l) div(x, tmp, rs);
        push_up(i);
    }
    auto query(int l, int r, int i) -> long
    {
        if(l <= a[i].l and a[i].r <= r)
            return a[i].sum;
        push_down(i); long ans = 0;
        if(l <= a[ls].r) ans += query(l, r, ls);
        if(r >= a[rs].l) ans += query(l, r, rs);
        return ans%P;
    }
}

# Squirrel Migration

# 解题思路

考虑对于每一条边的最大贡献,对于树上的一条边,把它切断的话,树会分成两个联通块,分别记为 S1S_1S2S_2 ,设 S1<S2|S_1| < |S_2|

那么这一条边最多被经过 2S12|S_1| 次,此时有 xS1,pxS2\forall x \in S_1, p_x \in S_2 , 考虑将重心 GG 作为根,那么现在的每一个 S1S_1 都是 GG 的一棵子树,那么使得权值最大的条件即为 xsubtree(y),px∉subtree(y)\forall x \in subtree(y), p_x \not \in subtree(y)

现在有了结论,就可以容斥求答案了。

f(i)f(i) 表示至少有 ii 个点的 pxp_x 属于 ii 所在子树的答案,则题目所求为 i=1n(1)if(i)(ni)!\sum\limits _{i=1}^{n} (-1)^i*f(i)*(n-i)! , 单独考虑每一个子树 subtree(y)subtree(y),设其子树大小为 siz(y)siz(y),则在这个子树中 f(i)=(xi)2i!f(i) = \binom{x}{i}^2*i!, 最终以 GG 为根的 ff 就是把所有子树使用背包合并起来,时间复杂度 O(n2)O(n^2)

考虑到这个子树背包的形式与多项式乘法一样,我们可以将每个子树的 ff 看成一个多项式,每次选出两个长度最短的多项式相乘,因为所有多项式的总长度为 nn,所以时间复杂度是 O(nlog2n)O(n \log^2 n) 的。(有点类似于启发式合并)。

#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#define long long long
using namespace std;
namespace polynomial { 
    // @Ptilopsis's 板子
    const int P = 998244353;
    const int g = 3, gi = 332748118;
    int lim, bit;
    vector<int> r;
    struct poly : vector<long> {
        poly() {};
        poly(long a)
        {
            resize(1);
            (*this)[0] = (a%P+P)%P;
        }
        poly& modx(int n)
        {
            resize(n);
            return *this;
        }
        friend bool operator < (const poly &a, const poly &b)
        { return a.size() < b.size(); }
    };
    inline long ksm(long a, long b, long p = P)
    {
        long ans = 1;
        for(; b; b >>= 1, a = a*a%p)
            if(b&1) ans = ans*a%p;
        return ans;
    }
    inline void dnt_prework(int n)
    {
        lim = 1, bit = 0;
        while(lim < n+1) lim <<= 1, bit++;
        r.resize(lim);
        for(int i = 0; i < lim; i++)
            r[i] = (r[i>>1]>>1) | ((i&1)<<(bit-1));
    }
    inline void builtin_change(poly &a, bool type)
    {
        for(int i = 0; i < lim; i++)
            if(i < r[i]) swap(a[i], a[r[i]]);
        for(int mid = 1; mid < lim; mid <<= 1)
        {
            long wn = ksm(type?g:gi, (P-1)/(mid<<1));
            for(int j = 0, k = 0; j < lim; j += (mid<<1), k = j)
            {
                for(long w = 1; k < j+mid; k++, w = w*wn%P)
                {
                    long x = a[k], y = w*a[k+mid]%P;
                    a[k] = (x+y)%P, a[k+mid] = (x-y+P)%P;
                }
            }
        }
    }
    inline void dnt(poly &a)
    {
        a.resize(lim);
        builtin_change(a, true);
    }
    inline void idnt(poly &a)
    {
        a.resize(lim);
        builtin_change(a, false);
        long liminv = ksm(lim, P-2);
        for(int i = 0; i < lim; i++)
            a[i] = a[i]*liminv%P;
    }
     
    inline poly operator + (const poly &a, const poly &b)
    {
        poly c = a;
        if(a.size() < b.size()) c.resize(b.size());
        for(int i = 0; i < b.size(); i++)
            c[i] = (c[i]+b[i])%P;
        return c;
    }
    inline poly operator - (const poly &a, const poly &b)
    {
        poly c = a;
        if(a.size() < b.size()) c.resize(b.size());
        for(int i = 0; i < b.size(); i++)
            c[i] = (c[i]-b[i]+P)%P;
        return c;
    }
    poly operator * (poly a, poly b)
    {
        int len = a.size()+b.size()-2;
        dnt_prework(len);
        poly ans; ans.resize(lim+1);
        dnt(a); dnt(b);
        for(int i = 0; i < lim; i++)
            ans[i] = a[i]*b[i]%P;
        idnt(ans);
        ans.resize(len+1);
        return ans;
    }
    inline poly operator * (poly a, const long &k)
    {
        for(int i = 0; i < a.size(); i++)
            a[i] = a[i]*k%P;
        return a;
    }
     
    inline poly derivation(const poly &a)
    {
        poly b; b.resize(a.size()-1);
        for(int i = 1; i < a.size(); i++)
            b[i-1] = a[i]*i%P;
        return b;
    }
    inline poly integral(const poly &a)
    {
        poly b; b.resize(a.size()+1);
        for(int i = a.size()-1; i >= 1; i--)
            b[i] = a[i-1]*ksm(i, P-2)%P;
        return b;
    }
    inline poly inv(const poly &a)
    {
        stack<int> st;
        int n = a.size();
        while(n > 1) st.push(n), n = (n+1)/2;
        poly b = ksm(a[0], P-2);
        while(st.size())
        {
            n = st.top(); st.pop();
            poly c = a; c.modx(n); b.modx(n);
            b = (b*2-((c*b).modx(n)*b).modx(n)).modx(n);
        }
        return b.modx(a.size());
    }
    inline poly ln(const poly &a)
    {
        poly b, tmp;
        b = integral((derivation(a)*inv(a)));
        return b.modx(a.size());
    }
    inline poly sqrt(const poly &a)
    {
        stack<int> st;
        int n = a.size();
        while(n > 1) st.push(n), n = (n+1)/2;
        poly b = ksm(a[0], P-2);
        while(st.size())
        {
            n = st.top(); st.pop();
            poly c = a;
            c.modx(n); b.modx(n);
            b = ((c+(b*b).modx(n))*inv(b*2)).modx(n);
        }
        return b.modx(a.size());
    }
    inline poly exp(const poly &a)
    {
        stack<int> st;
        int n = a.size();
        while(n > 1) st.push(n), n = (n+1)/2;
        poly b = 1;
        while(st.size())
        {
            n = st.top(); st.pop();
            poly c = a;
            b.modx(n); c.modx(n);
            b = ((1-ln(b)+c)*b).modx(n);
        }
        return b.modx(a.size());
    }
    inline poly pow(const poly &a, const long &k)
    {
        return exp(k*ln(a));
    }
}
using namespace polynomial;
const int MAXN = 1e5+10;
struct edge{
    int to, next;
}a[MAXN<<1];
int n, u, v, G, Gmx, siz[MAXN];
int head[MAXN], cnt = 1;
long fact[MAXN], finv[MAXN];
auto add_edge(int from, int to) -> void;
auto prework(int n) -> void;
auto dfs(int x, int fa) -> void;
auto C(int n, int m) -> long;
signed main()
{
    scanf("%d", &n);
    prework(n);
    for(int i = 1; i < n; i += 1)
    {
        scanf("%d%d", &u, &v);
        add_edge(u, v);
        add_edge(v, u);
    }
    Gmx = n; dfs(1, 0);
    dfs(G, 0); poly f;
    priority_queue<poly, vector<poly>, greater<poly>> q;
    for(int  i = head[G]; i; i = a[i].next)
    {
        int y = a[i].to;
        f.resize(siz[y]+1); f[0] = 1;
        for(int i = 0; i <= siz[y]; i += 1)
            f[i] = C(siz[y],i)%P * C(siz[y],i)%P * fact[i]%P;
        q.push(f);
    }
    while(q.size() > 1)
    {
        poly a = q.top(); q.pop();
        poly b = q.top(); q.pop();
        a = a*b;
        if((int)a.size() > n+2) a.resize(n+2);
        q.push(a);
    }
    f = q.top(); long ans = 0;
    for(int i = 0; i <= min(n,(int)f.size()-1); i += 1)
        (ans += ((i%2?-1:1) * f[i]%P * fact[n-i]%P + P)%P) %= P;
    printf("%lld\n", ans);
}
 
auto dfs(int x, int fa) -> void
{
    int mx = 0; siz[x] = 1;
    for(int  i = head[G]; i; i = a[i].next)
    {
        int y = a[i].to;
        if(y == fa) continue;
        dfs(y, x); siz[x] += siz[y];
        mx = max(mx, siz[y]);
    }
    mx = max(mx, n-siz[x]);
    if(mx < Gmx) G = x, Gmx = mx;
}
auto C(int n, int m) -> long
{
    if(n < m or n < 0 or m < 0) return 0;
    return fact[n]%P * finv[m]%P *finv[n-m]%P;
}
auto prework(int n) -> void
{
    fact[0] = finv[0] = 1;
    for(int i = 1; i <= n; i += 1) fact[i] = fact[i-1] * i % P;
    for(int i = 1; i <= n; i += 1) finv[i] = ksm(fact[i], P-2);
}

# EntropyIncreaser 与金字塔

# 解题思路

第一个矩阵是环形的,第二个和第三个矩阵都是蛇形填数。
我们先按照第一个矩阵分层考虑: 先考虑 x1=y1=1,x2=y2=nx_1=y_1=1, x_2=y_2=n 时的做法。设当前的环为从外到内第 aa 个环,则这个环里最小的数 LL4(a1)(na1)+14(a-1)(n-a-1)+1,最大的数 RR4a(na)4a(n-a),这个可以按每个环里数的个数推导。 先把每个环里左上角的数抠出来,剩下的数的答案就是 i=L+1Ri(R+L+1i)\sum\limits _{i=L+1}^{R} i*(R+L+1-i),把这个式子拆开得到 (R+L+1)i=L+1Rii=L+1Ri2(R+L+1)\sum\limits_{i=L+1}^Ri - \sum\limits _{i=L+1}^{R}i^2,用平方和公式就可以 O(1)O(1) 算了。
现在我们可以 O(1)O(1) 地算每个环的贡献了,然后开始考虑满分做法。满分做法跟上面的只有一点不同:环不完整。这个我们可以把环拆成上下左右四部分,每部分被包含的都是连续的一段,处理好上下界就可以套用上面的式子了。

#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
#define long long long
#define int128 __int128_t
const int P = 1e8;
long n0, xa, ya, xb, yb;
int128 n, x1, x2, y1, y2, ans;
auto solve() -> long;
int main()
{
    scanf("%lld", &n0);
    scanf("%lld%lld", &xa, &ya);
    scanf("%lld%lld", &xb, &yb);
    n = n0; x1 = xa; y1 = ya; x2 = xb; y2 = yb;
    printf("%lld", solve());
    return 0;
}
auto S1(int128 x) -> int128
{
    return x*(x+1)/2 % P;
}
auto S2(int128 x) -> int128
{
    return x*(x+1)/2*(x*2+1)/3 % P;
}
auto calc(int128 a, int type) -> int128
{
    int128 L = (4*(a-1)*(n-a+1)+1)%P;
    int128 R = 4*a*(n-a)%P;
    int128 lval = 0, rval = 0;
    int128 tot = n-a*2+1;
    if(type == 1)
    {
        if(x2 < a or a < x1) return 0;
        if(y2 < a+1 or n-a < y1) return 0;
        lval = max(a+1,y1)-a + L-1;
        rval = min(n-a,y2)-a + L;
    }
    else if(type == 2)
    {
        if(y2 < n-a+1 or n-a+1 < y1) return 0;
        if(x2 < a or n-a < x1) return 0;
        lval = max(a,x1)-a   + L+tot-1;
        rval = min(n-a,x2)-a + L+tot;
    }
    else if(type == 3)
    {
        if(x2 < n-a+1 or n-a+1 < x1) return 0;
        if(y2 < a+1 or n-a+1 < y1) return 0;
        lval = n-a+1-min(n-a+1,y2) + L+2*tot-1;
        rval = n-a+1-max(a+1,y1)   + L+2*tot;
    }
    else if(type == 4)
    {
        if(y2 < a or a < y1) return 0;
        if(x2 < a+1 or n-a+1 < x1) return 0;
        lval = n-a+1-min(n-a+1,x2) + L+3*tot-1;
        rval = n-a+1-max(a+1,x1)   + L+3*tot;
    }
    int128 tmp1 = (S1(rval)-S1(lval)+P)%P * (L+R+1)%P;
    int128 tmp2 = (S2(rval)-S2(lval)+P)%P;
    
    return (tmp1-tmp2+P)%P;
}
auto solve() -> long
{
    for(int a = 1; a <= (n+1)/2; a += 1)
    {
        int128 sum = 0;
        int128 val = (4*(a-1)*(n-a+1)+1)%P; 
        if(x1 <= a && a <= x2 && y1 <= a && a <= y2) (sum += val*val%P) %= P;
        (sum += calc(a, 1)) %= P; (sum += calc(a, 2)) %= P;
        (sum += calc(a, 3)) %= P; (sum += calc(a, 4)) %= P;
        (ans += sum*a%P) %= P;
    }
    return ans;
}

# 「XXOI 2019」惠和惠惠和惠惠惠

# 题目描述

题意相当于在二维坐标系下,一开始在 (0,0)(0,0),每次可以向右上走,向右走,向下走,且不能低于 xx 轴,在 nn 次操作后,需要到达 (n,0)(n,0),同时要求恰好触碰 xxkk 次。(包括开始和结束)

# 解题思路

f(i,j)f(i,j) 表示走到 (i,0)(i,0) ,触碰 xxjj 次的方案数,则转移方程式为 f(i,j)=k=0i1f(k,j1)f(ik,2)f(i,j) = \sum\limits _{k=0}^{i-1}f(k,j-1)*f(i-k,2)。初始状态 f(0,1)=1f(0,1)=1,目标状态 f(n,k)f(n,k),设 g(n)=f(n,2)g(n) = f(n,2),考虑另一个东西 h(n)h(n) 表示从 (0,0)(0,0) 走到 (n,0)(n,0) ,同时不能进入第四象限的方案数,则有 g(n)=h(n2)g(n) = h(n-2),因为开头和结尾必须分别往上和往下。
下面就是式子了... ...
为了不重不漏地计数,我们只需要枚举除了 (0,0)(0,0) 外第一次触碰 xx 轴的位置即可:

h(n)=h(n1)+i=2ng(i)h(ni)h(n)=h(n1)+i=2nh(i2)h(ni)h(n)=h(n1)+i=0n2h(i)h(n2i)\begin{aligned} h(n) & = h(n-1) + \sum _{i=2}^{n}g(i)*h(n-i) \\ h(n) & = h(n-1) + \sum _{i=2}^{n}h(i-2)*h(n-i) \\ h(n) & = h(n-1) + \sum _{i=0}^{n-2}h(i)*h(n-2-i) \end{aligned}

H(x)H(x)hh 的生成函数,则有:

H(x)=xH(x)+x2H2(x)+10=x2H2(x)+(x1)H(x)+1H(x)=(1x)±12x3x22x2\begin{aligned} H(x) & = xH(x) + x^2H^2(x) + 1 \\ 0 & = x^2H^2(x) + (x-1)H(x) + 1 \\ H(x) & = \frac{(1-x) \pm \sqrt{1-2x-3x^2}}{2x^2} \end{aligned}

因为要求 H(0)=1H(0) = 1,所以该式应取负号。

考虑 ff 的生成函数,设 f(,j)=Fj(x)f(*,j) = F_j(x) ,不难得到 F1(x)=1F_1(x) = 1

同时有 Fj(x)=(x+x2H(x))Fj1(x)F_j(x) = (x+x^2H(x))*F_{j-1}(x),所以 Fj(x)=(x+x2H(x))j1F_j(x) = (x+x^2H(x))^{j-1}

H(x)H(x) 代入得:

Fj(x)=(x+x2H(x))j1=(x+1x12x3x22)j1=(1+x12x3x22)j1\begin{aligned} F_j(x) & = (x+x^2H(x))^{j-1} \\ & = \left( x + \frac{1-x-\sqrt{1-2x-3x^2}}{2} \right) ^{j-1} \\ & = \left( \frac{1+x-\sqrt{1-2x-3x^2}}{2} \right)^{j-1} \end{aligned}

可以用多项式开根和暴力乘的快速幂得到 O(nlognlogk)O(n \log n \log k) 的做法。

可以用整式递推数列推导式子做到 O(mlogm+m3+nm)O(m \log m + m^3 + nm) 的时间复杂度,其中 m=12m = 12

# Cigar Box

# 解题思路

考虑一个合法的操作序列,一个数可能被操作过很多次,但只有最后一次有用,我们定义一个数的最后一次操作为一个关键操作。设 f(i,l,r()f(i,l,r() 为进行了 ii 次操作,其中有 ll 次是放到左边的关键操作, rr 次是放到右边的关键操作的方案数,于是我们就有了一个 DP:f(i,l,r)=f(i1,l,r)(l+r)+f(i1,l1,r)+f(i1,l,r1)f(i,l,r) = f(i-1,l,r)*(l+r) + f(i-1,l-1,r) + f(i-1,l,r-1) 。但时间复杂度 O(mn2)O(mn^2) , 原地爆炸。
我们发现,单独记录 l,rl,r 并没有太大作用,所以我们可以把这两维压成一维,即把 l,rl, r 合在一起: f(i,j)f(i,j) 表示进行了 ii 次操作,有 jj 次是关键操作的方案数。f(i,j)=f(i1,j1)+f(i1,j)2jf(i,j) = f(i-1,j-1) + f(i-1,j)*2j
统计最终答案时,我们可以枚举 i,ji,j 表示下标在 [i,j][i,j] 之间的数没有进行过任何关键操作,如果目标序列中 [i,j][i,j] 之间的数是递增的,那么这就是一种合法方案 (因为原序列是递增的,没有被操作过的数相对大小不变),此方案贡献为 f(m,i1+nj)(i1+nji1)f(m,i-1+n-j)*\binom{i-1+n-j}{i-1}
但是此时我们还少考虑一种情况:所有数都被操作过。这种情况直接加到答案里就行了。其贡献为 \sum\limits_{i=0}^{n} f(m,n)*\binom{n}

#include<iostream>
#include<cstdio>
using namespace std;
#define long long long
const int MAXN = 3005;
const int p = 998244353;
int n, m, a[MAXN];
long f[MAXN][MAXN], ans;
long fact[MAXN], finv[MAXN];
auto solve() -> void;
auto prework(int n) -> void;
int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i += 1)
        scanf("%d", &a[i]);
    prework(n); solve();
    printf("%lld", ans);
    return 0;
}
long C(int n, int m)
{
    if(n < m) return 0;
    return fact[n]*finv[m]%p*finv[n-m]%p;
}
long ksm(long a, long b)
{
    long ans = 1; a %= p;
    for(; b; b >>= 1, (a *= a) %= p)
        if(b&1) (ans *= a) %= p;
    return ans;
}
auto prework(int n) -> void
{
    fact[0] = finv[0] = 1;
    for(int i = 1; i <= n; i += 1) fact[i] = fact[i-1] * i % p;
    for(int i = 1; i <= n; i += 1) finv[i] = ksm(fact[i], p-2);
}
auto solve() -> void
{
    f[0][0] = 1;
    for(int i = 1; i <= m; i += 1)
    for(int j = 0; j <= n; j += 1)
    {
        if(j) (f[i][j] += f[i-1][j-1]) %= p;
        (f[i][j] += f[i-1][j]*j*2%p) %= p;
    }
    for(int i = 1; i <= n; i += 1)
    for(int j = i; j <= n; j += 1)
    {
        if(j != i and a[j-1] > a[j]) break;
        (ans += f[m][i-1+n-j] * C(i-1+n-j,i-1)%p) %= p;
    }
    for(int i = 0; i <= n; i += 1)
        (ans += f[m][n] * C(n,i)%p) %= p;
}

# Wine Thief

# 解题思路

直接算显然没法算,于是我们考虑计算每个数的贡献,即 aia_i 乘以包含 aia_i 的合法的方案数。 设 G(n,k,i)G(n,k,i) 表示 nn 个数中选 kk 个, 必须选第 ii 个数的合法方案数,所以答案就是: i=1naiG(n,k,i)\sum _{i=1}^{n} a_i*G(n,k,i),但是 G(n,k,i)G(n,k,i) 比较难求,暴力求的话只能枚举这个数前面后面选了多少个,如果能用其他的 GG 值求的话会比较方便。

为了不枚举这个数前面后面选了多少个,我们可以将序列连接成环来考虑。若把序列连接成环,只多了一个限制: a1a_1ana_n 不能同时选。先把 G(n,k,i)G(n,k,i) 的定义放到环上变成 G(n,k,i)G'(n,k,i),因为环是旋转同构的,所以对于任意的 ii ,他们的 G(n,k,i)G'(n,k,i) 都相同。设 f(n,k)f(n,k) 为在一个长度为 nn序列中,选出 kk 个数的方案数,考虑把这 kk 个数插入剩下的 nkn-k 个数之间的空隙中,所以 f(n,k)=Cnk+1nf(n,k) = C_{n-k+1}^{n} 。设 F(n,k)=G(n,k,i)F(n,k) = G'(n,k,i) ,首先有 F(n,k)=[k=1],n<3F(n,k)=[k=1], n<3 ,否则任意选一个数,环就会被这个数和它旁边的两个数断开,所以 F(n,k)=f(n3,k1),n>3F(n,k)=f(n-3,k-1), n>3

所以得到

F(n,k)={[k=1],n<3f(n3,k1),n3F(n,k) = \begin{cases} [k=1], & n < 3 \\ f(n-3,k-1), & n \ge 3 \end{cases}

但是环上还有一个 a1a_1ana_n 不能同时选的限制,考虑把这个限制去掉,我们只需要加上强制同时选 a1a_1ana_n 的情况就可以了,如果强制选上 a1a_1ana_n ,那么中间的 n4n-4 个点就又构成一个子问题,可以递归解决:

G(n,k,i)={0,i0F(n,k)+f(n4,k2),i=1F(n,k)+G(n4,k2,i1),i>1G(n,k,ni+1),i>n2G(n,k,i) = \begin{cases} 0 & ,i \le 0 \\ F(n,k) + f(n-4,k-2) & ,i=1 \\ F(n,k) + G(n-4,k-2,i-1) & ,i>1 \\ G(n,k,n-i+1) & ,i > \lceil \frac{n}{2} \rceil \end{cases}

如果每个 G(n,k,i)G(n,k,i) 直接递归的话是 O(n2)O(n^2) 的,考虑展开每个 G(n,k,i)G(n,k,i) 的递归:

G(n,k,i)=F(n,k)+F(n4,k2)++G(ni24,ki22,imod2)G(n,k,i) = F(n,k)+F(n-4,k-2)+ \cdots + G( n- \lfloor \frac{i}{2} \rfloor *4, k- \lfloor \frac{i}{2} \rfloor *2, i \bmod 2 )

对于每个 ii 而言,前面所有的 FF 项都是相同的,所以我们就可以直接预处理出 FF 的值的前缀和,然后 O(1)O(1) 计算出 GG 的值。
时间复杂度 O(能过)O(能过)

#include<iostream>
#include<cstdio>
using namespace std;
#define long long long
const int MAXN = 3e5+10;
const int p = 998244353;
int n, k, d, a[MAXN];
long sum[MAXN], ans;
long fact[MAXN], finv[MAXN];
auto solve() -> void;
auto prework() -> void;
int main()
{
    scanf("%d%d%d", &n, &k, &d);
    for(int i = 1; i <= n; i += 1)
        scanf("%d", &a[i]);
    prework(); solve();
    printf("%lld", ans);
    return 0;
}
auto C(int n, int m) -> long
{
    if(n < m or n < 0 or m < 0) return 0;
    return fact[n]%p * finv[m]%p * finv[n-m]%p;
}
auto f(int n, int k) -> long
{
    return C(n-k+1, k);
}
auto F(int n, int k) -> long
{
    if(n < 3) return (k == 1);
    else return f(n-3, k-1);
}
auto G(int n, int k, int i) -> long
{
    if(i <= 0) return 0;
    if(i == 1) return f(n-2, k-1);
    if(i > (n+1)/2) return G(n, k, n-i+1);
    return (sum[i/2] + G(n-(i/2)*4, k-(i/2)*2, i%2))%p;
}
auto ksm(long a, long b) -> long
{
    long ans = 1; a %= p;
    for(; b; b >>= 1, a = a*a%p)
        if(b&1) ans = ans*a%p;
    return ans;
}
auto prework() -> void
{
    fact[0] = finv[0] = 1;
    for(int i = 1; i <= n; i += 1) fact[i] = fact[i-1] * i % p;
    for(int i = 1; i <= n; i += 1) finv[i] = ksm(fact[i], p-2);
    for(int i = 0; i < n; i += 1) sum[i+1] = (sum[i]+F(n-i*4, k-i*2))%p;
}
auto solve() -> void
{
    for(int i = 1; i <= n; i += 1)
        (ans += a[i]*G(n, k, i)%p) %= p;
}