AtCoder Beginner Contest 217 D – Cutting Woods をPython3で解く

Share

AtCoder上にある問題のうち、AtCoder Problemsでdiff 800以上と判定されているものを順番に解いていく企画。
基本的な考え方は全てコード中のコメントに入れてあるので、参照のこと。

出典:
AtCoder Beginner Contest 217 D – Cutting Woods

実はC++なら、二分探索のみ用いて、ほぼ愚直でも通る。実のところ、Pythonでもリストではなく Array を使用すれば、同様にほぼ愚直でも通るらしい。

ここでは、愚直ではない\(O(Q \log Q)\)の回答を二通り紹介しておく。

# AtCoder Beginner Contest 217 D - Cutting Woods
# https://atcoder.jp/contests/abc217/tasks/abc217_d
# tag: 座標圧縮 BIT Union_Find 平衡二分木

# 実は C++ だと std::set を利用して割とあっさり解けるが、
# Python3 では該当する標準ライブラリが無いため、かなり
# 難しくなる問題。

# ひとまず、1 <= L <= 10^9 という制限が厳しいので、
# 座標圧縮をしてから考える。

# 切られているか切られていないかは BIT で管理する。
# つまり、圧縮後の座標において切られているところに
# 1 を加えていくことで、左端から数えて何回切られているか
# という情報を管理するようにする。

# 長さを求めるクエリが来たときには、BIT 上で
# 手前で切られている地点と次に切られている地点を求め、
# 長さを出力する……という方針。

# BITクラス
# 内部処理は 1-indexed だが、引数・返り値は 0-indexed に統一。
class Binary_Indexed_Tree:
    def __init__(self, N):
        self._len = 1 << ((N-1).bit_length())
        self._tree = [0] * (self._len + 1)

    # pos に対して x を加える。
    def add_to(self, x, pos):
        pos += 1
        while pos <= self._len:
            self._tree[pos] += x
            pos = pos + (pos & -pos)

    # pos までの累積和を返す。
    def get_csum(self, pos):
        pos += 1
        result = 0
        while pos > 0:
            result += self._tree[pos]
            pos = pos - (pos & -pos)
        return result
    
    # 累積和が value となる最小のインデックスを返す。
    def get_lower(self, value):
        if value < 0:
            return 0

        result = 0
        check = self._len // 2

        # ここからBIT上で直接二分探索を行っているような感じ。
        while check > 0:
            if value > self._tree[result + check]:
                value -= self._tree[result + check]
                result += check
            check //= 2

        return result

# ここからメイン。
def main():
    L, Q = map(int, input().split())
    queries = [list(map(int, input().split())) for _ in range(Q)]

    # 座標圧縮。 0 と L も加えておく。
    x_list = [x for c, x in queries] + [0, L]
    x_list = list(set(x_list))
    x_list.sort()
    comp_dic = {x:i for i, x in enumerate(x_list)}

    # BIT 作成。0 と L の地点をあらかじめ切っておく。
    bit = Binary_Indexed_Tree(len(x_list))
    bit.add_to(1, comp_dic[0])
    bit.add_to(1, comp_dic[L])

    # クエリ処理。
    for c, x in queries:
        if c == 1:
            bit.add_to(1, comp_dic[x])
        
        else:
            # クエリで指定されているところが、
            # 左端から何回切られている地点かを求める。
            value = bit.get_csum(comp_dic[x])

            # 上記情報を元に、左側の切断地点を求める。
            lower = bit.get_lower(value)

            # 同様に、右側の切断地点を求める。
            nxt = bit.get_lower(value + 1)

            print(x_list[nxt] - x_list[lower])

main()

Union_Findを用いつつ、クエリを逆順に処理する解き方もある。実は本番ではこれを最初に思いついたので、こちらで通した。

# クエリを逆から見ていった場合、バラバラになっている木を
# どんどんくっつけていくという操作になる。
# これを Union Find で管理しつつ、順次(逆から)答えを
# 求めていくことが可能。

from typing import Union


class Union_Find:
    # 親管理リストと高さ管理リストを初期化し、
    # 要素N個のUnion-Find森を作成する。
    # 親管理リストは、基本的には自分のひとつ上の親を表すが、
    # 値が負の場合には、自身が最上位の親(リーダー)であることを表し、
    # 自分を含めたグループの人数を管理することとする
    def __init__(self, N):
        self.parent = [-1] * N
        self.rank = [0] * N
        self.group_count = N
        self.N = N

    # xの所属するグループのリーダーを返す
    def find(self, x):
        # 自分自身がリーダーなら、自分を返す
        if self.parent[x] < 0:
            return x

        # 再帰的に捜索し、見つかれば繋ぎ変えておく
        # (計算量が増える=面倒くさいので)高さ管理は行わない
        par = self.find(self.parent[x])
        self.parent[x] = par
        return par

    # xとyのグループを統合する
    # (xのリーダー(統合先), yのリーダー(統合元)) を返す
    def unite(self, x, y):
        # それぞれのリーダーに対する操作を行うことになる
        x = self.find(x)
        y = self.find(y)

        # リーダーが同じなら何もする必要がない
        if x == y:
            return (-1, -1)

        # 木の高さが同じ場合:
        # グループの人数を合計しつつ適当に繋ぎ、繋げられた方の高さを1増やす
        if self.rank[x] == self.rank[y]:
            self.parent[x] += self.parent[y]
            self.parent[y] = x
            self.rank[x] += 1
            self.group_count -= 1
            return (x, y)

        # 木の高さが違うなら、低い方を高い方につなぐ
        if self.rank[x] < self.rank[y]:
            x, y = y, x
        self.parent[x] += self.parent[y]
        self.parent[y] = x
        self.group_count -= 1
        return (x, y)

    # xとyが同じグループかどうかを調べる
    def samep(self, x, y):
        return self.find(x) == self.find(y)

    # xの所属するグループのメンバー数を返す
    def get_member_count(self, x):
        x = self.find(x)
        return -self.parent[x]

# ここからメイン
def main():
    L, Q = map(int, input().split())
    queries = [list(map(int, input().split())) for _ in range(Q)]

    # あらかじめ出てくる地点を先読みする。
    # 0, L も切られる地点として含めておく。
    points = list(set([x for c, x in queries] + [0, L]))
    points.sort()

    # 各部分の長さ。
    piece_length = [r - l for l, r in zip(points, points[1:])]

    # クエリ上の位置 → n 番目の座標圧縮用辞書
    pos_to_n = {pos: n for n, pos in enumerate(points)}

    # Union_Find木の作成。
    # 各要素は最終状態での各ピースとする。
    # 結合時はリーダーを元にして piece_length を更新する。
    uft = Union_Find(len(piece_length))

    # 実際に切られる地点。
    cut_points = set([x for c, x in queries if c == 1] + [0, L])

    # 切られない地点は、あらかじめ結合しておく。
    for p in points:
        if p not in cut_points:
            x, y = uft.unite(pos_to_n[p]-1, pos_to_n[p])
            piece_length[x] += piece_length[y]

    result = []
    # 逆順にクエリを処理していく。
    for c, x in queries[::-1]:
        # 切るクエリなら、逆に木を結合する。
        if c == 1:
            x, y = uft.unite(pos_to_n[x]-1, pos_to_n[x])
            piece_length[x] += piece_length[y]

        # 指示された地点の長さを返す。
        else:
            result.append(piece_length[uft.find(pos_to_n[x])])

    # 得られた答えを逆順に出力。
    for r in result[::-1]:
        print(r)

main()

もちろん、C++ における std::set のような平衡二分探索木を用いて回答するやり方もあるが、scrblbugが作成したライブラリが無いため、ここでは紹介できないのであしからず。

Share

コメントを残す

メールアドレスが公開されることはありません。