package com.hankcs.hanlp.dependency.perceptron.learning;

import com.hankcs.hanlp.dependency.perceptron.structures.CompactArray;
import com.hankcs.hanlp.dependency.perceptron.structures.ParserModel;
import com.hankcs.hanlp.dependency.perceptron.transition.parser.Action;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:com/hankcs/hanlp/dependency/perceptron/learning/AveragedPerceptron.class */
public class AveragedPerceptron {
    /*
     * Shift/reduce scoring ignores feature slots in the half-open range
     * [SHIFT_REDUCE_SKIPPED_FROM, SHIFT_REDUCE_SKIPPED_TO); arc scoring uses every slot.
     * NOTE(review): these bounds are tied to the feature-template layout defined elsewhere;
     * confirm they still match if the templates change.
     */
    private static final int SHIFT_REDUCE_SKIPPED_FROM = 26;
    private static final int SHIFT_REDUCE_SKIPPED_TO = 32;

    // Raw (non-averaged) weights, one map per feature slot: feature value -> scalar weight.
    // Left null for instances built from a ParserModel.
    public HashMap<Object, Float>[] shiftFeatureWeights;
    public HashMap<Object, Float>[] reduceFeatureWeights;
    // Raw arc weights, one map per feature slot: feature value -> per-dependency weight vector.
    // Left null for instances built from a ParserModel.
    public HashMap<Object, CompactArray>[] leftArcFeatureWeights;
    public HashMap<Object, CompactArray>[] rightArcFeatureWeights;
    // Current training iteration; starts at 1 for a fresh perceptron and scales the averaged updates.
    public int iteration;
    // Length of the score vectors returned by leftArcScores/rightArcScores.
    public int dependencySize;
    // "Averaged" weights. During training these accumulate iteration * change for every update;
    // for a ParserModel-backed instance they are whatever the model stored.
    // NOTE(review): presumably finalized into true averages elsewhere (e.g. when a ParserModel is built).
    public HashMap<Object, Float>[] shiftFeatureAveragedWeights;
    public HashMap<Object, Float>[] reduceFeatureAveragedWeights;
    public HashMap<Object, CompactArray>[] leftArcFeatureAveragedWeights;
    public HashMap<Object, CompactArray>[] rightArcFeatureAveragedWeights;

    /**
     * Creates an empty perceptron ready for training.
     *
     * @param featureSize    number of feature slots; every weight array gets one map per slot
     * @param dependencySize length of the arc score vectors
     */
    @SuppressWarnings("unchecked") // Java cannot create generic arrays; every slot is filled with a typed map below
    public AveragedPerceptron(int featureSize, int dependencySize) {
        this.shiftFeatureWeights = new HashMap[featureSize];
        this.reduceFeatureWeights = new HashMap[featureSize];
        this.leftArcFeatureWeights = new HashMap[featureSize];
        this.rightArcFeatureWeights = new HashMap[featureSize];
        this.shiftFeatureAveragedWeights = new HashMap[featureSize];
        this.reduceFeatureAveragedWeights = new HashMap[featureSize];
        this.leftArcFeatureAveragedWeights = new HashMap[featureSize];
        this.rightArcFeatureAveragedWeights = new HashMap[featureSize];
        for (int slot = 0; slot < featureSize; slot++) {
            this.shiftFeatureWeights[slot] = new HashMap<>();
            this.reduceFeatureWeights[slot] = new HashMap<>();
            this.leftArcFeatureWeights[slot] = new HashMap<>();
            this.rightArcFeatureWeights[slot] = new HashMap<>();
            this.shiftFeatureAveragedWeights[slot] = new HashMap<>();
            this.reduceFeatureAveragedWeights[slot] = new HashMap<>();
            this.leftArcFeatureAveragedWeights[slot] = new HashMap<>();
            this.rightArcFeatureAveragedWeights[slot] = new HashMap<>();
        }
        this.iteration = 1;
        this.dependencySize = dependencySize;
    }

    /**
     * Wraps already-trained averaged weights (decode-only).
     * The raw weight arrays stay null and {@code iteration} stays 0, so {@link #changeWeight}
     * and scoring with {@code useAveragedWeights == false} are not usable on such an instance.
     */
    private AveragedPerceptron(HashMap<Object, Float>[] shiftAveraged, HashMap<Object, Float>[] reduceAveraged,
                               HashMap<Object, CompactArray>[] leftArcAveraged,
                               HashMap<Object, CompactArray>[] rightArcAveraged, int dependencySize) {
        this.shiftFeatureAveragedWeights = shiftAveraged;
        this.reduceFeatureAveragedWeights = reduceAveraged;
        this.leftArcFeatureAveragedWeights = leftArcAveraged;
        this.rightArcFeatureAveragedWeights = rightArcAveraged;
        this.dependencySize = dependencySize;
    }

    /**
     * Creates a decode-only perceptron sharing (not copying) the averaged weight maps of {@code parserModel}.
     */
    public AveragedPerceptron(ParserModel parserModel) {
        this(parserModel.shiftFeatureAveragedWeights, parserModel.reduceFeatureAveragedWeights,
                parserModel.leftArcFeatureAveragedWeights, parserModel.rightArcFeatureAveragedWeights,
                parserModel.dependencySize);
    }

    /**
     * Applies one perceptron update to the weights of {@code feature} for the given action.
     * The raw weight receives {@code change}; the averaged accumulator receives {@code iteration * change}.
     * Actions other than Shift, Reduce, RightArc and LeftArc are ignored.
     *
     * @param action          transition whose weights are updated
     * @param featureIndex    feature slot, selecting the per-slot map
     * @param feature         feature value used as the map key; {@code null} means "feature absent"
     * @param dependencyIndex position in the arc weight vector (used only for RightArc/LeftArc)
     * @param change          amount added to the raw weight
     * @return {@code change}, or {@code 0} when {@code feature} is null
     */
    public float changeWeight(Action action, int featureIndex, Object feature, int dependencyIndex, float change) {
        if (feature == null) {
            return 0.0f;
        }
        if (action == Action.Shift) {
            addScalarWeight(this.shiftFeatureWeights[featureIndex],
                    this.shiftFeatureAveragedWeights[featureIndex], feature, change);
        } else if (action == Action.Reduce) {
            addScalarWeight(this.reduceFeatureWeights[featureIndex],
                    this.reduceFeatureAveragedWeights[featureIndex], feature, change);
        } else if (action == Action.RightArc) {
            changeFeatureWeight(this.rightArcFeatureWeights[featureIndex],
                    this.rightArcFeatureAveragedWeights[featureIndex], feature, dependencyIndex, change,
                    this.dependencySize);
        } else if (action == Action.LeftArc) {
            changeFeatureWeight(this.leftArcFeatureWeights[featureIndex],
                    this.leftArcFeatureAveragedWeights[featureIndex], feature, dependencyIndex, change,
                    this.dependencySize);
        }
        return change;
    }

    /** Adds {@code change} to the raw scalar weight and {@code iteration * change} to its accumulator. */
    private void addScalarWeight(HashMap<Object, Float> weights, HashMap<Object, Float> averagedWeights,
                                 Object feature, float change) {
        // Single lookup per map; this class never stores null values, so null means "absent".
        Float current = weights.get(feature);
        weights.put(feature, current == null ? change : current + change);

        float weightedChange = this.iteration * change;
        Float accumulated = averagedWeights.get(feature);
        averagedWeights.put(feature, accumulated == null ? weightedChange : accumulated + weightedChange);
    }

    /**
     * Updates one entry of an arc weight vector, creating the vectors on first use.
     * The raw vector receives {@code change} and the averaged vector receives {@code iteration * change}
     * at {@code dependencyIndex}.
     * NOTE(review): correctness relies on {@code CompactArray.set} accumulating (adding to) the
     * existing value rather than overwriting it; confirm against CompactArray.
     *
     * @param weights         raw weight map for one feature slot
     * @param averagedWeights averaged weight map for the same slot; assumed to contain {@code feature}
     *                        whenever {@code weights} does
     * @param feature         feature value (map key)
     * @param dependencyIndex position in the weight vector
     * @param change          amount applied to the raw weight
     * @param size            unused; kept for interface compatibility
     */
    public void changeFeatureWeight(HashMap<Object, CompactArray> weights, HashMap<Object, CompactArray> averagedWeights,
                                    Object feature, int dependencyIndex, float change, int size) {
        CompactArray existing = weights.get(feature);
        if (existing != null) {
            existing.set(dependencyIndex, change);
            averagedWeights.get(feature).set(dependencyIndex, this.iteration * change);
        } else {
            // A fresh one-element vector starting at dependencyIndex.
            weights.put(feature, new CompactArray(dependencyIndex, new float[]{change}));
            averagedWeights.put(feature, new CompactArray(dependencyIndex, new float[]{this.iteration * change}));
        }
    }

    /** Advances the iteration counter that scales subsequent averaged updates. */
    public void incrementIteration() {
        this.iteration++;
    }

    /**
     * Scores a Shift transition as the sum of the weights of all present features,
     * skipping the feature slots excluded for shift/reduce.
     *
     * @param features           one feature value per slot; null entries are skipped
     * @param useAveragedWeights true to use the averaged weights, false for the raw weights
     */
    public float shiftScore(Object[] features, boolean useAveragedWeights) {
        return sumScalarWeights(useAveragedWeights ? this.shiftFeatureAveragedWeights : this.shiftFeatureWeights,
                features);
    }

    /**
     * Scores a Reduce transition; same rules as {@link #shiftScore(Object[], boolean)}.
     */
    public float reduceScore(Object[] features, boolean useAveragedWeights) {
        return sumScalarWeights(useAveragedWeights ? this.reduceFeatureAveragedWeights : this.reduceFeatureWeights,
                features);
    }

    /** Sums the scalar weights of the non-null, non-excluded features, in slot order. */
    private static float sumScalarWeights(HashMap<Object, Float>[] weights, Object[] features) {
        float score = 0.0f;
        for (int slot = 0; slot < features.length; slot++) {
            if (features[slot] == null || isExcludedForShiftReduce(slot)) {
                continue;
            }
            Float weight = weights[slot].get(features[slot]);
            if (weight != null) {
                score += weight;
            }
        }
        return score;
    }

    /** True for the feature slots that shift/reduce scoring ignores. */
    private static boolean isExcludedForShiftReduce(int slot) {
        return slot >= SHIFT_REDUCE_SKIPPED_FROM && slot < SHIFT_REDUCE_SKIPPED_TO;
    }

    /**
     * Computes LeftArc scores for every dependency index by summing the weight vectors
     * of all present features.
     *
     * @return a new array of length {@code dependencySize}
     */
    public float[] leftArcScores(Object[] features, boolean useAveragedWeights) {
        return sumArcWeights(useAveragedWeights ? this.leftArcFeatureAveragedWeights : this.leftArcFeatureWeights,
                features);
    }

    /**
     * Computes RightArc scores; same rules as {@link #leftArcScores(Object[], boolean)}.
     */
    public float[] rightArcScores(Object[] features, boolean useAveragedWeights) {
        return sumArcWeights(useAveragedWeights ? this.rightArcFeatureAveragedWeights : this.rightArcFeatureWeights,
                features);
    }

    /** Adds each present feature's compact weight vector into a dense score vector. */
    private float[] sumArcWeights(HashMap<Object, CompactArray>[] weights, Object[] features) {
        float[] scores = new float[this.dependencySize];
        for (int slot = 0; slot < features.length; slot++) {
            if (features[slot] == null) {
                continue;
            }
            CompactArray vector = weights[slot].get(features[slot]);
            if (vector == null) {
                continue;
            }
            int offset = vector.getOffset();
            float[] values = vector.getArray();
            // The compact vector covers positions [offset, offset + values.length).
            for (int k = 0; k < values.length; k++) {
                scores[offset + k] += values[k];
            }
        }
        return scores;
    }

    /** Number of feature slots. */
    public int featureSize() {
        return this.shiftFeatureAveragedWeights.length;
    }

    /** Total stored length of all averaged RightArc weight vectors. */
    public int raSize() {
        return totalLength(this.rightArcFeatureAveragedWeights);
    }

    /** Number of non-zero entries across all averaged RightArc weight vectors. */
    public int effectiveRaSize() {
        return countNonZero(this.rightArcFeatureAveragedWeights);
    }

    /** Total stored length of all averaged LeftArc weight vectors. */
    public int laSize() {
        return totalLength(this.leftArcFeatureAveragedWeights);
    }

    /** Number of non-zero entries across all averaged LeftArc weight vectors. */
    public int effectiveLaSize() {
        return countNonZero(this.leftArcFeatureAveragedWeights);
    }

    // Iterates the array it reads (the original bounded the RightArc loops by the LeftArc array length).
    private static int totalLength(HashMap<Object, CompactArray>[] weights) {
        int total = 0;
        for (HashMap<Object, CompactArray> slotWeights : weights) {
            for (CompactArray vector : slotWeights.values()) {
                total += vector.length();
            }
        }
        return total;
    }

    private static int countNonZero(HashMap<Object, CompactArray>[] weights) {
        int count = 0;
        for (HashMap<Object, CompactArray> slotWeights : weights) {
            for (CompactArray vector : slotWeights.values()) {
                for (float value : vector.getArray()) {
                    if (value != 0.0f) {
                        count++;
                    }
                }
            }
        }
        return count;
    }
}
