Различные распределения тем для одних и тех же данных с моделированием тем с помощью молотка

Я использую Mallet topic modeling и обучил модель. Сразу после обучения распечатываю раздачу тем для одного из документов обучающего набора и сохраняю. Затем я пробую тот же документ, что и тестовый набор, и пропускаю его по тем же каналам и так далее. Но у меня для этого есть совсем другая раздача тем. Тема с наивысшим рейтингом после обучения, которая с вероятностью около 0,54 имеет вероятность 0,000 при использовании в качестве тестового набора. Вот мои коды для обучения и тестирования:

 public static ArrayList<Object> trainModel() throws IOException {

        String fileName = "E:\\Alltogether.txt";
        String stopwords = "E:\\stopwords-en.txt";
        // Begin by importing documents from text to feature sequences
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();

        // Pipes: lowercase, tokenize, remove stopwords, map to features
        pipeList.add(new CharSequenceLowercase());
        pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
        pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
        pipeList.add(new TokenSequenceRemoveNonAlpha(true));
        pipeList.add(new TokenSequence2FeatureSequence());
        InstanceList instances = new InstanceList(new SerialPipes(pipeList));

        Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
        instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
                3, 2, 1)); // data, label, name fields

        int numTopics = 75;
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);

        model.setOptimizeInterval(20);
        model.addInstances(instances);
        model.setNumThreads(2);
        model.setNumIterations(2000);
        model.estimate();

        ArrayList<Object> results = new ArrayList<>();
        results.add(model);
        results.add(instances);

        Alphabet dataAlphabet = instances.getDataAlphabet();

        FeatureSequence tokens = (FeatureSequence) model.getData().get(66).instance.getData();
        LabelSequence topics = model.getData().get(66).topicSequence;

        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
        }
        System.out.println(out);

        // Estimate the topic distribution of the 66th instance,
        //  given the current Gibbs state.
        double[] topicDistribution = model.getTopicProbabilities(66);

        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();

        for (int topic = 0; topic < numTopics; topic++) {
            Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();

            out = new Formatter(new StringBuilder(), Locale.US);
            out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
            int rank = 0;
            while (iterator.hasNext() && rank < 10) {
                IDSorter idCountPair = iterator.next();
                out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
                rank++;
            }
            System.out.println(out);
        }

        return results;
    }

А вот и тестовая часть:

private static void testModel(ArrayList<Object> results, String testDir) {


    ParallelTopicModel model = (ParallelTopicModel) results.get(0);
    InstanceList allTrainInstances = (InstanceList) results.get(1);

    String stopwords = "E:\\stopwords-en.txt";

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();

    pipeList.add(new CharSequenceLowercase());
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
    pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
    pipeList.add(new TokenSequenceRemoveNonAlpha(true));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList instances = new InstanceList(new SerialPipes(pipeList));

    Reader fileReader = null;
    try {
        fileReader = new InputStreamReader(new FileInputStream(new File(testDir)), "UTF-8");
    } catch (UnsupportedEncodingException e) {
        e.printStackTrace();
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    }
    instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
            3, 2, 1)); // data, label, name fields

    TopicInferencer inferencer = model.getInferencer();
    inferencer.setRandomSeed(1);

    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
    System.out.println(testProbabilities);
    int index = getMaximum(testProbabilities);

    ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();

    Alphabet dataAlphabet = allTrainInstances.getDataAlphabet();
    Formatter out = new Formatter(new StringBuilder(), Locale.US);

    for (int topic = 0; topic < 75; topic++) {
        Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();

        out = new Formatter(new StringBuilder(), Locale.US);
        out.format("%d\t%.3f\t", topic, testProbabilities[topic]);
        int rank = 0;
        while (iterator.hasNext() && rank < 10) {
            IDSorter idCountPair = iterator.next();
            out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
            rank++;
        }
        System.out.println(out);
    }

}

В соответствии

    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);

Я просто вижу, что вероятности разные. Тем временем я пробовал использовать разные файлы, но всегда получаю ту же тему, что и тема с самым высоким рейтингом. Любая помощь приветствуется.


person user1419243    schedule 04.01.2018    source источник


Ответы (1)


Я отвечаю на свой вопрос для дальнейшего использования, если кто-то столкнется с той же проблемой. В документах MALLET сказано, что вы должны использовать одни и те же каналы для обучения и тестирования. Я понял, что «новое» использование тех же каналов, что и на этапе обучения, НЕ означает использование тех же каналов. Вы должны сохранить трубы при обучении модели и повторно загрузить их при тестировании. Я взял образец кода для этот вопрос, и теперь он работает.

person user1419243    schedule 05.01.2018