https://www.acmicpc.net/problem/17471

그룹을 나눈 다음 유니온파인드로 해결한 문제이다.
유니온 파인드 알고리즘, python
유니온 파인드 알고리즘이랑, 두 집합을 하나로 합치고(union) 두 원소가 같은 집합에 속해있는지 확인하는(find) 알고리즘이다. 유니온 파인드를 구현하기 위해서는 부모노드의 번호를 저장할 root
aiden0413.tistory.com
n개를 2개의 그룹으로 나눌 수 있다.
1개 n-1개, 2개 n-2개... n-1개 1개로 나눌 수 있다.
즉 k,n-k (1 <=k <=n-1)로 나눌 수 있다.
k, n-k로 나눌때 combinations을 사용해서 모든 케이스를 구했다.
해당 케이스를 a그룹이라고 하고, a그룹에 포함되지않는 나머지 노드들을 b그룹이라고 하고 각각의 그룹에서 유니온 파인드를 구현하여 union 하였다.
모두 합친 뒤에 만약 2개의 그룹으로 분리되었다면 parents배열에 수의 종류가 2개만 있어야 하므로 그런 경우 각 그룹의 합을 계산에서 차이의 최솟값을 계산하였다.
union find 구현과정
paretns 배열: 해당 노드가 속한 트리의 루트노드를 저장하는 배열
level: 트리의 깊이를 저장하는 배열
find 함수: 해당 노드의 루트노드를 찾으면서 parents배열 갱신 (경로 압축)
만약 x가 속한 루트노드(find(x))의 트리의 깊이(level [find(x)])가
y가 속한 루트노드(find(y))의 트리의 깊이(level [find(y)])보다 깊다면
y의 루트노드(paretns [find(y)])의 루트노드를 x로 갱신한다.
반대의 경우 x에 대해 y로 갱신한다.
만약 두 트리의 깊이가 같다면, x의 깊이를 1 늘리고 y의 루트노드를 x로 갱신한다.
import sys
from itertools import combinations
from collections import deque
input = sys.stdin.readline
n = int(input())
arr = list(map(int, input().split()))
graph = [[] for _ in range(n)]
for i in range(n):
graph[i] = [x-1 for x in list(map(int, input().split()))[1:]]
def solve(comb):
parents = [x for x in range(n)]
level = [0]*n
# union find 계산
# 경로 압축
def find(x):
if parents[x] != x:
parents[x] = find(parents[x])
return parents[x]
# union
def union(x,y):
x = find(x)
y = find(y)
if level[x]>level[y]:
parents[y] = x
elif level[x]<level[y]:
parents[x] = y
else:
level[x] += 1
parents[y] = x
# 그룹 a와 b 생성
group1 = list(comb)
group2 = [x for x in range(n) if not x in comb]
# 각각의 그룹을 union 하여 parents 배열 갱신
for x in group1:
q = deque([x])
visited = [0]*n
visited[x] = 1
while q:
cur = q.popleft()
union(x,cur)
for next in graph[cur]:
if visited[next] == 0 and next in group1:
visited[next] = 1
q.append(next)
for x in group2:
q = deque([x])
visited = [0]*n
visited[x] = 1
while q:
cur = q.popleft()
union(x,cur)
for next in graph[cur]:
if visited[next] == 0 and next in group2:
visited[next] = 1
q.append(next)
if len(set(parents)) != 2:
return 10**16
# parents 배열에 수의 종류가 2개라면 2개의 그룹으로 분리 되었다는 뜻이므로 차이 계산
else:
sum1 = sum([arr[x] for x in range(n) if x in group1])
sum2 = sum([arr[x] for x in range(n) if x in group2])
return abs(sum1 - sum2)
answer = 10**16
# 모든 경우의 수 comb에 대해 수행
for x in range(1,n):
for comb in combinations(range(n),x):
answer = min(answer,solve(comb))
print(-1 if answer==10**16 else answer)

'python' 카테고리의 다른 글
| [프로그래머스] 미로 탈출 명령어, python (0) | 2025.10.04 |
|---|---|
| [백준] 17825 주사위 윷놀이, python (0) | 2025.10.04 |
| [백준] 17136 색종이 붙이기, python (0) | 2025.10.03 |
| [백준] 17135 캐슬 디펜스, python (0) | 2025.10.03 |
| [백준] 17070 파이프 옮기기 1, python (0) | 2025.10.03 |