/*
 * Decompiled with CFR 0.152.
 */
package org.bytedeco.pytorch;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.Indexable;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.TensorOptions;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.pytorch.presets.torch;

/**
 * Java-side conveniences shared by the JavaCPP-mapped PyTorch tensor type.
 *
 * <p>Provides static factories that allocate a tensor with {@code torch.empty} and copy a Java
 * array into it. It also provides views of an existing tensor's memory as an NIO {@link Buffer}
 * ({@link #createBuffer(long)}) or a JavaCPP {@link Indexer} ({@link #createIndexer(boolean)}).
 * Both views only support tensors with a strided layout on the CPU device; anything else results
 * in an {@link UnsupportedOperationException}.
 *
 * <p>NOTE(review): this file imports both {@code org.bytedeco.pytorch.global.torch} and
 * {@code org.bytedeco.pytorch.presets.torch}. Two single-type imports with the same simple name
 * do not compile. Presumably {@code torch.class} in {@code @Properties} should be the presets
 * class, and {@code torch.ScalarType}/{@code Layout}/{@code DeviceType} the global one. Confirm
 * and fix the import list.
 */
@Properties(inherit={torch.class})
public abstract class AbstractTensor
extends Pointer
implements Indexable {
    /**
     * Wraps an existing native pointer.
     *
     * @param p pointer handed to {@link Pointer#Pointer(Pointer)}
     */
    public AbstractTensor(Pointer p) {
        super(p);
    }

    /**
     * Creates a 1-D tensor from {@code data}.
     *
     * @param data   values to copy
     * @param signed {@code true} for a {@code Char} (int8) tensor, {@code false} for {@code Byte} (uint8)
     * @return the new tensor
     */
    public static Tensor create(byte[] data, boolean signed) {
        return create(data, signed, data.length);
    }

    /** Creates a 1-D {@code Byte} (unsigned) tensor from {@code data}. */
    public static Tensor create(byte ... data) {
        return create(data, false, data.length);
    }

    /** Creates a 1-D {@code Short} tensor from {@code data}. */
    public static Tensor create(short ... data) {
        return create(data, new long[]{data.length});
    }

    /** Creates a 1-D {@code Int} tensor from {@code data}. */
    public static Tensor create(int ... data) {
        return create(data, new long[]{data.length});
    }

    /** Creates a 1-D {@code Long} tensor from {@code data}. */
    public static Tensor create(long ... data) {
        return create(data, new long[]{data.length});
    }

    /** Creates a 1-D {@code Float} tensor from {@code data}. */
    public static Tensor create(float ... data) {
        return create(data, new long[]{data.length});
    }

    /** Creates a 1-D {@code Double} tensor from {@code data}. */
    public static Tensor create(double ... data) {
        return create(data, new long[]{data.length});
    }

    /** Creates a 1-D {@code Bool} tensor from {@code data}. */
    public static Tensor create(boolean ... data) {
        return create(data, data.length);
    }

    /**
     * Creates a byte tensor of the given shape and copies {@code data} into it through an NIO buffer.
     *
     * @param data   values to copy, starting at the first element
     * @param signed {@code true} for a {@code Char} (int8) tensor, {@code false} for {@code Byte} (uint8)
     * @param shape  tensor dimensions
     * @return the new tensor
     */
    public static Tensor create(byte[] data, boolean signed, long ... shape) {
        Tensor t = allocate(signed ? torch.ScalarType.Char : torch.ScalarType.Byte, shape);
        ByteBuffer b = (ByteBuffer)t.createBuffer();
        b.put(data);
        return t;
    }

    /** Creates an unsigned {@code Byte} tensor of the given shape holding {@code data}. */
    public static Tensor create(byte[] data, long ... shape) {
        return create(data, false, shape);
    }

    /** Creates a {@code Short} tensor of the given shape holding {@code data}. */
    public static Tensor create(short[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Short, shape);
        ShortIndexer i = (ShortIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Creates an {@code Int} tensor of the given shape holding {@code data}. */
    public static Tensor create(int[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Int, shape);
        IntIndexer i = (IntIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Creates a {@code Long} tensor of the given shape holding {@code data}. */
    public static Tensor create(long[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Long, shape);
        LongIndexer i = (LongIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Creates a {@code Float} tensor of the given shape holding {@code data}. */
    public static Tensor create(float[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Float, shape);
        FloatIndexer i = (FloatIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Creates a {@code Double} tensor of the given shape holding {@code data}. */
    public static Tensor create(double[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Double, shape);
        DoubleIndexer i = (DoubleIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Creates a {@code Bool} tensor of the given shape holding {@code data}. */
    public static Tensor create(boolean[] data, long ... shape) {
        Tensor t = allocate(torch.ScalarType.Bool, shape);
        BooleanIndexer i = (BooleanIndexer)t.createIndexer();
        i.put(0L, data);
        return t;
    }

    /** Allocates an uninitialized tensor via {@code torch.empty}, with options that set only the dtype. */
    private static Tensor allocate(torch.ScalarType dtype, long[] shape) {
        return org.bytedeco.pytorch.global.torch.empty(shape, new TensorOptions(dtype), null);
    }

    /** Returns this tensor's options (layout, device, ...). */
    public abstract TensorOptions options();

    /** Returns this tensor's element type. */
    public abstract torch.ScalarType scalar_type();

    /** Returns the number of dimensions (0 for a scalar tensor). */
    public abstract long ndimension();

    /** Returns the size of dimension {@code var1}. */
    public abstract long size(long var1);

    /** Returns the stride of dimension {@code var1}, as used for indexer strides below. */
    public abstract long stride(long var1);

    /** Returns the number of elements. */
    public abstract long numel();

    /** Returns the size of the tensor's data in bytes, used to bound buffer and indexer views. */
    public abstract long nbytes();

    /** Returns a pointer to the first element of the tensor's data. */
    public abstract Pointer data_ptr();

    /**
     * Returns the tensor's dimensions as a Java array.
     *
     * @return an array of length {@link #ndimension()} whose entry {@code i} is {@code size(i)}
     */
    public long[] shape() {
        long[] out = new long[(int)this.ndimension()];
        for (int i = 0; i < out.length; ++i) {
            out[i] = this.size(i);
        }
        return out;
    }

    /** Equivalent to {@code createBuffer(0)}. */
    public <B extends Buffer> B createBuffer() {
        return this.createBuffer(0L);
    }

    /**
     * Returns an NIO buffer viewing this tensor's data (no copy).
     *
     * <p>The buffer type follows the dtype: bytes for 8-bit and bool types, shorts for 16-bit types
     * (including the raw bits of {@code Half}/{@code BFloat16}), ints, longs, floats or doubles.
     * Complex dtypes map to their real component type, with two components per element.
     *
     * @param index element offset where the buffer starts (scaled by 2 for complex dtypes,
     *              halved for {@code QUInt4x2})
     * @param <B>   the NIO buffer type matching the dtype; the cast is unchecked
     * @return the buffer view
     * @throws UnsupportedOperationException if the layout is not strided, the device is not the
     *                                       CPU, or the dtype is not handled
     */
    @SuppressWarnings("unchecked")
    public <B extends Buffer> B createBuffer(long index) {
        checkStridedCpu(this.options());
        torch.ScalarType dtype = this.scalar_type().intern();
        Pointer ptr = this.data_ptr();
        long bytes = this.nbytes();
        switch (dtype) {
            case Byte:
            case Char:
            case Bool:
            case QInt8:
            case QUInt8:
                return (B)new BytePointer(ptr).position(index).capacity(bytes).asBuffer();
            case Short:
            case Half:
            case BFloat16:
                return (B)new ShortPointer(ptr).position(index).capacity(bytes / 2L).asBuffer();
            case Int:
            case QInt32:
                return (B)new IntPointer(ptr).position(index).capacity(bytes / 4L).asBuffer();
            case Long:
                return (B)new LongPointer(ptr).position(index).capacity(bytes / 8L).asBuffer();
            case Float:
                return (B)new FloatPointer(ptr).position(index).capacity(bytes / 4L).asBuffer();
            case Double:
                return (B)new DoublePointer(ptr).position(index).capacity(bytes / 8L).asBuffer();
            // Complex elements occupy two components of the underlying real type.
            case ComplexHalf:
                return (B)new ShortPointer(ptr).position(index * 2L).capacity(bytes / 2L).asBuffer();
            case ComplexFloat:
                return (B)new FloatPointer(ptr).position(index * 2L).capacity(bytes / 4L).asBuffer();
            case ComplexDouble:
                return (B)new DoublePointer(ptr).position(index * 2L).capacity(bytes / 8L).asBuffer();
            case QUInt4x2:
                // Presumably two 4-bit values per byte, hence the halved element offset.
                return (B)new BytePointer(ptr).position(index / 2L).capacity(bytes).asBuffer();
            default:
                break;
        }
        throw new UnsupportedOperationException("Data type not supported: " + dtype);
    }

    /** Equivalent to {@code createIndexer(true)}. */
    public <I extends Indexer> I createIndexer() {
        return this.createIndexer(true);
    }

    /**
     * Returns a JavaCPP indexer over this tensor's data, using the tensor's own sizes and strides.
     *
     * <p>A 0-d tensor is indexed as a 1-d tensor of size 1. Complex tensors are indexed through an
     * indexer of the underlying real type, with one extra trailing dimension of size 2 selecting
     * the component. Their outer strides are converted from element units to component units.
     *
     * @param direct forwarded to the {@code Indexer.create} factories
     * @param <I>    the indexer type matching the dtype; the cast is unchecked
     * @return the indexer, with this tensor registered as its {@link Indexable}
     * @throws UnsupportedOperationException if the layout is not strided, the device is not the
     *                                       CPU, or the dtype is not handled
     */
    @SuppressWarnings("unchecked")
    public <I extends Indexer> I createIndexer(boolean direct) {
        checkStridedCpu(this.options());
        torch.ScalarType dtype = this.scalar_type().intern();
        Pointer ptr = this.data_ptr();
        long bytes = this.nbytes();
        int dims = (int)this.ndimension();
        boolean complex = dtype == torch.ScalarType.ComplexHalf
                || dtype == torch.ScalarType.ComplexFloat
                || dtype == torch.ScalarType.ComplexDouble;
        boolean scalar = dims == 0;
        int tensorRank = scalar ? 1 : dims;
        // Number of primitive components per element as seen by the indexer's pointer type.
        long components = complex ? 2L : 1L;
        int rank = tensorRank + (complex ? 1 : 0);
        long[] sizes = new long[rank];
        long[] strides = new long[rank];
        for (int i = 0; i < tensorRank; i++) {
            sizes[i] = scalar ? 1L : this.size(i);
            // Tensor strides count whole elements; the indexer counts primitive components.
            strides[i] = (scalar ? 1L : this.stride(i)) * components;
        }
        if (complex) {
            // Trailing dimension selecting one of the two adjacent components of a complex element.
            sizes[rank - 1] = 2L;
            strides[rank - 1] = 1L;
        }
        Indexer indexer;
        switch (dtype) {
            case Byte:
            case QUInt8:
            case QUInt4x2:
                indexer = UByteIndexer.create(new BytePointer(ptr).capacity(bytes), sizes, strides, direct);
                break;
            case Char:
            case QInt8:
                indexer = ByteIndexer.create(new BytePointer(ptr).capacity(bytes), sizes, strides, direct);
                break;
            case Short:
                indexer = ShortIndexer.create(new ShortPointer(ptr).capacity(bytes / 2L), sizes, strides, direct);
                break;
            case Int:
            case QInt32:
                indexer = IntIndexer.create(new IntPointer(ptr).capacity(bytes / 4L), sizes, strides, direct);
                break;
            case Long:
                indexer = LongIndexer.create(new LongPointer(ptr).capacity(bytes / 8L), sizes, strides, direct);
                break;
            case Half:
            case ComplexHalf:
                indexer = HalfIndexer.create(new ShortPointer(ptr).capacity(bytes / 2L), sizes, strides, direct);
                break;
            case Float:
            case ComplexFloat:
                indexer = FloatIndexer.create(new FloatPointer(ptr).capacity(bytes / 4L), sizes, strides, direct);
                break;
            case Double:
            case ComplexDouble:
                indexer = DoubleIndexer.create(new DoublePointer(ptr).capacity(bytes / 8L), sizes, strides, direct);
                break;
            case Bool:
                indexer = BooleanIndexer.create(new BooleanPointer(ptr).capacity(bytes), sizes, strides, direct);
                break;
            case BFloat16:
                indexer = Bfloat16Indexer.create(new ShortPointer(ptr).capacity(bytes / 2L), sizes, strides, direct);
                break;
            default:
                throw new UnsupportedOperationException("Data type not supported: " + dtype);
        }
        return (I)indexer.indexable(this);
    }

    /**
     * Rejects tensors whose memory cannot be viewed directly from Java.
     *
     * @param options the tensor's options
     * @throws UnsupportedOperationException if the layout is not strided or the device is not the CPU
     */
    private static void checkStridedCpu(TensorOptions options) {
        if (options.layout().intern() != torch.Layout.Strided) {
            throw new UnsupportedOperationException("Layout not supported: " + options.layout().intern());
        }
        if (options.device().type().intern() != torch.DeviceType.CPU) {
            throw new UnsupportedOperationException("Device type not supported: " + options.device().type().intern());
        }
    }

    static {
        Loader.load();
    }
}

