Princeton Algorithm Kd-Trees
普林斯顿大学算法课第 5 次作业 KD 树
题目链接: https://coursera.cs.princeton.edu/algs4/assignments/kdtree/specification.php
本次课程作业是编写一个数据结构,以表示单位正方形中的一组点,并支持高效的范围搜索(查找查询矩形中包含的所有点),以及高效的最近邻居搜索(找到最接近查询点的点)。
KD 树有许多应用,从对天文物体进行分类到计算机动画,再到加速神经网络,再到挖掘数据再到图像检索等。
首先要用暴力做法做一次,题目限定只能使用 SET 或者 java.util.TreeSet,这个就比较简单了,只需要注意一下 corner case,然后注意参数不合法的时候抛异常。
其的搜索和插入的算法与 BST 的算法相似,但是在根结点处,我们使用 x 坐标来判断大小,如果要插入的点的 x 坐标比在根结点的点小,向左移动,否则向右移动;然后在下一个级别,我们使用 y 坐标来判断大小,如果要插入的点的 y 坐标比结点中的点小,则向左移动,否则向右移动;然后在下一级,继续使用 x 坐标,依此类推……
由此,我们可以得到下图:
相对于 BST 的主要优势在于,它支持范围搜索和最近邻居搜索的高效实现。每个节点对应于单位正方形中与轴对齐的矩形,该矩形将其子树中的所有点都包含在内。根结点对应整个单位正方形,根的左、右子元素对应于两个矩形,该两个矩形被根结点的 x 坐标分开,以此类推……
由此,我们可以得到范围搜索和最近邻居搜索的思想思路。
进行范围搜索时,从根结点开始,递归地搜索左右子树,若查询矩形不与该结点对应的矩形相交,那么就不需要探索该节点及其子树。子树只有在可能包含查询矩形中包含的点时才被搜索。
进行最近邻居搜索时,从根结点开始,递归地搜索左右子树,如果到目前为止发现的最近点比查询点与结点对应的矩形之间的距离更近,则不需要探索该结点及其子树。也就是说,仅当一个结点可能包含一个比目前发现的最佳结点更接近的点时,才进行搜索。
这样的剪枝规则,依赖于能否快速找到附近的点。因此,我们需要注意在递归代码中,当有两个可能的子树的时候,总是选择位于分隔线同一侧的子树作为要探索的第一棵子树的查询点。这是因为在探索第一棵子树时发现的有可能是最近的点,将有利于探索第二棵子树时剪枝。
这里在实现的时候,递归先左先右当然都可以得到正确的结果,但是这里必须调整递归的顺序,才能达到剪枝的效果。
这是因为,如果左孩子包含 p,由于矩形是越来越小的,所以若点在某个 node 的矩形内被包含,则该 node 的 p 离这个所求 p 的距离就可能越小。min 越小,那么剪枝的效果就越明显,因为越来越多的就不需要再计算了。于是,应该始终优先去递归那个 contains(p) 的方向(因为有且只有可能要么是 left 要么还是 right)包含 p。
如果不进行剪枝,那么就算你的代码功底非常好,在规定时间内求得了正确解,没有超时,也一样不能通过测评:
- student sequence of kd-tree nodes involved in calls to Point2D methods: A D F I G B C E J - reference sequence of kd-tree nodes involved in calls to Point2D methods: A D F I G B C - failed on trial 1 of 1000
具体剪枝的策略就是,如果左孩子包含了目标点,那么就去左孩子,如果右孩子包含了目标点,那么就去右孩子。有可能左右孩子都不包含目标点,那么离谁近就去谁那。
// 先左先右当然都可以得到正确的结果,但是 // 这里必须调整递归的顺序,才能达到剪枝的效果 if (node.left != null && node.left.rect.contains(p)) { // 如果左孩子包含 p,由于矩形是越来越小的,所以若点在某个 node 的矩形内被包含,则该 node 的 p 离这个所求 p 的距离就可能越小 // min 越小,那么剪枝的效果就越明显,因为越来越多的就不需要再计算了 // 于是,应该始终优先去递归那个 contains(p) 的方向(因为有且只有可能要么是 left 要么还是 right)包含 p findNearest(p, node.left); findNearest(p, node.right); } else if (node.right != null && node.right.rect.contains(p)) { // 如果右孩子包含就先去右边 findNearest(p, node.right); findNearest(p, node.left); } else { // 也可能出现两个都不包含的情况,那么离谁近就先去谁那 // 注意调用时 null 的问题要特别处理,可以设置为无穷大 double toLeft = node.left != null ? node.left.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY; double toRight = node.right != null ? node.right.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY; if (toLeft < toRight) { findNearest(p, node.left); findNearest(p, node.right); } else { findNearest(p, node.right); findNearest(p, node.left); } }
为了代码实现的方便,二叉树当然要用递归的写法啦。
课程提供了若干可视化工具用于调试。
draw() 函数的正确性将会大幅度提高 debug 的效率,所以这个函数一定要写的正确。
在可视化过程中,使用暴力法求解的答案会标注为红色,使用 KDTree 方法求解的会标注为蓝色。由于我们非常有信心,暴力法肯定是对的,所以可以用这个方法来检验 KdTree 的搜索是不是正确。
使用上也非常简单:当检验区域搜索的时候,只需要用鼠标在上面画一个矩形;当检验最近邻居的时候,只需要将鼠标移动到想要搜索的那个点对应的位置上(也许这个点并没有在图中画出)。
另一个难点是处理重叠的点。重叠点在统计个数的时候不能被重复计算,我简单地开了一个 same 数组,但是可能没有必要。
另外特别要注意每一个新增点的时候,它对应的 RectHV 的范围一定要搞清楚,否则后面的事情没法做。不过这个也简单,只要把 draw() 写了,然后点几个点,根据画出来的图马上就知道自己写的对不对了。如果图和自己预想的不一样,那就肯定是写错了,这个是最容易 debug 的。
以下是完整代码,该代码通过 100% 的测试数据,得分 100 分。
更多题解请参看:
import edu.princeton.cs.algs4.Point2D; import edu.princeton.cs.algs4.RectHV; import java.util.ArrayList; import java.util.TreeSet; public class PointSET { private final TreeSet<Point2D> set; public PointSET() { set = new TreeSet<>(); } public boolean isEmpty() { return set.isEmpty(); } public int size() { return set.size(); } public void insert(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } if (!contains((p))) { set.add(p); } } public boolean contains(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } return set.contains(p); } public void draw() { for (Point2D p : set) { p.draw(); } } public Iterable<Point2D> range(RectHV rect) { if (rect == null) { throw new IllegalArgumentException(); } ArrayList<Point2D> list = new ArrayList<>(); for (Point2D p : set) { if (rect.contains(p)) { list.add(p); } } return list; } public Point2D nearest(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } Point2D ans = null; if (!isEmpty()) { double min = Double.POSITIVE_INFINITY; for (Point2D pp : set) { // Do not call 'distanceTo()' in this program; instead use 'distanceSquaredTo()'. [Performance] double d = pp.distanceSquaredTo(p); if (d < min) { min = d; ans = pp; } } } return ans; } public static void main(String[] args) { PointSET ps = new PointSET(); Point2D p1 = new Point2D(1, 1); Point2D p2 = new Point2D(1, 2); Point2D p3 = new Point2D(2, 1); Point2D p4 = new Point2D(0, 0); ps.insert(p1); ps.insert(p2); ps.insert(p3); ps.insert(p4); System.out.println(ps.nearest(p4)); for (Point2D p : ps.range(new RectHV(1, 1, 3, 3))) { System.out.println(p); } } }
import edu.princeton.cs.algs4.Point2D; import edu.princeton.cs.algs4.RectHV; import edu.princeton.cs.algs4.StdDraw; import java.util.ArrayList; /** * @author jxtxzzw */ public class KdTree { private Node root; private int size; private static class Node { private final Point2D p; private final int level; private Node left; private Node right; private final RectHV rect; // 记录重叠的点 private final ArrayList<Point2D> same = new ArrayList<>(); // 对根结点 public Node(Point2D p) { // 根结点层数是 0,范围是单位正方形 this(p, 0, 0, 1, 0, 1); } public Node(Point2D p, int level, double xmin, double xmax, double ymin, double ymax) { this.p = p; this.level = level; rect = new RectHV(xmin, ymin, xmax, ymax); } public void addSame(Point2D point) { same.add(point); } public boolean hasSamePoint() { return !same.isEmpty(); } } private Point2D currentNearest; private double min; public KdTree() { } public boolean isEmpty() { return size == 0; } public int size() { return size; } private int compare(Point2D p, Node n) { if (n.level % 2 == 0) { // 如果是偶数层,按 x 比较 if (Double.compare(p.x(), n.p.x()) == 0) { return Double.compare(p.y(), n.p.y()); } else { return Double.compare(p.x(), n.p.x()); } } else { // 按 y 比较 if (Double.compare(p.y(), n.p.y()) == 0) { return Double.compare(p.x(), n.p.x()); } else { return Double.compare(p.y(), n.p.y()); } } } private Node generateNode(Point2D p, Node parent) { int cmp = compare(p, parent); if (cmp < 0) { if (parent.level % 2 == 0) { // 偶数层,比较结果是小于,说明是加在左边 // 那么它的 xmin, ymin, ymax 都和父结点一样,xmax 设置为父结点的 p.x() return new Node(p, parent.level + 1, parent.rect.xmin(), parent.p.x(), parent.rect.ymin(), parent.rect.ymax()); } else { // 奇数层,加在下边,那么只需要修改 ymax return new Node(p, parent.level + 1, parent.rect.xmin(), parent.rect.xmax(), parent.rect.ymin(), parent.p.y()); } } else { if (parent.level % 2 == 0) { // 偶数层,加在右边,那么只需要修改 xmin return new Node(p, parent.level + 1, parent.p.x(), parent.rect.xmax(), parent.rect.ymin(), parent.rect.ymax()); } else { // 奇数层,比较结果是大于,说明是加在上边,修改 ymin return new Node(p, parent.level + 1, parent.rect.xmin(), parent.rect.xmax(), parent.p.y(), parent.rect.ymax()); } } } public void insert(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } else { if (root == null) { // 初始化根结点 size++; root = new Node(p); } else { // 二叉树,用递归的写法去调用 insert(p, root); } } } private void insert(Point2D p, Node node) { int cmp = compare(p, node); // 如果比较结果是小于,那么就是要往左边走,右边同理 if (cmp < 0) { // 走到头了就新建,否则继续走 if (node.left == null) { size++; node.left = generateNode(p, node); } else { insert(p, node.left); } } else if (cmp > 0) { if (node.right == null) { size++; node.right = generateNode(p, node); } else { insert(p, node.right); } } // 重叠的点,size 不加 1 } public boolean contains(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } else { if (root == null) { return false; } else { // 递归的写法 return contains(p, root); } } } private boolean contains(Point2D p, Node node) { if (node == null) { return false; } else if (p.equals(node.p)) { return true; } else { if (compare(p, node) < 0) { return contains(p, node.left); } else { return contains(p, node.right); } } } public void draw() { // 清空画布 StdDraw.clear(); // 递归调用 draw(root); } private void draw(Node node) { if (node != null) { // 点用黑色 StdDraw.setPenColor(StdDraw.BLACK); // 画点 node.p.draw(); // 根据是不是偶数设置红色还是蓝色 if (node.level % 2 == 0) { StdDraw.setPenColor(StdDraw.RED); StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax()); } else { StdDraw.setPenColor(StdDraw.BLUE); StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y()); } // 递归画 draw(node.left); draw(node.right); } } public Iterable<Point2D> range(RectHV rect) { if (rect == null) { throw new IllegalArgumentException(); } if (isEmpty()) { return null; } // 递归调用 return new ArrayList<>(range(rect, root)); } private ArrayList<Point2D> range(RectHV rect, Node node) { ArrayList<Point2D> list = new ArrayList<>(); // A subtree is searched only if it might contain a point contained in the query rectangle. if (node != null && rect.intersects(node.rect)) { // 递归地检查左右孩子 list.addAll(range(rect, node.left)); list.addAll(range(rect, node.right)); // 如果对当前点包含,则加入 if (rect.contains(node.p)) { list.add(node.p); // 重叠点应该只被计算一次 } } return list; } public Point2D nearest(Point2D p) { if (p == null) { throw new IllegalArgumentException(); } if (isEmpty()) { return null; } currentNearest = null; min = Double.POSITIVE_INFINITY; findNearest(p, root); return currentNearest; } private void findNearest(Point2D p, Node node) { if (node == null) { return; } // The square of the Euclidean distance between the point {@code p} and the closest point on this rectangle; 0 if the point is contained in this rectangle if (node.rect.distanceSquaredTo(p) <= min) { // Do not call 'distanceTo()' in this program; instead use 'distanceSquaredTo()'. [Performance] double d = node.p.distanceSquaredTo(p); if (d < min) { min = d; currentNearest = node.p; } // 先左先右当然都可以得到正确的结果,但是 // 这里必须调整递归的顺序,才能达到剪枝的效果 if (node.left != null && node.left.rect.contains(p)) { // 如果左孩子包含 p,由于矩形是越来越小的,所以若点在某个 node 的矩形内被包含,则该 node 的 p 离这个所求 p 的距离就可能越小 // min 越小,那么剪枝的效果就越明显,因为越来越多的就不需要再计算了 // 于是,应该始终优先去递归那个 contains(p) 的方向(因为有且只有可能要么是 left 要么还是 right)包含 p findNearest(p, node.left); findNearest(p, node.right); } else if (node.right != null && node.right.rect.contains(p)) { // 如果右孩子包含就先去右边 findNearest(p, node.right); findNearest(p, node.left); } else { // 也可能出现两个都不包含的情况,那么离谁近就先去谁那 // 注意调用时 null 的问题要特别处理,可以设置为无穷大 double toLeft = node.left != null ? node.left.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY; double toRight = node.right != null ? node.right.rect.distanceSquaredTo(p) : Double.POSITIVE_INFINITY; if (toLeft < toRight) { findNearest(p, node.left); findNearest(p, node.right); } else { findNearest(p, node.right); findNearest(p, node.left); } } } } public static void main(String[] args) { KdTree kd; kd = new KdTree(); kd.insert(new Point2D(0.7, 0.2)); kd.insert(new Point2D(0.5, 0.4)); kd.insert(new Point2D(0.2, 0.3)); kd.insert(new Point2D(0.4, 0.7)); kd.insert(new Point2D(0.9, 0.6)); assert kd.nearest(new Point2D(0.73, 0.36)).equals(new Point2D(0.7, 0.2)); } }
关于其他 Coursera 上的作业分析,可以参考: https://gitlab.jxtxzzw.com/jxtxzzw/coursera_assignments