선택 문제 (Selection Problem)

  • n개의 숫자들 중에서 k번째로 작은 숫자를 찾는 문제이다.

  • 간단한 해결 방법

    • 최소인 숫자들 찾아서 제거하는 방법을 k번 반복한다.

      • 최악의 경우 시간 복잡도 : O(kn)
    • 오름차순으로 정렬한 뒤 k번째 숫자를 찾는다.

      • 최악의 경우 시간 복잡도 : O(nlogn)
    • 간단하지만 비효율적이다.

  • 이진 탐색과 퀵 정렬을 섞으면 원하는 숫자를 효율적으로 찾을 수 있다.

  • 데이터 분석에서 중앙값(median)을 찾는 데 활용된다.


아이디어

  • 퀵 정렬에서 처럼 피벗을 정하고 피벗보다 작은건 왼쪽, 피벗보다 큰건 오른쪽에 위치하도록 정렬한다.

  • 작은 그룹의 크기와 큰 그룹의 크기는 피벗의 인덱스를 통해 알 수 있다.

  • 분할

    • 찾고자하는 k번째로 작은 숫자가 피벗인 경우 그대로 반환한다.

    • 찾고자하는 k번째로 작은 숫자가 작은 그룹에 있을 경우, 작은 그룹에서 k번째 작은 숫자를 찾는다.

    • 찾고자하는 k번째로 작은 숫자가 큰 그룹에 있을 경우, 큰 그룹에서 (k - (작은 그룹의 크기 + 1))번째로 작은 수를 찾는다.

  • 2개의 부분문제로 나눠지지만, 그중에 1개만 고려하면 되고, 부분문제의 크기가 일정하지 않은 크기로 감소하는 형태이다.

    pseudo code

    int Selection(int* arr, int left, int right, int k)
    {
      피벗을 선택하고 arr[left]와 자리를 바꾼다.
      피벗보다 작은 숫자는 arr[left] ~ arr[p - 1]로 이동한다.
      피벗보다 큰 숫자는 arr[p + 1] ~ arr[right]로 이동한다.
      피벗은 arr[p]에 위치시킨다. 여기서의 p는 위의 과정이 끝난 후 피벗보다 작은 그룹의 가장 오른쪽 인덱스이다.
      small_group_size = p - left;
      if (k <= small_group_size) return Selection(arr, left, p - 1, k);
      else if (k == small_group_size + 1) return arr[p];
      else return Selection(arr, p + 1, right, k - (small_group_size + 1));
    }


시간복잡도

  • 피벗을 정할 때 완벽하게 절반으로 나누는 분할은 불가능하다.

  • 시간 복잡도 계산을 위해 다음과 같은 조건을 설정했다.

    • 분할했을 때 큰 쪽의 크기가 입력 크기의 3/4 이상일 경우 bad 분할, 이하인 경우 good 분할이라고 정의한다.
  • 분할했을 때 good 분할이 될 확률은 1/2이다.

    • 평균 2회 연속해서 랜덤하게 피벗을 정하면 good 분할이 한번 나오는 것이다.
  • 따라서 평균 경우 시간 복잡도를 구하려면, 연속으로 good 분할이 됐다고 가정한 다음 해당 시간 복잡도에 2를 곱하면 되는 것이다.

  • 시간 복잡도를 구하는 과정은 다음과 같다.

    • 1번째 분할

      • 입력의 크기가 n일 때 두 그룹을 분할하는 데 걸리는 시간은 O(n)이다.

      • 분할 후 큰 부분의 최대 크기는 (3n - 1)/4 인데, 편의상 (3/4)n 이라고 생각한다.

    • 2번째 분할

      • 큰 부분의 입력 크기가 3n/4이고, 분할 시간은 O(3n/4)이다.

      • 분할 후 큰 부분의 최대 크기는 (3/4)2n이다.

    • 이를 반복하면 시간 복잡도는 다음과 같다.

      O(n + (3/4)n + (3/4)2n + ... + (3/4)in) = O(n)

    • 여기에 2를 곱한 평균 경우 시간 복잡도는 O(n)이다.


코드

#include <iostream>
#include <vector>

void PrintArr(const std::vector<int>& arr)
{
    for (const auto &e : arr)
        std::cout << e << ' ';
    std::cout << std::endl;
}

int Selection(std::vector<int>& arr, int left, int right, int k)
{
    if (left == right)
        return (arr[left]);

    int    pivot = (left + right) / 2;
    int    high = left + 1;
    int    low = right;

    std::swap(arr[pivot], arr[left]);
    while (high <= low)
    {
        while ((high <= right) && (arr[high] <= arr[left]))
            high++;
        while ((low >= left) && (arr[low] >= arr[left]))
            low--;
        if (high > low)
            break;
        std::swap(arr[low], arr[high]);
    }
    std::swap(arr[left], arr[low]);

    int small_group_size = low - left;
    if (k <= small_group_size)
        return Selection(arr, left, low - 1, k);
    else if (k == small_group_size + 1)
        return (arr[low]);
    else
        return Selection(arr, low + 1, right, k - (small_group_size + 1));
}

int main()
{
    using namespace std;

    vector<int> arr{ 6, 3, 11, 9, 12, 2, 8, 15 };
    int k = 7;

    PrintArr(arr);
    cout << k << "th number : " << Selection(arr, 0, arr.size() - 1, k) << endl;
}

/* stdout
6 3 11 9 12 2 8 15 
7th number : 12
*/