Princeton Algorithms, Part II, WordNet
普林斯顿大学算法课 WordNet 题解与代码
首先要理解什么是 WordNet,这里定义了同义词集、下位词、上位词、逻辑门等计算语言学的复杂概念,确实不太好懂,但是总体上说,它是一个有根的有向无环图 a rooted DAG,但是它不一定是树。

同义词集列表中第一个字段是 id,第二个字段是同义词集,构成同义词集的各个名词之间用空格分隔,第三个字段与本次作业无关。

上位词列表中第一个字段表示同义词的 id,后续字段是改同义词的上位词 id 号,一个词可以有多个上位词,所以可能有多个字段,这就相当于是确定了 DAG 中的边的关系。

在 WordNet 的类的构造中,参数为 null 或者单词不是 WordNet 中有效的单词,给出异常,这个比较好做。难以理解的是 The input to the constructor does not correspond to a rooted DAG. 这个要求,我们如何才能确定自己构造的是不是一个有根的有向无环图?实际上,algs4.jar 中的 Digraph(有向图)给出了 2 个函数,indegree() 和 outdegree() 可以很方便地计算一个顶点的入度和出度,我们可以通过这个来判断是不是一个 rooted DAG。
除了是不是单根,还需要调用 DirectedCycle 或者 Topological 去检查是不是有环。
private void validate(Digraph g) {
assert g != null;
int vertexNumber = g.V();
int rootNumber = 0;
for (int i = 0; i < vertexNumber; i++) {
// 出度为 0 的点是根节点(没有上位词的同义词集)
if (g.outdegree(i) == 0) {
rootNumber++;
}
}
// 根节点不足 1 或者大于 1 都不满足条件
if (rootNumber != 1) {
throw new IllegalArgumentException();
}
// The program uses neither 'DirectedCycle' nor 'Topological' to check whether the digraph is a DAG.
DirectedCycle dc = new DirectedCycle(g);
if (dc.hasCycle()) {
throw new IllegalArgumentException();
}
}
最近公共祖先比较好理解,推广到顶点的集合也很简单,就是找到所有 SAP 的最短的。
求公共祖先的过程是一个广度优先搜索的过程。
while (!q.isEmpty()) {
int x = q.poll();
Iterable<Integer> bag = g.adj(x);
// 加入后面的点
for (int vv : bag) {
if (!visited[vv]) {
q.add(vv);
visited[vv] = true;
// 更新距离
int d = distanceV.get(x);
int dd = distanceV.getOrDefault(vv, d + 1);
distanceV.put(vv, dd);
}
}
}
具体的策略是,先对 v 点做一次广搜,直到根结点,在每一次搜索的时候记录下 depth。
然后对 w 做一次广搜,搜索过程中遇到符合条件的(搜索过的),都是祖先,记录下 depth 并相加,取最小的 depth 就是最近公共祖先。
最多对所有的点访问 2 次(从 v 出发一次,从 w 出发一次),所以时间复杂度只与点的个数有关。
if (distanceV.containsKey(x)) {
// 更新最短的 LCA
int minDistance = distanceV.get(x) + distanceW.get(x);
if (sap[0] == -1 || minDistance < sap[0]) {
sap[0] = minDistance;
sap[1] = x;
}
}
// 这里不是 else 的关系,要继续往上找
Iterable<Integer> bag = g.adj(x);
for (int vv : bag) {
if (!visited[vv]) {
q.add(vv);
visited[vv] = true;
// 更新距离
int d = distanceW.get(x);
int dd = distanceW.getOrDefault(vv, d + 1);
distanceW.put(vv, dd);
}
}
另外在实现的时候需要注意性能和安全,例如有些可以在构造的时候就缓存出来的值就在构造的时候缓存好,不要等用的时候再去遍历。
validate(word); // 直接查缓存就可以了 return synsetToIdMap.containsKey(word);
以及有些返回类型是 Iterable 的,一定要确保返回类型不可变,不要返回原来的那个,要 new 一个新的去返回。
// 时刻记得做成不可变的 this.g = new Digraph(g);
// 时刻注意这种类型返回的时候一定要不可变,所以这里返回的时候返回一个新的,不返回原来那个 return new ArrayList<>(synsetToIdMap.keySet());
Princeton 的作业质量高还在于它的异常处理,需要注意所有不合法情况的判断,本题中对应的是单词为 null,更进一步的是单词不存在于单词集中。而对于图,则需要满足 DAG 和 rooted 这两个条件。
封装的情况也需要考虑,例如本题,在搜索 sap(int, int) 和 sap(Iterable, Iterable) 的时候,可以将这两个功能抽象出来,做一次封装。
最简单的当然是写 sap(int, int),然后对于 sap(Iterable, Iterable) 的情况,做一次双重循环,对于每一个 i 对于每一个 j 做一次 sap(i, j) 并取最小值。
但是这样的写法是不是真的足够好?在搜索的时候是不是重复搜索了很多?
所以我们可以按下面代码所述,进行更合理的封装和优化,对于已经搜索过的,就不再搜索了。
需要注意的是本题存在的映射关系有:“已知 id 获取同义词集”,“已知单词获取该单词对应的同义词集”这两种。其中,“已知单词获取该单词对应的同义词集”可以由“已知单词获取 id”和“已知 id 获取同义词集”完成,所以我们只需要保存“已知单词获取 id”和“已知 id 获取同义词集”这样两个 HashMap 即可。
Outcast 就非常简单了,WordNet 写好了以后直接调用,取最大值就好。
import edu.princeton.cs.algs4.Digraph;
import edu.princeton.cs.algs4.DirectedCycle;
import edu.princeton.cs.algs4.In;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
public class WordNet {
// id 对应的同义词集
private final ArrayList<String> synsetList = new ArrayList<>();
// 单词对应的 id,注意一个单词可能有多个 id 中都出现了
private final HashMap<String, List<Integer>> synsetToIdMap = new HashMap<>();
private final SAP sap;
public WordNet(String synsets, String hypernyms) {
validate(synsets);
validate(hypernyms);
In in;
int vertexNumber = 0;
in = new In(synsets);
while (in.hasNextLine()) {
// 顶点个数 = 同义词集的个数
vertexNumber++;
String[] ss = in.readLine().split(",");
int id = Integer.parseInt(ss[0]);
// 因为有 API 是返回整个同义词集,所以先存下来,避免后续需要的时候遍历 List 去拼接
synsetList.add(ss[1]);
String[] words = ss[1].split(" ");
// 将 id 与同义词 List 放进 HashMap
for (String word : words) {
if (!synsetToIdMap.containsKey(word)) {
synsetToIdMap.put(word, new ArrayList<>());
}
synsetToIdMap.get(word).add(id);
}
}
Digraph g = new Digraph(vertexNumber);
in = new In(hypernyms);
while (in.hasNextLine()) {
String[] ss = in.readLine().split(",");
int id = Integer.parseInt(ss[0]);
// 每一组上位词的对应关系就是一条边
for (int i = 1; i < ss.length; i++) {
g.addEdge(id, Integer.parseInt(ss[i]));
}
}
validate(g);
sap = new SAP(g);
}
private void validate(Digraph g) {
assert g != null;
int vertexNumber = g.V();
int rootNumber = 0;
for (int i = 0; i < vertexNumber; i++) {
// 出度为 0 的点是根节点(没有上位词的同义词集)
if (g.outdegree(i) == 0) {
rootNumber++;
}
}
// 根节点不足 1 或者大于 1 都不满足条件
if (rootNumber != 1) {
throw new IllegalArgumentException();
}
// The program uses neither 'DirectedCycle' nor 'Topological' to check whether the digraph is a DAG.
DirectedCycle dc = new DirectedCycle(g);
if (dc.hasCycle()) {
throw new IllegalArgumentException();
}
}
private void validate(String string) {
if (string == null) {
throw new IllegalArgumentException();
}
}
public Iterable<String> nouns() {
// 时刻注意这种类型返回的时候一定要不可变,所以这里返回的时候返回一个新的,不返回原来那个
return new ArrayList<>(synsetToIdMap.keySet());
}
public boolean isNoun(String word) {
validate(word);
// 直接查缓存就可以了
return synsetToIdMap.containsKey(word);
}
public int distance(String nounA, String nounB) {
validate(nounA);
validate(nounB);
if (!isNoun(nounA) || !isNoun(nounB)) {
throw new IllegalArgumentException();
}
return sap.length(synsetToIdMap.get(nounA), synsetToIdMap.get(nounB));
}
public String sap(String nounA, String nounB) {
return synsetList.get(sap.ancestor(synsetToIdMap.get(nounA), synsetToIdMap.get(nounB)));
}
public static void main(String[] args) {
WordNet wordNet = new WordNet("tiny_synsets.txt", "tiny_hypernyms.txt");
System.out.println(wordNet.nouns());
System.out.println(wordNet.isNoun("wordnet"));
System.out.println(wordNet.distance("WordNet3.1", "wordnet"));
System.out.println(wordNet.sap("wordnet", "WordNet3.1"));
}
}
import edu.princeton.cs.algs4.Digraph;
import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.StdIn;
import edu.princeton.cs.algs4.StdOut;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
/**
* @author jxtxzzw
*/
public class SAP {
private final Digraph g;
private final int vertexNumber;
public SAP(Digraph g) {
// 时刻记得做成不可变的
this.g = new Digraph(g);
vertexNumber = g.V();
}
public int length(int v, int w) {
validate(v);
validate(w);
return sap(v, w)[0];
}
public int ancestor(int v, int w) {
validate(v);
validate(w);
return sap(v, w)[1];
}
public int length(Iterable<Integer> v, Iterable<Integer> w) {
validate(v);
validate(w);
return sap(v, w)[0];
}
public int ancestor(Iterable<Integer> v, Iterable<Integer> w) {
validate(v);
validate(w);
return sap(v, w)[1];
}
private void validate(int v) {
// Any vertex argument is outside its prescribed range
if (v < 0 || v >= vertexNumber) {
throw new IllegalArgumentException();
}
}
private void validate(Iterable<Integer> v) {
// Any argument is null
if (v == null) {
throw new IllegalArgumentException();
}
for (Integer i : v) {
// Any iterable argument contains a null item
if (i == null) {
throw new IllegalArgumentException();
}
// Call validate(int v)
validate(i);
}
}
private int[] sap(int v, int w) {
ArrayList<Integer> vv = new ArrayList<>();
vv.add(v);
ArrayList<Integer> ww = new ArrayList<>();
ww.add(w);
return sap(vv, ww);
}
private int[] sap(Iterable<Integer> v, Iterable<Integer> w) {
// 0 表示 length,1 表示 ancestor
int[] sap = new int[2];
// -1 if no such path
Arrays.fill(sap, -1);
// 记录点 x 到点 v 之间的距离
HashMap<Integer, Integer> distanceV = new HashMap<>();
HashMap<Integer, Integer> distanceW = new HashMap<>();
// 遍历 q 的所有父结点
LinkedList<Integer> q = new LinkedList<>();
boolean[] visited = new boolean[vertexNumber];
Arrays.fill(visited, false);
for (int vx : v) {
q.add(vx);
visited[vx] = true;
distanceV.put(vx, 0);
}
while (!q.isEmpty()) {
int x = q.poll();
Iterable<Integer> bag = g.adj(x);
// 加入后面的点
for (int vv : bag) {
if (!visited[vv]) {
q.add(vv);
visited[vv] = true;
// 更新距离
int d = distanceV.get(x);
int dd = distanceV.getOrDefault(vv, d + 1);
distanceV.put(vv, dd);
}
}
}
Arrays.fill(visited, false);
// 遍历 w 的所有父结点,找到最先遇到的
for (int wx : w) {
q.add(wx);
visited[wx] = true;
distanceW.put(wx, 0);
}
while (!q.isEmpty()) {
int x = q.poll();
if (distanceV.containsKey(x)) {
// 更新最短的 LCA
int minDistance = distanceV.get(x) + distanceW.get(x);
if (sap[0] == -1 || minDistance < sap[0]) {
sap[0] = minDistance;
sap[1] = x;
}
}
// 这里不是 else 的关系,要继续往上找
Iterable<Integer> bag = g.adj(x);
for (int vv : bag) {
if (!visited[vv]) {
q.add(vv);
visited[vv] = true;
// 更新距离
int d = distanceW.get(x);
int dd = distanceW.getOrDefault(vv, d + 1);
distanceW.put(vv, dd);
}
}
}
return sap;
}
// do unit testing of this class
public static void main(String[] args) {
In in = new In("digraph2.txt");
Digraph g = new Digraph(in);
SAP sap = new SAP(g);
while (!StdIn.isEmpty()) {
int v = StdIn.readInt();
int w = StdIn.readInt();
int length = sap.length(v, w);
int ancestor = sap.ancestor(v, w);
StdOut.printf("length = %d, ancestor = %d\n", length, ancestor);
}
// while (!StdIn.isEmpty()) {
// ArrayList<Integer> vv = new ArrayList<>();
// ArrayList<Integer> ww = new ArrayList<>();
// String[] v = StdIn.readLine().split(" ");
// String[] w = StdIn.readLine().split(" ");
// for (String s : v) {
// vv.add(Integer.parseInt(s));
// }
// for (String s : w) {
// ww.add(Integer.parseInt(s));
// }
// int length = sap.length(vv, ww);
// int ancestor = sap.ancestor(vv, ww);
// StdOut.printf("length = %d, ancestor = %d\n", length, ancestor);
// }
}
}
public class Outcast {
private final WordNet wordNet;
public Outcast(WordNet wordnet) {
this.wordNet = wordnet;
}
public String outcast(String[] nouns) {
int max = 0;
String outcase = null;
for (String noun : nouns) {
int di = 0;
for (String s : nouns) {
di += wordNet.distance(noun, s);
}
if (di > max) {
max = di;
outcase = noun;
}
}
return outcase;
}
public static void main(String[] args) {
WordNet wordnet = new WordNet("synsets.txt", "hypernyms.txt");
Outcast outcast = new Outcast(wordnet);
System.out.println(outcast.outcast("horse zebra cat bear table".split(" ")));
System.out.println(outcast.outcast("water soda bed orange_juice milk apple_juice tea coffee".split(" ")));
System.out.println(outcast.outcast("apple pear peach banana lime lemon blueberry strawberry mango watermelon potato".split(" ")));
}
}
完整代码也可参见 https://gitlab.jxtxzzw.com/jxtxzzw/coursera_assignments
