ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [1153D] Serval and Rooted Tree
    문제 풀이/CodeForces 2023. 2. 5. 19:57

    난이도: 1900

     

    태그

    더보기
    • Dynamic Programming → DP on Tree (트리에서의 동적 계획법)
    • Greedy (그리디)
    • Depth-First Search (깊이 우선 탐색)

     

    풀이

    1. 특정 정점에 쓰인 값을 맨 위로 올려보자.

    더보기

    특정 값을 맨 위로 올리는 대신, 특정 "정점"을 맨 위로 올려보자는 생각입니다.

     

    그 정점에 쓰인 값을 \( 0 \)이라고 하면, 다른 정점에는 \( 0 \)보다 큰 값이나 \( 0 \)보다 작은 값이 들어가게 됩니다.

    이를 각각 \( + \)와 \( - \)라고 표현해봅시다.

     

    2. [1]을 토대로, 특정 정점을 맨 위로 올릴 때 우리가 해야 하는 걸 보자.

    더보기

    대강 이렇게 생각해보면, 0이 맨 위로 올라간다는 소리는 이 과정에서 만나는 max와 min에 대해

    • max의 다른 자식들은 -로
    • min의 다른 자식들은 +로

    고정된다는 의미입니다.

     

    또한, 같은 수를 2번 이상 쓰지 않으므로, 이러한 다른 자식들에는 0이 들어갈 수 없고, -와 +만으로 구성되어야겠죠.

     

    저희의 목표인, 맨 위에 올라가는 값을 최대화하기 위해서는

    0보다 더 큰 값인 +의 개수를 최소화하는 것이 목표가 됩니다.

     

    3. \( + \)의 개수를 최소화하는 방법은?

    더보기

    아무래도 이렇게만 쓰면 뭔가 어려워 보입니다.

    만약에 그 주변 정점들이 리프노드였다면, 어차피 +나 -로 고정되니까 간단해지지만

    그게 리프가 아닌 다른 정점이었다면, 그 정점을 +로/-로 만들기 위해 또 이런저런 걸 해야 하니까요.

     

    그러니, 계산은 컴퓨터에게 맡겨둡시다.

    \( dp_{v, s} \) = \( v \)를 루트로 하는 서브트리에서, 정점 \( v \)에 쓰인 값이 \( s \)가 되게 하기 위해 필요한 \( + \)의 최소 개수 를 정의해봅시다.

    저희가 이 DP를 사용하는 위치는 [2]에서 0이 쓰이지 않는 위치이므로, 모든 정점에는 \( + \)랑 \( - \)만이 쓰인다고 생각하면 됩니다.

     

    초항은, 모든 리프노드 \( v \)에 대해 \( dp_{v, -} = 0 \), \( dp_{v, +} = 1 \)이 됩니다.

    또한, 서브트리의 모든 정점을 \( - \)로 채워넣으면 정점 \( v \)에도 \( - \)가 쓰일테니, \( dp_{v, -} = 0 \)이 됩니다.

     

    점화식은, 아래 4가지 경우로 나눠서 생각해봅시다.

    • 정점 \( v \)에 적힌 연산이 \( \max \)인 경우
      • \( dp_{v, +} = \min\limits_{w} \sum\limits_{u \in \text{child}(v) \text{ and } u \neq w} (dp_{u, -}) + dp_{w, +} = \min\limits_{w} dp_{w, +} \)
        정점 \( v \)에 \( + \)를 적게 하려면 \( + \)가 적어도 1개 있어야 하지만, \( + \)의 개수를 최소화하려면 1개만 있는 게 최적이겠죠.
        그래서, \( + \)를 적을 위치 \( w \)를 정한 뒤, 나머지는 \( - \)로 채우는 걸 생각해보면 됩니다.
      • \( dp_{v, -} = 0 \)
    • 정점 \( v \)에 적힌 연산이 \( \min \)인 경우
      • \( dp_{v, +} = \sum\limits_{w \in \text{child}(v)} dp_{w, +} \)
        모든 곳에 \( + \)가 적혀야 min값도 \( + \)가 될테니, 어쩔 수 없죠.
      • \( dp_{v, -} = 0 \)

     

    이렇게 하면, \( dp_{v, -} \)와 \( dp_{v, +} \)를 모두 계산해낼 수 있습니다.

    이제 경로 탐색을 한 칸씩 들어가면서, 경로와 인접한, 선택되지 않은 정점들에 +와 -를 적절히 넣어주면서

    dp의 합을 구해주면 됩니다.

     

    그런데... 모든 리프노드를 시작점으로 잡아봐야 문제의 정답을 알 수 있을텐데

    이는 어떡할까요?

     

    4. 백트래킹하듯이, 이미 계산한 건 그대로 놔두자

    더보기

    dfs를 돌리면서, 각 정점에 대해 \( dp_{\text{child}(v), s} \)의 합을 적절히 구해둡시다.

    \( s \)는 정점 \( v \)의 연산 \( \text{min or max} \)에 따라 적절히 정해집니다.

     

    그리고, 거기서 한 정점을 0으로 선택한다면, 다른 자식들도 들어가야 할 값이 결정되므로

    이 값들을 (위에서 구해둔 합을 토대로) \( O(1) \)에 구해주면 됩니다.

     

    리프노드에 도달하면, 내려오면서 계산한 합을 토대로 답을 업데이트해주면 되고,

    탐색이 끝나면, 백트래킹하듯이 아까 구해둔 합을 다시 Revert해주면 됩니다.

     

    코드

    더보기

    문제의 답은, 이렇게 구한 \( + \)의 개수에 대해

    (리프노드의 개수) - (필요한 최소 \( + \)의 개수)가 됩니다.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    vector<int> adj[300020];
    int arr[300020];
     
    int dp[300020][2];
    void dpf(int now){
        int s0 = 0, s1 = 0, m0 = 1e9, m1 = 1e9;
        for (int nxt : adj[now]){
            dpf(nxt);
            s0 += dp[nxt][0]; s1 += dp[nxt][1];
            m0 = min(m0, dp[nxt][0]); m1 = min(m1, dp[nxt][1]);
        }
        if (adj[now].size() == 0){ dp[now][0= 0; dp[now][1= 1; }
        else{
            if (arr[now] == 0){ dp[now][0= s0; dp[now][1= s1; }
            if (arr[now] == 1){ dp[now][0= s0; dp[now][1= m1; }
        }
    }
     
    int ans = 1e9, res = 0;
    void dfs(int now){
        int sum = 0;
        for (int nxt : adj[now]){ sum += dp[nxt][ !arr[now] ]; }
        if (adj[now].size() == 0){ ans = min(ans, res); }
        else{
            for (int nxt : adj[now]){
                int val = sum - dp[nxt][ !arr[now] ];
                res += val; dfs(nxt); res -= val;
            }
        }
    }
     
    void Main(){
        int n; cin >> n;
        for (int i = 1; i <= n; i++){ cin >> arr[i]; }
        for (int i = 2; i <= n; i++){ int x; cin >> x; adj[x].push_back(i); }
        int m = 0for (int i = 1; i <= n; i++){
            if (adj[i].size() == 0){ m += 1; }
        }
        dpf(1); dfs(1); cout << m-ans;
    }
    cs
Designed by Tistory.