C++/PS

[BOJ/C++] 17476. 수열과 쿼리 28

Kareus 2021. 12. 27. 02:49

문제 : https://www.acmicpc.net/problem/17476

 

17476번: 수열과 쿼리 28

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.  1 L R X: 모든 L ≤ i ≤ R에 대해서 Ai = Ai + X를 적용한다.  2 L R: 모든 L ≤ i ≤ R에 대해서 Ai = ⌊√A

www.acmicpc.net

 

1. 설명

길이가 $N$인 수열 $A$에 대해 3개의 쿼리를 실행해야 합니다.

이때 $N \le 100000$이고, 총 쿼리 수 $Q \le 100000$입니다.

 

1 L R X : $L \le i \le R$인 모든 $A_i$에 대해 $A_i = A_i + X$를 실행한다.

2 L R : $L \le i \le R$인 모든 $A_i$에 대해 $A_i = \sqrt{A_i|$를 실행한다.

3 L R : $A_L + A_{L+1} + ... + A_R$의 값을 출력한다.

 

2. 풀이

더보기

구간에 대한 쿼리를 수행하므로 Segment Tree, 그 중에서도 Lazy propagation을 사용하는 Segment Tree를 이용해야 합니다.

1번과 3번 쿼리는 구현 예제로 나올만큼 간단한 쿼리지만, 2번이 문제입니다.

 

쿼리 실행을 리프 노드까지 일일이 실행하는 것은 시간 초과가 발생하기 때문에

더이상 노드를 내려가지 않아도 되는 조건을 적당히 잘 잡아주는 게 중요합니다.

이런 방식을 Segment Tree Beats라고 합니다. (적어도 그런 개념을 일컫는 것 같습니다. 이해하기 너무 어려워잉)

 

Segment Tree Beats에서는 query를 실행할 때 중단 조건과 갱신 조건을 체크합니다.

말 그대로, 더이상 해당 구간에 갱신할 필요가 없다면 중단 조건을 만족하여 갱신을 중단하게 됩니다.

갱신 조건은 반대로, 해당 구간에 존재하는 모든 노드가 갱신될 필요가 있다는 것을 의미합니다.

 

예를 들어, $A_i = \max(A_i, X)$라는 쿼리를 실행하는 경우, 갱신하려는 노드의 구간에 존재하는 최댓값이 이미 $X$ 이상임을 알고 있다면, 그 구간에 대해서는 더이상 쿼리를 실행할 필요가 없다는 것을 미리 알 수 있습니다.

 

문제에 따라 이 조건을 잘 찾아서 최적화를 하는 것이 관건입니다.

이를 이용해서 푸는 문제는 이 문제 이외에도 17473. 수열과 쿼리 2519455. Bitwise Queries 등이 있습니다.

 

아무튼 돌아와서, 2번 쿼리의 중단/갱신 조건을 알아내야 합니다.

우선, 1번 쿼리와 2번 쿼리가 각각 작용하는 게 다르기 때문에, lazy value를 두 개 설정했습니다.

덧셈에 관한 lazy value를 lazy_add, 제곱근에 관한 lazy value를 lazy_sqrt라고 하겠습니다.

 

동시에 판단하기엔 복잡하니 더 어려워 보이는 lazy_sqrt 부터 생각해봤습니다.

사실, lazy_sqrt 값이 없다면 (0이라면) lazy_add만 고려하면 되므로 이 경우에 대한 구현은 쉽습니다.

10999. 구간 합 구하기 2 에서 푸는 방식 그대로 적용해도 된다는 이야기니까요.

void lazy_update(int node, int start, int end)
{
    if (lazy_sqrt == 0)
    {
        v[node].sum += v[node].lazy_add * (end - start + 1);
        v[node].M += v[node].lazy_add;
        v[node].m += v[node].lazy_add;
        
        if (start != end)
        {
            v[node << 1].lazy_add += v[node].lazy_add;
            v[node << 1 | 1].lazy_add += v[node].lazy_add;
        }
    }
    ...
    v[node].lazy_add = 0;
}

void add(int node, int start, int end, int left, int right, long long val)
{
    lazy_update(node, start, end);

    if (left > end || right < start) return;

    if (left <= start && end <= right)
    {
        v[node].lazy_add += val;
        lazy_update(node, start, end);
        return;
    }

    int mid = (start + end) >> 1;

    add(node << 1, start, mid, left, right, val);
    add(node << 1 | 1, mid + 1, end, left, right, val);

    v[node].m = min(v[node << 1].m, v[node << 1 | 1].m);
    v[node].M = max(v[node << 1].M, v[node << 1 | 1].M);
    v[node].sum = v[node << 1].sum + v[node << 1 | 1].sum;
}

 

구간 내에 있는 모든 노드의 sqrt 값이 같다면, 구간 합은 한 노드의 sqrt * 구간의 크기와 같습니다.

따라서 이러한 경우에는 $sum = lazy_{sqrt} \times (end - start + 1)$을 적용하면 됩니다.

그렇지 않다면, 조건을 만족할 때까지 구간을 쪼개야겠죠.

 

중단 조건 또한 이걸 그대로 적용시켰습니다. 모든 구간의 sqrt 값이 같다는 말은,

최솟값의 sqrt 값이나 최댓값의 sqrt 값이나 같다는 것을 의미합니다.

따라서 각 노드마다 그 구간의 min, max 값을 저장하고 활용해야 합니다.

Segment Tree Beats를 사용하는 문제에서는 이걸 자주 활용하는 것 같네요.

푼 문제마다 이게 없던 적이 없었습니다.

if (floor(sqrt(v[node].m)) == floor(sqrt(v[node].M))) //break condition of sqrt query
{
    lazy_sqrt = floor(sqrt(v[node].m)); //also = floor(sqrt(v[node].max))
    lazy_update(node, start, end);
    return;
}

 

이제 lazy propagation으로 돌아옵시다.

모든 구간의 sqrt 값이 같을 때 lazy_sqrt 값을 업데이트하고 propagation을 실행하므로,

... //lazy_update
if (v[node].lazy_sqrt != 0)
{
    v[node].sum = v[node].lazy_sqrt * (end - start + 1);
    v[node].m = v[node].M = v[node].lazy_sqrt;
    
    v[node << 1].lazy_sqrt = v[node << 1 | 1].lazy_sqrt = v[node].lazy_sqrt;
}

v[node].lazy_add = 0;
v[node].lazy_sqrt = 0;

와 같은 형태가 될 것입니다.

이제 lazy_add와 합쳐야 합니다.

1] lazy_add = 0이고, lazy_sqrt = 0이면 노드 업데이트할 이유가 없습니다.

2] lazy_add != 0이고, lazy_sqrt = 0이면 노드에 값을 더하기만 하면 됩니다.

3] lazy_add = 0이고, lazy_sqrt != 0이면 위와 같이 업데이트하면 됩니다.

4] lazy_add != 0이고, lazy_sqrt != 0이면, lazy_add가 sqrt 적용 이전의 값인지 후의 값인지에 따라 계산이 달라집니다.

 

4번의 경우를 좀 더 자세히 봅시다.

현재 노드에 대한 쿼리를 실행할 때는 항상 lazy propagation을 먼저 실행합니다.

따라서 현재 노드의 lazy value는 모두 0으로 초기화된 상태입니다.

sqrt 쿼리를 실행할 때, 그 쿼리 이전에 더해야 할 건 모두 더했음을 의미합니다.

sqrt 쿼리의 중단 조건이 $\left \lceil \sqrt{\min} \right \rceil = \left \lceil \sqrt{\max} \right \rceil$ 이므로, 해당 구간에 존재하는 노드의 값은 모두 현재 노드의 값과 같습니다.

따라서 자식 노드의 값 또한 lazy_add 값에 상관 없이 현재 노드의 lazy_sqrt 값이 됩니다.

그러므로 sqrt 쿼리에서는 4번의 경우가 발생하지 않습니다. (항상 3번의 경우가 발생합니다)

다시 말하면 lazy_add는 sqrt 적용 이후에 더한 값만 존재할 수 있다고 볼 수 있습니다.

쿼리를 실행할 때 이미 propagation을 모두 실행했으니까요.

 

그러니 lazy_sqrt != 0일 때 add 쿼리를 실행한 경우를 살펴봅시다. (sqrt 이후에 더하는 경우)

부모 노드에서 sqrt 쿼리를 실행하고, 현재 노드에서 add 쿼리를 실행한 경우라면 가능합니다.

$A_i = \left \lceil \sqrt{A_i} \right \rceil + X$ 꼴이 됩니다.

add 쿼리이므로, 그냥 sqrt한 값에 lazy_add를 더하면 됩니다. (= lazy_sqrt + lazy_add)

lazy_sqrt != 0은 거꾸로 말하면, sqrt 쿼리의 중단 조건을 만족했다는 말이므로 현재 노드 및 자식 노드의 값이 모두 동일한 값입니다. lazy_sqrt와 lazy_add 모두 그대로 자식에게 전파해줄 수 있습니다.

v[node].sum = (v[node].lazy_sqrt + v[node].lazy_add) * (end - start + 1);
v[node].m = v[node].M = v[node].lazy_sqrt + v[node].lazy_add;

if (start != end)
{
    v[node << 1].lazy_sqrt = v[node << 1].lazy_sqrt = v[node].lazy_sqrt;
    v[node << 1].lazy_add = v[node << 1].add = v[node].lazy_add;
}

sqrt 쿼리를 실행할 때는 lazy_add = 0을 대입하면 코드가 동일하게 작동함을 알 수 있습니다.

따라서 최종적인 lazy_update 코드는 다음과 같습니다.

void lazy_update(int node, int start, int end)
{
    if (lazy_sqrt == 0)
    {
        v[node].sum += v[node].lazy_add * (end - start + 1);
        v[node].M += v[node].lazy_add;
        v[node].m += v[node].lazy_add;
        
        if (start != end)
        {
        	v[node << 1].lazy_add += v[node].lazy_add;
            v[node << 1 | 1].lazy_add += v[node].lazy_add;
        }
    }
    else
    {
        v[node].sum = (v[node].lazy_sqrt + v[node].lazy_add) * (end - start + 1);
        v[node].m = v[node].M = v[node].lazy_sqrt + v[node].lazy_add;

        if (start != end)
        {
            v[node << 1].lazy_sqrt = v[node << 1].lazy_sqrt = v[node].lazy_sqrt;
            v[node << 1].lazy_add = v[node << 1].add = v[node].lazy_add;
        }
    }
    
    v[node].lazy_add = 0;
    v[node].lazy_sqrt = 0;
}

 

...라고 하고 제출했더니 시간 초과가 발생했습니다.

고민해보다 빡쳐서 검색해서 JusticeHui 님의 풀이를 참고했더니 break condition이 하나 더 있었습니다.

 

구간의 $\min, \max$ 값이 딱 하나 차이나는 경우를 생각해봅시다.

위의 break condition에 걸리지 않았다면, $\sqrt{\min} \not= \sqrt{\max}$이므로

$\left \lceil \sqrt{\max} \right \rceil = \left \lceil \sqrt{\min} \right \rceil + 1$인 경우가 될 것입니다.

( $\sqrt{a+1} - \sqrt{a}$에 대해 제곱하면 $2a+1 - 2\sqrt{a^2+a} < 2a+1 - 2a = 1$

사실 엄밀하게 하려니까 바닥 함수 처리하기 너무 귀찮네요 ㅎㅎ;)

 

아무튼, $x \to \sqrt{x}$는 다시 말하면 $x = x + \sqrt{x} - x$와 같으므로,

해당 구간 전체에 이를 적용하는 것과 같습니다.

$\max - \min = \sqrt{\max} - \sqrt{\min}$이므로 덧셈을 적용해도 차이값은 유지가 됩니다.

이 점 때문에 중단 조건을 걸 수 있습니다.

 

예시로, 어떤 구간에 존재하는 값들이 8 8 8 8 9라면, $\left \lceil \sqrt{8} \right \rceil = 2, \left \lceil \sqrt{9} \right \rceil = 3$입니다. 따라서 중단 조건 1이 적용되지 않습니다.

sqrt 쿼리를 적용한 후의 값은 2 2 2 2 3이 될 것이며, $x = 8$인 노드에 대해서는 $x = x + \left \lceil \sqrt{x} \right \rceil - x$를 적용한 것과 같고, $y = 9$인 노드에 대해서도 $y = y + \left \lceil \sqrt{y} \right \rceil - y$를 적용한 것과 같습니다.

$y = x + 1$ 등을 적용하면 $y = y + \left \lceil \sqrt{x} + 1 \right \rceil - (x + 1) = y + \left \lceil \sqrt{x} \right \rceil - x$ 이므로, 더하는 값이 동일함을 알 수 있습니다. 따라서 전체 구간에 이 값을 add 하는 것으로 계산을 일찍 끝낼 수 있습니다.

if (v[node].m + 1 == v[node].M) //break condition 2
{
    v[node].lazy_add = floor(sqrt(v[node].m)) - v[node].m;
    lazy_update(node, start, end);
    return;
}

 

이렇게까지 하고나서야 제출하니 AC를 받을 수 있었습니다.

속이 메스껍네요. 풀이가 틀렸을 수도 있겠다는 생각이 듭니다.

 

3. 코드

더보기
#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long ll;

struct node
{
    ll m, M, sum, lazy_add, lazy_sqrt;
};

struct segTree_beats
{
    node v[1 << 18];

    void lazy_update(int node, int start, int end)
    {
        if (v[node].lazy_add == 0 && v[node].lazy_sqrt == 0) return;

        if (v[node].lazy_sqrt == 0)
        {
            v[node].sum += v[node].lazy_add * (end - start + 1);
            v[node].M += v[node].lazy_add;
            v[node].m += v[node].lazy_add;

            if (start != end)
            {
                v[node << 1].lazy_add += v[node].lazy_add;
                v[node << 1 | 1].lazy_add += v[node].lazy_add;
            }
        }
        else
        {
            v[node].sum = (v[node].lazy_sqrt + v[node].lazy_add) * (end - start + 1);
            v[node].m = v[node].M = v[node].lazy_sqrt + v[node].lazy_add;

            if (start != end)
            {
                v[node << 1].lazy_add = v[node << 1 | 1].lazy_add = v[node].lazy_add;
                v[node << 1].lazy_sqrt = v[node << 1 | 1].lazy_sqrt = v[node].lazy_sqrt;
            }
        }

        v[node].lazy_add = v[node].lazy_sqrt = 0;
    }

    node init(ll* arr, int node, int start, int end)
    {
        if (start == end) return v[node] = { arr[start], arr[start], arr[start], 0, 0 };

        int mid = (start + end) >> 1;
        init(arr, node << 1, start, mid);
        init(arr, node << 1 | 1, mid + 1, end);

        return v[node] = { min(v[node << 1].m, v[node << 1 | 1].m), max(v[node << 1].M, v[node << 1 |  1].M) , v[node << 1].sum + v[node << 1 | 1].sum, 0, 0};
    }

    void add(int node, int start, int end, int left, int right, ll val)
    {
        lazy_update(node, start, end);

        if (left > end || right < start) return;

        if (left <= start && end <= right)
        {
            v[node].lazy_add += val;
            lazy_update(node, start, end);
            return;
        }

        int mid = (start + end) >> 1;

        add(node << 1, start, mid, left, right, val);
        add(node << 1 | 1, mid + 1, end, left, right, val);

        v[node].m = min(v[node << 1].m, v[node << 1 | 1].m);
        v[node].M = max(v[node << 1].M, v[node << 1 | 1].M);
        v[node].sum = v[node << 1].sum + v[node << 1 | 1].sum;
    }

    void square(int node, int start, int end, int left, int right)
    {
        lazy_update(node, start, end);
        
        if (left > end || right < start) return;
        if (left <= start && end <= right)
        {
            if (floor(sqrt(v[node].m)) == floor(sqrt(v[node].M))) //each results would be all the same
            {
                v[node].lazy_sqrt = floor(sqrt(v[node].m));
                lazy_update(node, start, end);
                return;
            }

            if (v[node].m + 1 == v[node].M) //differernce of sqrt also be 1
            {
                v[node].lazy_add = floor(sqrt(v[node].m)) - v[node].m; //new_val = sqrt(old_val) => new_val = old_val + sqrt(old_val) - old_val
                lazy_update(node, start, end);
                return;
            }
        }

        int mid = (start + end) >> 1;
        square(node << 1, start, mid, left, right);
        square(node << 1 | 1, mid + 1, end, left, right);

        v[node].m = min(v[node << 1].m, v[node << 1 | 1].m);
        v[node].M = max(v[node << 1].M, v[node << 1 | 1].M);
        v[node].sum = v[node << 1].sum + v[node << 1 | 1].sum;
    }

    ll sum(int index, int start, int end, int left, int right)
    {
        lazy_update(index, start, end);

        if (left > end || right < start) return 0;
        if (left <= start && end <= right) return v[index].sum;

        int mid = (start + end) >> 1;
        return sum(index << 1, start, mid, left, right) + sum(index << 1 | 1, mid + 1, end, left, right);
    }
};

int N, M, op, a, b, c, d;
ll arr[100001];
segTree_beats tree;

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> N;
    for (int i = 0; i < N; i++) cin >> arr[i];
    tree.init(arr, 1, 0, N - 1);

    cin >> M;

    while (M--)
    {
        cin >> op >> a >> b;
        if (op == 1)
        {
            cin >> c;
            tree.add(1, 0, N - 1, a - 1, b - 1, c);
        }
        else if (op == 2)
            tree.square(1, 0, N - 1, a - 1, b - 1);
        else
            cout << tree.sum(1, 0, N - 1, a - 1, b - 1) << '\n';
    }

    return 0;
}

 

이제 잠시 엔진 쪽으로 넘어가고 싶습니다.

문제 푸는데 머리가 너무 고통받아요