作业|学习|算法|题解

WordNet

凝神长老 · 4月3日 · 2020年 · · · · · 617次已读

Princeton Algorithms, Part II, WordNet

普林斯顿大学算法课 WordNet 题解与代码

首先要理解什么是 WordNet,这里定义了同义词集、下位词、上位词、逻辑门等计算语言学的复杂概念,确实不太好懂,但是总体上说,它是一个有根的有向无环图 a rooted DAG,但是它不一定是树。

同义词集列表中第一个字段是 id,第二个字段是同义词集,构成同义词集的各个名词之间用空格分隔,第三个字段与本次作业无关。

上位词列表中第一个字段表示同义词的 id,后续字段是改同义词的上位词 id 号,一个词可以有多个上位词,所以可能有多个字段,这就相当于是确定了 DAG 中的边的关系。

在 WordNet 的类的构造中,参数为 null 或者单词不是 WordNet 中有效的单词,给出异常,这个比较好做。难以理解的是 The input to the constructor does not correspond to a rooted DAG. 这个要求,我们如何才能确定自己构造的是不是一个有根的有向无环图?实际上,algs4.jar 中的 Digraph(有向图)给出了 2 个函数,indegree() 和 outdegree() 可以很方便地计算一个顶点的入度和出度,我们可以通过这个来判断是不是一个 rooted DAG。

除了是不是单根,还需要调用 DirectedCycle 或者 Topological 去检查是不是有环。

private void validate(Digraph g) {
    assert g != null;
    int vertexNumber = g.V();
    int rootNumber = 0;
    for (int i = 0; i < vertexNumber; i++) {
        // 出度为 0 的点是根节点(没有上位词的同义词集)
        if (g.outdegree(i) == 0) {
            rootNumber++;
        }
    }
    // 根节点不足 1 或者大于 1 都不满足条件
    if (rootNumber != 1) {
        throw new IllegalArgumentException();
    }
    // The program uses neither 'DirectedCycle' nor 'Topological' to check whether the digraph is a DAG.
    DirectedCycle dc = new DirectedCycle(g);
    if (dc.hasCycle()) {
        throw new IllegalArgumentException();
    }
}

最近公共祖先比较好理解,推广到顶点的集合也很简单,就是找到所有 SAP 的最短的。

求公共祖先的过程是一个广度优先搜索的过程。

while (!q.isEmpty()) {
    int x = q.poll();
    Iterable<Integer> bag = g.adj(x);
    // 加入后面的点
    for (int vv : bag) {
        if (!visited[vv]) {
            q.add(vv);
            visited[vv] = true;
            // 更新距离
            int d = distanceV.get(x);
            int dd = distanceV.getOrDefault(vv, d + 1);
            distanceV.put(vv, dd);
        }
    }
}

具体的策略是,先对 v 点做一次广搜,直到根结点,在每一次搜索的时候记录下 depth。

然后对 w 做一次广搜,搜索过程中遇到符合条件的(搜索过的),都是祖先,记录下 depth 并相加,取最小的 depth 就是最近公共祖先。

最多对所有的点访问 2 次(从 v 出发一次,从 w 出发一次),所以时间复杂度只与点的个数有关。

if (distanceV.containsKey(x)) {
    // 更新最短的 LCA
    int minDistance = distanceV.get(x) + distanceW.get(x);
    if (sap[0] == -1 || minDistance < sap[0]) {
        sap[0] = minDistance;
        sap[1] = x;
    }
}
// 这里不是 else 的关系,要继续往上找
Iterable<Integer> bag = g.adj(x);
for (int vv : bag) {
    if (!visited[vv]) {
        q.add(vv);
        visited[vv] = true;
        // 更新距离
        int d = distanceW.get(x);
        int dd = distanceW.getOrDefault(vv, d + 1);
        distanceW.put(vv, dd);
    }
}

另外在实现的时候需要注意性能和安全,例如有些可以在构造的时候就缓存出来的值就在构造的时候缓存好,不要等用的时候再去遍历。

validate(word);
// 直接查缓存就可以了
return synsetToIdMap.containsKey(word);

以及有些返回类型是 Iterable 的,一定要确保返回类型不可变,不要返回原来的那个,要 new 一个新的去返回。

// 时刻记得做成不可变的
this.g = new Digraph(g);
// 时刻注意这种类型返回的时候一定要不可变,所以这里返回的时候返回一个新的,不返回原来那个
return new ArrayList<>(synsetToIdMap.keySet());

Princeton 的作业质量高还在于它的异常处理,需要注意所有不合法情况的判断,本题中对应的是单词为 null,更进一步的是单词不存在于单词集中。而对于图,则需要满足 DAG 和 rooted 这两个条件。

封装的情况也需要考虑,例如本题,在搜索 sap(int, int) 和 sap(Iterable, Iterable) 的时候,可以将这两个功能抽象出来,做一次封装。

最简单的当然是写 sap(int, int),然后对于 sap(Iterable, Iterable) 的情况,做一次双重循环,对于每一个 i 对于每一个 j 做一次 sap(i, j) 并取最小值。

但是这样的写法是不是真的足够好?在搜索的时候是不是重复搜索了很多?

所以我们可以按下面代码所述,进行更合理的封装和优化,对于已经搜索过的,就不再搜索了。

需要注意的是本题存在的映射关系有:“已知 id 获取同义词集”,“已知单词获取该单词对应的同义词集”这两种。其中,“已知单词获取该单词对应的同义词集”可以由“已知单词获取 id”和“已知 id 获取同义词集”完成,所以我们只需要保存“已知单词获取 id”和“已知 id 获取同义词集”这样两个 HashMap 即可。

Outcast 就非常简单了,WordNet 写好了以后直接调用,取最大值就好。

import edu.princeton.cs.algs4.Digraph;
import edu.princeton.cs.algs4.DirectedCycle;
import edu.princeton.cs.algs4.In;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;


public class WordNet {

    // id 对应的同义词集
    private final ArrayList<String> synsetList = new ArrayList<>();
    // 单词对应的 id,注意一个单词可能有多个 id 中都出现了
    private final HashMap<String, List<Integer>> synsetToIdMap = new HashMap<>();
    private final SAP sap;

    public WordNet(String synsets, String hypernyms) {
        validate(synsets);
        validate(hypernyms);
        In in;
        int vertexNumber = 0;

        in = new In(synsets);
        while (in.hasNextLine()) {
            // 顶点个数 = 同义词集的个数
            vertexNumber++;
            String[] ss = in.readLine().split(",");
            int id = Integer.parseInt(ss[0]);
            // 因为有 API 是返回整个同义词集,所以先存下来,避免后续需要的时候遍历 List 去拼接
            synsetList.add(ss[1]);
            String[] words = ss[1].split(" ");
            // 将 id 与同义词 List 放进 HashMap
            for (String word : words) {
                if (!synsetToIdMap.containsKey(word)) {
                    synsetToIdMap.put(word, new ArrayList<>());
                }
                synsetToIdMap.get(word).add(id);
            }
        }

        Digraph g = new Digraph(vertexNumber);
        in = new In(hypernyms);
        while (in.hasNextLine()) {
            String[] ss = in.readLine().split(",");
            int id = Integer.parseInt(ss[0]);
            // 每一组上位词的对应关系就是一条边
            for (int i = 1; i < ss.length; i++) {
                g.addEdge(id, Integer.parseInt(ss[i]));
            }
        }

        validate(g);

        sap = new SAP(g);
    }

    private void validate(Digraph g) {
        assert g != null;
        int vertexNumber = g.V();
        int rootNumber = 0;
        for (int i = 0; i < vertexNumber; i++) {
            // 出度为 0 的点是根节点(没有上位词的同义词集)
            if (g.outdegree(i) == 0) {
                rootNumber++;
            }
        }
        // 根节点不足 1 或者大于 1 都不满足条件
        if (rootNumber != 1) {
            throw new IllegalArgumentException();
        }
        // The program uses neither 'DirectedCycle' nor 'Topological' to check whether the digraph is a DAG.
        DirectedCycle dc = new DirectedCycle(g);
        if (dc.hasCycle()) {
            throw new IllegalArgumentException();
        }
    }

    private void validate(String string) {
        if (string == null) {
            throw new IllegalArgumentException();
        }
    }

    public Iterable<String> nouns() {
        // 时刻注意这种类型返回的时候一定要不可变,所以这里返回的时候返回一个新的,不返回原来那个
        return new ArrayList<>(synsetToIdMap.keySet());
    }

    public boolean isNoun(String word) {
        validate(word);
        // 直接查缓存就可以了
        return synsetToIdMap.containsKey(word);
    }

    public int distance(String nounA, String nounB) {
        validate(nounA);
        validate(nounB);
        if (!isNoun(nounA) || !isNoun(nounB)) {
            throw new IllegalArgumentException();
        }
        return sap.length(synsetToIdMap.get(nounA), synsetToIdMap.get(nounB));
    }

    public String sap(String nounA, String nounB) {
        return synsetList.get(sap.ancestor(synsetToIdMap.get(nounA), synsetToIdMap.get(nounB)));
    }

    public static void main(String[] args) {
        WordNet wordNet = new WordNet("tiny_synsets.txt", "tiny_hypernyms.txt");
        System.out.println(wordNet.nouns());
        System.out.println(wordNet.isNoun("wordnet"));
        System.out.println(wordNet.distance("WordNet3.1", "wordnet"));
        System.out.println(wordNet.sap("wordnet", "WordNet3.1"));
    }
}
import edu.princeton.cs.algs4.Digraph;
import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.StdIn;
import edu.princeton.cs.algs4.StdOut;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;


/**
 * @author jxtxzzw
 */
public class SAP {

    private final Digraph g;
    private final int vertexNumber;

    public SAP(Digraph g) {
        // 时刻记得做成不可变的
        this.g = new Digraph(g);
        vertexNumber = g.V();
    }

    public int length(int v, int w) {
        validate(v);
        validate(w);
        return sap(v, w)[0];
    }

    public int ancestor(int v, int w) {
        validate(v);
        validate(w);
        return sap(v, w)[1];
    }

    public int length(Iterable<Integer> v, Iterable<Integer> w) {
        validate(v);
        validate(w);
        return sap(v, w)[0];
    }

    public int ancestor(Iterable<Integer> v, Iterable<Integer> w) {
        validate(v);
        validate(w);
        return sap(v, w)[1];
    }

    private void validate(int v) {
        // Any vertex argument is outside its prescribed range
        if (v < 0 || v >= vertexNumber) {
            throw new IllegalArgumentException();
        }
    }

    private void validate(Iterable<Integer> v) {
        // Any argument is null
        if (v == null) {
            throw new IllegalArgumentException();
        }
        for (Integer i : v) {
            // Any iterable argument contains a null item
            if (i == null) {
                throw new IllegalArgumentException();
            }
            // Call validate(int v)
            validate(i);
        }
    }

    private int[] sap(int v, int w) {
        ArrayList<Integer> vv = new ArrayList<>();
        vv.add(v);
        ArrayList<Integer> ww = new ArrayList<>();
        ww.add(w);
        return sap(vv, ww);
    }

    private int[] sap(Iterable<Integer> v, Iterable<Integer> w) {
        // 0 表示 length,1 表示 ancestor
        int[] sap = new int[2];
        // -1 if no such path
        Arrays.fill(sap, -1);
        // 记录点 x 到点 v 之间的距离
        HashMap<Integer, Integer> distanceV = new HashMap<>();
        HashMap<Integer, Integer> distanceW = new HashMap<>();

        // 遍历 q 的所有父结点
        LinkedList<Integer> q = new LinkedList<>();
        boolean[] visited = new boolean[vertexNumber];
        Arrays.fill(visited, false);
        for (int vx : v) {
            q.add(vx);
            visited[vx] = true;
            distanceV.put(vx, 0);
        }
        while (!q.isEmpty()) {
            int x = q.poll();
            Iterable<Integer> bag = g.adj(x);
            // 加入后面的点
            for (int vv : bag) {
                if (!visited[vv]) {
                    q.add(vv);
                    visited[vv] = true;
                    // 更新距离
                    int d = distanceV.get(x);
                    int dd = distanceV.getOrDefault(vv, d + 1);
                    distanceV.put(vv, dd);
                }
            }
        }
        Arrays.fill(visited, false);
        // 遍历 w 的所有父结点,找到最先遇到的
        for (int wx : w) {
            q.add(wx);
            visited[wx] = true;
            distanceW.put(wx, 0);
        }
        while (!q.isEmpty()) {
            int x = q.poll();
            if (distanceV.containsKey(x)) {
                // 更新最短的 LCA
                int minDistance = distanceV.get(x) + distanceW.get(x);
                if (sap[0] == -1 || minDistance < sap[0]) {
                    sap[0] = minDistance;
                    sap[1] = x;
                }
            }
            // 这里不是 else 的关系,要继续往上找
            Iterable<Integer> bag = g.adj(x);
            for (int vv : bag) {
                if (!visited[vv]) {
                    q.add(vv);
                    visited[vv] = true;
                    // 更新距离
                    int d = distanceW.get(x);
                    int dd = distanceW.getOrDefault(vv, d + 1);
                    distanceW.put(vv, dd);
                }
            }
        }
        return sap;
    }

    // do unit testing of this class
    public static void main(String[] args) {
        In in = new In("digraph2.txt");
        Digraph g = new Digraph(in);
        SAP sap = new SAP(g);
        while (!StdIn.isEmpty()) {
            int v = StdIn.readInt();
            int w = StdIn.readInt();
            int length = sap.length(v, w);
            int ancestor = sap.ancestor(v, w);
            StdOut.printf("length = %d, ancestor = %d\n", length, ancestor);
        }
//        while (!StdIn.isEmpty()) {
//            ArrayList<Integer> vv = new ArrayList<>();
//            ArrayList<Integer> ww = new ArrayList<>();
//            String[] v = StdIn.readLine().split(" ");
//            String[] w = StdIn.readLine().split(" ");
//            for (String s : v) {
//                vv.add(Integer.parseInt(s));
//            }
//            for (String s : w) {
//                ww.add(Integer.parseInt(s));
//            }
//            int length = sap.length(vv, ww);
//            int ancestor = sap.ancestor(vv, ww);
//            StdOut.printf("length = %d, ancestor = %d\n", length, ancestor);
//        }
    }
}
public class Outcast {
    private final WordNet wordNet;

    public Outcast(WordNet wordnet) {
        this.wordNet = wordnet;
    }

    public String outcast(String[] nouns) {
        int max = 0;
        String outcase = null;
        for (String noun : nouns) {
            int di = 0;
            for (String s : nouns) {
                di += wordNet.distance(noun, s);
            }
            if (di > max) {
                max = di;
                outcase = noun;
            }
        }
        return outcase;
    }

    public static void main(String[] args) {
        WordNet wordnet = new WordNet("synsets.txt", "hypernyms.txt");
        Outcast outcast = new Outcast(wordnet);
        System.out.println(outcast.outcast("horse zebra cat bear table".split(" ")));
        System.out.println(outcast.outcast("water soda bed orange_juice milk apple_juice tea coffee".split(" ")));
        System.out.println(outcast.outcast("apple pear peach banana lime lemon blueberry strawberry mango watermelon potato".split(" ")));
    }
}

完整代码也可参见 https://gitlab.jxtxzzw.com/jxtxzzw/coursera_assignments

订阅评论动态
提醒
guest
0 评论
行内反馈
查看所有评论