/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.dashscope.tokenizers;

import com.alibaba.dashscope.exception.NoSpecialTokenExists;
import com.alibaba.dashscope.exception.UnSupportedSpecialTokenMode;
import com.alibaba.dashscope.tokenizers.EncodeBytesEntity;
import com.alibaba.dashscope.tokenizers.Tokenizer;
import com.alibaba.dashscope.utils.StringUtils;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Byte-level BPE tokenizer for Qwen models.
 *
 * <p>Mergeable ranks are read once, during class initialization, from the {@code qwen.tiktoken}
 * classpath resource (one {@code <base64-bytes> <rank>} pair per line). Special tokens
 * ({@code <|endoftext|>}, {@code <|im_start|>}, {@code <|im_end|>} and {@code <|extra_0|>} ..
 * {@code <|extra_204|>}) receive consecutive ids starting at {@value #SPECIAL_START_ID}.
 *
 * <p>All static tables are built in the static initializer and are not modified afterwards;
 * instances hold no fields of their own.
 */
public class QwenTokenizer
implements Tokenizer {
    private static final String SPECIAL_START = "<|";
    private static final String SPECIAL_END = "|>";
    private static final String ENDOFTEXT = "<|endoftext|>";
    private static final String IMSTART = "<|im_start|>";
    private static final String IMEND = "<|im_end|>";
    private static final String PATTEN_STRING = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
    /** Pre-tokenization regex, compiled once ({@link Pattern} instances are immutable). */
    private static final Pattern PRE_TOKENIZE_PATTERN = Pattern.compile(PATTEN_STRING);
    private static final int SPECIAL_START_ID = 151643;
    /** Number of {@code <|extra_N|>} special tokens appended after the three named ones. */
    private static final int EXTRA_SPECIAL_TOKEN_COUNT = 205;
    private static final String TOKEN_RANK_SEPARATOR = " ";
    private static final String vocabularyBpeFile = "qwen.tiktoken";
    /** Byte sequence -> rank (which is also its token id); keys compare by bytes. */
    private static final Map<EncodeBytesEntity, Integer> mergeableRanks;
    /** Special token text -> token id, in insertion order. */
    private static final Map<String, Integer> specialTokens;
    /** Token id -> raw bytes; entries are null for ids that are not assigned. */
    private static final byte[][] decodeMap;

    /**
     * Concatenates the bytes of two entities into a new entity (rank not set).
     */
    private EncodeBytesEntity mergePair(EncodeBytesEntity first, EncodeBytesEntity second) {
        byte[] bytesPair = Arrays.copyOf(first.bytes, first.bytes.length + second.bytes.length);
        System.arraycopy(second.bytes, 0, bytesPair, first.bytes.length, second.bytes.length);
        return new EncodeBytesEntity(bytesPair);
    }

    /**
     * Finds the adjacent pair in {@code ids} with the lowest merge rank.
     *
     * @param ids current byte sequences of one chunk
     * @return the concatenated pair with its {@code rank} set, or {@code null} when no adjacent
     *     pair is present in the vocabulary
     */
    private EncodeBytesEntity getLowestIndexBytePair(EncodeBytesEntity[] ids) {
        int minRank = Integer.MAX_VALUE;
        EncodeBytesEntity minRankPair = null;
        for (int i = 0; i < ids.length - 1; ++i) {
            EncodeBytesEntity bytePair = this.mergePair(ids[i], ids[i + 1]);
            Integer rank = mergeableRanks.get(bytePair);
            // Strict '<' keeps the leftmost pair among equal ranks, so repeated pairs need no
            // separate de-duplication.
            if (rank != null && rank < minRank) {
                bytePair.rank = rank;
                minRank = rank;
                minRankPair = bytePair;
            }
        }
        return minRankPair;
    }

    /**
     * Scans {@code ids} left to right and replaces every non-overlapping adjacent occurrence of
     * {@code bytePair} with {@code bytePair} itself.
     *
     * @return a new, possibly shorter, array
     */
    private EncodeBytesEntity[] merge(EncodeBytesEntity[] ids, EncodeBytesEntity bytePair) {
        EncodeBytesEntity[] merged = new EncodeBytesEntity[ids.length];
        int mergedIndex = 0;
        int i = 0;
        while (i < ids.length) {
            boolean hasNext = i < ids.length - 1;
            if (hasNext && this.mergePair(ids[i], ids[i + 1]).equals(bytePair)) {
                merged[mergedIndex++] = bytePair;
                i += 2;
            } else {
                merged[mergedIndex++] = ids[i];
                ++i;
            }
        }
        return Arrays.copyOf(merged, mergedIndex);
    }

    /**
     * Encodes one pre-tokenized chunk: starts from single UTF-8 bytes and repeatedly merges the
     * lowest-ranked adjacent pair until no adjacent pair is in the vocabulary.
     */
    private List<Integer> encodeChunk(String chunk) {
        byte[] chunkBytes = chunk.getBytes(StandardCharsets.UTF_8);
        EncodeBytesEntity[] ids = new EncodeBytesEntity[chunkBytes.length];
        for (int i = 0; i < chunkBytes.length; ++i) {
            EncodeBytesEntity single = new EncodeBytesEntity(new byte[]{chunkBytes[i]});
            single.rank = mergeableRanks.get(single);
            ids[i] = single;
        }
        EncodeBytesEntity bytePair;
        while (ids.length >= 2 && (bytePair = this.getLowestIndexBytePair(ids)) != null) {
            ids = this.merge(ids, bytePair);
        }
        List<Integer> tokens = new ArrayList<Integer>(ids.length);
        for (EncodeBytesEntity key : ids) {
            tokens.add(key.rank);
        }
        return tokens;
    }

    /**
     * Encodes {@code text} treating special-token strings as ordinary text.
     *
     * @param text input text
     * @return token ids
     */
    @Override
    public List<Integer> encodeOrdinary(String text) {
        List<Integer> tokenIds = new ArrayList<Integer>();
        Matcher matcher = PRE_TOKENIZE_PATTERN.matcher(text);
        while (matcher.find()) {
            tokenIds.addAll(this.encodeChunk(matcher.group()));
        }
        return tokenIds;
    }

    /**
     * Splits {@code text} around special-token strings, or returns it as a single chunk when it
     * cannot contain one.
     */
    private List<String> splitWithSpecial(String text) {
        if (text.contains(SPECIAL_START) && text.contains(SPECIAL_END)) {
            return StringUtils.splitByStrings(text, specialTokens.keySet());
        }
        List<String> chunks = new ArrayList<String>();
        chunks.add(text);
        return chunks;
    }

    /**
     * Encodes {@code text}, optionally recognizing special tokens.
     *
     * @param text input text
     * @param allowedSpecial {@code "all"} (default when {@code null}) encodes special-token strings
     *     as their special ids; {@code "none"} encodes them as ordinary text; {@code "none_raise"}
     *     rejects text that contains any special-token string
     * @return token ids
     * @throws NoSpecialTokenExists in {@code "none_raise"} mode when {@code text} contains a
     *     special-token string
     * @throws UnSupportedSpecialTokenMode when {@code allowedSpecial} is not one of the modes above
     */
    @Override
    public List<Integer> encode(String text, String allowedSpecial) throws NoSpecialTokenExists, UnSupportedSpecialTokenMode {
        if (allowedSpecial == null) {
            allowedSpecial = "all";
        }
        Map<String, Integer> specialTokensUse;
        if ("all".equals(allowedSpecial)) {
            specialTokensUse = specialTokens;
        } else if ("none".equals(allowedSpecial)) {
            specialTokensUse = new LinkedHashMap<String, Integer>();
        } else if ("none_raise".equals(allowedSpecial)) {
            // Special tokens are disallowed in this mode; their presence is an error (the
            // previous check was inverted and raised only when NO special token was present).
            specialTokensUse = new LinkedHashMap<String, Integer>();
            for (String token : specialTokens.keySet()) {
                if (text.contains(token)) {
                    throw new NoSpecialTokenExists(String.format("Special token %s is not allowed in %s", token, text));
                }
            }
        } else {
            throw new UnSupportedSpecialTokenMode(String.format("UnSupport allowedSpecial: %s", allowedSpecial));
        }
        if (specialTokensUse.isEmpty()) {
            return this.encodeOrdinary(text);
        }
        List<Integer> tokens = new ArrayList<Integer>();
        for (String chunk : this.splitWithSpecial(text)) {
            Integer specialId = specialTokensUse.get(chunk);
            if (specialId != null) {
                tokens.add(specialId);
            } else {
                tokens.addAll(this.encodeOrdinary(chunk));
            }
        }
        return tokens;
    }

    /**
     * Decodes token ids back to text.
     *
     * <p>All token bytes are concatenated before UTF-8 decoding, because a single multi-byte
     * character may be split across several byte-level tokens.
     *
     * @param tokens token ids
     * @return decoded text
     * @throws IllegalArgumentException if an id is null, out of range or not assigned
     */
    @Override
    public String decode(List<Integer> tokens) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        for (Integer token : tokens) {
            if (token == null || token < 0 || token >= decodeMap.length || decodeMap[token] == null) {
                throw new IllegalArgumentException("Unknown token id: " + token);
            }
            byte[] tokenBytes = decodeMap[token];
            buffer.write(tokenBytes, 0, tokenBytes.length);
        }
        return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
    }

    static {
        specialTokens = Collections.unmodifiableMap(buildSpecialTokens());
        mergeableRanks = Collections.unmodifiableMap(loadMergeableRanks());
        decodeMap = buildDecodeMap(mergeableRanks, specialTokens);
    }

    /** Assigns consecutive ids, starting at SPECIAL_START_ID, to the fixed special tokens. */
    private static Map<String, Integer> buildSpecialTokens() {
        Map<String, Integer> map = new LinkedHashMap<String, Integer>();
        int nextId = SPECIAL_START_ID;
        map.put(ENDOFTEXT, nextId++);
        map.put(IMSTART, nextId++);
        map.put(IMEND, nextId++);
        for (int i = 0; i < EXTRA_SPECIAL_TOKEN_COUNT; ++i) {
            // Concatenation (not String.format) keeps ASCII digits regardless of default locale.
            map.put("<|extra_" + i + "|>", nextId++);
        }
        return map;
    }

    /** Reads the base64-encoded byte sequences and their ranks from the qwen.tiktoken resource. */
    private static Map<EncodeBytesEntity, Integer> loadMergeableRanks() {
        ClassLoader classLoader = QwenTokenizer.class.getClassLoader();
        InputStream inputStream = classLoader.getResourceAsStream(vocabularyBpeFile);
        if (inputStream == null) {
            throw new IllegalStateException("Could not find qwen.tiktoken in resources");
        }
        Map<EncodeBytesEntity, Integer> ranks = new LinkedHashMap<EncodeBytesEntity, Integer>();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] splits = line.split(TOKEN_RANK_SEPARATOR);
                if (splits.length != 2) {
                    throw new IllegalStateException("Invalid line in qwen.tiktoken: " + line);
                }
                byte[] tokenBytes = Base64.getDecoder().decode(splits[0]);
                int rank = Integer.parseInt(splits[1]);
                ranks.put(new EncodeBytesEntity(tokenBytes, rank), rank);
            }
        } catch (IOException e) {
            throw new RuntimeException("Could not load qwen.tiktoken from resources", e);
        }
        return ranks;
    }

    /** Builds the id -> bytes table, sized by the largest id so sparse ids cannot overflow it. */
    private static byte[][] buildDecodeMap(Map<EncodeBytesEntity, Integer> ranks, Map<String, Integer> specials) {
        int maxId = -1;
        for (Integer id : ranks.values()) {
            maxId = Math.max(maxId, id);
        }
        for (Integer id : specials.values()) {
            maxId = Math.max(maxId, id);
        }
        byte[][] table = new byte[maxId + 1][];
        for (Map.Entry<EncodeBytesEntity, Integer> entry : ranks.entrySet()) {
            byte[] keyBytes = entry.getKey().bytes;
            table[entry.getValue()] = Arrays.copyOf(keyBytes, keyBytes.length);
        }
        for (Map.Entry<String, Integer> entry : specials.entrySet()) {
            table[entry.getValue()] = entry.getKey().getBytes(StandardCharsets.UTF_8);
        }
        return table;
    }
}

