/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.autodiff.samediff.config.SDValueType;
import org.nd4j.autodiff.samediff.internal.IDependeeGroup;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A tagged value holding one of three shapes of data, identified by {@link SDValueType}:
 * a single tensor ({@code TENSOR}), a list of tensors ({@code LIST}), or a string-keyed map of
 * tensors ({@code DICT}).
 *
 * <p>Instances are created only through the static {@code create}/{@code empty} factories.
 * Each instance receives an id from a static {@link AtomicLong} counter at construction time.
 * {@link #equals(Object)} and {@link #hashCode()} depend solely on that id, never on the
 * contained arrays.
 */
public class SDValue
implements IDependeeGroup<INDArray> {
    private SDValueType sdValueType;
    private INDArray tensorValue;
    private Map<String, INDArray> dictValue;
    private List<INDArray> listValue;
    /** Source of per-instance ids; incremented once per constructed SDValue. */
    private static final AtomicLong counter = new AtomicLong(0L);
    /** Id assigned at construction; the only state used by equals/hashCode. */
    protected transient long id = counter.getAndIncrement();

    /** Use the static factory methods instead. */
    private SDValue() {
    }

    /**
     * Returns the id assigned to this instance at construction.
     *
     * @return this value's id
     */
    @Override
    public long getId() {
        return this.id;
    }

    /**
     * Returns the tensors of this value as a collection; identical to {@link #getListValue()}.
     *
     * @return a single-element list for a non-null tensor value, the stored list for LIST values,
     *     or {@code null} when neither is stored (e.g. DICT values)
     */
    @Override
    public Collection<INDArray> getCollection() {
        return this.getListValue();
    }

    /**
     * Creates an "empty" value of the requested kind.
     *
     * <p>Note that for {@code TENSOR} this is not a zero-length array: it is a length-1 array of
     * zeros cast to {@code dataType}.
     *
     * @param valueType the kind of value to create
     * @param dataType  data type of the tensor; only used for {@code TENSOR}
     * @return an empty fixed-size list for {@code LIST}, an empty immutable map for {@code DICT},
     *     or a length-1 zero tensor for {@code TENSOR}
     * @throws IllegalArgumentException if {@code valueType} is not one of the handled kinds
     */
    public static SDValue empty(SDValueType valueType, DataType dataType) {
        switch (valueType) {
            case LIST: {
                return SDValue.create(Arrays.asList(new INDArray[0]));
            }
            case DICT: {
                return SDValue.create(Collections.emptyMap());
            }
            case TENSOR: {
                return SDValue.create(Nd4j.zeros(1).castTo(dataType));
            }
        }
        throw new IllegalArgumentException("Unable to create empty value, unknown value type " + valueType);
    }

    /**
     * Returns the tensor held by this value.
     *
     * <p>If a list is stored and it contains exactly one element, that element is returned
     * instead of the tensor field, so single-element LIST values can be read as tensors.
     *
     * @return the single list element, otherwise the stored tensor (may be {@code null})
     */
    public INDArray getTensorValue() {
        if (this.listValue != null && this.listValue.size() == 1) {
            return this.listValue.get(0);
        }
        return this.tensorValue;
    }

    /**
     * Returns the tensors of this value as a list.
     *
     * @return a fixed-size single-element list wrapping the tensor when one is stored, otherwise
     *     the stored list itself (not a copy), or {@code null} if no list is stored
     */
    public List<INDArray> getListValue() {
        if (this.tensorValue != null) {
            return Arrays.asList(this.tensorValue);
        }
        return this.listValue;
    }

    /**
     * Wraps a single tensor as a {@code TENSOR} value.
     *
     * @param inputValue the tensor to wrap (stored as-is)
     * @return a new TENSOR value
     */
    public static SDValue create(INDArray inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.tensorValue = inputValue;
        sdValue.sdValueType = SDValueType.TENSOR;
        return sdValue;
    }

    /**
     * Wraps a collection of tensors as a {@code LIST} value.
     *
     * <p>A {@link List} argument (or {@code null}) is stored as-is. Any other collection is copied
     * into a new {@link ArrayList} in iteration order, because this class stores tensors as a list.
     *
     * @param inputValue the tensors to wrap
     * @return a new LIST value
     */
    public static SDValue create(Collection<INDArray> inputValue) {
        SDValue sdValue = new SDValue();
        if (inputValue == null || inputValue instanceof List) {
            sdValue.listValue = (List<INDArray>) inputValue;
        } else {
            // A blind cast here would throw ClassCastException for e.g. a Set.
            sdValue.listValue = new ArrayList<>(inputValue);
        }
        sdValue.sdValueType = SDValueType.LIST;
        return sdValue;
    }

    /**
     * Wraps a list of tensors as a {@code LIST} value.
     *
     * <p>The list is stored by reference, not copied, so later changes to it are visible through
     * this value.
     *
     * @param inputValue the tensors to wrap
     * @return a new LIST value
     */
    public static SDValue create(List<INDArray> inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.listValue = inputValue;
        sdValue.sdValueType = SDValueType.LIST;
        return sdValue;
    }

    /**
     * Wraps a name-to-tensor map as a {@code DICT} value.
     *
     * <p>The map is stored by reference, not copied.
     *
     * @param inputValue the map to wrap
     * @return a new DICT value
     */
    public static SDValue create(Map<String, INDArray> inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.dictValue = inputValue;
        sdValue.sdValueType = SDValueType.DICT;
        return sdValue;
    }

    /**
     * Two SDValues are equal when their ids match.
     *
     * <p>Returns {@code false}, rather than throwing, for {@code null} or non-SDValue arguments,
     * as required by the {@link Object#equals(Object)} contract.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SDValue)) {
            return false;
        }
        SDValue sd = (SDValue) o;
        return sd.getId() == this.getId();
    }

    /** Hash of the id, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Long.hashCode(this.getId());
    }

    /**
     * Debug description containing this value's id, type, and the ids of the contained arrays
     * (or the entry count for dictionaries). Returns an empty string when nothing is stored.
     */
    @Override
    public String toString() {
        INDArray h = this.getTensorValue();
        StringBuilder st = new StringBuilder();
        if (h != null) {
            st.append("--sdValueId-");
            st.append(this.getId() + "--key--" + this.getSdValueType() + " --Array " + h.getId());
        } else {
            List<INDArray> listx = this.getListValue();
            if (listx != null && listx.size() > 0) {
                st.append("--sdValueId-");
                st.append(this.getId() + "--key--" + this.getSdValueType() + " -- List Size " + listx.size());
                for (INDArray gh : listx) {
                    if (gh == null) {
                        st.append(" --Array NULL ");
                        continue;
                    }
                    st.append(" --Array " + gh.getId() + " --\t ");
                }
            } else if (this.dictValue != null) {
                // DICT values previously rendered as an empty string.
                st.append("--sdValueId-");
                st.append(this.getId() + "--key--" + this.getSdValueType() + " -- Dict Size " + this.dictValue.size());
            }
        }
        return st.toString();
    }

    /**
     * Returns the kind of data this value holds.
     *
     * @return the value type set by the factory method that created this instance
     */
    public SDValueType getSdValueType() {
        return this.sdValueType;
    }

    /**
     * Returns the stored map (not a copy).
     *
     * @return the map for DICT values, otherwise {@code null}
     */
    public Map<String, INDArray> getDictValue() {
        return this.dictValue;
    }
}

