DPでシミュレーションを行うと、計算量は となり、TLEしてしまいます。
しかし、くしらくんが 段目に達する遷移を除くと、それまでの遷移に規則性があることに気づくと思います。
まず のときシミュレーションをすればよいです。
以下 のときについて考えます。
ここで
段目にいく通り数
段目から 段目へ 遷移できるかどうか ( or )
と定義します。
そして、DPの遷移をもとに の漸化式を求めます
となります。
高速に求めるためを行列累乗をします。計算式は以下の通りになります。
最終的にを求め、それぞれから遷移をすることで答えを求めることができます。 計算量は です。
from bisect import * mod = 998244353 def idx_le(A, x): # x 以下の最大の要素位置 / なければ "No" return bisect_right(A, x)-1 if bisect_right(A, x)-1 != -1 else "No" def idx_lt(A, x): # x 未満の最大の要素位置 / なければ "No" return bisect_left(A, x)-1 if bisect_right(A, x)-1 != -1 else "No" def idx_ge(A, x): # x 以上の最小の要素位置 / なければ "No" return bisect_left(A, x) if bisect_left(A, x) != len(A) else "No" def idx_gt(A, x): # x 超過の最小の要素位置 / なければ "No" return bisect_right(A, x) if bisect_right(A, x) != len(A) else "No" def cnt_le(A, x): # x 以下の要素の個数 if(idx_le(A, x) == "No"): return 0 return idx_le(A, x) + 1 def cnt_lt(A, x): # x 未満の要素の個数 if(idx_lt(A, x) == "No"): return 0 return idx_lt(A, x) + 1 def cnt_ge(A, x): # x 以上の要素の個数 return len(A) - cnt_lt(A, x) def matrixMultiplication_2D(a,b,m): #行列の掛け算(a×b) m:mod I,J,K,L = len(a),len(b[0]),len(b),len(a[0]) if(L!=K): return -1 c = [[0] * J for _ in range(I)] for i in range(I) : for j in range(J) : for k in range(K) : c[i][j] += a[i][k] * b[k][j] c[i][j] %= m return c def matrixExponentiation_2D(x,n,m): #行列の累乗 (x^n) m:mod y = [[0] * len(x) for _ in range(len(x))] for i in range(len(x)): y[i][i] = 1 while n > 0: if n & 1: y = matrixMultiplication_2D(x,y,m) x = matrixMultiplication_2D(x,x,m) n >>= 1 return y def f(x): A = [] for i in range(m): A.append([dp[i]]) B = matrixMultiplication_2D(matrixExponentiation_2D(P,x-m,mod), A, mod) C = [] for i in range(m): C.append(B[i][0]) return C n,X = map(int,input().split()) A = list(map(int,input().split())) A.sort() m = max(A) L = [0] * m for a in A: L[m-a] = 1 dp = [0 for i in range(m+10)] dp[0] = 1 for i in range(m+10): for a in A: if(i+a<m+10): dp[i+a] += dp[i] P = [] for i in range(m-1): Q = [0 for i in range(m)] Q[i+1] = 1 P.append(Q) P.append(L) if(X<m): S = [] for i in range(X): S.append(dp[i]) else: S = f(X) ans = 0 s = len(S) for i in range(s): ans += cnt_ge(A,s-i)*S[i] ans %= mod print(ans)
#include<bits/stdc++.h> using namespace std; using ll = long long; template<class T = long long> struct Matrix{ int _R , _C; std::vector<std::vector<T>> val; Matrix(int r, int c, T Val): val(r,std::vector<T>(c,Val)), _R(r), _C(c) {} Matrix(int r, int c): Matrix(r, c, 0) {} Matrix(int n): Matrix(n, n, 0) {} Matrix(std::vector<std::vector<T>>&matrix): val(matrix), _R(val.size()), _C(matrix[0].size()) {} Matrix(Matrix& matrix): val(matrix.val), _R(matrix._R), _C(matrix._C) {} std::vector<T>& operator[](int i){ return val[i]; } Matrix operator*(const Matrix& other){ assert(this->_C == other._R); Matrix result(this->_R, other._C, 0); for(int i = 0; i < this->_R; i++){ for(int j = 0; j < other._C; j++){ for(int k = 0; k < this->_C; k++){ result[i][j] += this->val[i][k] * other.val[k][j]; } } } return result; } }; struct mint{ static const int mod = 998244353; ll val; mint(int k = 0): val((k % mod + mod) % mod) {} mint operator*(const mint other){return val * other.val % mod;} mint operator+(const mint other){return val + other.val >= mod? val + other.val - mod : val + other.val;} mint& operator+=(const mint other){ val = val + other.val >= mod? val + other.val - mod : val + other.val; return *this; } }; Matrix<mint> pow(Matrix<mint>a,ll n){ int si = a[0].size(); Matrix<mint> res(si,si,0); for(int i = 0; i < si; i++){ res[i][i] = 1; } while(n > 0){ if(n&1) res = res * a; a = a * a; n >>=1; } return res; } int main(){ ll n, x, k = 100; cin >> n >> x; vector<int> a(n); for(int i = 0; i < n; i++)cin >> a[i]; assert(1 <= n && n <= 100); assert(1 <= x && x <= 1'000'000'000'000); for(int i = 0; i < n; i++)assert(1 <= a[i] && a[i] <= 100); Matrix<mint> mat(k, k, 0); for(int i = 0; i < n; i++) mat[0][a[i] - 1] = 1; for(int i = 1; i < k; i++) mat[i][i - 1] = 1; Matrix<mint> dp(k, 1, 0); dp[0][0] = 1; mat = pow(mat, x - 1); dp = mat * dp; mint ans = 0; for(int i = 0; i < n; i++){ for(int j = 0; j < a[i]; j++){ ans += dp[j][0]; } } cout << ans.val << endl; }