Merge "Transcoder: Refactor sample writer to not block clients."
diff --git a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
index bb0da88..afa5021 100644
--- a/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
+++ b/media/libmediatranscoding/transcoder/MediaSampleWriter.cpp
@@ -72,6 +72,11 @@
     AMediaMuxer* mMuxer;
 };
 
+// static
+std::shared_ptr<MediaSampleWriter> MediaSampleWriter::Create() {
+    return std::shared_ptr<MediaSampleWriter>(new MediaSampleWriter());
+}
+
 MediaSampleWriter::~MediaSampleWriter() {
     if (mState == STARTED) {
         stop();  // Join thread.
@@ -92,7 +97,7 @@
         return false;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != UNINITIALIZED) {
         LOG(ERROR) << "Sample writer is already initialized";
         return false;
@@ -104,39 +109,58 @@
     return true;
 }
 
-bool MediaSampleWriter::addTrack(const std::shared_ptr<MediaSampleQueue>& sampleQueue,
-                                 const std::shared_ptr<AMediaFormat>& trackFormat) {
-    if (sampleQueue == nullptr || trackFormat == nullptr) {
-        LOG(ERROR) << "Sample queue and track format must be non-null";
-        return false;
+MediaSampleWriter::MediaSampleConsumerFunction MediaSampleWriter::addTrack(
+        const std::shared_ptr<AMediaFormat>& trackFormat) {
+    if (trackFormat == nullptr) {
+        LOG(ERROR) << "Track format must be non-null";
+        return nullptr;
     }
 
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
     if (mState != INITIALIZED) {
         LOG(ERROR) << "Muxer needs to be initialized when adding tracks.";
-        return false;
+        return nullptr;
     }
-    ssize_t trackIndex = mMuxer->addTrack(trackFormat.get());
-    if (trackIndex < 0) {
-        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndex;
-        return false;
+    ssize_t trackIndexOrError = mMuxer->addTrack(trackFormat.get());
+    if (trackIndexOrError < 0) {
+        LOG(ERROR) << "Failed to add media track to muxer: " << trackIndexOrError;
+        return nullptr;
     }
+    const size_t trackIndex = static_cast<size_t>(trackIndexOrError);
 
     int64_t durationUs;
     if (!AMediaFormat_getInt64(trackFormat.get(), AMEDIAFORMAT_KEY_DURATION, &durationUs)) {
         durationUs = 0;
     }
 
-    mAllTracks.push_back(std::make_unique<TrackRecord>(sampleQueue, static_cast<size_t>(trackIndex),
-                                                       durationUs));
-    mSortedTracks.insert(mAllTracks.back().get());
-    return true;
+    mTracks.emplace(trackIndex, durationUs);
+    // The consumer captures a strong reference so the writer outlives every consumer handed out.
+
+    return [self = shared_from_this(), trackIndex](const std::shared_ptr<MediaSample>& sample) {
+        self->addSampleToTrack(trackIndex, sample);
+    };
+}
+
+void MediaSampleWriter::addSampleToTrack(size_t trackIndex,
+                                         const std::shared_ptr<MediaSample>& sample) {
+    if (sample == nullptr) return;
+
+    bool wasEmpty;
+    {
+        std::scoped_lock lock(mMutex);
+        wasEmpty = mSampleQueue.empty();
+        mSampleQueue.push(std::make_pair(trackIndex, sample));
+    }
+
+    if (wasEmpty) {
+        mSampleSignal.notify_one();
+    }
 }
 
 bool MediaSampleWriter::start() {
-    std::scoped_lock lock(mStateMutex);
+    std::scoped_lock lock(mMutex);
 
-    if (mAllTracks.size() == 0) {
+    if (mTracks.size() == 0) {
         LOG(ERROR) << "No tracks to write.";
         return false;
     } else if (mState != INITIALIZED) {
@@ -144,30 +168,28 @@
         return false;
     }
 
+    mState = STARTED;
     mThread = std::thread([this] {
         media_status_t status = writeSamples();
         if (auto callbacks = mCallbacks.lock()) {
             callbacks->onFinished(this, status);
         }
     });
-    mState = STARTED;
     return true;
 }
 
 bool MediaSampleWriter::stop() {
-    std::scoped_lock lock(mStateMutex);
-
-    if (mState != STARTED) {
-        LOG(ERROR) << "Sample writer is not started.";
-        return false;
+    {
+        std::scoped_lock lock(mMutex);
+        if (mState != STARTED) {
+            LOG(ERROR) << "Sample writer is not started.";
+            return false;
+        }
+        mState = STOPPED;
     }
 
-    // Stop the sources, and wait for thread to join.
-    for (auto& track : mAllTracks) {
-        track->mSampleQueue->abort();
-    }
+    mSampleSignal.notify_all();
     mThread.join();
-    mState = STOPPED;
     return true;
 }
 
@@ -191,83 +213,69 @@
     return writeStatus != AMEDIA_OK ? writeStatus : muxerStatus;
 }
 
-std::multiset<MediaSampleWriter::TrackRecord*>::iterator MediaSampleWriter::getNextOutputTrack() {
-    // Find the first track that has samples ready in its queue AND is not more than
-    // mMaxTrackDivergenceUs ahead of the slowest track. If no such track exists then return the
-    // slowest track and let the writer wait for samples to become ready. Note that mSortedTracks is
-    // sorted by each track's previous sample timestamp in ascending order.
-    auto slowestTrack = mSortedTracks.begin();
-    if (slowestTrack == mSortedTracks.end() || !(*slowestTrack)->mSampleQueue->isEmpty()) {
-        return slowestTrack;
-    }
-
-    const int64_t slowestTimeUs = (*slowestTrack)->mPrevSampleTimeUs;
-    int64_t divergenceUs;
-
-    for (auto it = std::next(slowestTrack); it != mSortedTracks.end(); ++it) {
-        // If the current track has diverged then the rest will have too, so we can stop the search.
-        // If not and it has samples ready then return it, otherwise keep looking.
-        if (__builtin_sub_overflow((*it)->mPrevSampleTimeUs, slowestTimeUs, &divergenceUs) ||
-            divergenceUs >= mMaxTrackDivergenceUs) {
-            break;
-        } else if (!(*it)->mSampleQueue->isEmpty()) {
-            return it;
-        }
-    }
-
-    // No track with pending samples within acceptable time interval was found, so let the writer
-    // wait for the slowest track to produce a new sample.
-    return slowestTrack;
-}
-
-media_status_t MediaSampleWriter::runWriterLoop() {
+media_status_t MediaSampleWriter::runWriterLoop() NO_THREAD_SAFETY_ANALYSIS {
     AMediaCodecBufferInfo bufferInfo;
     int32_t lastProgressUpdate = 0;
+    int trackEosCount = 0;
 
     // Set the "primary" track that will be used to determine progress to the track with longest
     // duration.
     int primaryTrackIndex = -1;
     int64_t longestDurationUs = 0;
-    for (auto& track : mAllTracks) {
-        if (track->mDurationUs > longestDurationUs) {
-            primaryTrackIndex = track->mTrackIndex;
-            longestDurationUs = track->mDurationUs;
+    for (auto it = mTracks.begin(); it != mTracks.end(); ++it) {
+        if (it->second.mDurationUs > longestDurationUs) {
+            primaryTrackIndex = it->first;
+            longestDurationUs = it->second.mDurationUs;
         }
     }
 
     while (true) {
-        auto outputTrackIter = getNextOutputTrack();
-
-        // Exit if all tracks have reached end of stream.
-        if (outputTrackIter == mSortedTracks.end()) {
+        if (trackEosCount >= mTracks.size()) {
             break;
         }
 
-        // Remove the track from the set, update it, and then reinsert it to keep the set in order.
-        TrackRecord* track = *outputTrackIter;
-        mSortedTracks.erase(outputTrackIter);
-
+        size_t trackIndex;
         std::shared_ptr<MediaSample> sample;
-        if (track->mSampleQueue->dequeue(&sample)) {
-            // Track queue was aborted.
-            return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
-        } else if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+        {
+            std::unique_lock lock(mMutex);
+            while (mSampleQueue.empty() && mState == STARTED) {
+                mSampleSignal.wait(lock);
+            }
+
+            if (mState != STARTED) {
+                return AMEDIA_ERROR_UNKNOWN;  // TODO(lnilsson): Custom error code.
+            }
+
+            auto& topEntry = mSampleQueue.top();
+            trackIndex = topEntry.first;
+            sample = topEntry.second;
+            mSampleQueue.pop();
+        }
+
+        TrackRecord& track = mTracks[trackIndex];
+
+        if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+            if (track.mReachedEos) {
+                continue;
+            }
+
             // Track reached end of stream.
-            track->mReachedEos = true;
+            track.mReachedEos = true;
+            trackEosCount++;
 
             // Preserve source track duration by setting the appropriate timestamp on the
             // empty End-Of-Stream sample.
-            if (track->mDurationUs > 0 && track->mFirstSampleTimeSet) {
-                sample->info.presentationTimeUs = track->mDurationUs + track->mFirstSampleTimeUs;
+            if (track.mDurationUs > 0 && track.mFirstSampleTimeSet) {
+                sample->info.presentationTimeUs = track.mDurationUs + track.mFirstSampleTimeUs;
             }
         }
 
-        track->mPrevSampleTimeUs = sample->info.presentationTimeUs;
-        if (!track->mFirstSampleTimeSet) {
+        track.mPrevSampleTimeUs = sample->info.presentationTimeUs;
+        if (!track.mFirstSampleTimeSet) {
             // Record the first sample's timestamp in order to translate duration to EOS
             // time for tracks that does not start at 0.
-            track->mFirstSampleTimeUs = sample->info.presentationTimeUs;
-            track->mFirstSampleTimeSet = true;
+            track.mFirstSampleTimeUs = sample->info.presentationTimeUs;
+            track.mFirstSampleTimeSet = true;
         }
 
         bufferInfo.offset = sample->dataOffset;
@@ -275,8 +283,7 @@
         bufferInfo.flags = sample->info.flags;
         bufferInfo.presentationTimeUs = sample->info.presentationTimeUs;
 
-        media_status_t status =
-                mMuxer->writeSampleData(track->mTrackIndex, sample->buffer, &bufferInfo);
+        media_status_t status = mMuxer->writeSampleData(trackIndex, sample->buffer, &bufferInfo);
         if (status != AMEDIA_OK) {
             LOG(ERROR) << "writeSampleData returned " << status;
             return status;
@@ -284,9 +291,9 @@
         sample.reset();
 
         // TODO(lnilsson): Add option to toggle progress reporting on/off.
-        if (track->mTrackIndex == primaryTrackIndex) {
-            const int64_t elapsed = track->mPrevSampleTimeUs - track->mFirstSampleTimeUs;
-            int32_t progress = (elapsed * 100) / track->mDurationUs;
+        if (trackIndex == primaryTrackIndex) {
+            const int64_t elapsed = track.mPrevSampleTimeUs - track.mFirstSampleTimeUs;
+            int32_t progress = (elapsed * 100) / track.mDurationUs;
             progress = std::clamp(progress, 0, 100);
 
             if (progress > lastProgressUpdate) {
@@ -296,10 +303,6 @@
                 lastProgressUpdate = progress;
             }
         }
-
-        if (!track->mReachedEos) {
-            mSortedTracks.insert(track);
-        }
     }
 
     return AMEDIA_OK;
diff --git a/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
index 92ce60a..698594f 100644
--- a/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/MediaTrackTranscoder.cpp
@@ -94,7 +94,10 @@
         abortTranscodeLoop();
         mMediaSampleReader->setEnforceSequentialAccess(false);
         mTranscodingThread.join();
-        mOutputQueue->abort();  // Wake up any threads waiting for samples.
+        {
+            std::scoped_lock lock{mSampleMutex};
+            mSampleQueue.abort();  // Release any buffered samples.
+        }
         mState = STOPPED;
         return true;
     }
@@ -109,8 +112,24 @@
     }
 }
 
-std::shared_ptr<MediaSampleQueue> MediaTrackTranscoder::getOutputQueue() const {
-    return mOutputQueue;
+void MediaTrackTranscoder::onOutputSampleAvailable(const std::shared_ptr<MediaSample>& sample) {
+    std::scoped_lock lock{mSampleMutex};
+    if (mSampleConsumer == nullptr) {
+        mSampleQueue.enqueue(sample);
+    } else {
+        mSampleConsumer(sample);
+    }
+}
+
+void MediaTrackTranscoder::setSampleConsumer(
+        const MediaSampleWriter::MediaSampleConsumerFunction& sampleConsumer) {
+    std::scoped_lock lock{mSampleMutex};
+    mSampleConsumer = sampleConsumer;
+    // Flush samples buffered before a consumer was set; an empty consumer keeps them queued.
+    std::shared_ptr<MediaSample> sample;
+    while (mSampleConsumer && !mSampleQueue.isEmpty() && !mSampleQueue.dequeue(&sample)) {
+        mSampleConsumer(sample);
+    }
 }
 
 }  // namespace android
diff --git a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
index 4730be3..35bdc40 100644
--- a/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/MediaTranscoder.cpp
@@ -123,14 +123,16 @@
     }
 
     // Add track to the writer.
-    const bool ok =
-            mSampleWriter->addTrack(transcoder->getOutputQueue(), transcoder->getOutputFormat());
-    if (!ok) {
+    auto consumer = mSampleWriter->addTrack(transcoder->getOutputFormat());
+    if (consumer == nullptr) {
         LOG(ERROR) << "Unable to add track to sample writer.";
         sendCallback(AMEDIA_ERROR_UNKNOWN);
         return;
     }
 
+    MediaTrackTranscoder* mutableTranscoder = const_cast<MediaTrackTranscoder*>(transcoder);
+    mutableTranscoder->setSampleConsumer(consumer);
+
     mTracksAdded.insert(transcoder);
     if (mTracksAdded.size() == mTrackTranscoders.size()) {
         // Enable sequential access mode on the sample reader to achieve optimal read performance.
@@ -304,7 +306,7 @@
         return AMEDIA_ERROR_INVALID_OPERATION;
     }
 
-    mSampleWriter = std::make_unique<MediaSampleWriter>();
+    mSampleWriter = MediaSampleWriter::Create();
     const bool initOk = mSampleWriter->init(fd, shared_from_this());
 
     if (!initOk) {
diff --git a/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
index e7c0271..35b1d33 100644
--- a/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/PassthroughTrackTranscoder.cpp
@@ -138,10 +138,7 @@
         }
 
         sample->info = info;
-        if (mOutputQueue->enqueue(sample)) {
-            LOG(ERROR) << "Output queue aborted";
-            return AMEDIA_ERROR_IO;
-        }
+        onOutputSampleAvailable(sample);
     }
 
     if (mStopRequested && !mEosFromSource) {
diff --git a/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp b/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
index b0bf59f..c7d775c 100644
--- a/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
+++ b/media/libmediatranscoding/transcoder/VideoTrackTranscoder.cpp
@@ -375,12 +375,7 @@
         sample->info.flags = bufferInfo.flags;
         sample->info.presentationTimeUs = bufferInfo.presentationTimeUs;
 
-        const bool aborted = mOutputQueue->enqueue(sample);
-        if (aborted) {
-            LOG(ERROR) << "Output sample queue was aborted. Stopping transcode.";
-            mStatus = AMEDIA_ERROR_IO;  // TODO: Define custom error codes?
-            return;
-        }
+        onOutputSampleAvailable(sample);
     } else if (bufferIndex == AMEDIACODEC_INFO_OUTPUT_FORMAT_CHANGED) {
         AMediaFormat* newFormat = AMediaCodec_getOutputFormat(mEncoder->getCodec());
         LOG(DEBUG) << "Encoder output format changed: " << AMediaFormat_toString(newFormat);
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h b/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
index 5f9822d..2032def 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaSampleReaderNDK.h
@@ -58,7 +58,6 @@
     virtual ~MediaSampleReaderNDK() override;
 
 private:
-
     /**
      * SamplePosition describes the position of a single sample in the media file using its
      * timestamp and index in the file.
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
index d4b1fcf..f762556 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaSampleWriter.h
@@ -17,17 +17,19 @@
 #ifndef ANDROID_MEDIA_SAMPLE_WRITER_H
 #define ANDROID_MEDIA_SAMPLE_WRITER_H
 
-#include <media/MediaSampleQueue.h>
+#include <media/MediaSample.h>
 #include <media/NdkMediaCodec.h>
 #include <media/NdkMediaError.h>
 #include <media/NdkMediaFormat.h>
 #include <utils/Mutex.h>
 
+#include <condition_variable>
 #include <functional>
 #include <memory>
 #include <mutex>
-#include <set>
+#include <queue>
 #include <thread>
+#include <unordered_map>
 
 namespace android {
 
@@ -62,18 +64,16 @@
 };
 
 /**
- * MediaSampleWriter writes samples to a muxer while keeping its input sources synchronized. Each
- * source track have its own MediaSampleQueue from which samples are dequeued by the sample writer
- * and written to the muxer. The sample writer always prioritizes dequeueing samples from the source
- * track that is farthest behind by comparing sample timestamps. If the slowest track does not have
- * any samples pending the writer moves on to the next track but never allows tracks to diverge more
- * than a configurable duration of time. The default muxer interface implementation is based
+ * MediaSampleWriter is a wrapper around a muxer. The sample writer puts samples on a queue that
+ * is serviced by an internal thread to minimize blocking time for clients. MediaSampleWriter also
+ * provides progress reporting. The default muxer interface implementation is based
  * directly on AMediaMuxer.
  */
-class MediaSampleWriter {
+class MediaSampleWriter : public std::enable_shared_from_this<MediaSampleWriter> {
 public:
-    /** The default maximum track divergence in microseconds. */
-    static constexpr uint32_t kDefaultMaxTrackDivergenceUs = 1 * 1000 * 1000;  // 1 second.
+    /** Function prototype for delivering media samples to the writer. */
+    using MediaSampleConsumerFunction =
+            std::function<void(const std::shared_ptr<MediaSample>& sample)>;
 
     /** Callback interface. */
     class CallbackInterface {
@@ -90,18 +90,7 @@
         virtual ~CallbackInterface() = default;
     };
 
-    /**
-     * Constructor with custom maximum track divergence.
-     * @param maxTrackDivergenceUs The maximum track divergence in microseconds.
-     */
-    MediaSampleWriter(uint32_t maxTrackDivergenceUs)
-          : mMaxTrackDivergenceUs(maxTrackDivergenceUs), mMuxer(nullptr), mState(UNINITIALIZED){};
-
-    /** Constructor using the default maximum track divergence. */
-    MediaSampleWriter() : MediaSampleWriter(kDefaultMaxTrackDivergenceUs){};
-
-    /** Destructor. */
-    ~MediaSampleWriter();
+    static std::shared_ptr<MediaSampleWriter> Create();
 
     /**
      * Initializes the sample writer with its default muxer implementation. MediaSampleWriter needs
@@ -125,12 +114,12 @@
     /**
      * Adds a new track to the sample writer. Tracks must be added after the sample writer has been
      * initialized and before it is started.
-     * @param sampleQueue The MediaSampleQueue to pull samples from.
      * @param trackFormat The format of the track to add.
-     * @return True if the track was successfully added.
+     * @return A sample consumer to add samples to if the track was successfully added, or nullptr
+     * if the track could not be added.
      */
-    bool addTrack(const std::shared_ptr<MediaSampleQueue>& sampleQueue /* nonnull */,
-                  const std::shared_ptr<AMediaFormat>& trackFormat /* nonnull */);
+    MediaSampleConsumerFunction addTrack(
+            const std::shared_ptr<AMediaFormat>& trackFormat /* nonnull */);
 
     /**
      * Starts the sample writer. The sample writer will start processing samples and writing them to
@@ -150,51 +139,69 @@
      */
     bool stop();
 
+    /** Destructor. */
+    ~MediaSampleWriter();
+
 private:
     struct TrackRecord {
-        TrackRecord(const std::shared_ptr<MediaSampleQueue>& sampleQueue, size_t trackIndex,
-                    int64_t durationUs)
-              : mSampleQueue(sampleQueue),
-                mTrackIndex(trackIndex),
-                mDurationUs(durationUs),
+        explicit TrackRecord(int64_t durationUs)
+              : mDurationUs(durationUs),
                 mFirstSampleTimeUs(0),
                 mPrevSampleTimeUs(INT64_MIN),
                 mFirstSampleTimeSet(false),
-                mReachedEos(false) {}
+                mReachedEos(false) {}
 
-        std::shared_ptr<MediaSampleQueue> mSampleQueue;
-        const size_t mTrackIndex;
+        TrackRecord() : TrackRecord(0) {}
+
         int64_t mDurationUs;
         int64_t mFirstSampleTimeUs;
         int64_t mPrevSampleTimeUs;
         bool mFirstSampleTimeSet;
         bool mReachedEos;
-
-        struct compare {
-            bool operator()(const TrackRecord* lhs, const TrackRecord* rhs) const {
-                return lhs->mPrevSampleTimeUs < rhs->mPrevSampleTimeUs;
-            }
-        };
     };
 
-    const uint32_t mMaxTrackDivergenceUs;
+    // Track index and sample.
+    using SampleEntry = std::pair<size_t, std::shared_ptr<MediaSample>>;
+
+    struct SampleComparator {
+        // Returns true if lhs should be dequeued after rhs (priority_queue pops the greatest).
+        bool operator()(const SampleEntry& lhs, const SampleEntry& rhs) const {
+            const bool lhsEos = lhs.second->info.flags & SAMPLE_FLAG_END_OF_STREAM;
+            const bool rhsEos = rhs.second->info.flags & SAMPLE_FLAG_END_OF_STREAM;
+            // EOS samples go after all regular samples; EOS ties are broken by track index.
+            if (lhsEos && !rhsEos) {
+                return true;
+            } else if (!lhsEos && rhsEos) {
+                return false;
+            } else if (lhsEos && rhsEos) {
+                return lhs.first > rhs.first;
+            }
+            // Regular samples are dequeued in ascending presentation time order.
+            return lhs.second->info.presentationTimeUs > rhs.second->info.presentationTimeUs;
+        }
+    };
+
     std::weak_ptr<CallbackInterface> mCallbacks;
     std::shared_ptr<MediaSampleWriterMuxerInterface> mMuxer;
-    std::vector<std::unique_ptr<TrackRecord>> mAllTracks;
-    std::multiset<TrackRecord*, TrackRecord::compare> mSortedTracks;
-    std::thread mThread;
 
-    std::mutex mStateMutex;
+    std::mutex mMutex;  // Protects sample queue and state.
+    std::condition_variable mSampleSignal;
+    std::thread mThread;
+    std::unordered_map<size_t, TrackRecord> mTracks;
+    std::priority_queue<SampleEntry, std::vector<SampleEntry>, SampleComparator> mSampleQueue
+            GUARDED_BY(mMutex);
+
     enum : int {
         UNINITIALIZED,
         INITIALIZED,
         STARTED,
         STOPPED,
-    } mState GUARDED_BY(mStateMutex);
+    } mState GUARDED_BY(mMutex);
 
+    MediaSampleWriter() : mState(UNINITIALIZED){};
+    void addSampleToTrack(size_t trackIndex, const std::shared_ptr<MediaSample>& sample);
     media_status_t writeSamples();
     media_status_t runWriterLoop();
-    std::multiset<TrackRecord*>::iterator getNextOutputTrack();
 };
 
 }  // namespace android
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h b/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
index 60a9139..c5e161c 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaTrackTranscoder.h
@@ -19,6 +19,7 @@
 
 #include <media/MediaSampleQueue.h>
 #include <media/MediaSampleReader.h>
+#include <media/MediaSampleWriter.h>
 #include <media/NdkMediaError.h>
 #include <media/NdkMediaFormat.h>
 #include <utils/Mutex.h>
@@ -75,10 +76,13 @@
     bool stop();
 
     /**
-     * Retrieves the track transcoder's output sample queue.
-     * @return The output sample queue.
+     * Set the sample consumer function. The MediaTrackTranscoder will deliver transcoded samples to
+     * this function. If the MediaTrackTranscoder is started before a consumer is set the transcoder
+     * will buffer a limited number of samples internally before stalling. Once a consumer has been
+     * set the internally buffered samples will be delivered to the consumer.
+     * @param sampleConsumer The sample consumer function.
      */
-    std::shared_ptr<MediaSampleQueue> getOutputQueue() const;
+    void setSampleConsumer(const MediaSampleWriter::MediaSampleConsumerFunction& sampleConsumer);
 
     /**
       * Retrieves the track transcoder's final output format. The output is available after the
@@ -91,12 +95,14 @@
 
 protected:
     MediaTrackTranscoder(const std::weak_ptr<MediaTrackTranscoderCallback>& transcoderCallback)
-          : mOutputQueue(std::make_shared<MediaSampleQueue>()),
-            mTranscoderCallback(transcoderCallback){};
+          : mTranscoderCallback(transcoderCallback){};
 
     // Called by subclasses when the actual track format becomes available.
     void notifyTrackFormatAvailable();
 
+    // Called by subclasses when a transcoded sample is available.
+    void onOutputSampleAvailable(const std::shared_ptr<MediaSample>& sample);
+
     // configureDestinationFormat needs to be implemented by subclasses, and gets called on an
     // external thread before start.
     virtual media_status_t configureDestinationFormat(
@@ -110,12 +116,14 @@
     // be aborted as soon as possible. It should be safe to call abortTranscodeLoop multiple times.
     virtual void abortTranscodeLoop() = 0;
 
-    std::shared_ptr<MediaSampleQueue> mOutputQueue;
     std::shared_ptr<MediaSampleReader> mMediaSampleReader;
     int mTrackIndex;
     std::shared_ptr<AMediaFormat> mSourceFormat;
 
 private:
+    std::mutex mSampleMutex;
+    MediaSampleQueue mSampleQueue GUARDED_BY(mSampleMutex);
+    MediaSampleWriter::MediaSampleConsumerFunction mSampleConsumer GUARDED_BY(mSampleMutex);
     const std::weak_ptr<MediaTrackTranscoderCallback> mTranscoderCallback;
     std::mutex mStateMutex;
     std::thread mTranscodingThread GUARDED_BY(mStateMutex);
diff --git a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
index 8d96867..9a367ca 100644
--- a/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
+++ b/media/libmediatranscoding/transcoder/include/media/MediaTranscoder.h
@@ -138,7 +138,7 @@
 
     std::shared_ptr<CallbackInterface> mCallbacks;
     std::shared_ptr<MediaSampleReader> mSampleReader;
-    std::unique_ptr<MediaSampleWriter> mSampleWriter;
+    std::shared_ptr<MediaSampleWriter> mSampleWriter;
     std::vector<std::shared_ptr<AMediaFormat>> mSourceTrackFormats;
     std::vector<std::shared_ptr<MediaTrackTranscoder>> mTrackTranscoders;
     std::mutex mTracksAddedMutex;
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
index e8acd48..9c9c8b5 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleReaderNDKTests.cpp
@@ -26,6 +26,7 @@
 #include <gtest/gtest.h>
 #include <media/MediaSampleReaderNDK.h>
 #include <utils/Timers.h>
+
 #include <cmath>
 #include <mutex>
 #include <thread>
diff --git a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
index 64240d4..46f3e9b 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaSampleWriterTests.cpp
@@ -274,102 +274,95 @@
     void SetUp() override {
         LOG(DEBUG) << "MediaSampleWriterTests set up";
         mTestMuxer = std::make_shared<TestMuxer>();
-        mSampleQueue = std::make_shared<MediaSampleQueue>();
     }
 
     void TearDown() override {
         LOG(DEBUG) << "MediaSampleWriterTests tear down";
         mTestMuxer.reset();
-        mSampleQueue.reset();
     }
 
 protected:
     std::shared_ptr<TestMuxer> mTestMuxer;
-    std::shared_ptr<MediaSampleQueue> mSampleQueue;
     std::shared_ptr<TestCallbacks> mTestCallbacks = std::make_shared<TestCallbacks>();
 };
 
 TEST_F(MediaSampleWriterTests, TestAddTrackWithoutInit) {
     const TestMediaSource& mediaSource = getMediaSource();
 
-    MediaSampleWriter writer{};
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_EQ(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutInit) {
-    MediaSampleWriter writer{};
-    EXPECT_FALSE(writer.start());
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_FALSE(writer->start());
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutTracks) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
-    EXPECT_FALSE(writer.start());
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
+    EXPECT_FALSE(writer->start());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestAddInvalidTrack) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, nullptr));
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
-
-    const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_FALSE(writer.addTrack(nullptr, mediaSource.mTrackFormats[0]));
+    EXPECT_EQ(writer->addTrack(nullptr), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestDoubleStartStop) {
-    MediaSampleWriter writer{};
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
 
     std::shared_ptr<TestCallbacks> callbacks =
             std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
+    EXPECT_TRUE(writer->init(mTestMuxer, callbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    EXPECT_NE(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));
 
-    ASSERT_TRUE(writer.start());
-    EXPECT_FALSE(writer.start());
+    ASSERT_TRUE(writer->start());
+    EXPECT_FALSE(writer->start());
 
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
     EXPECT_TRUE(callbacks->hasFinished());
-    EXPECT_FALSE(writer.stop());
+    EXPECT_FALSE(writer->stop());
 }
 
 TEST_F(MediaSampleWriterTests, TestStopWithoutStart) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
+    EXPECT_NE(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(mediaSource.mTrackFormats[0].get()));
 
-    EXPECT_FALSE(writer.stop());
+    EXPECT_FALSE(writer->stop());
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::NoEvent);
 }
 
 TEST_F(MediaSampleWriterTests, TestStartWithoutCallback) {
-    MediaSampleWriter writer{};
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
 
     std::weak_ptr<MediaSampleWriter::CallbackInterface> unassignedWp;
-    EXPECT_FALSE(writer.init(mTestMuxer, unassignedWp));
+    EXPECT_FALSE(writer->init(mTestMuxer, unassignedWp));
 
     std::shared_ptr<MediaSampleWriter::CallbackInterface> unassignedSp;
-    EXPECT_FALSE(writer.init(mTestMuxer, unassignedSp));
+    EXPECT_FALSE(writer->init(mTestMuxer, unassignedSp));
 
     const TestMediaSource& mediaSource = getMediaSource();
-    EXPECT_FALSE(writer.addTrack(mSampleQueue, mediaSource.mTrackFormats[0]));
-    ASSERT_FALSE(writer.start());
+    EXPECT_EQ(writer->addTrack(mediaSource.mTrackFormats[0]), nullptr);
+    ASSERT_FALSE(writer->start());
 }
 
 TEST_F(MediaSampleWriterTests, TestProgressUpdate) {
     const TestMediaSource& mediaSource = getMediaSource();
 
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     std::shared_ptr<AMediaFormat> videoFormat =
             std::shared_ptr<AMediaFormat>(AMediaFormat_new(), &AMediaFormat_delete);
@@ -377,42 +370,41 @@
                       mediaSource.mTrackFormats[mediaSource.mVideoTrackIndex].get());
 
     AMediaFormat_setInt64(videoFormat.get(), AMEDIAFORMAT_KEY_DURATION, 100);
-    EXPECT_TRUE(writer.addTrack(mSampleQueue, videoFormat));
-    ASSERT_TRUE(writer.start());
+    auto sampleConsumer = writer->addTrack(videoFormat);
+    EXPECT_NE(sampleConsumer, nullptr);
+    ASSERT_TRUE(writer->start());
 
     for (int64_t pts = 0; pts < 100; ++pts) {
-        mSampleQueue->enqueue(newSampleWithPts(pts));
+        sampleConsumer(newSampleWithPts(pts));
     }
-    mSampleQueue->enqueue(newSampleEos());
+    sampleConsumer(newSampleEos());
     mTestCallbacks->waitForWritingFinished();
 
     EXPECT_EQ(mTestCallbacks->getProgressUpdateCount(), 100);
 }
 
 TEST_F(MediaSampleWriterTests, TestInterleaving) {
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(mTestMuxer, mTestCallbacks));
 
     // Use two tracks for this test.
     static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    std::vector<std::pair<std::shared_ptr<MediaSample>, size_t>> interleavedSamples;
+    MediaSampleWriter::MediaSampleConsumerFunction sampleConsumers[kNumTracks];
+    std::vector<std::pair<std::shared_ptr<MediaSample>, size_t>> addedSamples;
     const TestMediaSource& mediaSource = getMediaSource();
 
     for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
         auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
+        sampleConsumers[trackIdx] = writer->addTrack(trackFormat);
+        EXPECT_NE(sampleConsumers[trackIdx], nullptr);
         EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
     }
 
     // Create samples in the expected interleaved order for easy verification.
-    auto addSampleToTrackWithPts = [&interleavedSamples, &sampleQueues](int trackIndex,
-                                                                        int64_t pts) {
+    auto addSampleToTrackWithPts = [&addedSamples, &sampleConsumers](int trackIndex, int64_t pts) {
         auto sample = newSampleWithPts(pts);
-        sampleQueues[trackIndex]->enqueue(sample);
-        interleavedSamples.emplace_back(sample, trackIndex);
+        sampleConsumers[trackIndex](sample);
+        addedSamples.emplace_back(sample, trackIndex);
     };
 
     addSampleToTrackWithPts(0, 0);
@@ -431,18 +423,24 @@
     addSampleToTrackWithPts(1, 13);
 
     for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        sampleQueues[trackIndex]->enqueue(newSampleEos());
+        sampleConsumers[trackIndex](newSampleEos());
     }
 
     // Start the writer.
-    ASSERT_TRUE(writer.start());
+    ASSERT_TRUE(writer->start());
 
     // Wait for writer to complete.
     mTestCallbacks->waitForWritingFinished();
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
 
+    std::sort(addedSamples.begin(), addedSamples.end(),
+              [](const std::pair<std::shared_ptr<MediaSample>, size_t>& left,
+                 const std::pair<std::shared_ptr<MediaSample>, size_t>& right) {
+                  return left.first->info.presentationTimeUs < right.first->info.presentationTimeUs;
+              });
+
     // Verify sample order.
-    for (auto entry : interleavedSamples) {
+    for (auto entry : addedSamples) {
         auto sample = entry.first;
         auto trackIndex = entry.second;
 
@@ -470,162 +468,10 @@
     }
 
     EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
     EXPECT_TRUE(mTestCallbacks->hasFinished());
 }
 
-TEST_F(MediaSampleWriterTests, TestMaxDivergence) {
-    static constexpr uint32_t kMaxDivergenceUs = 10;
-
-    MediaSampleWriter writer{kMaxDivergenceUs};
-    EXPECT_TRUE(writer.init(mTestMuxer, mTestCallbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    ASSERT_TRUE(writer.start());
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::Start());
-
-    // The first samples of each track can be written in any order since the writer does not have
-    // any previous timestamps to compare.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(0));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1));
-    mTestMuxer->popEvent(true);
-    mTestMuxer->popEvent(true);
-
-    // The writer will now be waiting on track 0 since it has the lowest previous timestamp.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 1));
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 2));
-
-    // The writer should dequeue the first sample above but not the second since track 0 now is too
-    // far ahead. Instead it should wait for track 1.
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, kMaxDivergenceUs + 1));
-
-    // Enqueue a sample from track 1 that puts it within acceptable divergence range again. The
-    // writer should dequeue that sample and then go back to track 0 since track 1 is empty.
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, kMaxDivergenceUs));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, kMaxDivergenceUs + 2));
-
-    // Both tracks are now empty so the writer should wait for track 1 which is farthest behind.
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(kMaxDivergenceUs + 3));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, kMaxDivergenceUs + 3));
-
-    for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        sampleQueues[trackIndex]->enqueue(newSampleEos());
-    }
-
-    // Wait for writer to complete.
-    mTestCallbacks->waitForWritingFinished();
-
-    // Verify EOS samples.
-    for (int trackIndex = 0; trackIndex < kNumTracks; ++trackIndex) {
-        auto trackFormat = mediaSource.mTrackFormats[trackIndex % mediaSource.mTrackCount];
-        int64_t duration = 0;
-        AMediaFormat_getInt64(trackFormat.get(), AMEDIAFORMAT_KEY_DURATION, &duration);
-
-        // EOS timestamp = first sample timestamp + duration.
-        const int64_t endTime = duration + (trackIndex == 1 ? 1 : 0);
-        const AMediaCodecBufferInfo info = {0, 0, endTime, AMEDIACODEC_BUFFER_FLAG_END_OF_STREAM};
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::WriteSample(trackIndex, nullptr, &info));
-    }
-
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
-    EXPECT_TRUE(mTestCallbacks->hasFinished());
-}
-
-TEST_F(MediaSampleWriterTests, TestTimestampDivergenceOverflow) {
-    auto testCallbacks = std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(mTestMuxer, testCallbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    // Prime track 0 with lower end of INT64 range, and track 1 with positive timestamps making the
-    // difference larger than INT64_MAX.
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(INT64_MIN + 1));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1000));
-    sampleQueues[1]->enqueue(newSampleWithPtsOnly(1001));
-
-    ASSERT_TRUE(writer.start());
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::Start());
-
-    // The first sample of each track can be pulled in any order.
-    mTestMuxer->popEvent(true);
-    mTestMuxer->popEvent(true);
-
-    // Wait to make sure the writer compares track 0 empty against track 1 non-empty. The writer
-    // should handle the large timestamp differences and chose to wait for track 0 even though
-    // track 1 has a sample ready.
-    std::this_thread::sleep_for(std::chrono::milliseconds(20));
-
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(INT64_MIN + 2));
-    sampleQueues[0]->enqueue(newSampleWithPtsOnly(1000));  // <-- Close the gap between the tracks.
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, INT64_MIN + 2));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(0, 1000));
-    EXPECT_EQ(mTestMuxer->popEvent(true), TestMuxer::WriteSampleWithPts(1, 1001));
-
-    EXPECT_TRUE(writer.stop());
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(testCallbacks->hasFinished());
-}
-
-TEST_F(MediaSampleWriterTests, TestAbortInputQueue) {
-    MediaSampleWriter writer{};
-    std::shared_ptr<TestCallbacks> callbacks =
-            std::make_shared<TestCallbacks>(false /* expectSuccess */);
-    EXPECT_TRUE(writer.init(mTestMuxer, callbacks));
-
-    // Use two tracks for this test.
-    static constexpr int kNumTracks = 2;
-    std::shared_ptr<MediaSampleQueue> sampleQueues[kNumTracks];
-    const TestMediaSource& mediaSource = getMediaSource();
-
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx] = std::make_shared<MediaSampleQueue>();
-
-        auto trackFormat = mediaSource.mTrackFormats[trackIdx % mediaSource.mTrackCount];
-        EXPECT_TRUE(writer.addTrack(sampleQueues[trackIdx], trackFormat));
-        EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::AddTrack(trackFormat.get()));
-    }
-
-    // Start the writer.
-    ASSERT_TRUE(writer.start());
-
-    // Abort the input queues and wait for the writer to complete.
-    for (int trackIdx = 0; trackIdx < kNumTracks; ++trackIdx) {
-        sampleQueues[trackIdx]->abort();
-    }
-
-    callbacks->waitForWritingFinished();
-
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Start());
-    EXPECT_EQ(mTestMuxer->popEvent(), TestMuxer::Stop());
-    EXPECT_TRUE(writer.stop());
-}
-
 // Convenience function for reading a sample from an AMediaExtractor represented as a MediaSample.
 static std::shared_ptr<MediaSample> readSampleAndAdvance(AMediaExtractor* extractor,
                                                          size_t* trackIndexOut) {
@@ -667,36 +513,35 @@
     ASSERT_GT(destinationFd, 0);
 
     // Initialize writer.
-    MediaSampleWriter writer{};
-    EXPECT_TRUE(writer.init(destinationFd, mTestCallbacks));
+    std::shared_ptr<MediaSampleWriter> writer = MediaSampleWriter::Create();
+    EXPECT_TRUE(writer->init(destinationFd, mTestCallbacks));
     close(destinationFd);
 
     // Add tracks.
     const TestMediaSource& mediaSource = getMediaSource();
-    std::vector<std::shared_ptr<MediaSampleQueue>> inputQueues;
+    std::vector<MediaSampleWriter::MediaSampleConsumerFunction> sampleConsumers;
 
     for (size_t trackIndex = 0; trackIndex < mediaSource.mTrackCount; trackIndex++) {
-        inputQueues.push_back(std::make_shared<MediaSampleQueue>());
-        EXPECT_TRUE(
-                writer.addTrack(inputQueues[trackIndex], mediaSource.mTrackFormats[trackIndex]));
+        auto consumer = writer->addTrack(mediaSource.mTrackFormats[trackIndex]);
+        sampleConsumers.push_back(consumer);
     }
 
     // Start the writer.
-    ASSERT_TRUE(writer.start());
+    ASSERT_TRUE(writer->start());
 
     // Enqueue samples and finally End Of Stream.
     std::shared_ptr<MediaSample> sample;
     size_t trackIndex;
     while ((sample = readSampleAndAdvance(mediaSource.mExtractor, &trackIndex)) != nullptr) {
-        inputQueues[trackIndex]->enqueue(sample);
+        sampleConsumers[trackIndex](sample);
     }
     for (trackIndex = 0; trackIndex < mediaSource.mTrackCount; trackIndex++) {
-        inputQueues[trackIndex]->enqueue(newSampleEos());
+        sampleConsumers[trackIndex](newSampleEos());
     }
 
     // Wait for writer.
     mTestCallbacks->waitForWritingFinished();
-    EXPECT_TRUE(writer.stop());
+    EXPECT_TRUE(writer->stop());
 
     // Compare output file with source.
     mediaSource.reset();
diff --git a/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
index a46c2bd..83f0a4a 100644
--- a/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/MediaTrackTranscoderTests.cpp
@@ -60,7 +60,6 @@
             break;
         }
         ASSERT_NE(mTranscoder, nullptr);
-        mTranscoderOutputQueue = mTranscoder->getOutputQueue();
 
         initSampleReader();
     }
@@ -115,34 +114,29 @@
     }
 
     // Drains the transcoder's output queue in a loop.
-    void drainOutputSampleQueue() {
-        mSampleQueueDrainThread = std::thread{[this] {
-            std::shared_ptr<MediaSample> sample;
-            bool aborted = false;
-            do {
-                aborted = mTranscoderOutputQueue->dequeue(&sample);
-            } while (!aborted && !(sample->info.flags & SAMPLE_FLAG_END_OF_STREAM));
-            mQueueWasAborted = aborted;
-            mGotEndOfStream =
-                    sample != nullptr && (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0;
-        }};
+    void drainOutputSamples(int numSamplesToSave = 0) {
+        mTranscoder->setSampleConsumer(
+                [this, numSamplesToSave](const std::shared_ptr<MediaSample>& sample) {
+                    ASSERT_NE(sample, nullptr);
+
+                    mGotEndOfStream = (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0;
+
+                    if (mSavedSamples.size() < numSamplesToSave) {
+                        mSavedSamples.push_back(sample);
+                    }
+
+                    if (mSavedSamples.size() == numSamplesToSave || mGotEndOfStream) {
+                        mSamplesSavedSemaphore.signal();
+                    }
+                });
     }
 
-    void joinDrainThread() {
-        if (mSampleQueueDrainThread.joinable()) {
-            mSampleQueueDrainThread.join();
-        }
-    }
-    void TearDown() override {
-        LOG(DEBUG) << "MediaTrackTranscoderTests tear down";
-        joinDrainThread();
-    }
+    void TearDown() override { LOG(DEBUG) << "MediaTrackTranscoderTests tear down"; }
 
     ~MediaTrackTranscoderTests() { LOG(DEBUG) << "MediaTrackTranscoderTests destroyed"; }
 
 protected:
     std::shared_ptr<MediaTrackTranscoder> mTranscoder;
-    std::shared_ptr<MediaSampleQueue> mTranscoderOutputQueue;
     std::shared_ptr<TestCallback> mCallback;
 
     std::shared_ptr<MediaSampleReader> mMediaSampleReader;
@@ -151,8 +145,8 @@
     std::shared_ptr<AMediaFormat> mSourceFormat;
     std::shared_ptr<AMediaFormat> mDestinationFormat;
 
-    std::thread mSampleQueueDrainThread;
-    bool mQueueWasAborted = false;
+    std::vector<std::shared_ptr<MediaSample>> mSavedSamples;
+    OneShotSemaphore mSamplesSavedSemaphore;
     bool mGotEndOfStream = false;
 };
 
@@ -161,11 +155,9 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-    drainOutputSampleQueue();
+    drainOutputSamples();
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 }
 
@@ -229,49 +221,27 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-    drainOutputSampleQueue();
+    drainOutputSamples();
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
     EXPECT_FALSE(mTranscoder->start());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 }
 
-TEST_P(MediaTrackTranscoderTests, AbortOutputQueue) {
-    LOG(DEBUG) << "Testing AbortOutputQueue";
-    EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
-              AMEDIA_OK);
-    ASSERT_TRUE(mTranscoder->start());
-    mTranscoderOutputQueue->abort();
-    drainOutputSampleQueue();
-    EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_ERROR_IO);
-    joinDrainThread();
-    EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_TRUE(mQueueWasAborted);
-    EXPECT_FALSE(mGotEndOfStream);
-}
-
 TEST_P(MediaTrackTranscoderTests, HoldSampleAfterTranscoderRelease) {
     LOG(DEBUG) << "Testing HoldSampleAfterTranscoderRelease";
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-
-    std::shared_ptr<MediaSample> sample;
-    EXPECT_FALSE(mTranscoderOutputQueue->dequeue(&sample));
-
-    drainOutputSampleQueue();
+    drainOutputSamples(1 /* numSamplesToSave */);
     EXPECT_EQ(mCallback->waitUntilFinished(), AMEDIA_OK);
-    joinDrainThread();
     EXPECT_TRUE(mTranscoder->stop());
-    EXPECT_FALSE(mQueueWasAborted);
     EXPECT_TRUE(mGotEndOfStream);
 
     mTranscoder.reset();
-    mTranscoderOutputQueue.reset();
+
     std::this_thread::sleep_for(std::chrono::milliseconds(20));
-    sample.reset();
+    mSavedSamples.clear();
 }
 
 TEST_P(MediaTrackTranscoderTests, HoldSampleAfterTranscoderStop) {
@@ -279,13 +249,12 @@
     EXPECT_EQ(mTranscoder->configure(mMediaSampleReader, mTrackIndex, mDestinationFormat),
               AMEDIA_OK);
     ASSERT_TRUE(mTranscoder->start());
-
-    std::shared_ptr<MediaSample> sample;
-    EXPECT_FALSE(mTranscoderOutputQueue->dequeue(&sample));
+    drainOutputSamples(1 /* numSamplesToSave */);
+    mSamplesSavedSemaphore.wait();
     EXPECT_TRUE(mTranscoder->stop());
 
     std::this_thread::sleep_for(std::chrono::milliseconds(20));
-    sample.reset();
+    mSavedSamples.clear();
 }
 
 TEST_P(MediaTrackTranscoderTests, NullSampleReader) {
diff --git a/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
index a2ffbe4..9713e17 100644
--- a/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/PassthroughTrackTranscoderTests.cpp
@@ -165,21 +165,23 @@
     ASSERT_TRUE(transcoder.start());
 
     // Pull transcoder's output samples and compare against input checksums.
+    bool eos = false;
     uint64_t sampleCount = 0;
-    std::shared_ptr<MediaSample> sample;
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder.getOutputQueue();
-    while (!outputQueue->dequeue(&sample)) {
-        ASSERT_NE(sample, nullptr);
+    transcoder.setSampleConsumer(
+            [&sampleCount, &sampleChecksums, &eos](const std::shared_ptr<MediaSample>& sample) {
+                ASSERT_NE(sample, nullptr);
+                EXPECT_FALSE(eos);
 
-        if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-            break;
-        }
+                if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+                    eos = true;
+                } else {
+                    SampleID sampleId{sample->buffer, static_cast<ssize_t>(sample->info.size)};
+                    EXPECT_TRUE(sampleId == sampleChecksums[sampleCount]);
+                    ++sampleCount;
+                }
+            });
 
-        SampleID sampleId{sample->buffer, static_cast<ssize_t>(sample->info.size)};
-        EXPECT_TRUE(sampleId == sampleChecksums[sampleCount]);
-        ++sampleCount;
-    }
-
+    callback->waitUntilFinished();
     EXPECT_EQ(sampleCount, sampleChecksums.size());
     EXPECT_TRUE(transcoder.stop());
 }
diff --git a/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h b/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
index a3ddd71..8d05353 100644
--- a/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
+++ b/media/libmediatranscoding/transcoder/tests/TrackTranscoderTestUtils.h
@@ -102,4 +102,25 @@
     bool mTrackFormatAvailable = false;
 };
 
+class OneShotSemaphore {
+public:
+    void wait() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        while (!mSignaled) {
+            mCondition.wait(lock);
+        }
+    }
+
+    void signal() {
+        std::unique_lock<std::mutex> lock(mMutex);
+        mSignaled = true;
+        mCondition.notify_all();
+    }
+
+private:
+    std::mutex mMutex;
+    std::condition_variable mCondition;
+    bool mSignaled = false;
+};
+
 };  // namespace android
diff --git a/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp b/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
index e809cbd..1b5bd13 100644
--- a/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
+++ b/media/libmediatranscoding/transcoder/tests/VideoTrackTranscoderTests.cpp
@@ -102,46 +102,40 @@
               AMEDIA_OK);
     ASSERT_TRUE(transcoder->start());
 
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder->getOutputQueue();
-    std::thread sampleConsumerThread{[&outputQueue] {
-        uint64_t sampleCount = 0;
-        std::shared_ptr<MediaSample> sample;
-        while (!outputQueue->dequeue(&sample)) {
-            ASSERT_NE(sample, nullptr);
-            const uint32_t flags = sample->info.flags;
+    bool eos = false;
+    uint64_t sampleCount = 0;
+    transcoder->setSampleConsumer([&sampleCount, &eos](const std::shared_ptr<MediaSample>& sample) {
+        ASSERT_NE(sample, nullptr);
+        const uint32_t flags = sample->info.flags;
 
-            if (sampleCount == 0) {
-                // Expect first sample to be a codec config.
-                EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) != 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_PARTIAL_FRAME) == 0);
-            } else if (sampleCount == 1) {
-                // Expect second sample to be a sync sample.
-                EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) == 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) != 0);
-                EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
-            }
-
-            if (!(flags & SAMPLE_FLAG_END_OF_STREAM)) {
-                // Expect a valid buffer unless it is EOS.
-                EXPECT_NE(sample->buffer, nullptr);
-                EXPECT_NE(sample->bufferId, 0xBAADF00D);
-                EXPECT_GT(sample->info.size, 0);
-            }
-
-            ++sampleCount;
-            if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-                break;
-            }
-            sample.reset();
+        if (sampleCount == 0) {
+            // Expect first sample to be a codec config.
+            EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) != 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_PARTIAL_FRAME) == 0);
+        } else if (sampleCount == 1) {
+            // Expect second sample to be a sync sample.
+            EXPECT_TRUE((flags & SAMPLE_FLAG_CODEC_CONFIG) == 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_SYNC_SAMPLE) != 0);
+            EXPECT_TRUE((flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
         }
-    }};
+
+        if (!(flags & SAMPLE_FLAG_END_OF_STREAM)) {
+            // Expect a valid buffer unless it is EOS.
+            EXPECT_NE(sample->buffer, nullptr);
+            EXPECT_NE(sample->bufferId, 0xBAADF00D);
+            EXPECT_GT(sample->info.size, 0);
+        } else {
+            EXPECT_FALSE(eos);
+            eos = true;
+        }
+
+        ++sampleCount;
+    });
 
     EXPECT_EQ(callback->waitUntilFinished(), AMEDIA_OK);
     EXPECT_TRUE(transcoder->stop());
-
-    sampleConsumerThread.join();
 }
 
 TEST_F(VideoTrackTranscoderTests, PreserveBitrate) {
@@ -167,7 +161,6 @@
     ASSERT_NE(outputFormat, nullptr);
 
     ASSERT_TRUE(transcoder->stop());
-    transcoder->getOutputQueue()->abort();
 
     int32_t outBitrate;
     EXPECT_TRUE(AMediaFormat_getInt32(outputFormat.get(), AMEDIAFORMAT_KEY_BIT_RATE, &outBitrate));
@@ -187,25 +180,7 @@
 }
 
 TEST_F(VideoTrackTranscoderTests, LingeringEncoder) {
-    struct {
-        void wait() {
-            std::unique_lock<std::mutex> lock(mMutex);
-            while (!mSignaled) {
-                mCondition.wait(lock);
-            }
-        }
-
-        void signal() {
-            std::unique_lock<std::mutex> lock(mMutex);
-            mSignaled = true;
-            mCondition.notify_all();
-        }
-
-        std::mutex mMutex;
-        std::condition_variable mCondition;
-        bool mSignaled = false;
-    } semaphore;
-
+    OneShotSemaphore semaphore;
     auto callback = std::make_shared<TestCallback>();
     auto transcoder = VideoTrackTranscoder::create(callback);
 
@@ -214,29 +189,24 @@
               AMEDIA_OK);
     ASSERT_TRUE(transcoder->start());
 
-    std::shared_ptr<MediaSampleQueue> outputQueue = transcoder->getOutputQueue();
     std::vector<std::shared_ptr<MediaSample>> samples;
-    std::thread sampleConsumerThread([&outputQueue, &samples, &semaphore] {
-        std::shared_ptr<MediaSample> sample;
-        while (samples.size() < 4 && !outputQueue->dequeue(&sample)) {
-            ASSERT_NE(sample, nullptr);
-            samples.push_back(sample);
+    transcoder->setSampleConsumer(
+            [&samples, &semaphore](const std::shared_ptr<MediaSample>& sample) {
+                if (samples.size() >= 4) return;
 
-            if (sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
-                break;
-            }
-            sample.reset();
-        }
+                ASSERT_NE(sample, nullptr);
+                samples.push_back(sample);
 
-        semaphore.signal();
-    });
+                if (samples.size() == 4 || sample->info.flags & SAMPLE_FLAG_END_OF_STREAM) {
+                    semaphore.signal();
+                }
+            });
 
     // Wait for the encoder to output samples before stopping and releasing the transcoder.
     semaphore.wait();
 
     EXPECT_TRUE(transcoder->stop());
     transcoder.reset();
-    sampleConsumerThread.join();
 
     // Return buffers to the codec so that it can resume processing, but keep one buffer to avoid
     // the codec being released.