blob: 3584e085c2b1b2441f350771f39901a0b2c09fcb [file] [log] [blame]
Andy Hung0f7ad8c2020-01-03 13:24:34 -08001/*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#pragma once
18
#include <atomic>
#include <cstddef>
#include <memory>
#include <mutex>
#include <utility>
21
22namespace android::mediametrics {
23
/**
 * SharedPtrWrap keeps a std::shared_ptr<T> behind a mutex so that the pointer
 * can be read and replaced from several threads at once.
 *
 * Member access through operator->() first takes a copy of the shared pointer
 * while holding the mutex, then dereferences that copy with the mutex already
 * released. The object being used therefore stays alive for the rest of the
 * full expression, even if another thread installs a replacement meanwhile.
 *
 * Compare the C++20 std::atomic<std::shared_ptr<T>> specialization:
 * https://en.cppreference.com/w/cpp/memory/shared_ptr/atomic2
 *
 * EXAMPLE:
 *
 * SharedPtrWrap<T> t{};
 *
 * thread1() {
 *     t->func(); // runs on either the original object or the one installed by thread2.
 * }
 *
 * thread2() {
 *     t.set(std::make_shared<T>()); // replaces the original object.
 * }
 */
template <typename T>
class SharedPtrWrap {
    mutable std::mutex mLock;  // guards mPtr itself (not the pointee)
    std::shared_ptr<T> mPtr;

public:
    /**
     * Builds the initial T in place, forwarding every argument to T's constructor.
     */
    template <typename... Args>
    explicit SharedPtrWrap(Args&&... args)
        : mPtr{std::make_shared<T>(std::forward<Args>(args)...)} {}

    /**
     * Returns a snapshot of the current shared pointer.
     *
     * The snapshot is returned by value on purpose. Reading mPtr through a
     * reference after the mutex is released would race with a concurrent set().
     *
     * Although this getter is const, it yields shared_ptr<T> rather than
     * shared_ptr<const T>, to stay compatible with plain shared_ptr usage.
     */
    std::shared_ptr<T> get() const {
        std::lock_guard<std::mutex> guard{mLock};
        std::shared_ptr<T> snapshot = mPtr;
        return snapshot;
    }

    /**
     * Installs ptr as the current shared pointer and hands back the pointer it replaced.
     *
     * ptr is taken by value so it can be swapped directly into place. The
     * previous pointer leaves this function inside the return value. If that
     * was the last reference, its object is therefore destroyed by the caller,
     * after the mutex has been released.
     */
    std::shared_ptr<T> set(std::shared_ptr<T> ptr) {
        {
            std::lock_guard<std::mutex> guard{mLock};
            mPtr.swap(ptr);
        }
        return ptr;  // now holds the previous pointer
    }

    /**
     * Member access on the current object.
     *
     * Yields the get() snapshot by value. The language then applies -> to that
     * temporary, which lives until the end of the enclosing full expression, so
     * the object cannot be freed mid-call by a concurrent set().
     * Do not change the return type to a reference!
     *
     * Although const, this yields shared_ptr<T> (not shared_ptr<const T>) for
     * compatibility with existing shared_ptr usage.
     */
    std::shared_ptr<T> operator->() const {
        return get();
    }

    // operator*() is intentionally not provided. A T& handed out by it would
    // not keep the object alive, so it could dangle after a concurrent set().
};
93
/**
 * Wraps member access to the class T by a lock.
 *
 * The object T is constructed inside the LockWrap, which guarantees that it is
 * only ever reached through locked access. When T's members are accessed through
 * ->, a monitor-style lock is acquired, which prevents multiple threads from
 * executing methods of the object T at the same time.
 * Suggested by Kevin R.
 *
 * The lock is held from the call to operator->() until the end of the full
 * expression containing it, because that is when the returned holder temporary
 * is destroyed. A reference or pointer into T that is kept beyond that full
 * expression is no longer protected by the lock.
 *
 * EXAMPLE:
 *
 * // Accumulator class which is very slow, requires locking for multiple threads.
 *
 * class Accumulator {
 *     int32_t value_ = 0;
 * public:
 *     void add(int32_t incr) {
 *         const int32_t temp = value_;
 *         sleep(0); // yield
 *         value_ = temp + incr;
 *     }
 *     int32_t get() { return value_; }
 * };
 *
 * // We use LockWrap on Accumulator to have safe multithread access.
 * android::mediametrics::LockWrap<Accumulator> a{}; // locked accumulator succeeds
 *
 * // Conversely, the following line fails:
 * // auto a = std::make_shared<Accumulator>(); // this fails, only 50% adds atomic.
 *
 * constexpr size_t THREADS = 100;
 * constexpr size_t ITERATIONS = 10;
 * constexpr int32_t INCREMENT = 1;
 *
 * // Test by generating multiple threads, all adding simultaneously.
 * std::vector<std::future<void>> threads;
 * threads.reserve(THREADS); // reserve, not size: the loop appends with push_back.
 * for (size_t i = 0; i < THREADS; ++i) {
 *     threads.push_back(std::async(std::launch::async, [&] {
 *         for (size_t j = 0; j < ITERATIONS; ++j) {
 *             a->add(INCREMENT); // add needs locked access here.
 *         }
 *     }));
 * }
 * threads.clear(); // destroying std::async futures waits for each task to finish.
 *
 * // If the add operations are not atomic, value will be smaller than expected.
 * ASSERT_EQ(INCREMENT * THREADS * ITERATIONS, (size_t)a->get());
 *
 */
template <typename T>
class LockWrap {
    /**
     * Holding class that keeps the pointer and the lock.
     *
     * We return this holding class from operator->() to keep the lock until the
     * member function call or member variable access is completed.
     *
     * It is neither copyable nor movable (it owns a std::lock_guard). Returning
     * it by value from operator->() relies on C++17 guaranteed copy elision.
     */
    class LockedPointer {
        friend LockWrap;
        LockedPointer(T *t, std::recursive_mutex *lock, std::atomic<size_t> *recursionDepth)
            : mT(t), mLock(*lock), mRecursionDepth(recursionDepth) {
            // mLock is initialized (acquired) before this body runs, so the
            // counter is only modified while the mutex is held.
            ++*mRecursionDepth;
        }

        T* const mT;
        std::lock_guard<std::recursive_mutex> mLock;
        std::atomic<size_t>* mRecursionDepth;

    public:
        LockedPointer(const LockedPointer&) = delete;
        LockedPointer& operator=(const LockedPointer&) = delete;

        ~LockedPointer() {
            // Runs before the mLock member is destroyed, so the lock is still held.
            --*mRecursionDepth; // Used for testing, we do not check underflow.
        }

        // Reached when the holder is const, i.e. via LockWrap's const operator->().
        const T* operator->() const {
            return mT;
        }
        T* operator->() {
            return mT;
        }
    };

    // We must use a recursive mutex because the end of the full expression may
    // involve another reference to T->.
    //
    // A recursive mutex allows the same thread to recursively acquire,
    // but different thread would block.
    //
    // Example which fails with a normal mutex:
    //
    // android::mediametrics::LockWrap<std::vector<int>> v{std::initializer_list<int>{1, 2}};
    // const int sum = v->operator[](0) + v->operator[](1);
    //
    mutable std::recursive_mutex mLock;
    // mutable so that the const operator->() can hand a T* to LockedPointer.
    // Const access is restored because that path returns a const LockedPointer,
    // which yields const T*.
    mutable T mT;
    mutable std::atomic<size_t> mRecursionDepth{}; // Used for testing.

public:
    /**
     * Constructs the wrapped T in place, forwarding every argument to T's constructor.
     */
    template <typename... Args>
    explicit LockWrap(Args&&... args) : mT(std::forward<Args>(args)...) {}

    /**
     * Locked read-only member access. The lock is released at the end of the
     * enclosing full expression.
     */
    const LockedPointer operator->() const {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }

    /**
     * Locked member access. The lock is released at the end of the enclosing
     * full expression.
     */
    LockedPointer operator->() {
        return LockedPointer(&mT, &mLock, &mRecursionDepth);
    }

    // Returns the number of holder objects currently alive. The counter only
    // changes while the mutex is held, so this equals the recursion depth of the
    // thread holding the lock (0 when it is unlocked). It is read without the lock,
    // so the value is a snapshot.
    // @TestApi
    size_t getRecursionDepth() const {
        return mRecursionDepth.load();
    }
};
204
205} // namespace android::mediametrics