操作は、全て加算/減算のどちらかしか行いません。この問題では操作後の配列がどのようになるかを求めることで答えを求めることが出来ます。
と の大小関係で場合分けをします。
の場合(全て減算)
方針: となる を選び が になるまで を減算をする。これを 回繰り返す。
をソートします。
インデックスが小さい順に処理をしていきます。
操作を行った要素が になるまで、それが常に配列内で最小値となるので、この貪欲法が成り立ちます。
の場合(全て加算)
方針: となる を選び を加算をする。これを 回繰り返す。
操作後では下図のようになります。これをどうやって実現するか考えましょう。
をソートします。(これにブロックをいくつか入れていくという方針になります。)
各高さで目印を付け、そこまでの容量を求めます。
どこの目印までブロックを入れることが出来るか求めます。
残ったブロックで目印以下の要素で共通の高さにします。
さらに余ったブロックを乗せます。
※ 整数 から までの和は です。
from bisect import *
def f(a,b):
return (a+b)*(b-a+1)//2
n,s = map(int,input().split())
A = list(map(int,input().split()))
A.sort()
if(s<=sum(A)):
cnt = sum(A)-s
t = 0
ans = 0
for i in range(n):
if(t+A[i]<=cnt):
ans += f(1,A[i])
t += A[i]
else:
ans += f(A[i]-(cnt-t)+1,A[i])
break
print(ans)
else:
H = [A[0]]
W = []
CS = [0]
for i in range(1,n):
if(A[i-1]!=A[i]):
H.append(A[i])
W.append(i)
H.append(10**18)
W.append(n)
for i in range(len(H)-1):
CS.append(CS[-1] + (H[i+1]-H[i])*W[i])
cnt = s-sum(A)
idx = bisect_right(CS, cnt)-1
bl = False
if(idx==n):
bl = True
idx -= 1
v = H[idx]
t = 0
ans = 0
for i in range(n):
if(v<A[i]):
break
if(t+(v-A[i])<=cnt):
ans += f(A[i],v-1)
t += v-A[i]
s = cnt - t
num = W[idx]
q,r = s//num, s%num
ans += f(v,v+q-1)*num + (v+q)*r
print(ans)の場合、全ての値を最低でどのくらいまで上げるかを二分探索することが出来ます。
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
ll culc(ll min_v, ll max_v){
return max_v * (max_v + 1) / 2 - min_v * (min_v - 1) / 2;
}
ll lowSolve(ll n, ll s, vector<ll> a){
ll ans = 0;
for(int i = 0; i < n; i++){
ll mn = max(0LL, a[i] - s);
ans += culc(mn + 1, a[i]);
s = max(s - a[i], 0LL);
}
return ans;
}
ll highSolve(ll n, ll s, vector<ll> a){
ll ok = 0 , ng = 1e9;
auto f = [&](ll mid)-> ll {
ll res = 0;
for(int i = 0; i < n; i++) res += max(0LL, mid - a[i]);
return res;
};
while(ng - ok > 1){
ll mid = (ok + ng) / 2;
if(f(mid) <= s) ok = mid;
else ng = mid;
}
ll ans = 0;
for(int i = 0; i < n; i++){
if(a[i] < ok)ans += culc(a[i],ok-1);
s -= max(0LL, ok - a[i]);
}
return ans + ok * s;
}
int main(){
ll n,s;
cin >> n >> s;
vector<ll> a(n);
for(int i = 0; i < n; i++) cin >> a[i];
ll sum = 0;
for(int i = 0; i < n; i++) sum += a[i];
sort(a.begin(), a.end());
cout << (sum >= s ? lowSolve(n, sum - s, a) : highSolve(n, s - sum, a)) << endl;
}