카테고리 없음 2021. 4. 30. 18:47

최근접 점의 쌍 찾기(Closest Pair)

  • 2차원 평면 상의 n개의 점이 입력으로 주어질 때, 거리가 가장 가까운 한 쌍의 점을 찾는 문제

  • 모든 점에 대하여 각각의 두 점 사이의 거리를 계산하는 방법의 시간 복잡도는 O(n2)이다.

  • 분할 정복을 이용하면 시간 복잡도는 O(n(logn)2)이다.


분할 정복

알고리즘

  • 2개나 3개의 점이 남을 때까지 분할한다.

  • 최근접 쌍을 찾고 분할한 부분문제들을 합칠 때, 합치는 경계를 기준으로 좌우의 점들에 대해 다시 검증을 해야 한다.

    • 중간 영역에 속하는 점들을 판별하는 방법은 다음과 같다.

      • 각 부분문제의 최근접 점 사이의 거리 중 작은 값을 택하고 이를 d라고 하자.

      • 왼쪽 부분문제의 가장 오른쪽 점의 x좌표에서 d만큼 뺀 x좌표부터, 오른쪽 부분문제의 가장 왼쪽 점의 x좌표에서 d만큼 더한 x좌표 사이에 포함되는 점들이 바로 중간 영역에 속하는 점들이라고 할 수 있다.

  • arr는 x좌표 기준 오름차순으로 정렬된 좌표(x, y)들의 배열이다.

    Pseudo Code

    ClosestPair(arr)
    {
      if (arr.size() <= 3)
        return (2개 또는 3개의 점들 사이의 최근접 쌍);
      arr를 반으로 나눠서 arr_left와 arr_right로 분할한다.
      closest_pair_left = ClosestPair(arr_left);
      closest_pair_right = ClosestPair(arr_right);
      d = min(closest_pair_left, closest_pair_right); 를 통해 중간 영역에 속하는 점들을 판별한다.
      중간 영역에 속하는 점들 중에서 최근접 점의 쌍을 closest_pair_center로 저장한다.
      return (closest_pair_left, closest_pair_right, closest_pair_center 중에 최근접 점의 쌍);
    }

시간 복잡도

  • 처음 배열을 정렬할 때의 시간 복잡도는 O(nlogn)이다.

  • 분할해서 계산하는건 모두 O(1)이라고 생각되고, 층 수가 O(logn)인 것 같다.

  • 따라서 시간 복잡도는 O(nlogn)인데.. 처음 정렬하는게 제일 영향이 크다는게 의아하다.


코드

#include <iostream>
#include <vector>
#include <cmath>
#include <limits>

class vec2;

float GetDistance(const vec2& v1, const vec2& v2);

class vec2
{
public:
    int x_;
    int y_;

    vec2()
    {
        x_ = rand() % 100;
        y_ = rand() % 100;
    }

    vec2(int x_in, int y_in)
    : x_(x_in), y_(y_in)
    {}

    vec2 operator - (const vec2& v) const
    {
        return (vec2(x_ - v.x_, y_ - v.y_));
    }

    friend std::ostream& operator << (std::ostream& out, const vec2& v)
    {
        out << '(' << v.x_ << ", " << v.y_ << ')';
        return out;
    }
};

class ClosestPair
{
public:
    vec2 v1_;
    vec2 v2_;
    float dist_;

    ClosestPair()
    {
        dist_ = std::numeric_limits<float>::max();
    }
    ClosestPair(const vec2& v1, const vec2& v2)
    {
        v1_ = v1;
        v2_ = v2;
        dist_ = GetDistance(v1, v2);
    }

    friend std::ostream& operator << (std::ostream& out, const ClosestPair& cp)
    {
        out << "Point1 : " << cp.v1_ << "\tPoint2 : " << cp.v2_ << "\tDistance : " << cp.dist_ << std::endl;
        return out;
    }
};

float GetDistance(const vec2& v1, const vec2& v2)
{
    return std::sqrt(std::pow((v1 - v2).x_, 2) + std::pow((v1 - v2).y_, 2));
}

void QuickSort(std::vector<vec2>& arr, int left, int right)
{
    if (left >= right)
        return ;

    int    pivot = (left + right) / 2;
    int    high = left + 1;
    int    low = right;

    std::swap(arr[pivot], arr[left]);
    while (high <= low)
    {
        while ((high <= right) && (arr[high].x_ <= arr[left].x_))
            high++;
        while ((low > left) && (arr[low].x_ >= arr[left].x_))
            low--;
        if ((high > right) || (low == left) || (high >= low))
            break;
        std::swap(arr[high], arr[low]);
    }
    std::swap(arr[left], arr[low]);

    QuickSort(arr, left, low - 1);
    QuickSort(arr, low + 1, right);
}

ClosestPair GetUnderThreeCase(const std::vector<vec2>& v, int left, int right)
{
    if (right == left + 2)
    {
        float dist1 = GetDistance(v[left], v[left + 1]);
        float dist2 = GetDistance(v[left], v[right]);
        float dist3 = GetDistance(v[left + 1], v[right]);

        if (dist1 < dist2 && dist1 < dist3)
            return ClosestPair(v[left], v[left + 1]);
        else if (dist2 < dist1 && dist2 < dist3)
            return ClosestPair(v[left], v[right]);
        else
            return ClosestPair(v[left + 1], v[right]);
    }
    else
        return ClosestPair(v[left], v[right]);
}

ClosestPair CheckCenter(const std::vector<vec2>& v, int left, int right, float d)
{
    float x_leftside = v[(left + right) / 2].x_ - d;
    float x_rightside = v[(left + right) / 2 + 1].x_ + d;
    ClosestPair closest_pair_center;

    for (int i = 0; i <= (left + right) / 2; ++i)
    {
        if (v[i].x_ < x_leftside)
            continue ;
        for (int j = (left + right) / 2 + 1; j <= right; ++j)
        {
            if (v[j].x_ > x_rightside)
                continue ;
            if (GetDistance(v[i], v[j]) < closest_pair_center.dist_)
                closest_pair_center = ClosestPair(v[i], v[j]);
        }
    }
    return closest_pair_center;
}

ClosestPair GetClosestPair(const std::vector<vec2>& v, int left, int right)
{
    if (right <= left + 2)
        return GetUnderThreeCase(v, left, right);
    ClosestPair closest_pair_left = GetClosestPair(v, left, (left + right) / 2);
    ClosestPair closest_pair_right = GetClosestPair(v, (left + right) / 2 + 1, right);

    float d = std::min(closest_pair_left.dist_, closest_pair_right.dist_);
    ClosestPair closest_pair_center = CheckCenter(v, left, right, d);

    if (closest_pair_left.dist_ < closest_pair_right.dist_ && \
        closest_pair_left.dist_ < closest_pair_center.dist_)
        return closest_pair_left;
    else if (closest_pair_right.dist_ < closest_pair_left.dist_ && \
        closest_pair_right.dist_ < closest_pair_center.dist_)
        return closest_pair_right;
    else
        return closest_pair_center;
}

int main()
{
    using namespace std;

    srand(time(NULL));

    vector<vec2> v(10);
    QuickSort(v, 0, v.size() - 1);
    for (const auto &e : v)
        cout << e << '\n';

    ClosestPair closest_pair;
    closest_pair = GetClosestPair(v, 0, v.size() - 1);
    cout << closest_pair << endl;
}
  • 코드를 완벽히 하려면 sqrt를 제외하고 비교하는 등의 최적화를 하고, 예외처리도 해서 좀 더 정리해야 한다.

  • 그래픽 요소 추가하면 좋을 듯

  • 책에서는 부분문제를 합칠 때 y좌표를 기준으로 정렬한 뒤 d와 비교하라고 하는데, 굳이 정렬해야 하는 이유를 찾지 못해서 그냥 비교했다.