MTP: add strict bounds checking for all incoming packets

Previously we did not sanity-check incoming MTP packets,
which could result in crashes from reading past the end of a packet.
Now all scalar MTP packet getter functions return a boolean result
(true on success, false if the read would run past the end of the packet),
the array getters return NULL in that case,
and we now return errors for malformed packets.

Bug: 18113092
Change-Id: Ic7623ee96f00652bdfb4f66acb16a93db5a1c105
diff --git a/media/mtp/MtpDataPacket.cpp b/media/mtp/MtpDataPacket.cpp
index e6e19e3..052b700 100644
--- a/media/mtp/MtpDataPacket.cpp
+++ b/media/mtp/MtpDataPacket.cpp
@@ -51,104 +51,190 @@
     MtpPacket::putUInt32(MTP_CONTAINER_TRANSACTION_ID_OFFSET, id);
 }
 
-uint16_t MtpDataPacket::getUInt16() {
-    int offset = mOffset;
-    uint16_t result = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8);
-    mOffset += 2;
-    return result;
+// getUInt8/16/32/64 below decode a little-endian value at mOffset and
+// advance mOffset past it. If fewer than sizeof(value) bytes remain before
+// mPacketSize they return false, leaving value and mOffset unchanged.
+// Testing mOffset > mPacketSize first keeps an out-of-range offset from
+// slipping past the remaining-size check.
+bool MtpDataPacket::getUInt8(uint8_t& value) {
+    if (mOffset > mPacketSize || mPacketSize - mOffset < sizeof(value))
+        return false;
+    value = mBuffer[mOffset++];
+    return true;
 }
 
-uint32_t MtpDataPacket::getUInt32() {
+bool MtpDataPacket::getUInt16(uint16_t& value) {
+    if (mOffset > mPacketSize || mPacketSize - mOffset < sizeof(value))
+        return false;
     int offset = mOffset;
-    uint32_t result = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) |
+    value = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8);
+    mOffset += sizeof(value);
+    return true;
+}
+
+bool MtpDataPacket::getUInt32(uint32_t& value) {
+    if (mOffset > mPacketSize || mPacketSize - mOffset < sizeof(value))
+        return false;
+    int offset = mOffset;
+    value = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) |
            ((uint32_t)mBuffer[offset + 2] << 16)  | ((uint32_t)mBuffer[offset + 3] << 24);
-    mOffset += 4;
-    return result;
+    mOffset += sizeof(value);
+    return true;
 }
 
-uint64_t MtpDataPacket::getUInt64() {
+bool MtpDataPacket::getUInt64(uint64_t& value) {
+    if (mOffset > mPacketSize || mPacketSize - mOffset < sizeof(value))
+        return false;
     int offset = mOffset;
-    uint64_t result = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) |
+    value = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) |
            ((uint64_t)mBuffer[offset + 2] << 16) | ((uint64_t)mBuffer[offset + 3] << 24) |
            ((uint64_t)mBuffer[offset + 4] << 32) | ((uint64_t)mBuffer[offset + 5] << 40) |
            ((uint64_t)mBuffer[offset + 6] << 48)  | ((uint64_t)mBuffer[offset + 7] << 56);
-    mOffset += 8;
-    return result;
+    mOffset += sizeof(value);
+    return true;
 }
 
-void MtpDataPacket::getUInt128(uint128_t& value) {
-    value[0] = getUInt32();
-    value[1] = getUInt32();
-    value[2] = getUInt32();
-    value[3] = getUInt32();
+// Reads value[0]..value[3] as consecutive 32-bit words. Evaluation stops at
+// the first failed read, so value may be partially filled on failure.
+bool MtpDataPacket::getUInt128(uint128_t& value) {
+    return getUInt32(value[0]) && getUInt32(value[1]) &&
+           getUInt32(value[2]) && getUInt32(value[3]);
 }
 
-void MtpDataPacket::getString(MtpStringBuffer& string)
+// Delegates to MtpStringBuffer::readFromPacket() and forwards its result.
+bool MtpDataPacket::getString(MtpStringBuffer& string)
 {
-    string.readFromPacket(this);
+    return string.readFromPacket(this);
 }
 
+// The array getters read a 32-bit element count and then that many elements.
+// They return a newly allocated list that the caller must delete, or NULL if
+// the packet ends before the count or any element could be read.
 Int8List* MtpDataPacket::getAInt8() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     Int8List* result = new Int8List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getInt8());
+    for (uint32_t i = 0; i < count; i++) {
+        int8_t value;
+        if (!getInt8(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 UInt8List* MtpDataPacket::getAUInt8() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     UInt8List* result = new UInt8List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getUInt8());
+    for (uint32_t i = 0; i < count; i++) {
+        uint8_t value;
+        if (!getUInt8(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 Int16List* MtpDataPacket::getAInt16() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     Int16List* result = new Int16List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getInt16());
+    for (uint32_t i = 0; i < count; i++) {
+        int16_t value;
+        if (!getInt16(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 UInt16List* MtpDataPacket::getAUInt16() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     UInt16List* result = new UInt16List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getUInt16());
+    for (uint32_t i = 0; i < count; i++) {
+        uint16_t value;
+        if (!getUInt16(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 Int32List* MtpDataPacket::getAInt32() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     Int32List* result = new Int32List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getInt32());
+    for (uint32_t i = 0; i < count; i++) {
+        int32_t value;
+        if (!getInt32(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 UInt32List* MtpDataPacket::getAUInt32() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     UInt32List* result = new UInt32List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getUInt32());
+    for (uint32_t i = 0; i < count; i++) {
+        uint32_t value;
+        if (!getUInt32(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 Int64List* MtpDataPacket::getAInt64() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     Int64List* result = new Int64List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getInt64());
+    for (uint32_t i = 0; i < count; i++) {
+        int64_t value;
+        if (!getInt64(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }
 
 UInt64List* MtpDataPacket::getAUInt64() {
+    uint32_t count;
+    if (!getUInt32(count))
+        return NULL;
     UInt64List* result = new UInt64List;
-    int count = getUInt32();
-    for (int i = 0; i < count; i++)
-        result->push(getUInt64());
+    for (uint32_t i = 0; i < count; i++) {
+        uint64_t value;
+        if (!getUInt64(value)) {
+            delete result;
+            return NULL;
+        }
+        result->push(value);
+    }
     return result;
 }