package nie.translator.rtranslator.tools.nn;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.Arrays;

/* loaded from: classes2.dex */
public class CacheContainerNative {
    /** Name suffixes of the two cache tensors stored per layer, in slot order (key, then value). */
    private static final String[] CACHE_KINDS = {"key", "value"};

    /** Handle returned by the native {@code initialize}; reset to 0 once {@link #close()} has run. */
    private long cacheContainerNativePointer;

    /**
     * Tensors taken from the source result, indexed by slot. This object keeps references to them,
     * presumably so their buffers stay reachable while the native side may still use them
     * (NOTE(review): confirm against the native {@code insertValues} implementation).
     */
    private final OnnxTensor[] cacheTensors;

    /** {@code {slotCount, dim0, dim1, dim2, dim3}}; entries 1..4 are the shape of every cache tensor. */
    private final int[] shape;

    /**
     * Creates a native cache container and hands it the buffers of the
     * {@code present.<layer>.decoder.key} / {@code present.<layer>.decoder.value} tensors of
     * {@code result}. Slot {@code 2 * layer} receives the key and slot {@code 2 * layer + 1} the value.
     *
     * @param ortEnvironment unused here; kept for API compatibility
     * @param result session output that contains the {@code present.*.decoder.*} tensors
     * @param numLayers number of layers; the container gets {@code 2 * numLayers} slots
     * @param dim0 first dimension of every cache tensor (passed to native {@code initialize})
     * @param dim1 second dimension of every cache tensor
     * @param dim2 third dimension of every cache tensor
     * @param dim3 fourth dimension of every cache tensor
     * @throws IllegalStateException if the tensors' internal buffers cannot be accessed reflectively
     * @throws java.util.NoSuchElementException if an expected tensor name is missing from {@code result}
     */
    public CacheContainerNative(OrtEnvironment ortEnvironment, OrtSession.Result result, int numLayers,
                                int dim0, int dim1, int dim2, int dim3) {
        int slotCount = numLayers * CACHE_KINDS.length;
        this.cacheTensors = new OnnxTensor[slotCount];
        // Record the shape before anything can fail so the object is never left with a null shape.
        this.shape = new int[] {slotCount, dim0, dim1, dim2, dim3};
        this.cacheContainerNativePointer = initialize(slotCount, dim0, dim1, dim2, dim3);
        boolean populated = false;
        try {
            // OnnxTensor does not expose its backing buffer publicly, hence the reflective access.
            Method bufferAccessor = OnnxTensor.class.getDeclaredMethod("getBuffer");
            bufferAccessor.setAccessible(true);
            for (int layer = 0; layer < numLayers; layer++) {
                for (int kind = 0; kind < CACHE_KINDS.length; kind++) {
                    int slot = layer * CACHE_KINDS.length + kind;
                    OnnxTensor tensor = (OnnxTensor) result.get(cacheName(layer, kind)).get();
                    this.cacheTensors[slot] = tensor;
                    insertValues(this.cacheContainerNativePointer, slot,
                            (ByteBuffer) bufferAccessor.invoke(tensor));
                }
            }
            populated = true;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Unable to access the buffers of the decoder cache tensors", e);
        } finally {
            // Do not leak the native container when construction fails part-way.
            if (!populated) {
                releaseNative();
            }
        }
    }

    private native void close(long j);

    private native ByteBuffer getBuffer(long j, int i);

    private native long initialize(int i, int i2, int i3, int i4, int i5);

    private native void insertValues(long j, int i, ByteBuffer byteBuffer);

    private native void reorder(long j, int[] iArr);

    /** Builds the session output name for the cache tensor of {@code layer} and kind index {@code kind}. */
    private static String cacheName(int layer, int kind) {
        return "present." + layer + ".decoder." + CACHE_KINDS[kind];
    }

    /** Closes every non-null value in {@code values}; used to release partially built outputs. */
    private static void closeAll(OnnxValue[] values) {
        for (OnnxValue value : values) {
            if (value != null) {
                value.close();
            }
        }
    }

    /**
     * Returns the native handle.
     *
     * @throws IllegalStateException if the container has been closed
     */
    private long requireOpen() {
        long pointer = this.cacheContainerNativePointer;
        if (pointer == 0) {
            throw new IllegalStateException("Native cache container is not available (closed or never initialized)");
        }
        return pointer;
    }

    /** Releases the native container at most once and forgets its handle. */
    private void releaseNative() {
        long pointer = this.cacheContainerNativePointer;
        if (pointer != 0) {
            this.cacheContainerNativePointer = 0;
            close(pointer);
        }
    }

    /**
     * Releases the native container. Calling this more than once is a no-op. The tensors taken from
     * the source result are not closed here.
     */
    public void close() {
        releaseNative();
    }

    /**
     * Wraps each native slot buffer in a FLOAT {@link OnnxTensor} of shape
     * {@code {dim0, dim1, dim2, dim3}} and packages them as an {@link OrtSession.Result} keyed by the
     * same {@code present.<layer>.decoder.key|value} names the constructor read from. The result is
     * built reflectively and owns (closes) the created tensors. Whether these tensors alias native
     * memory depends on the buffer the native {@code getBuffer} returns
     * (NOTE(review): presumably a direct buffer, so no copy — confirm).
     *
     * @param ortEnvironment environment used to create the tensors
     * @return the packaged cache, or {@code null} if tensor creation or reflective construction fails
     *     (any tensors already created are closed first)
     * @throws IllegalStateException if the container has been closed
     */
    public OrtSession.Result getCacheResult(OrtEnvironment ortEnvironment) {
        long nativePointer = requireOpen();
        int slotCount = this.shape[0];
        long[] tensorShape = {this.shape[1], this.shape[2], this.shape[3], this.shape[4]};
        String[] names = new String[slotCount];
        OnnxValue[] values = new OnnxValue[slotCount];
        boolean[] closeWithResult = new boolean[slotCount];
        Arrays.fill(closeWithResult, true);
        try {
            for (int slot = 0; slot < slotCount; slot++) {
                names[slot] = cacheName(slot / CACHE_KINDS.length, slot % CACHE_KINDS.length);
                ByteBuffer buffer = getBuffer(nativePointer, slot);
                values[slot] = OnnxTensor.createTensor(ortEnvironment, buffer, tensorShape, OnnxJavaType.FLOAT);
            }
            // OrtSession.Result has no public constructor, hence the reflective access.
            Constructor<OrtSession.Result> resultConstructor = OrtSession.Result.class.getDeclaredConstructor(
                    String[].class, OnnxValue[].class, boolean[].class);
            resultConstructor.setAccessible(true);
            return resultConstructor.newInstance(names, values, closeWithResult);
        } catch (OrtException | ReflectiveOperationException e) {
            e.printStackTrace();
            // No Result took ownership, so release whatever was created before the failure.
            closeAll(values);
            return null;
        }
    }

    /**
     * Forwards {@code iArr} to the native reorder of the cached slots. The exact semantics are
     * defined natively (NOTE(review): presumably a permutation of batch/beam indices — confirm).
     *
     * @throws IllegalStateException if the container has been closed
     */
    public void reorder(int[] iArr) {
        reorder(requireOpen(), iArr);
    }
}
