Add validity check in shared ptr

This commit is contained in:
Drew Galbraith 2023-05-30 01:05:50 -07:00
parent de49dcc01a
commit c6921b5459
1 changed files with 33 additions and 7 deletions

View File

@ -15,12 +15,13 @@ class SharedPtr {
} }
SharedPtr(const SharedPtr<T>& other) SharedPtr(const SharedPtr<T>& other)
: ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) { : init_(other.init_), ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) {
(*ref_cnt_)++; (*ref_cnt_)++;
} }
SharedPtr& operator=(const SharedPtr<T>& other) { SharedPtr& operator=(const SharedPtr<T>& other) {
Cleanup(); Cleanup();
init_ = other.init_;
ptr_ = other.ptr_; ptr_ = other.ptr_;
ref_cnt_ = other.ref_cnt_; ref_cnt_ = other.ref_cnt_;
(*ref_cnt_)++; (*ref_cnt_)++;
@ -29,14 +30,33 @@ class SharedPtr {
~SharedPtr() { Cleanup(); } ~SharedPtr() { Cleanup(); }
T& operator*() { return *ptr_; } T& operator*() {
const T& operator*() const { return *ptr_; } CheckValid();
T* operator->() { return ptr_; } return *ptr_;
const T* operator->() const { return ptr_; } }
const T& operator*() const {
CheckValid();
return *ptr_;
}
T* operator->() {
CheckValid();
return ptr_;
}
const T* operator->() const {
CheckValid();
return ptr_;
}
T* ptr() { return ptr_; } T* ptr() {
CheckValid();
return ptr_;
}
bool operator==(const SharedPtr<T>& other) { return ptr_ == other.ptr_; } bool operator==(const SharedPtr<T>& other) {
CheckValid();
other.CheckValid();
return ptr_ == other.ptr_;
}
bool empty() { return !init_; } bool empty() { return !init_; }
@ -55,6 +75,12 @@ class SharedPtr {
delete ref_cnt_; delete ref_cnt_;
} }
} }
void CheckValid() const {
if (!init_) {
panic("Accessing invalid shared ptr");
}
}
}; };
template <typename T, class... A> template <typename T, class... A>