世の中にはWavelet Matrixという便利なデータ構造があります。
調べて貼り付けましょう。
計算量は 構築に $O(N \log N)$, クエリ処理に $O(\log N)$ で、全体で $O((N + Q) \log N)$ です。
ライブラリ
ei1333さんのWavelet Matrixを参考にしています。
xxxxxxxxxx
/**
 * Bit vector with rank support.
 *
 * Bits are stored 32 per word; `sum[i]` holds the number of set bits in
 * words [0, i), so `rank(k)` (set bits in positions [0, k)) needs one
 * table lookup plus one popcount. Call `build()` after the last `set()`.
 */
struct SIDict {
private:
    int blk = 0;                // number of 32-bit words
    std::vector<unsigned> bit;  // raw bits (unsigned so that bit 31 is safe to shift)
    std::vector<int> sum;       // prefix popcounts per word
public:
    SIDict() = default;
    /// Creates a dictionary able to hold `len` bits, all cleared.
    SIDict(const int len): blk((len + 31) >> 5), bit(blk), sum(blk) {}

    /// Sets bit `k`.
    void set(const int k) { bit[k >> 5] |= 1u << (k & 31); }

    /// Fills the prefix-count table; must be called before any rank query.
    void build() {
        if(blk == 0) {
            return;  // nothing to index in an empty dictionary
        }
        sum[0] = 0;
        for(int i = 1; i < blk; ++i) {
            sum[i] = sum[i - 1] + __builtin_popcount(bit[i - 1]);
        }
    }

    /// Number of set bits in positions [0, k).
    int rank(const int k) const {
        const unsigned below = (1u << (k & 31)) - 1;  // mask of the bits under k inside its word
        return sum[k >> 5] + __builtin_popcount(bit[k >> 5] & below);
    }

    /// Number of bits equal to `val` in positions [0, k).
    int rank(const bool val, const int k) const { return val ? rank(k) : k - rank(k); }

    /// Value of bit `k`. Const so that const owners can read it.
    bool operator[](const int k) const noexcept { return (bit[k >> 5] >> (k & 31)) & 1; }
};

/**
 * Wavelet matrix over non-negative integers smaller than 2^log.
 *
 * Level `log - 1` is the most significant bit. At every level the sequence
 * is stably partitioned into "bit is 0" followed by "bit is 1";
 * `mid[level]` is the number of zeros, i.e. where the ones start.
 * All ranges are half-open [l, r).
 */
template <class T, int log = 18> struct WMBeta {
private:
    SIDict matrix[log];
    int mid[log] = {};

    /// Recovers the original value at position k by walking the levels top-down.
    T access(int k) const {
        T ret = 0;
        for(int level = log - 1; level >= 0; --level) {
            const bool f = matrix[level][k];
            if(f) {
                ret |= (T)1 << level;
            }
            k = matrix[level].rank(f, k) + mid[level] * f;
        }
        return ret;
    }

    /// Maps the range [l, r) at `level` to the sub-range of elements whose bit equals f on the next level.
    std::pair<int, int> succ(const bool f, const int l, const int r, const int level) const {
        const int base = mid[level] * f;
        return {matrix[level].rank(f, l) + base, matrix[level].rank(f, r) + base};
    }

public:
    WMBeta() {}

    /// Builds the matrix from `v`; the argument is taken by value because it is reordered level by level.
    WMBeta(std::vector<T> v) {
        const int len = v.size();
        std::vector<T> l(len), r(len);
        for(int level = log - 1; level >= 0; --level) {
            // One spare bit so that rank(len) stays inside the table.
            matrix[level] = SIDict(len + 1);
            int left = 0, right = 0;
            for(int i = 0; i < len; ++i) {
                if((v[i] >> level) & 1) {
                    matrix[level].set(i);
                    r[right++] = v[i];
                } else {
                    l[left++] = v[i];
                }
            }
            mid[level] = left;
            matrix[level].build();
            // New order: zeros first (swapped in from l), then the ones.
            v.swap(l);
            for(int i = 0; i < right; ++i) {
                v[left + i] = r[i];
            }
        }
    }

    /// Value at position k.
    T operator[](const int k) const noexcept { return access(k); }

    /// Number of occurrences of x in [0, r).
    int rank(const T x, int r) const {
        int l = 0;
        for(int level = log - 1; level >= 0; --level) {
            std::tie(l, r) = succ((x >> level) & 1, l, r, level);
        }
        return r - l;
    }

    /// k-th (0-indexed) smallest value in [l, r). Requires 0 <= k < r - l.
    T kth_min(int l, int r, int k) const {
        assert(0 <= k && k < r - l);
        T ret = 0;
        for(int level = log - 1; level >= 0; --level) {
            const int cnt = matrix[level].rank(false, r) - matrix[level].rank(false, l);
            const bool f = cnt <= k;  // answer lies among the ones when the zeros cannot cover k
            if(f) {
                ret |= T(1) << level;
                k -= cnt;
            }
            std::tie(l, r) = succ(f, l, r, level);
        }
        return ret;
    }

    /// k-th (0-indexed) largest value in [l, r).
    T kth_max(const int l, const int r, const int k) const { return kth_min(l, r, r - l - k - 1); }

    /// Number of elements in [l, r) that are strictly less than `upper`.
    int range_freq(int l, int r, const T upper) const {
        int ret = 0;
        // Every level down to and including 0 must be visited, otherwise the
        // lowest bit of `upper` is ignored and odd bounds are under-counted.
        for(int level = log - 1; level >= 0; --level) {
            const bool f = (upper >> level) & 1;
            if(f) {
                ret += matrix[level].rank(false, r) - matrix[level].rank(false, l);
            }
            std::tie(l, r) = succ(f, l, r, level);
        }
        return ret;
    }

    /// Number of elements in [l, r) with lower <= value < upper.
    int range_freq(const int l, const int r, const T lower, const T upper) const {
        return range_freq(l, r, upper) - range_freq(l, r, lower);
    }

    /// Largest value in [l, r) strictly less than `upper`, or (T)-1 if there is none.
    T prev(const int l, const int r, const T upper) const {
        const int cnt = range_freq(l, r, upper);
        return cnt == 0 ? (T)-1 : kth_min(l, r, cnt - 1);
    }

    /// Smallest value in [l, r) that is >= `lower`, or (T)-1 if there is none.
    T next(const int l, const int r, const T lower) const {
        const int cnt = range_freq(l, r, lower);
        return cnt == r - l ? (T)-1 : kth_min(l, r, cnt);
    }
};
/**
 * Wavelet matrix over arbitrary ordered values.
 *
 * Values are coordinate-compressed: `ys` holds the sorted distinct values
 * and the inner WMBeta stores, for each position, the index of its value
 * in `ys`. Queries translate their value arguments with `get()` (a
 * lower_bound into `ys`) and translate results back through `ys`.
 * All ranges are half-open [l, r). The number of distinct values must be
 * smaller than 2^log.
 */
template <class T, int log = 20> struct WaveletMatrix {
private:
    WMBeta<int, log> mat;
    std::vector<T> ys;

    /// Index of the first element of ys that is >= x (ys.size() if none).
    int get(const T x) const {
        return static_cast<int>(std::lower_bound(ys.cbegin(), ys.cend(), x) - ys.cbegin());
    }

    /// Original value at position k; needs a const-callable WMBeta::operator[].
    T access(const int k) const { return ys[mat[k]]; }

public:
    /// Builds the structure from v; v itself is only read.
    WaveletMatrix(const std::vector<T> &v): ys(v) {
        std::sort(ys.begin(), ys.end());
        ys.erase(std::unique(ys.begin(), ys.end()), ys.end());
        std::vector<int> t(v.size());
        for(std::size_t i = 0; i < v.size(); ++i) {
            t[i] = get(v[i]);
        }
        mat = WMBeta<int, log>(t);
    }

    /// Value at position k.
    T operator[](const int k) const noexcept { return access(k); }

    /// Number of occurrences of x in [0, r); 0 when x never occurs in the sequence.
    int rank(const T x, const int r) const {
        const int pos = get(x);
        if(pos == static_cast<int>(ys.size()) || ys[pos] != x) {
            return 0;
        }
        return mat.rank(pos, r);
    }

    /// k-th (0-indexed) smallest value in [l, r).
    T kth_min(const int l, const int r, const int k) const { return ys[mat.kth_min(l, r, k)]; }

    /// k-th (0-indexed) largest value in [l, r).
    T kth_max(const int l, const int r, const int k) const { return ys[mat.kth_max(l, r, k)]; }

    /// Number of elements in [l, r) strictly less than `upper`.
    int range_freq(const int l, const int r, const T upper) const { return mat.range_freq(l, r, get(upper)); }

    /// Number of elements in [l, r) with lower <= value < upper.
    int range_freq(const int l, const int r, const T lower, const T upper) const {
        return mat.range_freq(l, r, get(lower), get(upper));
    }

    /// Largest value in [l, r) strictly less than `upper`, or (T)-1 if there is none.
    T prev(const int l, const int r, const T upper) const {
        const int ret = mat.prev(l, r, get(upper));
        return ret == -1 ? (T)-1 : ys[ret];
    }

    /// Smallest value in [l, r) that is >= `lower`, or (T)-1 if there is none.
    T next(const int l, const int r, const T lower) const {
        const int ret = mat.next(l, r, get(lower));
        return ret == -1 ? (T)-1 : ys[ret];
    }
};
/**
* @brief Wavelet Matrix
*/
// Reads n values and q queries from stdin.
//   type 1:  "1 x"      -> every value is conceptually increased by x
//   other:   "t l r k"  -> print the k-th largest value in positions l..r (1-indexed, inclusive)
// A uniform increase never changes the relative order, so the structure is
// built once and the accumulated offset is added when printing.
int main() {
    int n, q;
    std::cin >> n >> q;

    std::vector<int64_t> a(n);
    for(int i = 0; i < n; ++i) {
        std::cin >> a[i];
    }

    WaveletMatrix wm(a);
    int64_t offset = 0;

    for(; q > 0; --q) {
        int type;
        std::cin >> type;
        if(type != 1) {
            int l, r, k;
            std::cin >> l >> r >> k;
            // Convert to a 0-indexed half-open range and a 0-indexed k.
            const auto answer = wm.kth_max(l - 1, r, k - 1) + offset;
            std::cout << answer << '\n';
            continue;
        }
        int x;
        std::cin >> x;
        offset += x;
    }
}
xxxxxxxxxx
import static java.lang.Math.*;
import java.io.Closeable;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Objects;
import java.util.Scanner;
import java.util.function.IntPredicate;
public final class Main implements Closeable {
private static final Scanner sc = new Scanner(System.in);
private static final PrintWriter pw = new PrintWriter(System.out, false);
public static final void main(final String[] args) {
final int n = sc.nextInt(), q = sc.nextInt();
final var a = new int[n];
Arrays.setAll(i -> sc.nextInt());
final var wm = new WaveletMatrix(a);
long plus = 0;
for(int i = 0; i < q; ++i) {
final int t = sc.nextInt();
if(t == 1) {
plus += sc.nextInt();
} else {
pw.println(wm.kthMax(sc.nextInt() - 1, sc.nextInt(), sc.nextInt() - 1) + plus);
}
}
new Main().close();
}
public final void close() {
pw.flush();
pw.close();
sc.close();
}
}
/** Small static helpers: a closed-interval check and a binary-search lower bound. */
class Utility {
    /** Returns true when {@code l <= x <= r}. */
    static final boolean scope(final int l, final int x, final int r) {
        if(x < l) {
            return false;
        }
        return x <= r;
    }

    /** Index of the first element of the sorted array {@code a} that is {@code >= x}, or {@code a.length} if none. */
    static final int lowerBound(final long[] a, final long x) {
        final IntPredicate reached = idx -> a[idx] >= x;
        return bins(a.length, -1, reached);
    }

    /**
     * Generic bisection: {@code ok} satisfies the predicate, {@code ng} does not;
     * narrows the pair until they are adjacent and returns the boundary on the ok side.
     */
    private static final int bins(int ok, int ng, final IntPredicate fn) {
        while(Math.abs(ok - ng) > 1) {
            final int probe = (ok + ng) / 2;
            final boolean good = fn.test(probe);
            ok = good ? probe : ok;
            ng = good ? ng : probe;
        }
        return ok;
    }
}
/**
 * Mutable, comparable pair of two comparable values.
 * Ordering is lexicographic: by {@code first}, then by {@code second}.
 *
 * @param <F> type of the first component
 * @param <S> type of the second component
 */
class Pair<F extends Comparable<? super F>, S extends Comparable<? super S>> implements Comparable<Pair<F, S>>, Cloneable {
    public F first;
    public S second;

    protected Pair(final F first, final S second) {
        this.first = first;
        this.second = second;
    }

    /** Factory method; equivalent to the constructor. */
    static final <F extends Comparable<? super F>, S extends Comparable<? super S>> Pair<F, S> of(final F a, final S b){ return new Pair<>(a, b); }

    /** Returns a new pair with the components exchanged. */
    Pair<S, F> swap(){ return Pair.of(second, first); }

    /** Two pairs are equal when they have the same runtime class and equal components (null-safe). */
    @Override
    public final boolean equals(final Object o) {
        if(this == o) {
            return true;
        }
        if(o == null || getClass() != o.getClass()) {
            return false;
        }
        final Pair<?, ?> p = (Pair<?, ?>) o;
        return Objects.equals(first, p.first) && Objects.equals(second, p.second);
    }

    @Override
    public final int hashCode(){ return Objects.hash(first, second); }

    /** Formats the pair as {@code (first, second)}. */
    @Override
    public final String toString(){ return "(" + first + ", " + second + ")"; }

    /** Shallow copy: the components themselves are shared with the original. */
    @SuppressWarnings("unchecked")
    @Override
    public final Pair<F, S> clone() {
        try {
            return (Pair<F, S>) super.clone();
        } catch(final CloneNotSupportedException e){
            // Cannot happen because this class implements Cloneable.
            throw new AssertionError(e);
        }
    }

    /** Lexicographic comparison: {@code first} decides unless it ties, then {@code second}. */
    @Override
    public final int compareTo(final Pair<F, S> p) {
        final int byFirst = first.compareTo(p.first);
        return byFirst != 0 ? byFirst : second.compareTo(p.second);
    }
}
/** A {@code Pair} of two boxed {@code Long} values, created from primitive longs. */
final class IntPair extends Pair<Long, Long> {
    private IntPair(final long first, final long second) {
        super(Long.valueOf(first), Long.valueOf(second));
    }

    /** Factory method: boxes both arguments and wraps them in an {@code IntPair}. */
    static final IntPair of(final long a, final long b) {
        final IntPair made = new IntPair(a, b);
        return made;
    }
}
/**
 * Wavelet matrix over arbitrary {@code int}/{@code long} values.
 *
 * Values are coordinate-compressed: {@code ys} holds the sorted distinct
 * values and the inner {@link WaveletMatrixBeta} stores, for each position,
 * the index of its value in {@code ys}. All ranges are half-open [l, r).
 * The number of distinct values must be smaller than 2^log (default log = 20).
 */
final class WaveletMatrix {
    private final WaveletMatrixBeta mat;
    private final long[] ys;

    WaveletMatrix(final int[] arr){ this(arr, 20); }
    WaveletMatrix(final long[] arr){ this(arr, 20); }

    /** Builds the structure from {@code arr} using {@code log} bit levels; {@code arr} is not modified. */
    WaveletMatrix(final int[] arr, final int log) {
        ys = Arrays.stream(arr).asLongStream().sorted().distinct().toArray();
        final long[] t = new long[arr.length];
        Arrays.setAll(t, i -> index(arr[i]));
        mat = new WaveletMatrixBeta(t, log);
    }

    /** Builds the structure from {@code arr} using {@code log} bit levels; {@code arr} is not modified. */
    WaveletMatrix(final long[] arr, final int log) {
        ys = Arrays.stream(arr).sorted().distinct().toArray();
        final long[] t = new long[arr.length];
        Arrays.setAll(t, i -> index(arr[i]));
        mat = new WaveletMatrixBeta(t, log);
    }

    /** Index of the first element of {@code ys} that is {@code >= x} ({@code ys.length} if none). */
    private final int index(final long x){ return Utility.lowerBound(ys, x); }

    /** Value at position k. */
    final long get(final int k){ return ys[(int) mat.access(k)]; }

    /** Number of occurrences of x in [0, r); 0 when x never occurs in the sequence. */
    final int rank(final int r, final long x) {
        final int pos = index(x);
        if(pos == ys.length || ys[pos] != x) {
            return 0;
        }
        return mat.rank(pos, r);
    }

    /** Number of occurrences of x in [l, r). */
    final int rank(final int l, final int r, final long x){ return rank(r, x) - rank(l, x); }

    /** k-th (0-indexed) smallest value in [l, r). */
    final long kthMin(final int l, final int r, final int k){ return ys[(int) mat.kthMin(l, r, k)]; }

    /** k-th (0-indexed) largest value in [l, r). */
    final long kthMax(final int l, final int r, final int k){ return ys[(int) mat.kthMax(l, r, k)]; }

    /** Number of elements in [l, r) strictly less than {@code upper}. */
    final int rangeFreq(final int l, final int r, final long upper){ return mat.rangeFreq(l, r, index(upper)); }

    /** Number of elements in [l, r) with {@code lower <= value < upper}. */
    final int rangeFreq(final int l, final int r, final long lower, final long upper){ return mat.rangeFreq(l, r, index(lower), index(upper)); }

    /** Largest value in [l, r) strictly less than {@code upper}, or -1 if there is none. */
    final long prev(final int l, final int r, final long upper) {
        final long ret = mat.prev(l, r, index(upper));
        return ret == -1 ? -1 : ys[(int) ret];
    }

    /** Smallest value in [l, r) that is {@code >= lower}, or -1 if there is none. */
    final long next(final int l, final int r, final long lower) {
        final long ret = mat.next(l, r, index(lower));
        return ret == -1 ? -1 : ys[(int) ret];
    }

    /**
     * Wavelet matrix over non-negative values smaller than 2^log.
     * Level {@code log - 1} is the most significant bit; at each level the
     * sequence is stably partitioned into zeros followed by ones, and
     * {@code mid[level]} is the number of zeros.
     */
    private static final class WaveletMatrixBeta {
        private final int log;
        private final SuccinctIndexableDictionary[] matrix;
        private final int[] mid;

        /** Builds all levels. NOTE: {@code arr} is reordered in place while building. */
        WaveletMatrixBeta(final long[] arr, final int log) {
            final int len = arr.length;
            this.log = log;
            matrix = new SuccinctIndexableDictionary[log];
            mid = new int[log];
            final long[] l = new long[len], r = new long[len];
            for(int level = log; --level >= 0;) {
                // One spare bit so that rank(len) stays inside the table.
                matrix[level] = new SuccinctIndexableDictionary(len + 1);
                int left = 0, right = 0;
                for(int i = 0; i < len; ++i) {
                    if(((arr[i] >> level) & 1) == 1) {
                        matrix[level].set(i);
                        r[right++] = arr[i];
                    } else {
                        l[left++] = arr[i];
                    }
                }
                mid[level] = left;
                matrix[level].build();
                // New order for the next level: zeros first, then ones. The scratch
                // buffers are fully rewritten before being read again, so no swap
                // (and no temporary array) is needed.
                System.arraycopy(l, 0, arr, 0, left);
                System.arraycopy(r, 0, arr, left, right);
            }
        }

        /** Maps [l, r) at {@code level} to the sub-range of elements whose bit equals f on the next level. */
        private final IntPair succ(final boolean f, final int l, final int r, final int level) {
            final int base = f ? mid[level] : 0;
            return IntPair.of(matrix[level].rank(f, l) + base, matrix[level].rank(f, r) + base);
        }

        /** Recovers the stored value at position k. */
        final long access(int k) {
            long ret = 0;
            for(int level = log; --level >= 0;) {
                final boolean f = matrix[level].get(k);
                if(f) {
                    ret |= 1L << level;
                }
                k = matrix[level].rank(f, k) + (f ? mid[level] : 0);
            }
            return ret;
        }

        /** Number of occurrences of x in [0, r). */
        final int rank(final long x, int r) {
            int l = 0;
            for(int level = log; --level >= 0;) {
                final IntPair p = succ(((x >> level) & 1) == 1, l, r, level);
                l = p.first.intValue();
                r = p.second.intValue();
            }
            return r - l;
        }

        /**
         * k-th (0-indexed) smallest value in [l, r).
         *
         * @throws IndexOutOfBoundsException if k is not in [0, r - l)
         */
        final long kthMin(int l, int r, int k) {
            if(!Utility.scope(0, k, r - l - 1)) {
                throw new IndexOutOfBoundsException();
            }
            long ret = 0;
            for(int level = log; --level >= 0;) {
                final int cnt = matrix[level].rank(false, r) - matrix[level].rank(false, l);
                final boolean f = cnt <= k;  // answer lies among the ones when the zeros cannot cover k
                if(f) {
                    // long shift, matching access(); an int shift would be wrong from level 31 upwards
                    ret |= 1L << level;
                    k -= cnt;
                }
                final IntPair p = succ(f, l, r, level);
                l = p.first.intValue();
                r = p.second.intValue();
            }
            return ret;
        }

        /** k-th (0-indexed) largest value in [l, r). */
        final long kthMax(final int l, final int r, final int k){ return kthMin(l, r, r - l - k - 1); }

        /** Number of elements in [l, r) strictly less than {@code upper}. */
        final int rangeFreq(int l, int r, final long upper) {
            int ret = 0;
            for(int level = log; --level >= 0;) {
                final boolean f = ((upper >> level) & 1) == 1;
                if(f) {
                    ret += matrix[level].rank(false, r) - matrix[level].rank(false, l);
                }
                final IntPair p = succ(f, l, r, level);
                l = p.first.intValue();
                r = p.second.intValue();
            }
            return ret;
        }

        /** Number of elements in [l, r) with {@code lower <= value < upper}. */
        final int rangeFreq(final int l, final int r, final long lower, final long upper){ return rangeFreq(l, r, upper) - rangeFreq(l, r, lower); }

        /** Largest value in [l, r) strictly less than {@code upper}, or -1 if there is none. */
        final long prev(final int l, final int r, final long upper) {
            final int cnt = rangeFreq(l, r, upper);
            return cnt == 0 ? -1 : kthMin(l, r, cnt - 1);
        }

        /** Smallest value in [l, r) that is {@code >= lower}, or -1 if there is none. */
        final long next(final int l, final int r, final long lower) {
            final int cnt = rangeFreq(l, r, lower);
            return cnt == r - l ? -1 : kthMin(l, r, cnt);
        }

        /**
         * Bit vector with rank support: 32 bits per word, {@code sum[i]} is the
         * number of set bits in words [0, i). Call {@link #build()} after the last {@link #set(int)}.
         */
        private static final class SuccinctIndexableDictionary {
            private final int blk;
            private final int[] bit, sum;

            SuccinctIndexableDictionary(final int len) {
                blk = (len + 31) >> 5;
                bit = new int[blk];
                sum = new int[blk];
            }

            /** Sets bit k. */
            final void set(final int k){ bit[k >> 5] |= 1 << (k & 31); }

            /** Fills the prefix-count table; must be called before any rank query. */
            final void build() {
                sum[0] = 0;
                for(int i = 0; ++i < blk;) {
                    sum[i] = sum[i - 1] + Integer.bitCount(bit[i - 1]);
                }
            }

            /** Value of bit k. */
            final boolean get(final int k){ return ((bit[k >> 5] >> (k & 31)) & 1) == 1; }

            /** Number of set bits in positions [0, k). */
            final int rank(final int k){ return (sum[k >> 5] + Integer.bitCount(bit[k >> 5] & ((1 << (k & 31)) - 1))); }

            /** Number of bits equal to {@code val} in positions [0, k). */
            final int rank(final boolean val, final int k){ return val ? rank(k) : k - rank(k); }
        }
    }
}
Javaは入力か出力のどちらかを高速化してあげると通ります。
import strutils, sequtils, bitops, algorithm

# C-style operator aliases for ints.
# NOTE: Nim derives the precedence of a custom operator from its first
# character, so `>>` and `<<` bind like comparisons (weaker than `&`, `|`,
# `+`, `-`). Every shift below is therefore parenthesised explicitly.
proc `|`(x: int, y: int): int = x or y
proc `&`(x: int, y: int): int = x and y
proc `>>`(x: int, y: int): int = x shr y
proc `<<`(x: int, y: int): int = x shl y
proc `|=`(x: var int, y: int): void = x = x | y
proc `?`(x: bool): int = (if x: 1 else: 0)

type Pair[F, S] = ref object
  first*: F
  second*: S

proc initPair[F, S](f: F, s: S): Pair[F, S] {.inline.} =
  Pair[F, S](first: f, second: s)

# Bit vector with rank support: 32 bits are used per word and sum[i] is the
# number of set bits in words [0, i). Call build() after the last set().
type SIDict = ref object
  blk: int
  bit, sum: seq[int]

proc initSIDict(n: int): SIDict {.inline.} =
  ## Creates a dictionary able to hold `n` bits, all cleared.
  var sid = new SIDict
  sid.blk = (n + 31) >> 5
  sid.bit.setLen(sid.blk)
  sid.sum.setLen(sid.blk)
  sid

proc set(sid: SIDict, i: int) {.inline.} =
  ## Sets bit `i`.
  sid.bit[i >> 5] |= (1 << (i & 31))

proc build(sid: SIDict) {.inline.} =
  ## Fills the prefix-count table; must be called before any rank query.
  sid.sum[0] = 0
  for i in 0 ..< (sid.blk - 1):
    sid.sum[i + 1] = sid.sum[i] + popcount(sid.bit[i])

proc `[]`(sid: SIDict, i: int): bool {.inline.} =
  ## Value of bit `i`.
  ((sid.bit[i >> 5] >> (i & 31)) & 1) == 1

proc rank(sid: SIDict, i: int): int {.inline.} =
  ## Number of set bits in positions [0, i).
  sid.sum[i >> 5] + popcount(sid.bit[i >> 5] & ((1 << (i & 31)) - 1))

proc rank(sid: SIDict, val: bool, i: int): int {.inline.} =
  ## Number of bits equal to `val` in positions [0, i).
  if val: sid.rank(i) else: i - sid.rank(i)

# Wavelet matrix over non-negative ints smaller than 2^log. Level log-1 is the
# most significant bit; at each level the sequence is stably partitioned into
# zeros followed by ones and mid[level] is the number of zeros.
type WMBeta = ref object
  log: int
  mat: seq[SIDict]
  mid: seq[int]

proc initWMBeta(arr: seq[int], log: int): WMBeta {.inline.} =
  var wmb = new WMBeta
  var a = arr
  let n = len(a)
  wmb.log = log
  wmb.mat.setLen(log)
  wmb.mid.setLen(log)
  var l, r: seq[int]
  l.setLen(n)
  r.setLen(n)
  for level in countdown(log - 1, 0):
    # One spare bit so that rank(n) stays inside the table.
    wmb.mat[level] = initSIDict(n + 1)
    var left = 0
    var right = 0
    for i in 0 ..< n:
      if ((a[i] >> level) & 1) == 1:
        wmb.mat[level].set(i)
        r[right] = a[i]
        right += 1
      else:
        l[left] = a[i]
        left += 1
    wmb.mid[level] = left
    wmb.mat[level].build()
    # New order: zeros first (swapped in from l), then the ones.
    swap(a, l)
    for i in 0 ..< right:
      a[left + i] = r[i]
  wmb

proc `[]`(wmb: WMBeta, i: int): int {.inline.} =
  ## Recovers the stored value at position `i`.
  var res = 0
  var k = i
  for level in countdown(wmb.log - 1, 0):
    let f = wmb.mat[level][k]
    if f:
      res |= (1 << level)
    k = wmb.mat[level].rank(f, k) + wmb.mid[level] * ?f
  res

proc succ(wmb: WMBeta, f: bool, l: int, r: int, level: int): Pair[int, int] {.inline.} =
  ## Maps [l, r) at `level` to the sub-range of elements whose bit equals `f`.
  initPair(wmb.mat[level].rank(f, l) + wmb.mid[level] * ?f,
           wmb.mat[level].rank(f, r) + wmb.mid[level] * ?f)

proc rank(wmb: WMBeta, x: int, id: int): int {.inline.} =
  ## Number of occurrences of `x` in [0, id).
  var l = 0
  var r = id
  for level in countdown(wmb.log - 1, 0):
    let p = wmb.succ(((x >> level) & 1) == 1, l, r, level)
    l = p.first
    r = p.second
  r - l

proc kthMin(wmb: WMBeta, a: int, b: int, i: int): int {.inline.} =
  ## i-th (0-indexed) smallest value in [a, b). Requires 0 <= i < b - a.
  assert 0 <= i and i < b - a
  var l = a
  var r = b
  var k = i
  var ret = 0
  for level in countdown(wmb.log - 1, 0):
    let cnt = wmb.mat[level].rank(false, r) - wmb.mat[level].rank(false, l)
    let f = cnt <= k  # answer lies among the ones when the zeros cannot cover k
    if f:
      ret |= (1 << level)
      k -= cnt
    let p = wmb.succ(f, l, r, level)
    l = p.first
    r = p.second
  ret

proc kthMax(wmb: WMBeta, l: int, r: int, k: int): int {.inline.} =
  ## k-th (0-indexed) largest value in [l, r).
  wmb.kthMin(l, r, r - l - k - 1)

proc rangeFreq(wmb: WMBeta, a: int, b: int, upper: int): int {.inline.} =
  ## Number of elements in [a, b) strictly less than `upper`.
  var l = a
  var r = b
  var ret = 0
  for level in countdown(wmb.log - 1, 0):
    # Parentheses are required: without them `&` binds tighter than `>>`
    # and the expression becomes (upper >> (level & 1)) == 1.
    let f = ((upper >> level) & 1) == 1
    if f:
      ret += wmb.mat[level].rank(false, r) - wmb.mat[level].rank(false, l)
    let p = wmb.succ(f, l, r, level)
    l = p.first
    r = p.second
  ret

proc rangeFreq(wmb: WMBeta, l: int, r: int, lower: int, upper: int): int {.inline.} =
  ## Number of elements in [l, r) with lower <= value < upper.
  wmb.rangeFreq(l, r, upper) - wmb.rangeFreq(l, r, lower)

proc prev(wmb: WMBeta, l: int, r: int, upper: int): int {.inline.} =
  ## Largest value in [l, r) strictly less than `upper`, or -1 if none.
  let cnt = wmb.rangeFreq(l, r, upper)
  if cnt == 0: -1 else: wmb.kthMin(l, r, cnt - 1)

proc next(wmb: WMBeta, l: int, r: int, lower: int): int {.inline.} =
  ## Smallest value in [l, r) that is >= `lower`, or -1 if none.
  let cnt = wmb.rangeFreq(l, r, lower)
  # Every element is below `lower` exactly when cnt equals the range length.
  if cnt == r - l: -1 else: wmb.kthMin(l, r, cnt)

# Wavelet matrix over arbitrary ints: values are coordinate-compressed into
# indices of the sorted distinct list `ys` and stored in a WMBeta.
type WaveletMatrix = ref object
  mat: WMBeta
  ys: seq[int]

proc get(wm: WaveletMatrix, x: int): int {.inline.} =
  ## Index of the first element of ys that is >= x (len(ys) if none).
  lowerBound(wm.ys, x)

proc uniq(a: openArray[int]): seq[int] {.inline.} =
  ## Drops consecutive duplicates (all duplicates when `a` is sorted).
  for el in a:
    if result.len == 0 or result[^1] != el:
      result.add(el)

proc initWaveletMatrix*(a: seq[int], log: int = 20): WaveletMatrix {.inline.} =
  ## Builds the structure from `a` using `log` bit levels; the number of
  ## distinct values must be smaller than 2^log.
  var wm = new WaveletMatrix
  wm.ys = sorted(a).uniq
  var t = newSeq[int](len(a))
  for i, el in a:
    t[i] = wm.get(el)
  wm.mat = initWMBeta(t, log)
  wm

proc `[]`*(wm: WaveletMatrix, i: int): int {.inline.} =
  ## Value at position `i`.
  wm.ys[wm.mat[i]]

proc rank*(wm: WaveletMatrix, r: int, x: int): int {.inline.} =
  ## Number of occurrences of `x` in [0, r); 0 when x never occurs.
  let pos = wm.get(x)
  if pos == len(wm.ys) or wm.ys[pos] != x: 0 else: wm.mat.rank(pos, r)

proc rank*(wm: WaveletMatrix, l: int, r: int, x: int): int {.inline.} =
  ## Number of occurrences of `x` in [l, r).
  wm.rank(r, x) - wm.rank(l, x)

proc kthMin*(wm: WaveletMatrix, l: int, r: int, k: int): int {.inline.} =
  ## k-th (0-indexed) smallest value in [l, r).
  wm.ys[wm.mat.kthMin(l, r, k)]

proc kthMax*(wm: WaveletMatrix, l: int, r: int, k: int): int {.inline.} =
  ## k-th (0-indexed) largest value in [l, r).
  wm.ys[wm.mat.kthMax(l, r, k)]

proc rangeFreq*(wm: WaveletMatrix, l: int, r: int, upper: int): int {.inline.} =
  ## Number of elements in [l, r) strictly less than `upper`.
  wm.mat.rangeFreq(l, r, wm.get(upper))

proc rangeFreq*(wm: WaveletMatrix, l: int, r: int, lower: int, upper: int): int {.inline.} =
  ## Number of elements in [l, r) with lower <= value < upper.
  wm.mat.rangeFreq(l, r, wm.get(lower), wm.get(upper))

proc prev*(wm: WaveletMatrix, l: int, r: int, upper: int): int {.inline.} =
  ## Largest value in [l, r) strictly less than `upper`, or -1 if none.
  let ret = wm.mat.prev(l, r, wm.get(upper))
  if ret == -1: -1 else: wm.ys[ret]

proc next*(wm: WaveletMatrix, l: int, r: int, lower: int): int {.inline.} =
  ## Smallest value in [l, r) that is >= `lower`, or -1 if none.
  let ret = wm.mat.next(l, r, wm.get(lower))
  if ret == -1: -1 else: wm.ys[ret]

# Driver: "1 x" adds x to every value (tracked as an offset, since a uniform
# increase keeps the order); any other query "t l r k" prints the k-th largest
# value among positions l..r (1-indexed, inclusive).
when isMainModule:
  let nq = stdin.readLine.split.map parseInt
  let q = nq[1]
  let a = stdin.readLine.split.map parseInt
  var plus = 0
  var wm = initWaveletMatrix(a)
  for _ in 0 ..< q:
    let qry = stdin.readLine.split.map parseInt
    if qry[0] == 1:
      plus += qry[1]
    else:
      echo wm.kthMax(qry[1] - 1, qry[2], qry[3] - 1) + plus