알고리즘/PST 세그먼트트리

BOJ 백준 7469 K번째 수

kyjdummy 2025. 4. 6. 16:38

문제 : https://www.acmicpc.net/problem/7469

 

[ 문제 요약 ]

- 1차원 배열이 주어지고, 그 배열의 범위가 주어지면, 그 범위 안에서 K 번째로 큰 수를 출력하는 것

 

[ 테스트 케이스 설명 ]

7 3		// 배열의크기 N(1<=100,000), 쿼리수 Q(1<=5,000)
1 5 2 6 3 7 4	// 각 정수는 절대 값 십억을 넘지 않는 정수
2 5 3		// 왼쪽범위i,오른쪽범위j,찾을k번째수(1<i,j<=N / 1<= k<= j-i+1)
4 4 1
1 7 3
//답
5
6
3

 

[ 문제 풀이 ]

저는 이 문제를 영속성 세그먼트 트리를 이용해 풀었습니다. 영속성 세그먼트 트리의 기본 내용과 구현 방법은 아래 글을 참고해 주세요!

 

PST 알고리즘 : https://kyjdummy.tistory.com/5

 

Persistent Segment Tree(PST) - 영속성 세그먼트 트리

[ 글을 쓰게 된 계기 ]- PST(Persistent Segment Tree) 알고리즘을 공부하면서 느낀 점은, 개념에 대한 대략적인 설명은 많지만, 원리나 코드 적용 방법, 구현 과정에서의 인사이트를 초보자가 쉽게 이해

kyjdummy.tistory.com

 

그러나 이 문제는 단순히 PST 알고리즘을 알고 있다고 풀 수 있지는 않습니다. 다른 알고리즘 스킬이 들어가야 합니다. 

 

테스트 케이스와 같이 숫자 1, 5, 2, 6, 3, 7, 4가 있을 때, 특정 범위가 주어지면 거기서 어떻게 빠르게 K 번째 수를 찾느냐가 중요한 문제입니다.

 

아래 표를 보면 놀라운 사실을 하나 알 수 있습니다. 입력된 숫자들을 작은 수부터 차례로 1로 마킹을 해나가다 보면 규칙이 있습니다.

  1 5 2 6 3 7 4
첫번째 작은 수(1) 1 0 0 0 0 0 0
두번째 작은 수(2) 1 0 1 0 0 0 0
세번째 작은 수(3) 1 0 1 0 1 0 0
네번째 작은 수(4) 1 0 1 0 1 0 1
다섯번째 작은 수(5) 1 1 1 0 1 0 1
여섯번째 작은 수(6) 1 1 1 1 1 0 1
일곱번째 작은 수(7) 1 1 1 1 1 1 1

 

 

첫 번째 작은 수는 1입니다. 이 첫 번째 작은 수인 1에 숫자 1을 마킹합니다. 나머지는 0입니다.

두 번째 작은 수는 2입니다. 이 두 번째 작은 수인 2에 숫자 1을 마킹하고, 이전에 1이 마킹되었으니 복사해서 같이 마킹해 놓습니다.

세 번째 작은 수는 3입니다. 이 세 번째 작은 수인 3에 숫자 1을 마킹하고, 이전 1과 2가 마킹되었으니 복사해서 같이 마킹해 놓습니다. 

이렇게 일곱 번째 수까지 모두 반복합니다.

 

이랬을 때 놀라운 사실은 특정 범위가 주어지면, k 번째 작은 수는, 그 구간에 있는 1들의 합이 k가 되는 가장 첫 번째 수라는 것입니다.

 

예를 들어 인덱스 3부터 5까지 두 번째 작은 수가 무엇인지를 구한다 했을 때, 인덱스 3은 숫자 2이고, 인덱스 4는 6, 인덱스 5는 숫자 3입니다. 그러면 2,6,3 중 두 번째 작은 수는 3임을 쉽게 알 수 있지만, 위표를 통해 구하는 방법은 아래와 같습니다. 

 

- 3부터 5까지 구간 합이 2가 되는 가장 낮은 행번호 찾기.

 

이때 구간의 1의 합이 최초로 2가 되는 행은 세 번째 작은 수(3)입니다. 그 이후로 구간합이 2인 경우가 몇 개 더 생기게 되지만, 최초인 것이 답입니다. 

 

다른 예시로 1번째 인덱스 와 3번째 인덱스 사이에 3번째로 작은 수를 구한다 치면, 1,5,2 중에 3번째로 작은 수는 5입니다. 이것을 위 표를 통해 구한 다치면, 해당 구간의 1들의 합이 3이 되는 가장 첫 번째 행이 다섯 번째 작은 수, 즉 5가 됩니다.

 

위 내용이 선뜻 잘 이해되지 않을 수 있지만 손으로 해보면 이해가 쉽습니다.

 

결론은, 위와 같은 표를 만들어서 빠르게 K 번째 수를 구할 수 있습니다. 먼저 배열의 초기 값들을 입력받고, 값이 작은 순으로 정렬합니다. 이때, 입력받은 인덱스를 같이 저장해 놓아야 합니다. 

 

그리고 영속성 세그먼트 트리를 구현하여 작은 숫자대로 세그먼트 트리에 마킹해 나가면서 루트 노드들을 저장합니다.

 

세그먼트 트리에 마킹하는 것은 해당 숫자의 '인덱스'입니다. 그래야 추후 k 번째 수를 찾는 쿼리가 입력될 때, 입력된 그 인덱스를 갖고 이분 탐색을 진행할 수 있습니다.

 

그리고 해당 인덱스 범위 내에서 1의 합이 K가 되는 가장 작은 행번호를 가져와 출력해 주기만 하면 됩니다.

 

인덱스 범위 내에서 1의 합이 K가 되는 가장 작은 행번호는 이분 탐색으로 빠르게 구해줄 수 있습니다.

 

아래는 해당 내용을 그대로 코드로 구현한 것입니다.

 

 

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.StringTokenizer;
class Object implements Comparable<Object>{
	int idx, value;
	Object(int i, int v){
		this.idx = i;
		this.value = v;
	}
	@Override
	public int compareTo(Object o) {
		if(value != o.value)
			return value - o.value;
		return idx - o.idx;
	}
}
class Node{
	int left, right, sum;
	Node(int l, int r, int s){
		this.left	= l;
		this.right	= r;
		this.sum	= s;
	}
}
class Main{
	
	static ArrayList<Node> nodes;

	public static void main(String[] args)throws Exception{
		BufferedReader	br	= new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st	= new StringTokenizer(br.readLine());
		int N				= Integer.parseInt(st.nextToken());// 배열의크기 N(1<=100,000)
		int Q				= Integer.parseInt(st.nextToken());// 쿼리수 Q(1<=5,000)
		int roots[]			= new int[N + 1];// 루트는 원래 N개만 필요하지만, 최초 세그먼트 초기화시킨 후 담을 루트노드 값 때문에 + 1함
		List<Object> list	= new ArrayList<>();
		nodes				= new ArrayList<>();
		
		st = new StringTokenizer(br.readLine());
		for(int i=0; i<N; i++)
		{
			int value = Integer.parseInt(st.nextToken());
			list.add(new Object(i, value));
		}
		
		Collections.sort(list);// value가 낮은 기준으로 오름차순 정렬
		
		// 세그먼트트리를 초기화하고 루트노드 번호를 roots[0]에 담는다.
		roots[0] = init(0, N - 1);
		
		for(int i=0; i<list.size(); i++)
		{
			// 배열의 크기만큼 PST버전을 만들어 주고 만든 후 PST의 각각의 루트 번호를 roots에 담는다.
			// i+1번째 루트노드는, i번째 루트 노드를 통해 구해진다.
			// value기준으로 오름차순 정렬된 입력순서(idx)를 세그먼트 트리에 마킹하며 버전을 갱신해준다.
			roots[i + 1] = update(roots[i], 0, N - 1, list.get(i).idx);
		}
		
		StringBuilder sb = new StringBuilder();
		while(Q-->0)
		{
			st = new StringTokenizer(br.readLine());
			int i = Integer.parseInt(st.nextToken()) - 1; // 인덱스가 0부터시작했기 때문에 -1을 해주어 인덱스 보정
			int j = Integer.parseInt(st.nextToken()) - 1; // 인덱스가 0부터시작했기 때문에 -1을 해주어 인덱스 보정
			int k = Integer.parseInt(st.nextToken());
			
			int s = 0;
			int e = N - 1;
			int res = 0;
			while(s <= e)
			{
				int mid = (s + e) >> 1;
                // 이분탐색을 통해 i,j범위의 합이 k가되는 가장 작은 mid를 찾는다.
				int cal = sum(roots[mid + 1], 0, N - 1, i, j);
				if(k <= cal)
				{
					res = mid;
					e = mid - 1;
				}
				else s = mid + 1;
			}
			
			sb.append(list.get(res).value)
				.append('\n');
		}
		System.out.print(sb);
	}
	public static int init(int s, int e) {
		if(s == e)// 리프 노드인 경우 왼쪽, 오른쪽 자신은 -1로, sum 0으로
		{
			nodes.add(new Node(-1, -1, 0));
			return nodes.size() - 1;
		}
		int mid = (s + e) >> 1;
		int l = init(s, mid);
		int r = init(mid + 1, e);
		
		nodes.add(new Node(l, r, 0));
		
		return nodes.size() - 1;
	}
	public static int update(int nowNode, int s, int e, int targetIdx) {
		
		Node now = nodes.get(nowNode);
		
		if(s == e)
		{
			nodes.add(new Node(-1, -1, now.sum + 1));
			return nodes.size() - 1;
		}
		int mid = (s + e) >> 1;
		int l = now.left;
		int r = now.right;
		
		if(targetIdx <= mid) {
			l = update(now.left, s, mid, targetIdx);
		}
		else {
			r = update(now.right, mid + 1, e, targetIdx);
		}
		
		nodes.add(new Node(l, r, now.sum + 1));
		
		return nodes.size() - 1;
	}
	
	public static int sum(int nowNode, int s, int e, int left, int right) {
		if(e < left || right < s)
			return 0;
		
		if(left<=s && e<=right)
			return nodes.get(nowNode).sum;
		
		int mid = (s + e) >> 1;
		
		return sum(nodes.get(nowNode).left, s, mid, left, right)
				+ sum(nodes.get(nowNode).right, mid + 1, e, left, right);
	}
}