K Closest Points

ID: 612; medium; *

Solution 1 (Java)

/**
 * Definition for a point.
 * class Point {
 *     int x;
 *     int y;
 *     Point() { x = 0; y = 0; }
 *     Point(int a, int b) { x = a; y = b; }
 * }
 */

public class Solution {
    /**
     * @param points: a list of points
     * @param origin: a point
     * @param k: An integer
     * @return: the k closest points
     */
    public Point[] kClosest(Point[] points, Point origin, int k) {
        if (points == null || points.length < k)
            return new Point[0];
        
        Queue<Point> maxheap = new PriorityQueue<>(k, new Comparator<Point>() {
            public int compare(Point p1, Point p2) {
                int diff = dist(origin, p1) - dist(origin, p2);
                if (diff != 0)
                    return diff;
                if (p1.x != p2.x)
                    return p1.x - p2.x;
                return p1.y - p2.y;
            }
        });

        // O(nlogn)
        for (Point p : points) {
            maxheap.offer(p);
        }
        // O(klogn)
        Point[] res = new Point[k];
        for (int i = 0; i < k; i++) {
            res[i] = maxheap.poll();
        }
        return res;
    }

    private int dist(Point p1, Point p2) {
        return (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y);
    }
}

Notes

  • This is the min heap solution. The thought process is the most straight-forward.

  • Time complexity: O(nlogn + klogn) = O(nlogn)

Solution 2 (Java)

/**
 * Definition for a point.
 * class Point {
 *     int x;
 *     int y;
 *     Point() { x = 0; y = 0; }
 *     Point(int a, int b) { x = a; y = b; }
 * }
 */

public class Solution {
    /**
     * @param points: a list of points
     * @param origin: a point
     * @param k: An integer
     * @return: the k closest points
     */
    public Point[] kClosest(Point[] points, Point origin, int k) {
        Queue<Point> pq = new PriorityQueue<>(k, new Comparator<Point>() {
            public int compare(Point p1, Point p2) {
                int diff = dist(origin, p2) - dist(origin, p1);
                if (diff == 0) diff = p2.x - p1.x;
                if (diff == 0) diff = p2.y - p1.y;
                return diff;
            }
        });

        // O(nlogk)
        for (Point p : points) {
            pq.offer(p);
            if (pq.size() > k) pq.poll();
        }
        k = pq.size();
        Point[] res = new Point[k];
        // O(klogk)
        while (!pq.isEmpty()) {
            res[--k] = pq.poll();
        }
        return res;
    }

    private int dist(Point p1, Point p2) {
        return (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y);
    }
}

Notes

  • This is the max heap solution.

  • If p2 is larger than p1, return positive values. This gives us a max heap, vice versa.

  • Time complexity: O(nlogk + klogk) = O(nlogk).

    • So, this max heap version is better than the min heap version though the latter is more straight-forward to think about. The reason is clear because k is bounded by n.

    • Heap is priority queue in Java and we customize the ordering by overriding the Comparator. We iterate the whole points array and add each point to the priority queue. However, the priority queue only stores the top k closest points, as supposed to the min heap, which essentially sorts the points by their distance to the origin.

    • Thus, the offer and poll methods actually take O(logk) time in this solution because the size of the heap is k.

Solution 3 (Java)

/**
 * Definition for a point.
 * class Point {
 *     int x;
 *     int y;
 *     Point() { x = 0; y = 0; }
 *     Point(int a, int b) { x = a; y = b; }
 * }
 */

public class Solution {
    /**
     * @param points: a list of points
     * @param origin: a point
     * @param k: An integer
     * @return: the k closest points
     */
    public Point[] kClosest(Point[] points, Point origin, int k) {
        if (points == null || points.length < k)
            return new Point[0];
        
        Comparator comp = new Comparator<Point>() {
            public int compare(Point p1, Point p2) {
                // notice we want smaller points to be at the front
                // so p1 - p2 (different from max heap)
                int diff = dist(origin, p1) - dist(origin, p2);
                if (diff == 0) diff = p1.x - p2.x;
                if (diff == 0) diff = p1.y - p2.y;
                return diff;
            }
        };

        // O(n) for quickSelect
        quickSelect(points, comp, 0, points.length - 1, k - 1);
        // sorting takes O(klogk)
        Arrays.sort(points, 0, k - 1, comp);
        return Arrays.copyOf(points, k);
    }

    private Point quickSelect(Point[] points, Comparator<Point> comp, int start, int end, int k) {
        if (start >= end) return points[k];
        int left = start;
        int right = end;
        Point pivot = points[left + (right - left) / 2];

        while (left <= right) {
            while (left <= right && comp.compare(points[left], pivot) < 0) {
                left++;
            }
            while (left <= right && comp.compare(points[right], pivot) > 0) {
                right--;
            }
            if (left <= right) {
                Point temp = points[left];
                points[left] = points[right];
                points[right] = temp;
                left++;
                right--;
            }
        }
        if (k <= right) 
            return quickSelect(points, comp, start, right, k);
        if (k >= left)
            return quickSelect(points, comp, left, end, k);
        return points[k];
    }

    private int dist(Point p1, Point p2) {
        return (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y);
    }
}

Notes

  • This is the quick select solution.

  • Time complexity: O(n + klogk), which means this is the best solution so far.

  • For the quickSelect method, it is important to know that it is finding the kth closest point. After it is done, we have points[0, ..., k - 1] being the top k closest points. However, it is only guaranteed that points[k - 1] is the kth closest point, while the remaining points[0, ..., k - 2] is not necessarily sorted. Thus, we eventually use Arrays.sort() to sort these remaining points.

  • Although this solution gives the best time complexity, it requires us to know quickSelect very well. Moreover, using proper data structures is also important, so the max heap version is already a valid and efficient solution.

Last updated