연결된 간선 k-1 개를 모든 조합으로 나눠서 그룹의 수를 계산하기엔 너무 복잡도가 높아서 완전 탐색보단 이진 탐색으로 생각해보자.
특정 x 명인 그룹이 최대로 k 개의 그룹이 나누어 지는지, 이 중의 k개 그룹으로 나눠지는 x의 최대 x 는 무엇일까
그룹의 인원의 범위는 최소 모든 노드를 나눈 '노드의 최댓값' 에서 모든 시험장이 한 그룹인 '노드들의 총합' 까지 이다.
즉, 범위는 num.max ≤ x ≤ num.sum 이다.
탐색은 DFS 로 하면서 그룹을 나누는 기준이 필요하다. 기준은 총 3개가 필요하다.
노드를 탐색하면서 현재 노드의 값(node), 왼쪽 자식들의 합(left), 오른쪽 자식들의 합(right)을 계산한다.
1. 그룹을 나누지 않는다. (node + left + right ≤ x)
node + left + right 의 합이 x 보다 작으면 그룹을 나눌 필요가 없다.
2. 두 그룹으로 나눈다. (1 cut)
1에 해당되지 않은 (x ≤node + left + right) 노드 중에 node+left or node+right 둘 중 하나가 x 를 넘지 않으면 1번 나눈다.
3. 세 그룹으로 나눈다. (2 cut)
1,2 에 해당하지 않은 경우 2번 나눈다.
코드
public class P_81305 {
// 몇번 그룹을 나눌지 저장하는 전역변수
static int cut = 0;
// 루트 노드(시험장) 번호저장
static int root;
// 연결 노드
static int[][] trees;
// 노드(시험장)의 학생수
static int[] amount;
private static int solution(int k, int[] num, int[][] links) {
// 특정 x 명인 그룹이 최대로 k 개의 그룹이 나누어 지는지, 이 중의 k개 그룹으로 나눠지는 x의 최소 x 는 무엇인가?
// x 의 범위는 num.max ≤ x ≤ num.sum
int start = 0;
int end = 0;
int limit;
trees = links;
amount = num;
// 루트 노드 찾기
findRoot(num.length);
// 이진 탐색 범위
// start : 노드의 최댓값
// end : 노드들의 합
for (int n : num) {
start = Math.max(n, start);
end += n;
}
// 이진 탐색 시작
while (start<end) {
limit = (start + end)/2;
if (getGroup(limit) <= k) {
end = limit;
} else {
start = limit+1;
}
}
return start;
}
// root 노드 찾는 함수
private static void findRoot(int size) {
boolean[] notRoot = new boolean[size];
// root 를 찾기
// 자식이 없으면 root
for (int[] tree : trees) {
if (tree[0] != -1) notRoot[tree[0]] = true;
if (tree[1] != -1) notRoot[tree[1]] = true;
}
for (int i=0; i<notRoot.length; i++) {
if(!notRoot[i]) {
root = i;
break;
}
}
}
// 개별 그룹의 최대 학생수 limit 을 안넘는 그룹 개수
private static int getGroup(int limit) {
cut = 0;
DFS(root, limit);
// 총 그룹수는 나눈 횟수 + 1
return cut+1;
}
// 시험잠 탐색
private static int DFS(int node, int limit) {
int left, right;
left = (trees[node][0] == -1) ? 0 : DFS(trees[node][0], limit);
right = (trees[node][1] == -1) ? 0 : DFS(trees[node][1], limit);
// 1. 그룹이 나뉘어지지 않는다.(N + L + R ≤ limit)
if(amount[node] + left + right <= limit) {
return amount[node] + left + right;
}
// 2. 2그룹으로 나눈다.
// 2그룹으로 나눌수 있는 조건이 한번은 cut 나야하니까 (N + L + R ≥ limit 이기 때문에)
// node + min ≤ limit 값이면 1번만 cut. 이외면 2번 cut.
if (amount[node] + Math.min(left, right) <= limit) {
cut+=1;
return amount[node] + Math.min(left,right);
}
// 3. 이외에는 3그룹으로 나눈다.
cut += 2;
return amount[node];
}
}