资讯专栏INFORMATION COLUMN

动手实现一个 LRU cache

Cc_2011 / 1782人阅读

摘要:不过其中的流程算是一个简易的实现,可以对加深一些理解。实现二因此如何来实现一个完整的缓存呢,这次不考虑过期时间的问题。缓存数量超过阈值时移除链表尾部数据。

前言

LRU 是 Least Recently Used 的简写,字面意思则是最近最少使用

通常用于缓存的淘汰策略实现,由于缓存的内存非常宝贵,所以需要根据某种规则来剔除数据保证内存不被撑满。

如常用的 Redis 就有以下几种策略:

策略 描述
volatile-lru 从已设置过期时间的数据集中挑选最近最少使用的数据淘汰
volatile-ttl 从已设置过期时间的数据集中挑选将要过期的数据淘汰
volatile-random 从已设置过期时间的数据集中任意选择数据淘汰
allkeys-lru 从所有数据集中挑选最近最少使用的数据淘汰
allkeys-random 从所有数据集中任意选择数据进行淘汰
no-envicition 禁止驱逐数据
摘抄自:https://github.com/CyC2018/Interview-Notebook/blob/master/notes/Redis.md#%E5%8D%81%E4%B8%89%E6%95%B0%E6%8D%AE%E6%B7%98%E6%B1%B0%E7%AD%96%E7%95%A5
实现一

之前也有接触过一道面试题,大概需求是:

实现一个 LRU 缓存,当缓存数据达到 N 之后需要淘汰掉最近最少使用的数据。

N 小时之内没有被访问的数据也需要淘汰掉。

以下是我的实现:

public class LRUAbstractMap extends java.util.AbstractMap {

    private final static Logger LOGGER = LoggerFactory.getLogger(LRUAbstractMap.class);

    /**
     * 检查是否超期线程
     */
    private ExecutorService checkTimePool ;

    /**
     * map 最大size
     */
    private final static int MAX_SIZE = 1024 ;

    private final static ArrayBlockingQueue QUEUE = new ArrayBlockingQueue<>(MAX_SIZE) ;

    /**
     * 默认大小
     */
    private final static int DEFAULT_ARRAY_SIZE =1024 ;


    /**
     * 数组长度
     */
    private int arraySize ;

    /**
     * 数组
     */
    private Object[] arrays ;


    /**
     * 判断是否停止 flag
     */
    private volatile boolean flag = true ;


    /**
     * 超时时间
     */
    private final static Long EXPIRE_TIME = 60 * 60 * 1000L ;

    /**
     * 整个 Map 的大小
     */
    private volatile AtomicInteger size  ;


    public LRUAbstractMap() {


        arraySize = DEFAULT_ARRAY_SIZE;
        arrays = new Object[arraySize] ;

        //开启一个线程检查最先放入队列的值是否超期
        executeCheckTime();
    }

    /**
     * 开启一个线程检查最先放入队列的值是否超期 设置为守护线程
     */
    private void executeCheckTime() {
        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
                .setNameFormat("check-thread-%d")
                .setDaemon(true)
                .build();
        checkTimePool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS,
                new ArrayBlockingQueue<>(1),namedThreadFactory,new ThreadPoolExecutor.AbortPolicy());
        checkTimePool.execute(new CheckTimeThread()) ;

    }

    @Override
    public Set entrySet() {
        return super.keySet();
    }

    @Override
    public Object put(Object key, Object value) {
        int hash = hash(key);
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            arrays[index] = new Node(null,null, key, value);

            //写入队列
            QUEUE.offer((Node) arrays[index]) ;

            sizeUp();
        }else {
            Node cNode = currentNode ;
            Node nNode = cNode ;

            //存在就覆盖
            if (nNode.key == key){
                cNode.val = value ;
            }

            while (nNode.next != null){
                //key 存在 就覆盖 简单判断
                if (nNode.key == key){
                    nNode.val = value ;
                    break ;
                }else {
                    //不存在就新增链表
                    sizeUp();
                    Node node = new Node(nNode,null,key,value) ;

                    //写入队列
                    QUEUE.offer(currentNode) ;

                    cNode.next = node ;
                }

                nNode = nNode.next ;
            }

        }

        return null ;
    }


    @Override
    public Object get(Object key) {

        int hash = hash(key) ;
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            return null ;
        }
        if (currentNode.next == null){

            //更新时间
            currentNode.setUpdateTime(System.currentTimeMillis());

            //没有冲突
            return currentNode ;

        }

        Node nNode = currentNode ;
        while (nNode.next != null){

            if (nNode.key == key){

                //更新时间
                currentNode.setUpdateTime(System.currentTimeMillis());

                return nNode ;
            }

            nNode = nNode.next ;
        }

        return super.get(key);
    }


    @Override
    public Object remove(Object key) {

        int hash = hash(key) ;
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            return null ;
        }

        if (currentNode.key == key){
            sizeDown();
            arrays[index] = null ;

            //移除队列
            QUEUE.poll();
            return currentNode ;
        }

        Node nNode = currentNode ;
        while (nNode.next != null){

            if (nNode.key == key){
                sizeDown();
                //在链表中找到了 把上一个节点的 next 指向当前节点的下一个节点
                nNode.pre.next = nNode.next ;
                nNode = null ;

                //移除队列
                QUEUE.poll();

                return nNode;
            }

            nNode = nNode.next ;
        }

        return super.remove(key);
    }

    /**
     * 增加size
     */
    private void sizeUp(){

        //在put值时候认为里边已经有数据了
        flag = true ;

        if (size == null){
            size = new AtomicInteger() ;
        }
        int size = this.size.incrementAndGet();
        if (size >= MAX_SIZE) {
            //找到队列头的数据
            Node node = QUEUE.poll() ;
            if (node == null){
                throw new RuntimeException("data error") ;
            }

            //移除该 key
            Object key = node.key ;
            remove(key) ;
            lruCallback() ;
        }

    }

    /**
     * 数量减小
     */
    private void sizeDown(){

        if (QUEUE.size() == 0){
            flag = false ;
        }

        this.size.decrementAndGet() ;
    }

    @Override
    public int size() {
        return size.get() ;
    }

    /**
     * 链表
     */
    private class Node{
        private Node next ;
        private Node pre ;
        private Object key ;
        private Object val ;
        private Long updateTime ;

        public Node(Node pre,Node next, Object key, Object val) {
            this.pre = pre ;
            this.next = next;
            this.key = key;
            this.val = val;
            this.updateTime = System.currentTimeMillis() ;
        }

        public void setUpdateTime(Long updateTime) {
            this.updateTime = updateTime;
        }

        public Long getUpdateTime() {
            return updateTime;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "key=" + key +
                    ", val=" + val +
                    "}";
        }
    }


    /**
     * copy HashMap 的 hash 实现
     * @param key
     * @return
     */
    public int hash(Object key) {
        int h;
        return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
    }

    private void lruCallback(){
        LOGGER.debug("lruCallback");
    }


    private class CheckTimeThread implements Runnable{

        @Override
        public void run() {
            while (flag){
                try {
                    Node node = QUEUE.poll();
                    if (node == null){
                        continue ;
                    }
                    Long updateTime = node.getUpdateTime() ;

                    if ((updateTime - System.currentTimeMillis()) >= EXPIRE_TIME){
                        remove(node.key) ;
                    }
                } catch (Exception e) {
                    LOGGER.error("InterruptedException");
                }
            }
        }
    }

}

感兴趣的朋友可以直接从:

https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUAbstractMap.java

下载代码本地运行。

代码看着比较多,其实实现的思路还是比较简单:

采用了与 HashMap 一样的保存数据方式,只是自己手动实现了一个简易版。

内部采用了一个队列来保存每次写入的数据。

写入的时候判断缓存是否大于了阈值 N,如果满足则根据队列的 FIFO 特性将队列头的数据删除。因为队列头的数据肯定是最先放进去的。

再开启了一个守护线程用于判断最先放进去的数据是否超期(因为就算超期也是最先放进去的数据最有可能满足超期条件。)

设置为守护线程可以更好的表明其目的(最坏的情况下,如果是一个用户线程最终有可能导致程序不能正常退出,因为该线程一直在运行,守护线程则不会有这个情况。)

以上代码大体功能满足了,但是有一个致命问题。

就是最近最少使用没有满足,删除的数据都是最先放入的数据。

不过其中的 put get 流程算是一个简易的 HashMap 实现,可以对 HashMap 加深一些理解。
实现二

因此如何来实现一个完整的 LRU 缓存呢,这次不考虑过期时间的问题。

其实从上一个实现也能想到一些思路:

要记录最近最少使用,那至少需要一个有序的集合来保证写入的顺序。

在使用了数据之后能够更新它的顺序。

基于以上两点很容易想到一个常用的数据结构:链表

每次写入数据时将数据放入链表头结点。

使用数据时候将数据移动到头结点

缓存数量超过阈值时移除链表尾部数据。

因此有了以下实现:

public class LRUMap {
    private final Map cacheMap = new HashMap<>();

    /**
     * 最大缓存大小
     */
    private int cacheSize;

    /**
     * 节点大小
     */
    private int nodeCount;


    /**
     * 头结点
     */
    private Node header;

    /**
     * 尾结点
     */
    private Node tailer;

    public LRUMap(int cacheSize) {
        this.cacheSize = cacheSize;
        //头结点的下一个结点为空
        header = new Node<>();
        header.next = null;

        //尾结点的上一个结点为空
        tailer = new Node<>();
        tailer.tail = null;

        //双向链表 头结点的上结点指向尾结点
        header.tail = tailer;

        //尾结点的下结点指向头结点
        tailer.next = header;


    }

    public void put(K key, V value) {
        cacheMap.put(key, value);

        //双向链表中添加结点
        addNode(key, value);
    }

    public V get(K key){

        Node node = getNode(key);

        //移动到头结点
        moveToHead(node) ;

        return cacheMap.get(key);
    }

    private void moveToHead(Node node){

        //如果是最后的一个节点
        if (node.tail == null){
            node.next.tail = null ;
            tailer = node.next ;
            nodeCount -- ;
        }

        //如果是本来就是头节点 不作处理
        if (node.next == null){
            return ;
        }

        //如果处于中间节点
        if (node.tail != null && node.next != null){
            //它的上一节点指向它的下一节点 也就删除当前节点
            node.tail.next = node.next ;
            nodeCount -- ;
        }

        //最后在头部增加当前节点
        //注意这里需要重新 new 一个对象,不然原本的node 还有着下面的引用,会造成内存溢出。
        node = new Node<>(node.getKey(),node.getValue()) ;
        addHead(node) ;

    }

    /**
     * 链表查询 效率较低
     * @param key
     * @return
     */
    private Node getNode(K key){
        Node node = tailer ;
        while (node != null){

            if (node.getKey().equals(key)){
                return node ;
            }

            node = node.next ;
        }

        return null ;
    }


    /**
     * 写入头结点
     * @param key
     * @param value
     */
    private void addNode(K key, V value) {

        Node node = new Node<>(key, value);

        //容量满了删除最后一个
        if (cacheSize == nodeCount) {
            //删除尾结点
            delTail();
        }

        //写入头结点
        addHead(node);

    }



    /**
     * 添加头结点
     *
     * @param node
     */
    private void addHead(Node node) {

        //写入头结点
        header.next = node;
        node.tail = header;
        header = node;
        nodeCount++;

        //如果写入的数据大于2个 就将初始化的头尾结点删除
        if (nodeCount == 2) {
            tailer.next.next.tail = null;
            tailer = tailer.next.next;
        }

    }    

    private void delTail() {
        //把尾结点从缓存中删除
        cacheMap.remove(tailer.getKey());

        //删除尾结点
        tailer.next.tail = null;
        tailer = tailer.next;

        nodeCount--;

    }

    private class Node {
        private K key;
        private V value;
        Node tail;
        Node next;

        public Node(K key, V value) {
            this.key = key;
            this.value = value;
        }

        public Node() {
        }

        public K getKey() {
            return key;
        }

        public void setKey(K key) {
            this.key = key;
        }

        public V getValue() {
            return value;
        }

        public void setValue(V value) {
            this.value = value;
        }

    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder() ;
        Node node = tailer ;
        while (node != null){
            sb.append(node.getKey()).append(":")
                    .append(node.getValue())
                    .append("-->") ;

            node = node.next ;
        }


        return sb.toString();
    }
}

源码:
https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRUMap.java

实际效果,写入时:

    @Test
    public void put() throws Exception {
        LRUMap lruMap = new LRUMap(3) ;
        lruMap.put("1",1) ;
        lruMap.put("2",2) ;
        lruMap.put("3",3) ;

        System.out.println(lruMap.toString());

        lruMap.put("4",4) ;
        System.out.println(lruMap.toString());

        lruMap.put("5",5) ;
        System.out.println(lruMap.toString());
    }

//输出:
1:1-->2:2-->3:3-->
2:2-->3:3-->4:4-->
3:3-->4:4-->5:5-->

使用时:

    @Test
    public void get() throws Exception {
        LRUMap lruMap = new LRUMap(3) ;
        lruMap.put("1",1) ;
        lruMap.put("2",2) ;
        lruMap.put("3",3) ;

        System.out.println(lruMap.toString());
        System.out.println("==============");

        Integer integer = lruMap.get("1");
        System.out.println(integer);
        System.out.println("==============");
        System.out.println(lruMap.toString());
    }
    
//输出
1:1-->2:2-->3:3-->
==============
1
==============
2:2-->3:3-->1:1-->

实现思路和上文提到的一致,说下重点:

数据是直接利用 HashMap 来存放的。

内部使用了一个双向链表来存放数据,所以有一个头结点 header,以及尾结点 tailer。

每次写入头结点,删除尾结点时都是依赖于 header tailer,如果看着比较懵建议自己实现一个链表熟悉下,或结合下文的对象关系图一起理解。

使用数据移动到链表头时,第一步是需要在双向链表中找到该节点。这里就体现出链表的问题了。查找效率很低,最差需要 O(N)。之后依赖于当前节点进行移动。

在写入头结点时有判断链表大小等于 2 时需要删除初始化的头尾结点。这是因为初始化时候生成了两个双向节点,没有数据只是为了形成一个数据结构。当真实数据进来之后需要删除以方便后续的操作(这点可以继续优化)。

以上的所有操作都是线程不安全的,需要使用者自行控制。

下面是对象关系图:

初始化时

写入数据时
LRUMap lruMap = new LRUMap(3) ;
lruMap.put("1",1) ;

lruMap.put("2",2) ;

lruMap.put("3",3) ;

lruMap.put("4",4) ;

获取数据时

数据和上文一样:

Integer integer = lruMap.get("2");

通过以上几张图应该是很好理解数据是如何存放的了。

实现三

其实如果对 Java 的集合比较熟悉的话,会发现上文的结构和 LinkedHashMap 非常类似。

对此不太熟悉的朋友可以先了解下 LinkedHashMap 底层分析 。

所以我们完全可以借助于它来实现:

public class LRULinkedMap {


    /**
     * 最大缓存大小
     */
    private int cacheSize;

    private LinkedHashMap cacheMap ;


    public LRULinkedMap(int cacheSize) {
        this.cacheSize = cacheSize;

        cacheMap = new LinkedHashMap(16,0.75F,true){
            @Override
            protected boolean removeEldestEntry(Map.Entry eldest) {
                if (cacheSize + 1 == cacheMap.size()){
                    return true ;
                }else {
                    return false ;
                }
            }
        };
    }

    public void put(K key,V value){
        cacheMap.put(key,value) ;
    }

    public V get(K key){
        return cacheMap.get(key) ;
    }


    public Collection> getAll() {
        return new ArrayList>(cacheMap.entrySet());
    }
}

源码:
https://github.com/crossoverJie/Java-Interview/blob/master/src/main/java/com/crossoverjie/actual/LRULinkedMap.java

这次就比较简洁了,也就几行代码(具体的逻辑 LinkedHashMap 已经帮我们实现好了)

实际效果:

    @Test
    public void put() throws Exception {
        LRULinkedMap map = new LRULinkedMap(3) ;
        map.put("1",1);
        map.put("2",2);
        map.put("3",3);

        for (Map.Entry e : map.getAll()){
            System.out.print(e.getKey() + " : " + e.getValue() + "	");
        }

        System.out.println("");
        map.put("4",4);
        for (Map.Entry e : map.getAll()){
            System.out.print(e.getKey() + " : " + e.getValue() + "	");
        }
    }
    
//输出
1 : 1    2 : 2    3 : 3    
2 : 2    3 : 3    4 : 4        

使用时:

    @Test
    public void get() throws Exception {
        LRULinkedMap map = new LRULinkedMap(4) ;
        map.put("1",1);
        map.put("2",2);
        map.put("3",3);
        map.put("4",4);

        for (Map.Entry e : map.getAll()){
            System.out.print(e.getKey() + " : " + e.getValue() + "	");
        }

        System.out.println("");
        map.get("1") ;
        for (Map.Entry e : map.getAll()){
            System.out.print(e.getKey() + " : " + e.getValue() + "	");
        }
    }

}

//输出
1 : 1    2 : 2    3 : 3    4 : 4    
2 : 2    3 : 3    4 : 4    1 : 1

LinkedHashMap 内部也有维护一个双向队列,在初始化时也会给定一个缓存大小的阈值。初始化时自定义是否需要删除最近不常使用的数据,如果是则会按照实现二中的方式管理数据。

其实主要代码就是重写了 LinkedHashMap 的 removeEldestEntry 方法:

    protected boolean removeEldestEntry(Map.Entry eldest) {
        return false;
    }

它默认是返回 false,也就是不会管有没有超过阈值。

所以我们自定义大于了阈值时返回 true,这样 LinkedHashMap 就会帮我们删除最近最少使用的数据。

总结

以上就是对 LRU 缓存的实现,了解了这些至少在平时使用时可以知其所以然。

当然业界使用较多的还有 guava 的实现,并且它还支持多种过期策略。

号外

最近在总结一些 Java 相关的知识点,感兴趣的朋友可以一起维护。

地址: https://github.com/crossoverJie/Java-Interview

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/68959.html

相关文章

  • Guava 源码分析(Cache 原理)

    摘要:缓存本次主要讨论缓存。清除数据时的回调通知。具体不在本次的讨论范围。应该是以下原因新起线程需要资源消耗。维护过期数据还要获取额外的锁,增加了消耗。 showImg(https://segmentfault.com/img/remote/1460000015272232); 前言 Google 出的 Guava 是 Java 核心增强的库,应用非常广泛。 我平时用的也挺频繁,这次就借助日...

    wangxinarhat 评论0 收藏0
  • Guava 源码分析(Cache 原理)

    摘要:缓存本次主要讨论缓存。清除数据时的回调通知。具体不在本次的讨论范围。应该是以下原因新起线程需要资源消耗。维护过期数据还要获取额外的锁,增加了消耗。 showImg(https://segmentfault.com/img/remote/1460000015272232); 前言 Google 出的 Guava 是 Java 核心增强的库,应用非常广泛。 我平时用的也挺频繁,这次就借助日...

    Thanatos 评论0 收藏0
  • 一个线程安全的 lrucache 实现 --- 读 leveldb 源码

    摘要:在阅读的源代码的时候,发现其中的类正是一个线程安全的实现,代码非常优雅。至此一个线程安全的类就已经全部实现,在中使用的缓存是,其实就是聚合多个实例,真正的逻辑都在类中。 缓存是计算机的每一个层次中都是一个非常重要的概念,缓存的存在可以大大提高软件的运行速度。Least Recently Used(lru) cache 即最近最久未使用的缓存,多见与页面置换算法,lru 缓存算法在缓存的...

    widuu 评论0 收藏0
  • NPM酷库:lru-cache 基于内存的缓存管理

    摘要:酷库,每天两分钟,了解一个流行库。而直接将数据保存在程序变量中,最经济快捷。但是这样就会带来一些其他问题,比如缓存更新缓存过期等。用于在内存中管理缓存数据,并且支持算法。可以让程序不依赖任何外部数据库实现缓存管理。 NPM酷库,每天两分钟,了解一个流行NPM库。 为了优化程序性能,我们常常需要奖数据缓存起来,根据实际情况,我们可以将数据存储到磁盘、数据库、redis等。 但是有时候要缓...

    Yumenokanata 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<