/*
 * Decompiled with CFR 0.152.
 */
package ghidra.app.plugin.core.strings;

import generic.jar.ResourceFile;
import ghidra.app.plugin.core.strings.StringTrigramIterator;
import ghidra.app.plugin.core.strings.Trigram;
import ghidra.app.services.StringValidatorQuery;
import ghidra.app.services.StringValidatorService;
import ghidra.app.services.StringValidityScore;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * A {@link StringValidatorService} that scores strings against a trigram frequency model.
 * <p>
 * The model is a map from {@link Trigram} to the log10 of that trigram's relative frequency.
 * A string's score is the average of the log-probabilities of its trigrams; the score is
 * returned together with a threshold selected from a per-length threshold table.
 * <p>
 * Instances are normally created from a model file via {@link #read(ResourceFile)}.
 */
public class TrigramStringValidator implements StringValidatorService {
    /**
     * When true, the second trigram of every string is left out of the score sum (it is
     * still counted in the divisor). Kept so that scores stay identical to what earlier
     * versions produced.
     */
    @Deprecated(forRemoval=true, since="10.3")
    private static final boolean PRESERVE_BUG_SKIP_TRIGRAM = true;

    /** Score reported for a string that yields no trigrams at all. */
    private static final double DEFAULT_LOG_VALUE = -20.0;

    /** Threshold reported when no entry of the threshold table applies. */
    private static final double INVALID_THRESHOLD = 10.0;

    /** {@code thresholds[0]} applies to a length of this value; shorter lengths have no entry. */
    private static final int THRESHOLD_INDEX_OFFSET = 4;

    /** Symbol size assumed when the model file has no "Symbol Size" header. */
    private static final int DEFAULT_SYMBOL_SIZE = 128;

    // Precompiled so that the patterns are not recompiled for every string that is scored.
    private static final Pattern MULTIPLE_SPACES = Pattern.compile(" {2,}");
    private static final Pattern MULTIPLE_TABS = Pattern.compile("\t{2,}");

    private final ResourceFile sourceFile;
    private final Map<Trigram, Double> trigramLogs;
    private final long totalNumTrigrams;
    private final Function<String, String> modelValueTransformer;
    private final double[] thresholds;

    /**
     * Reads a trigram model file and builds a validator from it.
     *
     * @param f the model file
     * @return a new validator backed by the model in {@code f}
     * @throws IOException if the file can not be read, a data line does not have exactly
     * four tab-separated fields, or a numeric value can not be parsed
     */
    public static TrigramStringValidator read(ResourceFile f) throws IOException {
        return TrigramStringValidator.readModel(f);
    }

    /**
     * Creates a validator from already-computed model data.
     *
     * @param trigramLogs map of trigram to log10 probability
     * @param totalNumTrigrams total trigram count of the model; used to derive the
     * log-probability assigned to trigrams that are missing from {@code trigramLogs}
     * @param modelValueTransformer transformation applied to each string before it is scored
     * @param thresholds threshold table; element 0 applies to a length of 4, and the last
     * element applies to all longer lengths
     * @param sourceFile the file the model was read from
     */
    public TrigramStringValidator(Map<Trigram, Double> trigramLogs, long totalNumTrigrams, Function<String, String> modelValueTransformer, double[] thresholds, ResourceFile sourceFile) {
        this.trigramLogs = trigramLogs;
        this.totalNumTrigrams = totalNumTrigrams;
        this.modelValueTransformer = modelValueTransformer;
        this.thresholds = thresholds;
        this.sourceFile = sourceFile;
    }

    /**
     * Returns the file that was supplied when this validator was constructed.
     *
     * @return the model source file
     */
    public ResourceFile getSourceFile() {
        return this.sourceFile;
    }

    @Override
    public String getValidatorServiceName() {
        return "ngram";
    }

    /**
     * Scores the query's string value.
     * <p>
     * The string is first passed through the model's value transformer. If the transformed
     * string produces no trigrams the score is {@code -20.0}. Otherwise the score is the sum
     * of the log-probabilities of its trigrams divided by the number of trigrams, where a
     * trigram missing from the model contributes {@code log10(1 / totalNumTrigrams)}.
     * <p>
     * The threshold in the result is looked up using the number of trigrams.
     *
     * @param query the query holding the string to score
     * @return the score, carrying the original string, the transformed string, the computed
     * score and the threshold that applies
     */
    @Override
    public StringValidityScore getStringValidityScore(StringValidatorQuery query) {
        String transformedString = this.modelValueTransformer.apply(query.stringValue());
        double score = DEFAULT_LOG_VALUE;
        int trigramCount = 0;
        StringTrigramIterator it = Trigram.iterate(transformedString);
        if (it.hasNext()) {
            double missingTrigramScore = Math.log10(1.0 / (double) this.totalNumTrigrams);
            score = 0.0;
            while (it.hasNext()) {
                Trigram trigram = it.next();
                ++trigramCount;
                if (PRESERVE_BUG_SKIP_TRIGRAM && trigramCount == 2) {
                    // Deliberately retained: the second trigram adds nothing to the sum
                    // but has already been counted, so it still affects the average.
                    continue;
                }
                Double logProb = this.trigramLogs.get(trigram);
                score += (logProb != null) ? logProb.doubleValue() : missingTrigramScore;
            }
            score /= (double) trigramCount;
        }
        return new StringValidityScore(query.stringValue(), transformedString, score, this.getThresholdForStringOfLength(trigramCount));
    }

    /**
     * Returns the total trigram count that was supplied when this validator was constructed.
     *
     * @return total number of trigrams in the model
     */
    public long getTotalNumTrigrams() {
        return this.totalNumTrigrams;
    }

    /**
     * Returns the model contents as {@code trigram=logProbability} strings, ordered by the
     * natural ordering of the trigrams.
     *
     * @return iterator over one formatted string per trigram in the model
     */
    public Iterator<String> dumpModel() {
        return this.trigramLogs.keySet()
                .stream()
                .sorted()
                .map(trigram -> "%s=%s".formatted(trigram.toCharSeq(), this.trigramLogs.get(trigram)))
                .iterator();
    }

    /**
     * Selects the threshold that applies to the given length.
     *
     * @param len the length to look up (the caller in this class passes the trigram count)
     * @return {@code thresholds[len - 4]}, clamped to the last element for larger lengths;
     * {@code 10.0} when {@code len} is less than 4 or no thresholds are available
     */
    private double getThresholdForStringOfLength(int len) {
        int index = len - THRESHOLD_INDEX_OFFSET;
        // A model file without a "Thresholds" header leaves the table null; without this
        // guard that case (and an empty table) failed with an exception here.
        // NOTE(review): falling back to INVALID_THRESHOLD presumably marks the string as
        // not valid - confirm this is the desired behavior for a model without thresholds.
        if (index < 0 || this.thresholds == null || this.thresholds.length == 0) {
            return INVALID_THRESHOLD;
        }
        return this.thresholds[Math.min(index, this.thresholds.length - 1)];
    }

    /**
     * Parses a trigram model file.
     * <p>
     * Blank lines are ignored. Leading lines that start with {@code #} are header lines of
     * the form {@code # name: value}; the names "Model Type", "Thresholds" and "Symbol Size"
     * are recognized and any other header is ignored. The header section ends at the first
     * non-blank line that does not start with {@code #}. Every following non-blank line must
     * hold four tab-separated fields: the three parts of the trigram and its count.
     *
     * @param sourceFile the model file
     * @return validator built from the file contents
     * @throws IOException if the file can not be read, a data line does not have exactly
     * four fields, or a numeric value can not be parsed
     */
    private static TrigramStringValidator readModel(ResourceFile sourceFile) throws IOException {
        Map<Trigram, Integer> counts = new HashMap<>();
        long totalTrigrams = 0L;
        String modelType = null;
        double[] thresholds = null;
        int symbolSize = DEFAULT_SYMBOL_SIZE;
        int lineNum = 0;
        boolean inFileHeaderSection = true;
        String currString = "";
        try (BufferedReader br = new BufferedReader(new InputStreamReader(sourceFile.getInputStream(), StandardCharsets.UTF_8))) {
            while ((currString = br.readLine()) != null) {
                ++lineNum;
                if (currString.isBlank()) {
                    continue;
                }
                if (inFileHeaderSection && currString.startsWith("#")) {
                    String[] headerFields = TrigramStringValidator.parseHeaderLine(currString.substring(1).trim());
                    if (headerFields != null) {
                        switch (headerFields[0]) {
                            case "Model Type" -> modelType = headerFields[1];
                            case "Thresholds" -> thresholds = TrigramStringValidator.parseThresholds(headerFields[1]);
                            case "Symbol Size" -> symbolSize = Integer.parseInt(headerFields[1]);
                            default -> {
                                // unrecognized header: ignored
                            }
                        }
                    }
                    continue;
                }
                inFileHeaderSection = false;
                String[] lineParts = currString.split("\\t");
                if (lineParts.length != 4) {
                    throw new IOException("Invalid field count in ngram %s:%d: %s".formatted(sourceFile.getName(), lineNum, currString));
                }
                Trigram trigram = Trigram.fromStringRep(lineParts[0], lineParts[1], lineParts[2]);
                int currCount = Integer.parseInt(lineParts[3]);
                int[] codePoints = trigram.codePoints();
                // Entries whose middle code point is 0, or whose first and last code points
                // are both 0, are left out of the model.
                if (codePoints[1] == 0 || (codePoints[0] == 0 && codePoints[2] == 0)) {
                    continue;
                }
                counts.merge(trigram, currCount, Integer::sum);
                totalTrigrams += (long) currCount;
            }

            // Computed in long arithmetic: the symbol size comes from the file, and its
            // cube overflows an int for larger values.
            long symbolsSquared = (long) symbolSize * symbolSize;
            long expectedEntryCount = symbolsSquared * symbolSize + symbolsSquared * 2L;
            // Each expected entry that is absent from the file adds one to the total.
            totalTrigrams += expectedEntryCount - counts.size();

            Map<Trigram, Double> logProb = TrigramStringValidator.calculateLogProbs(counts, totalTrigrams);
            Function<String, String> transformer = TrigramStringValidator
                    .getStringTransformer(Objects.requireNonNullElse(modelType, ""))
                    .andThen(TrigramStringValidator::normalizeWhitespace);
            return new TrigramStringValidator(logProb, totalTrigrams, transformer, thresholds, sourceFile);
        }
        catch (NumberFormatException nfe) {
            // Report the offending line and keep the parse failure as the cause.
            throw new IOException("Error parsing string ngram %s:%d: %s".formatted(sourceFile.getName(), lineNum, currString), nfe);
        }
    }

    /**
     * Trims the string and collapses each run of two or more spaces into one space and each
     * run of two or more tabs into one tab.
     *
     * @param s string to normalize
     * @return the normalized string
     */
    private static String normalizeWhitespace(String s) {
        String singleSpaced = MULTIPLE_SPACES.matcher(s.trim()).replaceAll(" ");
        return MULTIPLE_TABS.matcher(singleSpaced).replaceAll("\t");
    }

    /**
     * Returns the string transformation that belongs to a model type.
     *
     * @param modelTypeName value of the "Model Type" header; must not be null
     * @return a lower-casing function for "lowercase", otherwise the identity function
     */
    private static Function<String, String> getStringTransformer(String modelTypeName) {
        return switch (modelTypeName) {
            // Locale.ROOT keeps the mapping independent of the JVM's default locale
            // (e.g. a Turkish default locale would map 'I' to a dotless i).
            case "lowercase" -> s -> s.toLowerCase(Locale.ROOT);
            default -> Function.identity();
        };
    }

    /**
     * Splits a header line of the form {@code name: value} at its first colon.
     *
     * @param s header text with the leading {@code #} already removed
     * @return two-element array of trimmed name and trimmed value, or null if {@code s} has
     * no colon or starts with one
     */
    private static String[] parseHeaderLine(String s) {
        int colon = s.indexOf(':');
        if (colon <= 0) {
            return null;
        }
        return new String[] { s.substring(0, colon).trim(), s.substring(colon + 1).trim() };
    }

    /**
     * Parses a comma-separated list of floating point values.
     *
     * @param s comma-separated values; whitespace around each value is ignored
     * @return the parsed values, in order
     * @throws NumberFormatException if any element is not a valid double
     */
    private static double[] parseThresholds(String s) {
        String[] parts = s.split(",");
        double[] results = new double[parts.length];
        for (int i = 0; i < parts.length; ++i) {
            results[i] = Double.parseDouble(parts[i].trim());
        }
        return results;
    }

    /**
     * Converts raw trigram counts into log10 relative frequencies.
     *
     * @param counts map of trigram to occurrence count
     * @param totalTrigrams the divisor used to turn a count into a relative frequency
     * @return new map of trigram to {@code log10(count / totalTrigrams)}
     */
    private static Map<Trigram, Double> calculateLogProbs(Map<Trigram, Integer> counts, long totalTrigrams) {
        double totalTrigramsD = totalTrigrams;
        Map<Trigram, Double> logTrigrams = new HashMap<>();
        for (Map.Entry<Trigram, Integer> entry : counts.entrySet()) {
            double relativeFrequency = (double) entry.getValue().intValue() / totalTrigramsD;
            logTrigrams.put(entry.getKey(), Math.log10(relativeFrequency));
        }
        return logTrigrams;
    }
}

