/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.client.ml.inference.preprocessing;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

/**
 * Client-side representation of the {@code custom_word_embedding} inference pre-processor.
 * <p>
 * Holds a 2-D array of embedding weights (one {@code byte[]} per entry, read with
 * {@link XContentParser#binaryValue()}), a 2-D array of quantization scales, the name of the
 * source field and the name of the destination field. Instances are built by {@link #PARSER}
 * and serialized back with {@link #toXContent}.
 */
public class CustomWordEmbedding implements PreProcessor {
    public static final String NAME = "custom_word_embedding";
    static final ParseField FIELD = new ParseField("field", new String[0]);
    static final ParseField DEST_FIELD = new ParseField("dest_field", new String[0]);
    static final ParseField EMBEDDING_WEIGHTS = new ParseField("embedding_weights", new String[0]);
    static final ParseField EMBEDDING_QUANT_SCALES = new ParseField("embedding_quant_scales", new String[0]);

    /**
     * Parser for the pre-processor body. The {@code true} flag makes it ignore unknown fields.
     * Constructor arguments are bound in the order they are declared in the static initializer
     * below: {@code a[0]} quant scales, {@code a[1]} weights, {@code a[2]} field,
     * {@code a[3]} dest_field.
     */
    public static final ConstructingObjectParser<CustomWordEmbedding, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        true,
        a -> new CustomWordEmbedding((short[][]) a[0], (byte[][]) a[1], (String) a[2], (String) a[3]));

    // Must stay textually after PARSER: static initializers run in source order, so PARSER has
    // to be assigned before fields are declared on it.
    static {
        // a[0]: nested arrays of shorts -> short[][]
        PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
            List<List<Short>> rows = parseArrays(EMBEDDING_QUANT_SCALES.getPreferredName(), XContentParser::shortValue, p);
            short[][] scales = new short[rows.size()][];
            for (int row = 0; row < rows.size(); row++) {
                List<Short> boxed = rows.get(row);
                short[] unboxed = new short[boxed.size()];
                for (int col = 0; col < boxed.size(); col++) {
                    unboxed[col] = boxed.get(col);
                }
                scales[row] = unboxed;
            }
            return scales;
        }, EMBEDDING_QUANT_SCALES, ObjectParser.ValueType.VALUE_ARRAY);

        // a[1]: flat array of binary values -> byte[][]
        PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
            // Same positional guard parseArrays applies to the quant scales; without it the loop
            // below would start consuming whatever tokens follow a non-array value.
            if (p.currentToken() != XContentParser.Token.START_ARRAY) {
                throw unexpectedToken(p, EMBEDDING_WEIGHTS.getPreferredName());
            }
            List<byte[]> rows = new ArrayList<>();
            while (p.nextToken() != XContentParser.Token.END_ARRAY) {
                rows.add(p.binaryValue());
            }
            return rows.toArray(new byte[0][]);
        }, EMBEDDING_WEIGHTS, ObjectParser.ValueType.VALUE_ARRAY);

        // a[2], a[3]
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), DEST_FIELD);
    }

    private final short[][] embeddingsQuantScales;
    private final byte[][] embeddingsWeights;
    private final String fieldName;
    private final String destField;

    /** Builds the error raised when {@code p} is not positioned on the array start a field requires. */
    private static IllegalArgumentException unexpectedToken(XContentParser p, String fieldName) {
        return new IllegalArgumentException("unexpected token [" + p.currentToken() + "] for [" + fieldName + "]");
    }

    /**
     * Parses an array of arrays of scalar values, e.g. {@code [[1, 2], [3]]}.
     *
     * @param fieldName  field name used in error messages
     * @param fromParser converts the parser's current scalar token into a {@code T}
     * @param p          parser positioned on the outer {@code START_ARRAY}
     * @return one list per inner array, in document order
     * @throws IllegalArgumentException if the outer value or any element of it is not an array
     * @throws IllegalStateException    if an inner element is not a value token (e.g. null)
     * @throws IOException              if reading from the parser fails
     */
    private static <T> List<List<T>> parseArrays(String fieldName, CheckedFunction<XContentParser, T, IOException> fromParser, XContentParser p) throws IOException {
        if (p.currentToken() != XContentParser.Token.START_ARRAY) {
            throw unexpectedToken(p, fieldName);
        }
        List<List<T>> values = new ArrayList<>();
        while (p.nextToken() != XContentParser.Token.END_ARRAY) {
            if (p.currentToken() != XContentParser.Token.START_ARRAY) {
                throw unexpectedToken(p, fieldName);
            }
            List<T> innerList = new ArrayList<>();
            while (p.nextToken() != XContentParser.Token.END_ARRAY) {
                if (!p.currentToken().isValue()) {
                    throw new IllegalStateException("expected non-null value but got [" + p.currentToken() + "] for [" + fieldName + "]");
                }
                innerList.add(fromParser.apply(p));
            }
            values.add(innerList);
        }
        return values;
    }

    /**
     * Parses a {@code CustomWordEmbedding} from {@code parser} using {@link #PARSER}.
     *
     * @param parser parser positioned on the object to read
     * @return the parsed pre-processor
     */
    public static CustomWordEmbedding fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates the pre-processor. The arrays are stored as given (not copied), so callers must not
     * mutate them afterwards.
     *
     * @param embeddingsQuantScales quantization scales
     * @param embeddingsWeights     embedding weights
     * @param fieldName             name of the field read by the pre-processor
     * @param destField             name of the field written by the pre-processor
     */
    CustomWordEmbedding(short[][] embeddingsQuantScales, byte[][] embeddingsWeights, String fieldName, String destField) {
        this.embeddingsQuantScales = embeddingsQuantScales;
        this.embeddingsWeights = embeddingsWeights;
        this.fieldName = fieldName;
        this.destField = destField;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Writes this pre-processor as an object with the {@code field}, {@code dest_field},
     * {@code embedding_quant_scales} and {@code embedding_weights} fields.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FIELD.getPreferredName(), this.fieldName);
        builder.field(DEST_FIELD.getPreferredName(), this.destField);
        // The (Object) casts pin overload resolution to field(String, Object) for the 2-D arrays.
        builder.field(EMBEDDING_QUANT_SCALES.getPreferredName(), (Object) this.embeddingsQuantScales);
        builder.field(EMBEDDING_WEIGHTS.getPreferredName(), (Object) this.embeddingsWeights);
        builder.endObject();
        return builder;
    }

    /** Two instances are equal when both field names match and both arrays are deeply equal. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        CustomWordEmbedding that = (CustomWordEmbedding) o;
        return Objects.equals(this.fieldName, that.fieldName)
            && Objects.equals(this.destField, that.destField)
            && Arrays.deepEquals(this.embeddingsWeights, that.embeddingsWeights)
            && Arrays.deepEquals(this.embeddingsQuantScales, that.embeddingsQuantScales);
    }

    /** Hash consistent with {@link #equals}: uses deep hashes of the array contents. */
    @Override
    public int hashCode() {
        return Objects.hash(this.fieldName, this.destField, Arrays.deepHashCode(this.embeddingsQuantScales), Arrays.deepHashCode(this.embeddingsWeights));
    }
}

