/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal.memory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Longs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Session memory manager that recycles released arrays instead of freeing them immediately.
 *
 * <p>Released arrays are kept in a cache keyed by data type and shape string
 * ({@code Arrays.toString(shape)}) and handed back out by {@code allocate} for a matching request.
 * Total cached bytes are capped at {@code maxMemFrac} of the memory size read at construction time:
 * {@link Pointer#maxBytes()} on non-CUDA backends, otherwise the {@code cuda.totalMemory} entry of
 * the first element of {@code cuda.devicesInformation}. When a release would exceed the cap, the
 * least-recently cached arrays are evicted (and closed) first.
 *
 * <p>Caching can be disabled via the {@code org.nd4j.autodiff.samediff} system property (default
 * {@code "true"}) or {@link #setEnableCache(boolean)}; released arrays are then simply closed.
 *
 * <p>NOTE(review): no synchronization is used, so an instance presumably must be confined to one
 * thread at a time — confirm with callers.
 */
public class ArrayCacheMemoryMgr
extends AbstractMemoryMgr {
    private static final Logger log = LoggerFactory.getLogger(ArrayCacheMemoryMgr.class);
    /** Fraction (exclusive range 0..1) of total memory the cache may occupy. */
    private final double maxMemFrac;
    /** Validated and exposed via getter/setter, but not consulted by the caching logic in this class. */
    private long smallArrayThreshold;
    /** Validated and exposed via getter/setter, but not consulted by the caching logic in this class. */
    private double largerArrayMaxMultiple;
    /** Upper bound, in bytes, on {@link #currentCacheSize}. */
    private long maxCacheBytes;
    /** Memory size, in bytes, that {@link #maxCacheBytes} was derived from. */
    private long totalMemBytes;
    /** Bytes currently held by cached arrays (buffer length times element width). */
    private long currentCacheSize = 0L;
    /** Ids of cached arrays in insertion order; iteration starts at the eviction candidate. */
    private LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
    /** Cached arrays keyed by id; kept in step with {@link #lruCache}. */
    private Map<Long, INDArray> lruCacheValues = new HashMap<>();
    /** Cached arrays grouped by data type and {@code Arrays.toString(shape)}. */
    private Table<DataType, String, List<INDArray>> arrays = HashBasedTable.create();
    // NOTE(review): this property key looks like a package name rather than a dedicated flag — confirm intent.
    private boolean enableCache = Boolean.parseBoolean(System.getProperty("org.nd4j.autodiff.samediff", "true"));

    /** Creates a manager with {@code maxMemFrac = 0.25}, threshold {@code 1024} and multiple {@code 2.0}. */
    public ArrayCacheMemoryMgr() {
        this(0.25, 1024L, 2.0);
    }

    /**
     * Creates a manager whose cache may hold up to {@code maxMemFrac} of the backend's total memory.
     *
     * @param maxMemFrac             fraction of total memory for the cache, strictly between 0 and 1
     * @param smallArrayThreshold    must be {@code >= 0}; stored only
     * @param largerArrayMaxMultiple must be {@code >= 1.0}; stored only
     * @throws IllegalArgumentException if any argument is out of range
     */
    public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
        Preconditions.checkArgument(maxMemFrac > 0.0 && maxMemFrac < 1.0, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
        Preconditions.checkArgument(smallArrayThreshold >= 0L, "Small array threshold must be >= 0, got %s", smallArrayThreshold);
        Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
        this.maxMemFrac = maxMemFrac;
        this.smallArrayThreshold = smallArrayThreshold;
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
        if (this.isCpu()) {
            this.totalMemBytes = Pointer.maxBytes();
        } else {
            // Only the first CUDA device is consulted for the memory budget.
            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
            List<?> devList = (List<?>) p.get("cuda.devicesInformation");
            Map<?, ?> m = (Map<?, ?>) devList.get(0);
            this.totalMemBytes = (Long) m.get("cuda.totalMemory");
        }
        this.maxCacheBytes = (long) (maxMemFrac * (double) this.totalMemBytes);
    }

    /** Returns true unless the executioner reports the "CUDA" backend (case-insensitive). */
    private boolean isCpu() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        return !"CUDA".equalsIgnoreCase(backend);
    }

    /**
     * Returns a cached array of the given type and shape if one is available, otherwise a new
     * uninitialized array. The {@code detached} flag is not consulted: new arrays are always created
     * with {@code Nd4j.createUninitializedDetached}.
     */
    @Override
    public INDArray allocate(boolean detached, DataType dataType, long ... shape) {
        if (this.enableCache) {
            INDArray cached = this.pollLiveCachedArray(dataType, Arrays.toString(shape));
            if (cached != null) {
                return cached;
            }
        }
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    /**
     * Allocates an array matching {@code descriptor}. Empty descriptors are created directly (and
     * detached if requested); rank-0 shapes and shapes containing a zero dimension are never served
     * from the cache. A cached array whose ordering differs from the descriptor has its order reset.
     */
    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        if (descriptor.isEmpty()) {
            INDArray ret = Nd4j.create(descriptor);
            return detached ? ret.detach() : ret;
        }
        DataType dataType = descriptor.dataType();
        long[] shape = descriptor.getShape();
        if (this.enableCache && shape.length > 0 && !Longs.contains(shape, 0L)) {
            INDArray cached = this.pollLiveCachedArray(dataType, Arrays.toString(shape));
            if (cached != null) {
                if (cached.ordering() != descriptor.getOrder()) {
                    cached.setOrder(descriptor.getOrder());
                }
                return cached;
            }
        }
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    /**
     * Hands {@code array} back to the manager. When caching is enabled and the array is cacheable it
     * is stored for reuse (evicting older entries if needed); otherwise it is closed if closeable.
     *
     * @throws NullPointerException  if {@code array} is null
     * @throws IllegalStateException if the same array id is released while still cached
     */
    @Override
    public void release(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        long id = array.getId();
        Preconditions.checkState(!this.lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);
        if (!this.enableCache) {
            closeIfCloseable(array);
            return;
        }
        if (array.data() == null) {
            // No buffer to recycle; returning here also avoids dereferencing the null buffer below.
            closeIfCloseable(array);
            return;
        }
        // Presumably useCount > 1 means the buffer is still referenced elsewhere, so recycling it
        // could alias live data; UTF8 arrays are never cached. Neither case may enter the LRU.
        if (Nd4j.getExecutioner().useCount(array.data()) > 1 || array.dataType() == DataType.UTF8) {
            closeIfCloseable(array);
            return;
        }
        long thisBytes = cachedByteSize(array);
        if (thisBytes > this.maxCacheBytes) {
            // Could never fit, even with an empty cache.
            closeIfCloseable(array);
            return;
        }
        this.evictLeastRecentlyUsed(thisBytes);
        this.cacheArray(array);
    }

    /**
     * Removes and returns the first live cached array for the given key, or null if none exists.
     * Arrays found already closed are skipped; their LRU entries stay so eviction reclaims their bytes.
     */
    private INDArray pollLiveCachedArray(DataType dataType, String shapeKey) {
        List<INDArray> candidates = this.arrays.get(dataType, shapeKey);
        if (candidates == null) {
            return null;
        }
        while (!candidates.isEmpty()) {
            INDArray arr = candidates.remove(0);
            if (arr.wasClosed()) {
                continue;
            }
            this.currentCacheSize -= cachedByteSize(arr);
            log.debug("Cache hit for data type {} and shape {}", dataType, shapeKey);
            this.lruCache.remove(arr.getId());
            this.lruCacheValues.remove(arr.getId());
            // A fresh id keeps the double-release check in release() from matching the old entry.
            ((BaseNDArray) arr).assignNewId();
            return arr;
        }
        return null;
    }

    /** Evicts (and closes) the oldest cached arrays until {@code incomingBytes} fits under the cap. */
    private void evictLeastRecentlyUsed(long incomingBytes) {
        Iterator<Long> iter = this.lruCache.iterator();
        while (this.currentCacheSize + incomingBytes > this.maxCacheBytes && iter.hasNext()) {
            long oldestId = iter.next();
            iter.remove();
            INDArray oldest = this.lruCacheValues.remove(oldestId);
            if (oldest == null) {
                continue;
            }
            List<INDArray> bucket = this.arrays.get(oldest.dataType(), Arrays.toString(oldest.shape()));
            if (bucket != null) {
                // Identity, not equals(): INDArray equality is presumably content-based, so
                // List.remove(Object) could drop a different array holding equal values.
                bucket.removeIf(candidate -> candidate == oldest);
            }
            this.currentCacheSize -= cachedByteSize(oldest);
            closeIfCloseable(oldest);
        }
    }

    /** Stores {@code array} under its type/shape key and records it as the newest LRU entry. */
    private void cacheArray(INDArray array) {
        DataType dt = array.dataType();
        String arrayShapeString = Arrays.toString(array.shape());
        List<INDArray> bucket = this.arrays.get(dt, arrayShapeString);
        if (bucket == null) {
            bucket = new ArrayList<>();
            this.arrays.put(dt, arrayShapeString, bucket);
        }
        bucket.add(array);
        this.currentCacheSize += cachedByteSize(array);
        this.lruCache.add(array.getId());
        this.lruCacheValues.put(array.getId(), array);
    }

    /** Bytes accounted for an array: underlying buffer length times element width. */
    private static long cachedByteSize(INDArray array) {
        return array.data().length() * (long) array.dataType().width();
    }

    /** Closes {@code array} only if it reports itself closeable. */
    private static void closeIfCloseable(INDArray array) {
        if (array.closeable()) {
            array.close();
        }
    }

    /** Closes every cached array that is closeable and resets all cache bookkeeping. */
    @Override
    public void close() {
        for (List<INDArray> bucket : this.arrays.values()) {
            for (INDArray arr : bucket) {
                closeIfCloseable(arr);
            }
        }
        // Drop references to closed arrays so the manager holds no stale state.
        this.arrays.clear();
        this.lruCache.clear();
        this.lruCacheValues.clear();
        this.currentCacheSize = 0L;
    }

    public double getMaxMemFrac() {
        return this.maxMemFrac;
    }

    public long getSmallArrayThreshold() {
        return this.smallArrayThreshold;
    }

    public double getLargerArrayMaxMultiple() {
        return this.largerArrayMaxMultiple;
    }

    public long getMaxCacheBytes() {
        return this.maxCacheBytes;
    }

    public long getTotalMemBytes() {
        return this.totalMemBytes;
    }

    public long getCurrentCacheSize() {
        return this.currentCacheSize;
    }

    /** Returns the live (mutable) LRU id set. */
    public LinkedHashSet<Long> getLruCache() {
        return this.lruCache;
    }

    /** Returns the live (mutable) id-to-array map. */
    public Map<Long, INDArray> getLruCacheValues() {
        return this.lruCacheValues;
    }

    /** Returns the live (mutable) type/shape table of cached arrays. */
    public Table<DataType, String, List<INDArray>> getArrays() {
        return this.arrays;
    }

    public boolean isEnableCache() {
        return this.enableCache;
    }

    public void setSmallArrayThreshold(long smallArrayThreshold) {
        this.smallArrayThreshold = smallArrayThreshold;
    }

    public void setLargerArrayMaxMultiple(double largerArrayMaxMultiple) {
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
    }

    public void setMaxCacheBytes(long maxCacheBytes) {
        this.maxCacheBytes = maxCacheBytes;
    }

    public void setTotalMemBytes(long totalMemBytes) {
        this.totalMemBytes = totalMemBytes;
    }

    public void setCurrentCacheSize(long currentCacheSize) {
        this.currentCacheSize = currentCacheSize;
    }

    public void setLruCache(LinkedHashSet<Long> lruCache) {
        this.lruCache = lruCache;
    }

    public void setLruCacheValues(Map<Long, INDArray> lruCacheValues) {
        this.lruCacheValues = lruCacheValues;
    }

    public void setArrays(Table<DataType, String, List<INDArray>> arrays) {
        this.arrays = arrays;
    }

    public void setEnableCache(boolean enableCache) {
        this.enableCache = enableCache;
    }
}

