문제 : https://www.acmicpc.net/problem/11012

 

[ 문제 요약 ]

  • 각 테스트케이스마다 2차원 좌표 평면 위에 여러 개의 점이 주어집니다. 이후 두 개의 좌표로 표현된 직사각형이 주어지며, 각 테스트케이스마다 해당 직사각형 내부(경계 포함)에 있는 점들의 개수를 구해 출력하는 문제입니다. 두 좌표는 직사각형의 서로 대각선 꼭짓점에 해당하며, 이 범위 안에 있는 모든 점의 개수를 세면 됩니다.

 

[ 테스트 케이스 설명 ]

1		// 테스트케이스 수(1<=20)
3 1		// 좌표의수N(0<=10,000), 쿼리수M(0<=50,000)
3 5		// N개의 줄에 x,y 각 좌표가 주어짐 각수는 0<=10의5승
2 3
1 1
1 2 1 3		// M개의 줄에 l, r, b, t가 주어짐 (0<=10의5승), l과r은 x좌표, b,t는 y좌표
답 : 2

 

 

[ 문제 풀이 ]

 

이 문제를 풀기 전에, 직사각형 안에 있는 점의 개수를 어떻게 빠르게 찾아 더 할 것인지를 먼저 알아야 합니다.

 

방법은, 좌표 평면에서 점의 위치를 특정한 형태로 만들어 놓고, 간단한 수학적인 계산으로 빠르게 구할 수 있습니다.

 

아래 표는 2차원 좌표평면을 표현한 것이고, 0은 점이 없는 것, 1은 점이 있는 위치를 나타냅니다.

 

세로는 Y좌표를, 가로는 X좌표를 나타냅니다.

 

 

Y좌표 :  3 0 0 1 0 1 0
Y좌표 :  2 0 1 1 0 0 0
Y좌표 : 1 0 1 0 1 0 1
Y좌표 : 0 0 0 0 0 0 0
  X좌표 : 0 X좌표 :  1 X좌표 :  2 X좌표 :  3 X좌표 :  4 X좌표 :  5

 

 

위 와 같이 있다고 했을 때, 위 표를 아래 표와 같은 형태로 전환해서 만들면,

 

사각형의 두 꼭짓점을 알았을 때 쉽게 사각형 안에 있는 점 개수의 총합을 구할 수 있습니다.

 

글로 먼저 설명하자면, x좌표 하나하나마다 y좌표가 낮은 곳에서 높은 곳으로 올라가면서 누적합을 구하면 됩니다.

 

 

Y좌표 :  3 0 2 ↑ 2 ↑ 1 ↑ 1 1 ↑
Y좌표 :  2 0 2 ↑ 1 ↑ 1 ↑  0 1 ↑
Y좌표 : 1 0 1 ↑ 0 ↑ 1 ↑ 0 1 ↑
Y좌표 : 0 0 0 ↑ 0 ↑ 0 ↑ 0
  X좌표 : 0 X좌표 :  1 X좌표 :  2 X좌표 :  3 X좌표 :  4 X좌표 :  5

 

 

위 와 같이 X좌표마다 Y좌표를 작은 수에서 큰 수로 올라가면서 누적합을 구합니다.

 

위와 같은 표가 준비되었을 때, 예를 들어 구하려는 사각형의 두 꼭짓점이 (y=3, x=1) / (y=1, x=3)이라 하면,

 

Y좌표는 3과 1입니다. 먼저 Y가 3일 때를 구합니다. X의 범위는 1~3까지가 되니 Y=3, X=1~3 범위의 모든 숫자를 더합니다.

 

그러면 2+2+1 = 5입니다. Y가 3일 때는 구했으니 Y가 1일 때의 합도 구합니다.

 

그런데 Y 값이 낮은 행은 Y 값에서 1을 추가로 빼야 합니다.

 

그럼 Y가 0이면서, X가 1~3 사이의 값을 모두 더합니다. Y좌표가 0인 경우는 X의 값이 모두 0이므로 합해도 0입니다.

 

그럼 Y=3일 때 구한 5와 Y=0일 때 구한 0을 빼면 5가 나오는데, 이 5가 해당 사각형 안에 있는 점의 개수입니다.

 

위와 같이, 좌표 평면에서 Y를 낮은 좌표에서 높은 좌표순으로 올라가면서 누적합을 구해놓고,

 

Y좌표에 따른 X좌표 사이의 값들의 합을 구해서 빼주기만 하면 됩니다.

 

이것을 빠르게 구현하기 위해 영속성 세그먼트 트리(PST)를 사용합니다.

 

영속성 세그먼트 트리의 기본 내용과 구현 방법은 아래 글을 참고해 주세요!

 

영속성 세그먼트 트리 : https://kyjdummy.tistory.com/5

 

Persistent Segment Tree(PST) - 영속성 세그먼트 트리

[ 글을 쓰게 된 계기 ]- PST(Persistent Segment Tree) 알고리즘을 공부하면서 느낀 점은, 개념에 대한 대략적인 설명은 많지만, 원리나 코드 적용 방법, 구현 과정에서의 인사이트를 초보자가 쉽게 이해

kyjdummy.tistory.com

 

영속성 세그먼트 트리를 구현할 때, 루트 노드(roots)는 Y좌표입니다. Y좌표 기준으로 X좌표를 업데이트해나갑니다.

 

 

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;
class Point implements Comparable<Point>{
	int x, y;
	Point(int x, int y){
		this.x=x;
		this.y=y;
	}
	@Override
	public int compareTo(Point o) {
		return y - o.y;
	}
}
class Node{
	int left, right, sum;
	Node(int l, int r, int s){
		left=l;
		right=r;
		sum=s;
	}
}
class Main{
	
	static final int LEN = 100_001;
	static ArrayList<Node> nodes;
	
	public static int init(int s, int e) {
		if(s == e)	// 리프노드인 경우 노드를 추가하고 해당 노드 번호를 반환하고 종료
		{
			nodes.add(new Node(-1, -1, 0));
			return nodes.size() - 1;
		}
		// 리프노드가 아닌 경우 왼쪽 오른쪽을 탐색하며 노드를 만들어나간다.
		int mid = (s + e) >> 1;
		int l = init(s, mid);
		int r = init(mid + 1, e);
		nodes.add(new Node(l, r, 0));
		// 최종적으로 자기자신의 노드번호(size - 1)를 반환
		return nodes.size() - 1;
	}
	public static int update(int nowNode, int s, int e, int targetIdx) {
		
		Node now = nodes.get(nowNode);
		
		if(s == e)// 리프노드인 경우, 기존 sum에서 + 1을 하여 x좌표에 대해 입력된 횟수를 1더함
		{
			nodes.add(new Node(-1, -1, now.sum + 1));
			return nodes.size() - 1;
		}
		// 리프노드가 아닌 경우, targetIdx의 값, 즉, 목표로하는 x좌표의 값에 따라 왼쪽으로갈지, 오른쪽으로 갈지 정하여 탐색
		int mid = (s + e) >> 1;
		int l = now.left;
		int r = now.right;
		
		if(targetIdx <= mid) {
			l = update(now.left, s, mid, targetIdx);
		}
		else {
			r = update(now.right, mid + 1, e, targetIdx);
		}
		
		// 왼쪽, 혹은 오른쪽 자식이 갱신되었을 텐데, 그 갱신된 값을 통해 노드를 새롭게 생성한다. 이 때 x좌표에 대해 입력된 횟수 + 1을함
		nodes.add(new Node(l,r, now.sum + 1));
		// 갱신된 노드번호(size - 1)반환
		return nodes.size() - 1;
	}
	public static int sum(int nowNode, int s, int e, int left, int right) {
		
		if(e < left || right < s)
			return 0;
		
		if(left<=s && e<=right)
			return nodes.get(nowNode).sum;
		
		int mid = (s + e) >> 1;
		
		return sum(nodes.get(nowNode).left, s, mid, left, right)
				+ sum(nodes.get(nowNode).right, mid + 1, e, left, right);
	}
	public static void main(String[] args)throws Exception{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringBuilder sb = new StringBuilder();
		int T = Integer.parseInt(br.readLine());//테스트케이스 수(1<=20)
		while(T-->0)
		{
			StringTokenizer st = new StringTokenizer(br.readLine());
			int N = Integer.parseInt(st.nextToken());// 좌표수
			int M = Integer.parseInt(st.nextToken());// 쿼리수
			Point point[] = new Point[N];
			int roots[] = new int[LEN + 1];
			nodes = new ArrayList<>();
			
			for(int i=0; i<N; i++)
			{
				st = new StringTokenizer(br.readLine());
				int x	= Integer.parseInt(st.nextToken()) + 1;// 좌표의수N(0<=10,000)
				int y	= Integer.parseInt(st.nextToken()) + 1;// 쿼리수M(0<=50,000)
				point[i]= new Point(x,y);
			}
			// 초기 세그먼트 트리 생성, roots[0]에는 초기 세그먼트트리의 root번호, 즉 nodes의 사이즈가 들어간다.
			roots[0] = init(1, LEN);
			
			Arrays.sort(point);
			
			int prevY = 0;
			for(Point p : point)
			{
				for(int y=prevY; y<=p.y; y++)
				{
					// 이전 탐색한 y좌표 부터 현재 파악하려는 p.y좌표까지 복사를 진행한다.
					// 이로써 쿼리에서 어떤 y좌표가 입력되어도 그 때의 세그먼트트리 루트노드를 찾아갈 수 있다.
					roots[y] = roots[prevY]; 
				}
				// 복사된 루트노드(root[p.y])를 시작으로 x의 인덱스를 마킹하고 새로운 루트 노드를 재저장한다.
				// p.y가 같은 값이 여러개 입력되어도, 결국 roots[p.y]에는 마지막으로 x좌표를 업데이트한 루트노드번호만 저장됨
				roots[p.y] = update(roots[p.y], 1, LEN, p.x);
				
				prevY = p.y;
			}
			// 남은 y좌표를 끝까지 update한다.
			for(int y=prevY; y<=LEN; y++)
				roots[y] = roots[prevY];
			
			int result = 0;
			
			while(M-->0)
			{
				st = new StringTokenizer(br.readLine());
				// l, r, b, t가 주어짐 (0<=10의5승), l과r은 x좌표, b,t는 y좌표
				int l = Integer.parseInt(st.nextToken()) + 1;
				int r = Integer.parseInt(st.nextToken()) + 1;
				int b = Integer.parseInt(st.nextToken()) + 1;
				int t = Integer.parseInt(st.nextToken()) + 1;
				
				result += sum(roots[t], 1, LEN, l, r)
							- sum(roots[b - 1], 1, LEN, l, r);
			}
			sb.append(result)
				.append('\n');
		}
		System.out.print(sb);
	}
}

 

 

 

[ 코드 해석 ]

- 좌표 입력 시 +1을 하는 이유 : roots 배열은 y좌표 기준으로 생성되는데, PST를 처음 init 할 때, y좌표 0에 루트 노드 값을 담기 때문에 y좌표가 0인 경우가 필요합니다. 입력되는 y좌표의 범위가 0 ~ 100,000이기 때문에 +1을 하여 입력되는 y좌표를 최소 1부터 시작되도록 보정한 것입니다.

 

- roots 배열을 복사해나가는 이유 : 코드에서 roots[y] = roots[prevY] 부분이 있는데, 이건 모든 y좌표에 대해서 PST의 루트 노드를 저장시켜야 하기 때문입니다. 주어지는 점의 좌표와, 구하고자 하는 사각형의 꼭짓점 좌표가 일치하지 않을 수 있기 때문에 모든 y 점에 대해서 루트 노드를 저장시켜야 합니다. 복잡해 보이지만 결국 y좌표 기준으로 x좌표를 업데이트한 최종 PST만 쓰이게 됩니다.

 

문제 : https://www.acmicpc.net/problem/7469

 

[ 문제 요약 ]

- 1차원 배열이 주어지고, 그 배열의 범위가 주어지면, 그 범위 안에서 K 번째로 큰 수를 출력하는 것

 

[ 테스트 케이스 설명 ]

7 3		// 배열의크기 N(1<=100,000), 쿼리수 Q(1<=5,000)
1 5 2 6 3 7 4	// 각 정수는 절대 값 십억을 넘지 않는 정수
2 5 3		// 왼쪽범위i,오른쪽범위j,찾을k번째수(1<i,j<=N / 1<= k<= j-i+1)
4 4 1
1 7 3
//답
5
6
3

 

[ 문제 풀이 ]

저는 이 문제를 영속성 세그먼트 트리를 이용해 풀었습니다. 영속성 세그먼트 트리의 기본 내용과 구현 방법은 아래 글을 참고해 주세요!

 

PST 알고리즘 : https://kyjdummy.tistory.com/5

 

Persistent Segment Tree(PST) - 영속성 세그먼트 트리

[ 글을 쓰게 된 계기 ]- PST(Persistent Segment Tree) 알고리즘을 공부하면서 느낀 점은, 개념에 대한 대략적인 설명은 많지만, 원리나 코드 적용 방법, 구현 과정에서의 인사이트를 초보자가 쉽게 이해

kyjdummy.tistory.com

 

그러나 이 문제는 단순히 PST 알고리즘을 알고 있다고 풀 수 있지는 않습니다. 다른 알고리즘 스킬이 들어가야 합니다. 

 

테스트 케이스와 같이 숫자 1, 5, 2, 6, 3, 7, 4가 있을 때, 특정 범위가 주어지면 거기서 어떻게 빠르게 K 번째 수를 찾느냐가 중요한 문제입니다.

 

아래 표를 보면 놀라운 사실을 하나 알 수 있습니다. 입력된 숫자들을 작은 수부터 차례로 1로 마킹을 해나가다 보면 규칙이 있습니다.

  1 5 2 6 3 7 4
첫번째 작은 수(1) 1 0 0 0 0 0 0
두번째 작은 수(2) 1 0 1 0 0 0 0
세번째 작은 수(3) 1 0 1 0 1 0 0
네번째 작은 수(4) 1 0 1 0 1 0 1
다섯번째 작은 수(5) 1 1 1 0 1 0 1
여섯번째 작은 수(6) 1 1 1 1 1 0 1
일곱번째 작은 수(7) 1 1 1 1 1 1 1

 

 

첫 번째 작은 수는 1입니다. 이 첫 번째 작은 수인 1에 숫자 1을 마킹합니다. 나머지는 0입니다.

두 번째 작은 수는 2입니다. 이 두 번째 작은 수인 2에 숫자 1을 마킹하고, 이전에 1이 마킹되었으니 복사해서 같이 마킹해 놓습니다.

세 번째 작은 수는 3입니다. 이 세 번째 작은 수인 3에 숫자 1을 마킹하고, 이전 1과 2가 마킹되었으니 복사해서 같이 마킹해 놓습니다. 

이렇게 일곱 번째 수까지 모두 반복합니다.

 

이랬을 때 놀라운 사실은 특정 범위가 주어지면, k 번째 작은 수는, 그 구간에 있는 1들의 합이 k가 되는 가장 첫 번째 수라는 것입니다.

 

예를 들어 인덱스 3부터 5까지 두 번째 작은 수가 무엇인지를 구한다 했을 때, 인덱스 3은 숫자 2이고, 인덱스 4는 6, 인덱스 5는 숫자 3입니다. 그러면 2,6,3 중 두 번째 작은 수는 3임을 쉽게 알 수 있지만, 위표를 통해 구하는 방법은 아래와 같습니다. 

 

- 3부터 5까지 구간 합이 2가 되는 가장 낮은 행번호 찾기.

 

이때 구간의 1의 합이 최초로 2가 되는 행은 세 번째 작은 수(3)입니다. 그 이후로 구간합이 2인 경우가 몇 개 더 생기게 되지만, 최초인 것이 답입니다. 

 

다른 예시로 1번째 인덱스 와 3번째 인덱스 사이에 3번째로 작은 수를 구한다 치면, 1,5,2 중에 3번째로 작은 수는 5입니다. 이것을 위 표를 통해 구한 다치면, 해당 구간의 1들의 합이 3이 되는 가장 첫 번째 행이 다섯 번째 작은 수, 즉 5가 됩니다.

 

위 내용이 선뜻 잘 이해되지 않을 수 있지만 손으로 해보면 이해가 쉽습니다.

 

결론은, 위와 같은 표를 만들어서 빠르게 K 번째 수를 구할 수 있습니다. 먼저 배열의 초기 값들을 입력받고, 값이 작은 순으로 정렬합니다. 이때, 입력받은 인덱스를 같이 저장해 놓아야 합니다. 

 

그리고 영속성 세그먼트 트리를 구현하여 작은 숫자대로 세그먼트 트리에 마킹해 나가면서 루트 노드들을 저장합니다.

 

세그먼트 트리에 마킹하는 것은 해당 숫자의 '인덱스'입니다. 그래야 추후 k 번째 수를 찾는 쿼리가 입력될 때, 입력된 그 인덱스를 갖고 이분 탐색을 진행할 수 있습니다.

 

그리고 해당 인덱스 범위 내에서 1의 합이 K가 되는 가장 작은 행번호를 가져와 출력해 주기만 하면 됩니다.

 

인덱스 범위 내에서 1의 합이 K가 되는 가장 작은 행번호는 이분 탐색으로 빠르게 구해줄 수 있습니다.

 

아래는 해당 내용을 그대로 코드로 구현한 것입니다.

 

 

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.StringTokenizer;
class Object implements Comparable<Object>{
	int idx, value;
	Object(int i, int v){
		this.idx = i;
		this.value = v;
	}
	@Override
	public int compareTo(Object o) {
		if(value != o.value)
			return value - o.value;
		return idx - o.idx;
	}
}
class Node{
	int left, right, sum;
	Node(int l, int r, int s){
		this.left	= l;
		this.right	= r;
		this.sum	= s;
	}
}
class Main{
	
	static ArrayList<Node> nodes;

	public static void main(String[] args)throws Exception{
		BufferedReader	br	= new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st	= new StringTokenizer(br.readLine());
		int N				= Integer.parseInt(st.nextToken());// 배열의크기 N(1<=100,000)
		int Q				= Integer.parseInt(st.nextToken());// 쿼리수 Q(1<=5,000)
		int roots[]			= new int[N + 1];// 루트는 원래 N개만 필요하지만, 최초 세그먼트 초기화시킨 후 담을 루트노드 값 때문에 + 1함
		List<Object> list	= new ArrayList<>();
		nodes				= new ArrayList<>();
		
		st = new StringTokenizer(br.readLine());
		for(int i=0; i<N; i++)
		{
			int value = Integer.parseInt(st.nextToken());
			list.add(new Object(i, value));
		}
		
		Collections.sort(list);// value가 낮은 기준으로 오름차순 정렬
		
		// 세그먼트트리를 초기화하고 루트노드 번호를 roots[0]에 담는다.
		roots[0] = init(0, N - 1);
		
		for(int i=0; i<list.size(); i++)
		{
			// 배열의 크기만큼 PST버전을 만들어 주고 만든 후 PST의 각각의 루트 번호를 roots에 담는다.
			// i+1번째 루트노드는, i번째 루트 노드를 통해 구해진다.
			// value기준으로 오름차순 정렬된 입력순서(idx)를 세그먼트 트리에 마킹하며 버전을 갱신해준다.
			roots[i + 1] = update(roots[i], 0, N - 1, list.get(i).idx);
		}
		
		StringBuilder sb = new StringBuilder();
		while(Q-->0)
		{
			st = new StringTokenizer(br.readLine());
			int i = Integer.parseInt(st.nextToken()) - 1; // 인덱스가 0부터시작했기 때문에 -1을 해주어 인덱스 보정
			int j = Integer.parseInt(st.nextToken()) - 1; // 인덱스가 0부터시작했기 때문에 -1을 해주어 인덱스 보정
			int k = Integer.parseInt(st.nextToken());
			
			int s = 0;
			int e = N - 1;
			int res = 0;
			while(s <= e)
			{
				int mid = (s + e) >> 1;
                // 이분탐색을 통해 i,j범위의 합이 k가되는 가장 작은 mid를 찾는다.
				int cal = sum(roots[mid + 1], 0, N - 1, i, j);
				if(k <= cal)
				{
					res = mid;
					e = mid - 1;
				}
				else s = mid + 1;
			}
			
			sb.append(list.get(res).value)
				.append('\n');
		}
		System.out.print(sb);
	}
	public static int init(int s, int e) {
		if(s == e)// 리프 노드인 경우 왼쪽, 오른쪽 자신은 -1로, sum 0으로
		{
			nodes.add(new Node(-1, -1, 0));
			return nodes.size() - 1;
		}
		int mid = (s + e) >> 1;
		int l = init(s, mid);
		int r = init(mid + 1, e);
		
		nodes.add(new Node(l, r, 0));
		
		return nodes.size() - 1;
	}
	public static int update(int nowNode, int s, int e, int targetIdx) {
		
		Node now = nodes.get(nowNode);
		
		if(s == e)
		{
			nodes.add(new Node(-1, -1, now.sum + 1));
			return nodes.size() - 1;
		}
		int mid = (s + e) >> 1;
		int l = now.left;
		int r = now.right;
		
		if(targetIdx <= mid) {
			l = update(now.left, s, mid, targetIdx);
		}
		else {
			r = update(now.right, mid + 1, e, targetIdx);
		}
		
		nodes.add(new Node(l, r, now.sum + 1));
		
		return nodes.size() - 1;
	}
	
	public static int sum(int nowNode, int s, int e, int left, int right) {
		if(e < left || right < s)
			return 0;
		
		if(left<=s && e<=right)
			return nodes.get(nowNode).sum;
		
		int mid = (s + e) >> 1;
		
		return sum(nodes.get(nowNode).left, s, mid, left, right)
				+ sum(nodes.get(nowNode).right, mid + 1, e, left, right);
	}
}

 

 

[ 글을 쓰게 된 계기 ]

- PST(Persistent Segment Tree) 알고리즘을 공부하면서 느낀 점은, 개념에 대한 대략적인 설명은 많지만, 원리나 코드 적용 방법, 구현 과정에서의 인사이트를 초보자가 쉽게 이해할 수 있도록 정리한 자료는 거의 없다는 것이었습니다. 그래서 제가 문제를 풀며 정리한 내용을, 같은 고민을 하는 분들과 공유하고자 이 글을 작성하게 되었습니다. 이 글은 기본적인 세그먼트 트리 구현 방법을 알고 있는 분을 대상으로 작성되었습니다. 세그먼트 트리의 작동 원리를 모른다면, 본 내용을 이해하기는 조금 어려울 수 있습니다.

 

[ 영속성(Persistence)이란? ]

- 영속성이란, 이전 상태를 보존하면서 새로운 상태를 만들어내는 것을 의미합니다.

- 일반적인 세그먼트 트리처럼, 기존의 값을 덮어쓰지 않고, 변경 이력을 '버전별'로 관리합니다.

 

 

[ 일반적인 세그먼트 트리와의 차이 ]

구분 일반 세그먼트 트리 PST(영속성) 세그먼트 트리
업데이트 덮어쓰기 버전별로 새로운 트리 생성
과거 데이터 조회 불가능 가능
메모리 효율적 업데이트마다 노드가 늘어남

 

 

[ Persistent Segment Tree의 시각적 자료 ]

- 회색으로 칠해진 원래 트리에서, 오른쪽 가장 아래 값을 수정한다고 했을 때, 수정되는 파랑, 빨강 노드들만 새로 생성하고, 나머지는 주솟값을 연동하는 방식으로 구현합니다.

 

 

 

[ 구현하는 방법 ]

영속성 세그먼트 트리를 구현하기 위해 알아야 할 몇 가지가 있습니다. 

먼저 일반 세그먼트 트리는 아래 코드와 같이 1차원 배열로 간단하게 만들 수 있습니다.

 

int tree[] = new int[H];

 

 

 

하지만 영속성 세그먼트 트리를 구현할 때는 위와 같이 1차원 배열로 선언하는 것은 좋지 않습니다.(노드 업데이트가 적다면 상관없음) 영속성 세그먼트 트리는 값을 업데이트할 때마다 노드가 지속적으로 늘어나기 때문에, 노드를 많이 생성할 경우 언젠가 배열을 초과할 것입니다. 그래서 안전하게 List로 세그먼트 트리를 표현합니다.

 

그리고 필요한 노드만 생성하고, 나머지 노드들은 기존 노드의 주솟값(정확히는 list에서의 노드의 위치)을 갖고 있어야 하기 때문에 별도의 Node 자료구조를 생성합니다. 아래 코드는 Node 객체와 노드를 담을 리스트를 선언한 것입니다. 아래 코드에서 sum은 누적합 세그먼트 트리를 구현한다고 가정했을 때 누적합을 담을 변수입니다.

 

class Node{
	int left, right, sum;
	Node(int left, int right, int sum){
		this.left = left;
		this.right = right;
		this.sum = sum;
	}
}
List<Node> nodes = new ArrayList<>();

 

 

또한, 위 그림에서 보듯이, 오른쪽 아래 값을 수정할 때마다 루트 노드가 계속해서 추가로 늘어나며 생성됨을 알 수 있습니다. 이렇게 업데이트가 될 때마다 root 노드가 계속 생성되며 늘어나기 때문에 이 늘어나는 root 노드를 계속 저장해 줄 배열이나 리스트를 따로 두어야 합니다. 이렇게 저장해 놓으면, root 노드마다 어떻게 업데이트 되었는지 알 수 있는 것입니다. 이렇게 관리하기 때문에 버전별로 관리한다고 표현합니다. root 노드의 값을 저장할 때 저는 int형 배열로 생성했습니다. 이 배열안에 root 노드의 값이 들어가있습니다. 이 root 노드의 값은, 세그먼트 트리를 표현하는 list에서의 root노드 위치입니다.

 

 

int[] roots = new int[업데이트를 하는 횟수 + 1];
// +1을 하는 이유는, 최초 세그먼트트리 init시 루트노드를 담아야 하기 때문

 

 

위와 같이 기본적인 개념을 알고, 누적합을 구하는 세그먼트 트리를 만들되, 업데이트 시점마다 버전(새 루트 노드)을 저장하도록 하는 영속성 세그먼트 트리 코드를 보여드리겠습니다.

 

코드의 전체 흐름은, 최초 세그먼트 트리를 만들고(init), 값을 2번 업데이트하고, 각 버전 마다의 누적합을 출력하는 것입니다. 배열의 길이는 5이고, 업데이트는 2번 한다고 가정했습니다.

 

import java.util.ArrayList;
import java.util.List;

class Node{
	int left, right, sum;
	Node(int left, int right, int sum){
		this.left = left;
		this.right = right;
		this.sum = sum;
	}
}

class Main{
	
	static final int LEN = 5;
	static List<Node>	nodes = new ArrayList<>();
	static int[] roots = new int[3];
	
	public static void main(String[] args)throws Exception{
		
		roots[0] = init(1, LEN);	// 최초 세그먼트 트리 생성
		
		roots[1] = update(roots[0], 1, LEN, 5, 100);// 최초 root값 = root[0]
		
		roots[2] = update(roots[1], 1, LEN, 1, 100);// 첫번 째 갱신 후 root 값 = root[1]
		// roots[2]에는 두번 째 갱신 후 루트의 위치가 들어가 있습니다.
        
		int version1 = sum(roots[1], 1, LEN, 1, LEN);// 첫번 째 갱신 후의 sum
		
		int version2 = sum(roots[2], 1, LEN, 1, LEN);// 두번 째 갱신 후의 sum
		
		System.out.println(version1);// 출력 결과 : 100
		System.out.println(version2);// 출력 결과 : 200
	}
	public static int sum(int nowNode, int s, int e, int left, int right) {
		if(e < left || right < s)// 유효하지 않은 범위일 때 0 반환
			return 0;
		
		Node now = nodes.get(nowNode);
		
		if(left<=s && e<= right)
			return now.sum;	// 유효한 범위일 때 값 반환
		
		int mid			= (s + e) >> 1;
		int leftNode	= now.left;
		int rightNode	= now.right;
		
		return sum(leftNode, s, mid, left, right)
				+ sum(rightNode, mid + 1, e, left, right);
	}
	public static int update(int nowNode, int s, int e, int targetIdx, int value) {
		Node now = nodes.get(nowNode);
		if(s == e)
		{
			nodes.add(new Node(-1, -1, now.sum + value));// 기존 노드의 sum에서 value를 추가한 노드를 새롭게 생성
		}
		else
		{
			int mid		= (s + e) >> 1;
			int left	= now.left;
			int right	= now.right;
			// 업데이트할 위치(targetIdx)가 mid보다 왼쪽에 있다면 왼쪽으로 내려감
			if(targetIdx <= mid)
			{
				left = update(now.left, s, mid, targetIdx, value);
			}
			// 업데이트할 위치(targetIdx)가 mid보다 오른쪽에 있다면 오른쪽으로 내려감
			else
			{
				right = update(now.right, mid + 1, e, targetIdx, value);
			}
			// 새 노드 생성
			nodes.add(new Node(left, right, now.sum + value));
		}
		// 새롭게 만든 노드의 위치 반환
		return nodes.size() - 1;
	}
	public static int init(int s, int e) {
		if(s == e)// 루트 노드일 경우 left, right를 -1로 하여 유효하지 않은 값으로 세팅, 최초 sum값은 0으로 세팅
		{
			nodes.add(new Node(-1,-1, 0));
		}
		else
		{
			// 루트노드가 아닌 경우, 왼쪽과 오른쪽을 각각 탐색하여 그 노드의 위치를 가져옴
			int mid = (s + e) >> 1;
			int left = init(s, mid);
			int right = init(mid + 1, e);
			
			nodes.add(new Node(left, right, 0));
		}
		// 현재까지 노드에 들어간 위치를 반환함으로 써 nodes 리스트에서 현재 노드의 위치를 알 수 있게 한다.
		return nodes.size() - 1;
	}
}

 

 

[ 변수 설명 ]

구현의 편의성을 위해 세그먼트 트리의 node를 담을 리스트(nodes)와, 루트 노드의 정보를 담을 배열(roots)을 전역변수(static)으로 선언했습니다. 

 

[ init 함수 설명 ]

세그먼트 트리 생성은 후위 순회 식으로 값을 추가합니다. 루트 노드일 때와, 아닐 때로 나누어 각각 노드를 생성해 넣습니다. 그리고 반환 값은 새롭게 생성한 노드의 위치(nodes.size() - 1)를 반환해야 합니다. init 함수에서 무조건 하나의 node를 생성하기 때문에 항상 새롭게 생성한 노드의 위치를 반환해 주어야 합니다. 

 

[ update 함수 설명 ]

update 함수도 init 함수와 크게 다르지 않습니다. 무조건 함수안에서 하나의 노드를 새롭게 생성합니다. 다만 누적합(sum)을 기존 노드의 값 + 업데이트하려는 값(value)을 해주는 것이 다릅니다. 추가로 타겟팅하는 위치의 인덱스(targetIdx)를 잘 찾아가도록 mid 값으로 분기를 해주면 됩니다.

 

[ sum 함수 설명 ]

sum 함수는 일반적인 세그먼트 트리의 누적합을 구하는 것과 크게 다르지 않아 생략합니다. 

 

[ roots 저장 방식 ]

코드를 보면 update할 때마다 roots 배열에 값을 저장하는 것을 볼 수 있습니다. node가 후위 순회 식으로 저장되기 때문에 roots는 결국 그 당시 nodes 리스트의 size() - 1 값이 됩니다.

 

'알고리즘 > PST 세그먼트트리' 카테고리의 다른 글

BOJ 백준 11012 Egg  (0) 2025.04.07
BOJ 백준 7469 K번째 수  (0) 2025.04.06

+ Recent posts