본문 바로가기

Algorithm/Baekjoon

Baekjoon #2606 - 바이러스 풀이(단방향과 양방향 그래프)

문제

문제의 내용을 요약하자면, 서로 연결된 여러 대의 컴퓨터 중 첫 번째 컴퓨터가 바이러스에 걸렸을 때 바이러스에 감염되는 컴퓨터의 수를 구하는 것이다. 예를 들어, 아래 그림과 같이 컴퓨터가 연결되어 있고, 이 중 1번 컴퓨터가 바이러스에 감염되었다면 1번 컴퓨터와 연결된 모든 컴퓨터(2, 3, 5, 6)의 수량을 출력해야 한다.

입력으로 주어지는 값은 다음과 같다. 이때, 컴퓨터의 수가 7이면, 1번부터 차례대로 번호가 매겨진다.

기호 내용 예시
첫 번째 줄 N 컴퓨터의 수 7
두 번째 줄 M 직접 연결된 컴퓨터 번호 쌍의 수 6
세 번째 줄 + M 직접 연결된 컴퓨터 번호 쌍 1 2
2 3
1 5
5 2
5 6
4 7

풀이

총 두 번의 시도 끝에 성공하였다(소스코드). 정말 놓치기 쉬운 부분을 놓쳐서 오답처리가 되었는데, 무엇이 잘못 되었길래 오답 처리가 되었는지는 뒤에서 언급하도록 하고, 문제에 접근한 방식부터 설명해보겠다. 이 문제는 DFS 탐색 알고리즘 문제이다. 왜 그렇게 생각하냐면, 각각의 노드들이 연결되어 있고, 각 노드에는 또 다른 연결(간선) 정보가 있으므로 인접한 노드를 탐색하여야 하기 때문이다. 위의 예시에서 주어진 컴퓨터들의 연결 정보를 입력받고, 이를 리스트로 나타내면 다음과 같다.

pares = [[2, 5], [3], [], [7], [2, 6], []]

# 실제로는 -1을 해주었으므로 다음과 같다.
pares = [[1, 4], [2], [], [6], [1, 5], []]

여기서 1번 컴퓨터에 해당하는 정보는 [2, 5]이며, 이는 1번 컴퓨터가 2번과 5번 컴퓨터에 직접 연결되어 있다는 의미이다. 1번 컴퓨터가 2번 컴퓨터와 연결되어 있으므로, 2번 컴퓨터 또한 바이러스에 감염될 것이고, 2번 컴퓨터와 연결된 3번 컴퓨터도 바이러스에 감염될 것이다. 이 로직을 재귀함수로 나타낸 전체 코드는 다음과 같다.

cnt = int(input())

if cnt == 0:
    print(0)

else:
    pares = [[] for _ in range(cnt)]

    for _ in range(int(input())):
        fc, tc = map(int, input().split())
        pares[fc-1].append(tc-1)

    def recursive(index):
        if visited[index] == 1:
            return

        visited[index] = 1

        for c in pares[index]:
            recursive(c)

    visited = [0]*len(pares)
    current = 0
    recursive(current)
    print(sum(visited) - 1)

변수명을 설명해보자면, pares는 컴퓨터의 연결 쌍 정보를 저장한 리스트이고, fc, tc는 각각 현재 컴퓨터와 연결된 컴퓨터를 의미한다. 그리고 감염여부를 visited 리스트에 0으로 초기화하였다. 문제에서는 1번 컴퓨터가 감염되었을 경우라는 가정이 있기 때문에 현재 탐색하려는 컴퓨터인 current를 0으로 초기화하였다. 그리고 연결된 모든 노드를 탐색하며 0을 1로 바꿔주었다. 모든 재귀함수가 종료되면 visited의 값을 모두 더하고(리스트의 요소가 0 또는 1이므로 바이러스에 감염된 컴퓨터의 수를 알 수 있다), 1번 컴퓨터를 제외시키기 위해 1을 빼주었다. 이대로 코드를 실행하고 테스트케이스를 입력하면 원하는 값인 4가 출력되는 것을 확인할 수 있다.

그.러.나!

이 코드를 제출하면 20%정도에서 오답처리가 되는데, 여러번 코드를 뜯어도 보고, 직접 손으로 그림도 그려가면서 확인해봐도 그 이유를 찾을 수 없었다. 결국, 구글링해본 끝에 한가지 중요한 사실이 빠져있다는 것을 알게 되었다. 그 이유로 말할 것 같으면, 바로 방향성이다. 현재 내가 짜놓은 코드는 단방향 그래프이다. 이걸 어떻게 알 수 있냐면, 이 부분을 보면 된다.

for _ in range(int(input())):
    fc, tc = map(int, input().split())

    # 단방향
    pares[fc-1].append(tc-1)

예를 들어, 1 2가 입력되었을 때 1번 컴퓨터와 2번 컴퓨터가 연결되어 있다는 정보가 담겨있다. 이를 1번 컴퓨터에 해당하는 리스트에 2번 컴퓨터만 추가하였기 때문에 단방향 그래프가 형성된 것이다.

그러면 어떻게 해야하는가?

정말 간단하고도 당연한 말인데, 1 2를 입력받았을 때 1번 컴퓨터에 2번 컴퓨터를 추가하고, 2번 컴퓨터에도 1번 컴퓨터를 추가하면 된다. 그래야 양방향 그래프가 형성이 되기 때문이다. 만약에, 컴퓨터들이 아래와 같이 연결되어 있다고 가정해보자.

pares = [[2], [], [2], []]

즉, 1번 컴퓨터는 2번 컴퓨터와 연결되어 있는데, 2번 컴퓨터는 어떠한 연결된 정보가 담겨있지 않고, 3번 컴퓨터에서는 2번 컴퓨터와 연결되어 있다는 정보를 입력받은 상태이다. 이 경우, 위의 코드를 실행하였을 때의 과정을 간단히 나타내보면 다음과 같다.

# 1번: [2] -> 2번: [] -> 없음 (종료)

그 결과 3번 컴퓨터에 대한 정보는 아예 누락되어 1번 컴퓨터로 인해 추가로 감염된 컴퓨터는 2번 컴퓨터 밖에 없다는 잘못된 연산이 수행되어 1을 출력하게 된다. 따라서, 다음과 같이 코드를 추가하여 양방향으로 연결해주어야 한다.

for _ in range(int(input())):
    fc, tc = map(int, input().split())

    # 양방향
    pares[fc-1].append(tc-1)
    pares[tc-1].append(fc-1)

전체 코드

cnt = int(input())

if cnt == 0:
    print(0)

else:
    pares = [[] for _ in range(cnt)]

    for _ in range(int(input())):
        fc, tc = map(int, input().split())

        pares[fc-1].append(tc-1)
        pares[tc-1].append(fc-1)

    def recursive(index):
        if visited[index] == 1:
            return

        visited[index] = 1

        for c in pares[index]:
            recursive(c)

    visited = [0]*len(pares)
    current = 0
    recursive(current)
    print(sum(visited) - 1)

마치며

처음으로 어떠한 답을 참고하지도 않고 DFS를 recursive로 구현한 문제이다. 물론, 양방향 처리까지 해주었다면 더할 나위 없었겠지만, 이를 계기로 또 하나 배우게 되었다는 점에서 큰 의미가 있는 것 같다.