用两个哈希表实现O(1)复杂度的LFU算法

LFU缓存是非常常见的缓存策略算法,如果考虑在O(1)的复杂度下实现,还是有一点意思的。

LFU定义

LFUCache 类:

  • LFUCache(int capacity) - 用数据结构的容量 capacity 初始化对象
  • int get(int key) - 如果键 key 存在于缓存中,则获取键的值,否则返回 -1 。
  • void put(int key, int value) - 如果键 key 已存在,则变更其值;如果键不存在,请插入键值对。当缓存达到其容量 capacity 时,则应该在插入新项之前,移除最不经常使用的项。在此问题中,当存在平局(即两个或更多个键具有相同使用频率)时,应该去除 最久未使用 的键。

为了确定最不常使用的键,可以为缓存中的每个键维护一个 使用计数器 。使用计数最小的键是最久未使用的键。

当一个键首次插入到缓存中时,它的使用计数器被设置为 1 (由于 put 操作)。对缓存中的键执行 get 或 put 操作,使用计数器的值将会递增。

思路分析

既然时间复杂度需要时O(1),那我们一定要想到的就是哈希表,并且存储key-value对的方式就是哈希表。

那么问题就在于,如果达到容量了,我们怎么找出需要移除的那一项,即如何快速找出最不经常使用的那一项的key,然后通过哈希表找到value

也就是说我们需要有一种数据结构,可以让我们快速找到(count,time)与key之间的映射关系,并且(count,time)权重最小的

这里有几种思路:

  1. 数组并配合任意一种排序算法,保证数组按照(count,time)升序或者降序排列

  2. BST,AVL,红黑树,优先级队列,等二叉搜索树,通过(count,time)来进行排序

  3. 我们定义两个哈希表,第一个 freq_table 以频率 freq 为索引,每个索引存放一个双向链表,这个链表里存放所有使用频率为 freq 的缓存,缓存里存放三个信息,分别为键 key,值 value,以及使用频率 freq。第二个 key_table 以键值 key 为索引,每个索引存放对应缓存在 freq_table 中链表里的内存地址,这样我们就能利用两个哈希表来使得两个操作的时间复杂度均为 O(1)O(1)O(1)。同时需要记录一个当前缓存最少使用的频率 minFreq,这是为了删除操作服务的。

只有第三种思路可以达到O(1),其实仔细思考下,因为我们只需要直到最少使用频率是哪个就好,没必要每次都完整排好序。

具体实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class LFUCache {
int minfreq, capacity;
Map<Integer, Node> keyTable;
Map<Integer, DoublyLinkedList> freqTable;

public LFUCache(int capacity) {
this.minfreq = 0;
this.capacity = capacity;
keyTable = new HashMap<Integer, Node>();
freqTable = new HashMap<Integer, DoublyLinkedList>();
}

public int get(int key) {
if (capacity == 0) {
return -1;
}
if (!keyTable.containsKey(key)) {
return -1;
}
Node node = keyTable.get(key);
int val = node.val, freq = node.freq;
freqTable.get(freq).remove(node);
// 如果当前链表为空,我们需要在哈希表中删除,且更新minFreq
if (freqTable.get(freq).size == 0) {
freqTable.remove(freq);
if (minfreq == freq) {
minfreq += 1;
}
}
// 插入到 freq + 1 中
DoublyLinkedList list = freqTable.getOrDefault(freq + 1, new DoublyLinkedList());
list.addFirst(new Node(key, val, freq + 1));
freqTable.put(freq + 1, list);
keyTable.put(key, freqTable.get(freq + 1).getHead());
return val;
}

public void put(int key, int value) {
if (capacity == 0) {
return;
}
if (!keyTable.containsKey(key)) {
// 缓存已满,需要进行删除操作
if (keyTable.size() == capacity) {
// 通过 minFreq 拿到 freqTable[minFreq] 链表的末尾节点
Node node = freqTable.get(minfreq).getTail();
keyTable.remove(node.key);
freqTable.get(minfreq).remove(node);
if (freqTable.get(minfreq).size == 0) {
freqTable.remove(minfreq);
}
}
DoublyLinkedList list = freqTable.getOrDefault(1, new DoublyLinkedList());
list.addFirst(new Node(key, value, 1));
freqTable.put(1, list);
keyTable.put(key, freqTable.get(1).getHead());
minfreq = 1;
} else {
// 与 get 操作基本一致,除了需要更新缓存的值
Node node = keyTable.get(key);
int freq = node.freq;
freqTable.get(freq).remove(node);
if (freqTable.get(freq).size == 0) {
freqTable.remove(freq);
if (minfreq == freq) {
minfreq += 1;
}
}
DoublyLinkedList list = freqTable.getOrDefault(freq + 1, new DoublyLinkedList());
list.addFirst(new Node(key, value, freq + 1));
freqTable.put(freq + 1, list);
keyTable.put(key, freqTable.get(freq + 1).getHead());
}
}
}

class Node {
int key, val, freq;
Node prev, next;

Node() {
this(-1, -1, 0);
}

Node(int key, int val, int freq) {
this.key = key;
this.val = val;
this.freq = freq;
}
}

class DoublyLinkedList {
Node dummyHead, dummyTail;
int size;

DoublyLinkedList() {
dummyHead = new Node();
dummyTail = new Node();
dummyHead.next = dummyTail;
dummyTail.prev = dummyHead;
size = 0;
}

public void addFirst(Node node) {
Node prevHead = dummyHead.next;
node.prev = dummyHead;
dummyHead.next = node;
node.next = prevHead;
prevHead.prev = node;
size++;
}

public void remove(Node node) {
Node prev = node.prev, next = node.next;
prev.next = next;
next.prev = prev;
size--;
}

public Node getHead() {
return dummyHead.next;
}

public Node getTail() {
return dummyTail.prev;
}
}