/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.NDSerializer;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PushbackInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

public class NDList
extends ArrayList<NDArray>
implements NDResource,
BytesSupplier {
    private static final long serialVersionUID = 1L;

    /** Constructs an empty {@code NDList}. */
    public NDList() {
    }

    /**
     * Constructs an empty {@code NDList} with the given initial capacity.
     *
     * @param initialCapacity the initial capacity passed to the backing {@link ArrayList}
     */
    public NDList(int initialCapacity) {
        super(initialCapacity);
    }

    /**
     * Constructs an {@code NDList} containing the given arrays, in argument order.
     *
     * @param arrays the arrays to place in the list
     */
    public NDList(NDArray ... arrays) {
        super(Arrays.asList(arrays));
    }

    /**
     * Constructs an {@code NDList} holding the elements of another collection.
     *
     * @param other the collection whose elements are copied into this list
     */
    public NDList(Collection<NDArray> other) {
        super(other);
    }

    /**
     * Decodes an {@code NDList} from a byte array, picking the decoder from the leading bytes.
     *
     * <p>The formats are probed in this order:
     *
     * <ol>
     *   <li>bytes 0-1 are {@code 'P','K'}: the input is read as a ZIP archive of numpy entries;
     *   <li>bytes 0-3 are {@code 0x39,'N','U','M'}: the input is handed to {@code
     *       NDSerializer.decode} and wrapped in a one-element list;
     *   <li>byte 8 is <code>'{'</code>: the input is read as safetensors (8-byte length followed
     *       by a JSON header);
     *   <li>otherwise: a 4-byte element count followed by that many serialized arrays.
     * </ol>
     *
     * <p>NOTE(review): the second probe compares byte 0 with {@code 0x39}; the NumPy {@code .npy}
     * magic starts with {@code 0x93} - confirm which value is intended.
     *
     * <p>NOTE(review): inputs shorter than 9 bytes are rejected, yet the default encoding of an
     * empty list is only the 4-byte count - confirm whether that round trip should be supported.
     *
     * @param manager the manager used to create the decoded arrays
     * @param byteArray the encoded bytes
     * @return the decoded list
     * @throws IllegalArgumentException if the input is shorter than 9 bytes, declares a negative
     *     element count, or cannot be read
     */
    public static NDList decode(NDManager manager, byte[] byteArray) {
        if (byteArray.length < 9) {
            throw new IllegalArgumentException("Invalid input length: " + byteArray.length);
        }
        boolean zipMagic = byteArray[0] == 'P' && byteArray[1] == 'K';
        boolean numMagic =
                byteArray[0] == 0x39
                        && byteArray[1] == 'N'
                        && byteArray[2] == 'U'
                        && byteArray[3] == 'M';
        boolean jsonAfterLength = byteArray[8] == '{';
        try {
            if (zipMagic) {
                return decodeNumpy(manager, new ByteArrayInputStream(byteArray));
            }
            if (numMagic) {
                return new NDList(NDSerializer.decode(manager, new ByteArrayInputStream(byteArray)));
            }
            if (jsonAfterLength) {
                return decodeSafetensors(manager, new ByteArrayInputStream(byteArray));
            }
            ByteBuffer buffer = ByteBuffer.wrap(byteArray);
            int count = buffer.getInt();
            if (count < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + count);
            }
            NDList decoded = new NDList();
            for (int remaining = count; remaining > 0; --remaining) {
                decoded.add(NDSerializer.decode(manager, buffer));
            }
            return decoded;
        } catch (IOException | BufferUnderflowException e) {
            throw new IllegalArgumentException("Invalid NDArray input", e);
        }
    }

    /**
     * Decodes an {@code NDList} from a stream, picking the decoder from the first nine bytes.
     *
     * <p>The nine bytes are read up front and then pushed back, so the selected decoder sees the
     * stream from its beginning. The formats are probed in this order: {@code 'P','K'} (ZIP
     * archive of numpy entries), {@code 0x39,'N','U','M'} (a single array handled by {@code
     * NDSerializer.decode}), <code>'{'</code> at byte 8 (safetensors), and finally a 4-byte
     * element count followed by arrays read through {@code manager.decode}.
     *
     * @param manager the manager used to create the decoded arrays
     * @param is the stream to read from; it is not closed by this method
     * @return the decoded list
     * @throws IllegalArgumentException if fewer than nine bytes are available, the declared
     *     element count is negative, or reading fails
     */
    public static NDList decode(NDManager manager, InputStream is) {
        try {
            byte[] header = new byte[9];
            new DataInputStream(is).readFully(header);

            // Hand the sniffed bytes back so the chosen decoder starts at offset 0.
            PushbackInputStream rewound = new PushbackInputStream(is, 9);
            rewound.unread(header);

            boolean zipMagic = header[0] == 'P' && header[1] == 'K';
            if (zipMagic) {
                return decodeNumpy(manager, rewound);
            }
            boolean numMagic =
                    header[0] == 0x39 && header[1] == 'N' && header[2] == 'U' && header[3] == 'M';
            if (numMagic) {
                return new NDList(NDSerializer.decode(manager, rewound));
            }
            if (header[8] == '{') {
                return decodeSafetensors(manager, rewound);
            }

            DataInputStream data = new DataInputStream(rewound);
            int count = data.readInt();
            if (count < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + count);
            }
            NDList decoded = new NDList();
            for (int remaining = count; remaining > 0; --remaining) {
                decoded.add(manager.decode(data));
            }
            return decoded;
        } catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    /**
     * Reads a safetensors payload: an 8-byte little-endian header length, a UTF-8 JSON header of
     * that length, and then the tensor data addressed by each entry's {@code data_offsets}.
     *
     * <p>Arrays are returned in JSON key order and named after their key; the {@code
     * __metadata__} key is skipped. Structural problems in the input (bad header length, missing
     * fields, inconsistent offsets) are reported as {@link IOException} so that both public
     * {@code decode} methods translate them into {@link IllegalArgumentException}.
     *
     * @param manager the manager used to create the arrays
     * @param is the stream positioned at the start of the safetensors payload
     * @return the decoded arrays
     * @throws IOException if the stream ends early or the header is malformed
     */
    private static NDList decodeSafetensors(NDManager manager, InputStream is) throws IOException {
        DataInputStream dis =
                is instanceof DataInputStream ? (DataInputStream) is : new DataInputStream(is);

        byte[] lengthBytes = new byte[8];
        dis.readFully(lengthBytes);
        long headerLength = ByteBuffer.wrap(lengthBytes).order(ByteOrder.LITTLE_ENDIAN).getLong();
        // The length comes from the input; reject values that cannot be a Java array size.
        if (headerLength < 0 || headerLength > Integer.MAX_VALUE) {
            throw new IOException("Invalid safetensors header length: " + headerLength);
        }

        byte[] header = new byte[(int) headerLength];
        dis.readFully(header);
        String json = new String(header, StandardCharsets.UTF_8);
        JsonObject jsonObject = JsonUtils.GSON.fromJson(json, JsonObject.class);
        if (jsonObject == null) {
            // Gson yields null for an empty document.
            throw new IOException("Malformed safetensors metadata: " + json);
        }

        List<Pair<String, SafeTensor>> entries = new ArrayList<>();
        int dataLength = 0;
        for (String key : jsonObject.keySet()) {
            if ("__metadata__".equals(key)) {
                continue;
            }
            SafeTensor tensor = JsonUtils.GSON.fromJson(jsonObject.get(key), SafeTensor.class);
            boolean complete =
                    tensor != null
                            && tensor.dtype != null
                            && tensor.shape != null
                            && tensor.offsets != null
                            && tensor.offsets.length == 2;
            if (!complete || tensor.offsets[0] < 0 || tensor.offsets[1] < tensor.offsets[0]) {
                throw new IOException("Malformed safetensors metadata: " + json);
            }
            dataLength = Math.max(dataLength, tensor.offsets[1]);
            entries.add(new Pair<>(key, tensor));
        }

        // Read everything up to the largest end offset; each tensor is a slice of this buffer.
        byte[] data = new byte[dataLength];
        dis.readFully(data);

        NDList ret = new NDList(entries.size());
        for (Pair<String, SafeTensor> entry : entries) {
            SafeTensor tensor = entry.getValue();
            ByteBuffer slice = ByteBuffer.wrap(data, tensor.offsets[0], tensor.size());
            slice.order(ByteOrder.LITTLE_ENDIAN);
            NDArray array =
                    manager.create(
                            slice,
                            new Shape(tensor.shape),
                            DataType.fromSafetensors(tensor.dtype));
            array.setName(entry.getKey());
            ret.add(array);
        }
        return ret;
    }

    /**
     * Reads a ZIP archive and decodes every entry with {@code NDSerializer.decodeNumpy}.
     *
     * <p>An entry whose name ends in {@code .npy} and does not start with {@code arr_} lends its
     * name (minus the extension) to the decoded array; all other arrays keep whatever name the
     * decoder gave them.
     *
     * @param manager the manager used to create the arrays
     * @param is the stream holding the archive; it is not closed by this method
     * @return the decoded arrays, in archive order
     * @throws IOException if the archive cannot be read
     */
    private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException {
        NDList arrays = new NDList();
        ZipInputStream zip = new ZipInputStream(is);
        for (ZipEntry entry = zip.getNextEntry(); entry != null; entry = zip.getNextEntry()) {
            String entryName = entry.getName();
            NDArray array = NDSerializer.decodeNumpy(manager, zip);
            boolean autoNamed = entryName.startsWith("arr_");
            if (!autoNamed && entryName.endsWith(".npy")) {
                int stemLength = entryName.length() - ".npy".length();
                array.setName(entryName.substring(0, stemLength));
            }
            arrays.add(array);
        }
        return arrays;
    }

    /**
     * Returns the first array whose name equals {@code name}.
     *
     * @param name the name to look for
     * @return the first matching array, or {@code null} if no array carries that name
     */
    public NDArray get(String name) {
        return stream().filter(array -> name.equals(array.getName())).findFirst().orElse(null);
    }

    /**
     * Removes the first array whose name equals {@code name}.
     *
     * @param name the name to look for
     * @return the array that was removed, or {@code null} if no array carries that name
     */
    public NDArray remove(String name) {
        for (int position = 0; position < size(); ++position) {
            NDArray candidate = get(position);
            if (name.equals(candidate.getName())) {
                remove(position);
                return candidate;
            }
        }
        return null;
    }

    /**
     * Tells whether any array in this list is named {@code name}.
     *
     * @param name the name to look for
     * @return {@code true} if at least one array carries that name
     */
    public boolean contains(String name) {
        return stream().anyMatch(array -> name.equals(array.getName()));
    }

    /**
     * Returns the first array of this list.
     *
     * @return the element at index 0
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public NDArray head() {
        return get(0);
    }

    /**
     * Returns the only array of this list.
     *
     * @return the single element
     * @throws IndexOutOfBoundsException if the list does not hold exactly one element
     */
    public NDArray singletonOrThrow() {
        int count = size();
        if (count == 1) {
            return get(0);
        }
        throw new IndexOutOfBoundsException("Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + count);
    }

    /**
     * Appends every array of {@code other} to this list, one at a time and in order.
     *
     * @param other the list whose arrays are appended
     * @return this list, to allow chaining
     */
    public NDList addAll(NDList other) {
        other.forEach(array -> add(array));
        return this;
    }

    /**
     * Returns a new list holding the arrays from {@code fromIndex} to the end of this list.
     *
     * @param fromIndex index of the first array to include
     * @return a new {@code NDList} with the selected arrays
     */
    public NDList subNDList(int fromIndex) {
        int end = size();
        return subNDList(fromIndex, end);
    }

    /**
     * Returns a new list holding the arrays in the range {@code [fromIndex, toIndex)}.
     *
     * <p>The result is a copy of the range, built from {@link #subList(int, int)}; the arrays
     * themselves are shared with this list.
     *
     * @param fromIndex index of the first array to include
     * @param toIndex index just past the last array to include
     * @return a new {@code NDList} with the selected arrays
     */
    public NDList subNDList(int fromIndex, int toIndex) {
        List<NDArray> range = subList(fromIndex, toIndex);
        return new NDList(range);
    }

    public NDList toDevice(Device device, boolean copy) {
        if (!copy && this.stream().allMatch(array -> array.getDevice() == device)) {
            return this;
        }
        NDList newNDList = new NDList(this.size());
        this.forEach((? super E a) -> newNDList.add(a.toDevice(device, copy)));
        return newNDList;
    }

    /**
     * Returns the manager of the first array in this list.
     *
     * @return the manager reported by {@link #head()}
     * @throws IndexOutOfBoundsException if the list is empty
     */
    @Override
    public NDManager getManager() {
        NDArray first = head();
        return first.getManager();
    }

    /**
     * Returns the arrays backing this resource, which is this list itself (not a copy).
     *
     * @return this list
     */
    @Override
    public List<NDArray> getResourceNDArrays() {
        return this;
    }

    @Override
    public void attach(NDManager manager) {
        this.forEach((? super E array) -> array.attach(manager));
    }

    @Override
    public void tempAttach(NDManager manager) {
        this.forEach((? super E array) -> array.tempAttach(manager));
    }

    /** Calls {@code detach()} on every array in this list. */
    @Override
    public void detach() {
        for (NDArray array : this) {
            array.detach();
        }
    }

    /**
     * Encodes this list with {@link Encoding#ND_LIST}.
     *
     * @return the encoded bytes
     */
    public byte[] encode() {
        Encoding defaultEncoding = Encoding.ND_LIST;
        return encode(defaultEncoding);
    }

    /**
     * Encodes this list into a byte array using the given encoding.
     *
     * @param encoding the output format
     * @return the encoded bytes
     * @throws AssertionError if writing to the in-memory buffer raises an {@link IOException}
     */
    public byte[] encode(Encoding encoding) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            encode(baos, encoding);
            return baos.toByteArray();
        } catch (IOException e) {
            throw new AssertionError("NDList is not writable", e);
        }
    }

    /**
     * Writes this list to a stream with {@link Encoding#ND_LIST}.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    public void encode(OutputStream os) throws IOException {
        Encoding defaultEncoding = Encoding.ND_LIST;
        encode(os, defaultEncoding);
    }

    /**
     * Writes this list to a stream in the requested format. The stream is never closed here.
     *
     * <ul>
     *   <li>{@link Encoding#NPZ}: a ZIP archive with one {@code <name>.npy} entry per array.
     *   <li>{@link Encoding#SAFETENSORS}: an 8-byte little-endian header length, a JSON header
     *       mapping each name to dtype, shape and byte offsets, then the raw bytes of every
     *       array in list order.
     *   <li>Any other value: the element count as an {@code int}, followed by each array
     *       written by {@code NDSerializer.encode}.
     * </ul>
     *
     * <p>For NPZ and SAFETENSORS, arrays without a name are called {@code arr_0}, {@code arr_1},
     * ... counting only the unnamed ones.
     *
     * @param os the stream to write to
     * @param encoding the output format
     * @throws IOException if writing fails
     * @throws ArithmeticException if the safetensors data offsets do not fit in an {@code int}
     */
    public void encode(OutputStream os, Encoding encoding) throws IOException {
        if (encoding == Encoding.NPZ) {
            ZipOutputStream zos = new ZipOutputStream(os);
            int unnamed = 0;
            for (NDArray nd : this) {
                String name = nd.getName();
                if (name == null) {
                    name = "arr_" + unnamed;
                    ++unnamed;
                }
                zos.putNextEntry(new ZipEntry(name + ".npy"));
                NDSerializer.encodeAsNumpy(nd, zos);
            }
            // finish() completes the archive without closing the caller's stream.
            zos.finish();
            zos.flush();
            return;
        }
        if (encoding == Encoding.SAFETENSORS) {
            // Insertion-ordered so the JSON keys follow list order; decodeSafetensors rebuilds
            // the list in JSON key order, so this keeps a round trip order-preserving.
            Map<String, SafeTensor> header = new LinkedHashMap<>();
            int unnamed = 0;
            int offset = 0;
            for (NDArray nd : this) {
                String name = nd.getName();
                if (name == null) {
                    name = "arr_" + unnamed;
                    ++unnamed;
                }
                SafeTensor st = new SafeTensor();
                st.dtype = nd.getDataType().asSafetensors();
                st.shape = nd.getShape().getShape();
                long size = (long) nd.getDataType().getNumOfBytes() * nd.size();
                // addExact: fail loudly instead of wrapping to a negative offset.
                int limit = Math.addExact(offset, Math.toIntExact(size));
                st.offsets = new int[] {offset, limit};
                header.put(name, st);
                offset = limit;
            }
            byte[] json = JsonUtils.GSON.toJson(header).getBytes(StandardCharsets.UTF_8);
            ByteBuffer lengthPrefix = ByteBuffer.allocate(8);
            lengthPrefix.order(ByteOrder.LITTLE_ENDIAN);
            lengthPrefix.putLong(0, json.length);
            os.write(lengthPrefix.array());
            os.write(json);
            for (NDArray nd : this) {
                os.write(nd.toByteArray());
            }
            return;
        }
        DataOutputStream dos = new DataOutputStream(os);
        dos.writeInt(size());
        for (NDArray nd : this) {
            NDSerializer.encode(nd, dos);
        }
        dos.flush();
    }

    /**
     * Returns this list encoded with the default encoding.
     *
     * @return the bytes produced by {@link #encode()}
     */
    @Override
    public byte[] getAsBytes() {
        return encode();
    }

    /**
     * Returns this list encoded with the default encoding, wrapped in a {@link ByteBuffer}.
     *
     * @return a buffer wrapping the bytes produced by {@link #encode()}
     */
    @Override
    public ByteBuffer toByteBuffer() {
        byte[] encoded = encode();
        return ByteBuffer.wrap(encoded);
    }

    /**
     * Collects the shape of every array in this list.
     *
     * @return a new array holding each element's shape, in list order
     */
    public Shape[] getShapes() {
        Shape[] shapes = new Shape[size()];
        for (int i = 0; i < shapes.length; ++i) {
            shapes[i] = get(i).getShape();
        }
        return shapes;
    }

    /** Closes every array in this list and then empties the list. */
    @Override
    public void close() {
        for (NDArray array : this) {
            array.close();
        }
        clear();
    }

    /**
     * Builds a multi-line summary: a header with the list size, then one line per array showing
     * its index, its name (if any), its shape and its data type.
     *
     * @return the summary text
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder(200);
        text.append("NDList size: ").append(size()).append('\n');
        for (int i = 0; i < size(); ++i) {
            NDArray array = get(i);
            String label = array.getName();
            text.append(i).append(' ');
            if (label != null) {
                text.append(label);
            }
            text.append(": ");
            text.append(array.getShape());
            text.append(' ');
            text.append(array.getDataType());
            text.append('\n');
        }
        return text.toString();
    }

    /**
     * One tensor entry of the safetensors JSON header. Instances are converted to and from JSON
     * by Gson, so the field names (and the {@code data_offsets} alias) are part of the format.
     */
    private static final class SafeTensor {
        /** Data type name, as produced by {@code DataType.asSafetensors()}. */
        String dtype;
        /** Dimensions of the tensor. */
        long[] shape;
        /** Two-element {@code {begin, end}} byte range of the tensor within the data section. */
        @SerializedName(value="data_offsets")
        int[] offsets;

        private SafeTensor() {
        }

        /** Returns the number of bytes covered by {@link #offsets}, i.e. end minus begin. */
        int size() {
            return this.offsets[1] - this.offsets[0];
        }
    }

    /** Output formats understood by {@code NDList.encode(OutputStream, Encoding)}. */
    public static enum Encoding {
        /** Default format: the element count as an int, then each array via {@code NDSerializer.encode}. */
        ND_LIST,
        /** ZIP archive with one {@code .npy} entry per array. */
        NPZ,
        /** Safetensors: 8-byte little-endian header length, JSON header, then raw array bytes. */
        SAFETENSORS;

    }
}

