AtCoder Regular Contest 121 B – RGB Matching をPython3で解く

Share

AtCoder上にある問題のうち、AtCoder Problemsでdiff 800以上と判定されているものを順番に解いていく企画。
基本的な考え方は全てコード中のコメントに入れてあるので、参照のこと。

出典:
AtCoder Regular Contest 121 B – RGB Matching

最終的には、2つの数列の中からそれぞれ一つずつ数字を選び、その差を最小にする問題に帰着する。

# AtCoder Regular Contest 121 B - RGB Matching
# https://atcoder.jp/contests/arc121/tasks/arc121_b
# tag: 数列 考察 二分探索

# 明らかに同じ色の犬同士をペアにするほうがいいのだが、
# 頭数が奇数なら余りが出てしまうことがある。
# それをどのように組み合わせてやるかがポイント。

# 奇数頭数の色のグループを R, G、
# 偶数頭数の色のグループを B と仮に考える。

# この際、考慮しなければならない組み合わせは以下の通り

# 1) R, G から 1 頭ずつを選び、それを組み合わせる。
# これは R, G を前から順に見ていくことで、
# 最小コストを求めることができる。

# 2) R, G から 1 頭ずつ選び、それぞれを B から
# 選んだものと組み合わせる。
# これも同様に可能だが、RB の組み合わせと GB の組み合わせで
# 使用される B が同一のものである可能性がある。
# しかし、その場合、その最小コストは 1) の最小コスト
# 以上となるので、考慮しなくてもいい(数直線上に書いて
# 考察すると分かりやすい)。

def main():
    N = int(input())
    dogs = [input().split() for _ in range(2 * N)]

    dog_groups = [[] for _ in range(3)]

    # 色別にグループ分け。
    for a, c in dogs:
        a = int(a)
        if c == 'R':
            dog_groups[0].append(a)
        elif c == 'G':
            dog_groups[1].append(a)
        else:
            dog_groups[2].append(a)

    # ソートしておく。
    for g in dog_groups:
        g.sort()

    # 全て偶数頭数なら、最小値は 0 になる。
    if all(len(g) % 2 == 0 for g in dog_groups):
        print(0)
        return

    # 偶数頭数のものと奇数頭数のものに分ける。
    odd_groups = []
    for g in dog_groups:
        if len(g) % 2 == 1:
            odd_groups.append(g)
        else:
            even_group = g

    # 2つのソート済みリストから一つずつ数字を選び、
    # その差の最小を求める関数を作成しておく。
    def get_min_diff(list_a, list_b):
        # リストの中身がない場合は、最大値を返しておく。
        if len(list_a) == 0 or len(list_b) == 0:
            return 10**16

        idx_a, idx_b = 0, 0
        result = 10**16

        # 値の小さい方のインデックスを進めていきつつ、
        # 順に探索していく。
        while True:
            va, vb = list_a[idx_a], list_b[idx_b]
            if abs(va - vb) < result:
                result = abs(va - vb)

            if idx_a == len(list_a) - 1:
                idx_b += 1
            elif idx_b == len(list_b) - 1:
                idx_a += 1
            elif va <= vb:
                idx_a += 1
            else:
                idx_b += 1

            if idx_a == len(list_a) or idx_b == len(list_b):
                break

        return result

    result = min(
        get_min_diff(odd_groups[0], odd_groups[1]),
        get_min_diff(even_group, odd_groups[0]) + get_min_diff(even_group, odd_groups[1])
    )

    print(result)

main()

二分探索、というかbisectを用いて書くこともできる。

# 回答中の get_min_diff は、二分探索を利用して書いてもいい。
# いずれにせよ、ソートがボトルネックとなるため O(N log N)。

from bisect import bisect
def main():
    N = int(input())
    dogs = [input().split() for _ in range(2 * N)]

    dog_groups = [[] for _ in range(3)]

    # 色別にグループ分け。
    for a, c in dogs:
        a = int(a)
        if c == 'R':
            dog_groups[0].append(a)
        elif c == 'G':
            dog_groups[1].append(a)
        else:
            dog_groups[2].append(a)

    # ソートしておく。
    for g in dog_groups:
        g.sort()

    # 全て偶数頭数なら、最小値は 0 になる。
    if all(len(g) % 2 == 0 for g in dog_groups):
        print(0)
        return

    # 偶数頭数のものと奇数頭数のものに分ける。
    odd_groups = []
    for g in dog_groups:
        if len(g) % 2 == 1:
            odd_groups.append(g)
        else:
            even_group = g

    # 2つのソート済みリストから一つずつ数字を選び、
    # その差の最小を求める関数を作成しておく。
    def get_min_diff(list_a, list_b):
        # リストの中身がない場合は、最大値を返しておく。
        if len(list_a) == 0 or len(list_b) == 0:
            return 10**16

        result = 10**16

        for va in list_a:
            idx = bisect(list_b, va)
            if idx < len(list_b):
                vb = list_b[idx]
                if abs(va - vb) < result:
                    result = abs(va - vb)
            if idx > 0:
                vb = list_b[idx-1]
                if abs(va - vb) < result:
                    result = abs(va - vb)

        return result

    result = min(
        get_min_diff(odd_groups[0], odd_groups[1]),
        get_min_diff(even_group, odd_groups[0]) + get_min_diff(even_group, odd_groups[1])
    )

    print(result)

main()
Share

コメントを残す

メールアドレスが公開されることはありません。