rust实现的双向list
list的功能
先定义list的接口,就是list有的具体功能。可以看出这个list不仅仅在头尾可以增加和删除,也可以在list中间增加删除。
trait定义:
pub trait ListInterface<T> {
// NodePtr是list中node的指针,通过NodePtr操作node,具体定义参见后面的代码。
type NodePtr;
// IntoIter是一个iterator,list可以转换为IntoIter,此后list就不存在了。
type IntoIter: Iterator;
// Iter也是一个iteraotor,可以用来遍历list。
type Iter: Iterator;
fn is_empty(&self) -> bool;
// node的数量
fn len(&self) -> usize;
// at是具体的位置,从0开始直到len - 1。如果at > (len - 1)返回值为None
fn get(&self, at: usize) -> Option<Self::NodePtr>;
fn push_head(&mut self, value: T) -> Self::NodePtr;
fn push_tail(&mut self, value: T) -> Self::NodePtr;
fn pop_head(&mut self) -> Option<Self::NodePtr>;
fn pop_tail(&mut self) -> Option<Self::NodePtr>;
// 如果at > (len - 1),实际就是push_tail
fn insert_at(&mut self, at: usize, value: T) -> Self::NodePtr;
// 如果before不是指向本list的node,就不插入,返回值是None
fn insert(&mut self, before: &Self::NodePtr, value: T) -> Option<Self::NodePtr>;
// 如果at > (len - 1),就不remove,返回值为None
fn remove_at(&mut self, at: usize) -> Option<Self::NodePtr>;
// 如果ptr不是指向本list的node,就不remove,返回值是None
fn remove(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;
// 将node移动到head
fn top(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;
// 将node移动到tail
fn bottom(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;
fn into_iter(self) -> Self::IntoIter;
fn iter(&self) -> Self::Iter;
}
数据结构
首先定义list中的node,如下:
#[cfg_attr(test, derive(Eq, PartialEq))]
pub struct Node<T> {
pub value: T,
id: Option<usize>,
pre: Option<Pointer<T>>,
next: Option<Pointer<T>>,
}
impl<T> Node<T> {
fn new(id: usize, value: T) -> Self {
Self {
value,
id: Some(id),
pre: None,
next: None,
}
}
}
impl<T: std::fmt::Debug> fmt::Debug for Node<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"({}{:?}{})",
if self.pre.is_some() { "<" } else { "|" },
self.value,
if self.next.is_some() { ">" } else { "|" },
)
}
}
注意,Node中只有value是pub的,其他都是外界不可操作的,Node的new方法也不是pub。
其中的id解释一下,每个List都有一个唯一的id,List中Node的id与List的id一致,如果Node的id与List的id不一致,那就说明Node不是这个List产生的。
Pointer<T>定义如下,指向Node<T>。
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub struct Pointer<T>(Rc<RefCell<Node<T>>>);
impl<T> Pointer<T> {
fn new(id: usize, value: T) -> Self {
Self(Rc::new(RefCell::new(Node::new(id, value))))
}
pub fn node(&self) -> Ref<Node<T>> {
self.0.borrow()
}
pub fn node_mut(&self) -> RefMut<Node<T>> {
self.0.borrow_mut()
}
}
impl<T> Clone for Pointer<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
这是是List的定义:
pub struct List<T> {
id: usize,
count: usize,
head: Option<Pointer<T>>,
tail: Option<Pointer<T>>,
}
impl<T> List<T> {
pub fn new() -> Self {
static mut ID: AtomicUsize = AtomicUsize::new(1);
Self {
id: unsafe { ID.fetch_add(1, Ordering::SeqCst) },
count: 0,
head: None,
tail: None,
}
}
fn new_pointer(&mut self, value: T) -> Pointer<T> {
self.count += 1;
Pointer::new(self.id, value)
}
fn contains(&self, ptr: &Pointer<T>) -> bool {
ptr.node()
.id
.map(|v| v == self.id)
.unwrap_or(false)
}
}
impl<T> Default for List<T> {
fn default() -> Self { Self::new() }
}
impl<T> Drop for List<T> {
fn drop(&mut self) {
self.iter().for_each(|_| {
self.pop_tail();
});
}
}
impl<T: std::fmt::Debug> fmt::Debug for List<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut r = write!(f, "({})-[{}]", self.id, self.count);
let mut it = self.iter();
while let Some(p) = it.next() {
r = write!(f, "-{:?}", p.node());
}
r
}
}
再最后,是IntoIter和Iter的定义:
pub struct IntoIter<T>(List<T>);
impl<T> Iterator for IntoIter<T> {
type Item = Pointer<T>;
fn next(&mut self) -> Option<Self::Item> {
self.0.pop_head()
}
}
pub struct Iter<T> {
next: Option<Pointer<T>>,
}
impl<T> Iterator for Iter<T> {
type Item = Pointer<T>;
fn next(&mut self) -> Option<Self::Item> {
let p: Option<Pointer<T>> = self.next.take();
self.next = p
.as_ref()
.and_then(|v| v.node().next.as_ref().map(|v2| v2.clone()));
p
}
}
实现ListInterface trait
基本上都比较简单,代码如下:
impl<T> ListInterface<T> for List<T> {
type NodePtr = Pointer<T>;
type IntoIter = IntoIter<T>;
type Iter = Iter<T>;
fn is_empty(&self) -> bool {
self.count == 0
}
fn len(&self) -> usize {
self.count
}
fn get(&self, at: usize) -> Option<Pointer<T>> {
self.iter().nth(at)
}
fn push_head(&mut self, value: T) -> Pointer<T> {
if let Some(head) = self.head.clone() {
self.insert(&head, value).unwrap()
} else {
let ptr = self.new_pointer(value);
self.tail.replace(ptr.clone());
self.head.replace(ptr.clone());
ptr
}
}
fn push_tail(&mut self, value: T) -> Pointer<T> {
let ptr = self.new_pointer(value);
if let Some(tail) = self.tail.clone() {
tail.node_mut().next.replace(ptr.clone());
ptr.node_mut().pre.replace(tail);
self.tail.replace(ptr.clone());
} else {
self.tail.replace(ptr.clone());
self.head.replace(ptr.clone());
}
ptr
}
fn pop_head(&mut self) -> Option<Pointer<T>> {
self.head.clone().and_then(|v| self.remove(&v))
}
fn pop_tail(&mut self) -> Option<Pointer<T>> {
self.tail.clone().and_then(|v| self.remove(&v))
}
fn insert_at(&mut self, at: usize, value: T) -> Pointer<T> {
if let Some(ref p) = self.get(at) {
self.insert(p, value).unwrap()
} else {
self.push_tail(value)
}
}
fn insert(&mut self, at: &Pointer<T>, value: T) -> Option<Pointer<T>> {
if !self.contains(at) {
return None;
}
let ptr = self.new_pointer(value);
if let Some(ref pre) = at.node().pre {
pre.node_mut().next.replace(ptr.clone());
ptr.node_mut().pre.replace(pre.clone());
} else {
// the at was the head, so the node will be the head now.
self.head.replace(ptr.clone());
}
at.node_mut().pre.replace(ptr.clone());
ptr.node_mut().next.replace(at.clone());
Some(ptr)
}
fn remove_at(&mut self, at: usize) -> Option<Pointer<T>> {
let p = self.get(at);
p.as_ref().map(|v| self.remove(v));
p
}
fn remove(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
if !self.contains(ptr) {
return None;
}
if let Some(ref pre) = ptr.node().pre {
// ptr不是head
if let Some(ref next) = ptr.node().next {
// ptr不是tail
pre.node_mut().next.replace(next.clone());
next.node_mut().pre.replace(pre.clone());
} else {
// ptr是tail,tail改为前一个
pre.node_mut().next.take();
self.tail.replace(pre.clone());
}
} else {
// ptr is head
if let Some(ref next) = ptr.node().next {
// ptr is not tail
next.node_mut().pre.take();
self.head.replace(next.clone());
} else {
// node is tail
self.head.take();
self.tail.take();
}
}
// split the node from the list
ptr.node_mut().pre.take();
ptr.node_mut().next.take();
self.count -= 1;
ptr.node_mut().id.take();
Some(ptr.clone())
}
fn top(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
if !self.contains(ptr) {
return None;
}
if let Some(ref pre) = ptr.node().pre {
// ptr不是head
if let Some(ref next) = ptr.node().next {
// ptr不是tail
pre.node_mut().next.replace(next.clone());
next.node_mut().pre.replace(pre.clone());
} else {
// ptr是tail,tail改为前一个
pre.node_mut().next.take();
self.tail.replace(pre.clone());
}
}
if ptr.node().pre.is_some() {
// ptr不是head,将node放在head之前
if let Some(ref head) = self.head {
head.node_mut().pre.replace(ptr.clone());
}
ptr.node_mut().pre = None;
ptr.node_mut().next = self.head.clone();
// ptr改为node
self.head.replace(ptr.clone());
}
Some(ptr.clone())
}
fn bottom(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
if !self.contains(ptr) {
return None;
}
if let Some(ref next) = ptr.node().next {
// ptr不是tail
if let Some(ref pre) = ptr.node().pre {
// ptr不是head
next.node_mut().pre.replace(pre.clone());
pre.node_mut().next.replace(next.clone());
} else {
// ptr是head,head改为后一个
next.node_mut().pre.take();
self.head.replace(next.clone());
}
}
if ptr.node().next.is_some() {
// ptr不是tail,将ptr放在tail之后
if let Some(ref tail) = self.tail {
tail.node_mut().next.replace(ptr.clone());
}
ptr.node_mut().next = None;
ptr.node_mut().pre = self.tail.clone();
// ptr改为node
self.tail.replace(ptr.clone());
}
Some(ptr.clone())
}
fn into_iter(self) -> IntoIter<T> {
IntoIter(self)
}
fn iter(&self) -> Iter<T> {
Iter {
next: self.head.clone(),
}
}
}
关于Send和Sync
List<T>是!Send + !Sync,原因是Pointer中使用了Rc,Rc是!Send + !Sync。
那么可以使List支持Send + Sync吗?
将Rc替换为Arc是否支持Send + Sync呢?答案是否。
Arc<T>是Send + Sync必须要求T: Send + Sync,参见如下std中的定义:
impl<T: ?Sized + Sync + Send> Send for Arc<T>
impl<T: ?Sized + Sync + Send> Sync for Arc<T>
在Pointer<T>的定义中,就是要求RefCell<T>必须是Send + Sync。但是RefCell明确不支持Sync,参见如下std中的定义:
impl<T> Send for RefCell<T>
where
T: Send + ?Sized,
[src]
impl<T> !Sync for RefCell<T>
where
T: ?Sized,
所以Pointer<T>就不支持Send + Sync,那么List就更不能支持Send + Sync。
如何让List支持Send + Sync?
可以将Pointer修改如下,Pointer就是Send + Sync。
pub struct Pointer<T: Send>(Arc<Mutex<RefCell<Node<T>>>>);
同时Node和List也就是Send + Sync了,前提条件就是T必须是Send + Sync。如下所示:
impl<T> Send for Node<T>
where
T: Send,
impl<T> Sync for Node<T>
where
T: Send + Sync,
impl<T> Send for List<T>
where
T: Send,
impl<T> Sync for List<T>
where
T: Send,
宏
定义宏list,用来生成List:
macro_rules! list {
[] => { List::new() };
[$($x: expr),+] => {{
let mut list = List::new();
$(list.push_tail($x);)+
list
}};
[$($x: expr,)+] => { list![$($x),+] }
}
测试
len()方法
测试代码如下:
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_list_len() {
let l = list![1_u32, 2, 3, 4];
dbg!(&l);
assert_eq!(4, l.len())
}
}
如下的测试命令:
cargo test test_list_len -- --nocapture --test-threads=1
输出如下:
running 1 test
[leetcode/src/utilites/list3.rs:372] &l = (1)-[4]-(|1>)-(<2>)-(<3>)-(<4|)
test utilites::list3::tests::test_list_len ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 24 filtered out
get()方法
测试代码如下:
#[test]
fn test_list_get() {
let l = list![1_u32, 2, 3, 4];
dbg!(&l);
assert_eq!(1, l.get(0).unwrap().node().value);
assert_eq!(3, l.get(2).unwrap().node().value);
assert_eq!(4, l.get(3).unwrap().node().value);
assert_eq!(None, l.get(4));
assert_eq!(None, l.get(14));
l.get(1).unwrap().node_mut().value = 22;
assert_eq!(22, l.get(1).unwrap().node().value);
dbg!(&l);
}
输出如下:
running 1 test
[leetcode/src/utilites/list3.rs:381] &l = (1)-[4]-(|1>)-(<2>)-(<3>)-(<4|)
[leetcode/src/utilites/list3.rs:392] &l = (1)-[4]-(|1>)-(<22>)-(<3>)-(<4|)
test utilites::list3::tests::test_list_get ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 25 filtered out
push_head()方法
测试代码如下:
fn test_list_push_head() {
let mut l = List::<u32>::new();
l.push_head(3);
l.push_head(2);
l.push_head(1);
dbg!(&l);
assert_eq!(1, l.get(0).unwrap().node().value);
assert_eq!(2, l.get(1).unwrap().node().value);
assert_eq!(3, l.get(2).unwrap().node().value);
// --
let mut l = list![11_u32, 12, 13, 14];
l.push_head(3);
l.push_head(2);
l.push_head(1);
dbg!(&l);
assert_eq!(7, l.len());
assert_eq!(1, l.get(0).unwrap().node().value);
assert_eq!(3, l.get(2).unwrap().node().value);
assert_eq!(11, l.get(3).unwrap().node().value);
assert_eq!(13, l.get(5).unwrap().node().value);
assert_eq!(14, l.get(6).unwrap().node().value);
assert_eq!(None, l.get(7));
assert_eq!(None, l.get(17));
}
输出如下:
running 1 test
[leetcode/src/utilites/list3.rs:402] &l = (1)-[3]-(|1>)-(<2>)-(<3|)
[leetcode/src/utilites/list3.rs:414] &l = (2)-[7]-(|1>)-(<2>)-(<3>)-(<11>)-(<12>)-(<13>)-(<14|)
test utilites::list3::tests::test_list_push_head ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 26 filtered out
更多的测试代码,就不在这里啰嗦了,请参考playgournd的链接,有兴趣可以玩玩。
https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=7534c79e06a299af42e170dad583b5c5