diff --git a/zion/lib/shared_ptr.h b/zion/lib/shared_ptr.h index 7af3b10..8c26a31 100644 --- a/zion/lib/shared_ptr.h +++ b/zion/lib/shared_ptr.h @@ -15,12 +15,13 @@ class SharedPtr { } SharedPtr(const SharedPtr& other) - : ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) { + : init_(other.init_), ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) { (*ref_cnt_)++; } SharedPtr& operator=(const SharedPtr& other) { Cleanup(); + init_ = other.init_; ptr_ = other.ptr_; ref_cnt_ = other.ref_cnt_; (*ref_cnt_)++; @@ -29,14 +30,33 @@ class SharedPtr { ~SharedPtr() { Cleanup(); } - T& operator*() { return *ptr_; } - const T& operator*() const { return *ptr_; } - T* operator->() { return ptr_; } - const T* operator->() const { return ptr_; } + T& operator*() { + CheckValid(); + return *ptr_; + } + const T& operator*() const { + CheckValid(); + return *ptr_; + } + T* operator->() { + CheckValid(); + return ptr_; + } + const T* operator->() const { + CheckValid(); + return ptr_; + } - T* ptr() { return ptr_; } + T* ptr() { + CheckValid(); + return ptr_; + } - bool operator==(const SharedPtr& other) { return ptr_ == other.ptr_; } + bool operator==(const SharedPtr& other) { + CheckValid(); + other.CheckValid(); + return ptr_ == other.ptr_; + } bool empty() { return !init_; } @@ -55,6 +75,12 @@ class SharedPtr { delete ref_cnt_; } } + + void CheckValid() const { + if (!init_) { + panic("Accessing invalid shared ptr"); + } + } }; template