x86/amd-iommu: Support higher level PTEs in iommu_unmap_page

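This patch adds a map_size argument to fetch_pte and iommu_unmap_page
so that they can look up and clear PTEs at a caller-selected
page-table level, and thereby support different page sizes too.
All existing callers pass PM_MAP_4k.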

Signed-off-by: Joerg Roedel <joerg.roedel@amd.com>
diff --git a/arch/x86/kernel/amd_iommu.c b/arch/x86/kernel/amd_iommu.c
index addf658..002cf9c 100644
--- a/arch/x86/kernel/amd_iommu.c
+++ b/arch/x86/kernel/amd_iommu.c
@@ -62,7 +62,7 @@
 				      unsigned long start_page,
 				      unsigned int pages);
 static u64 *fetch_pte(struct protection_domain *domain,
-		      unsigned long address);
+		      unsigned long address, int map_size);
 static void update_domain(struct protection_domain *domain);
 
 #ifndef BUS_NOTIFY_UNBOUND_DRIVER
@@ -552,9 +552,9 @@
 }
 
 static void iommu_unmap_page(struct protection_domain *dom,
-			     unsigned long bus_addr)
+			     unsigned long bus_addr, int map_size)
 {
-	u64 *pte = fetch_pte(dom, bus_addr);
+	u64 *pte = fetch_pte(dom, bus_addr, map_size);
 
 	if (pte)
 		*pte = 0;
@@ -668,7 +668,7 @@
  * there is one, it returns the pointer to it.
  */
 static u64 *fetch_pte(struct protection_domain *domain,
-		      unsigned long address)
+		      unsigned long address, int map_size)
 {
 	int level;
 	u64 *pte;
@@ -676,7 +676,7 @@
 	level =  domain->mode - 1;
 	pte   = &domain->pt_root[PM_LEVEL_INDEX(level, address)];
 
-	while (level > 0) {
+	while (level > map_size) {
 		if (!IOMMU_PTE_PRESENT(*pte))
 			return NULL;
 
@@ -684,6 +684,11 @@
 
 		pte = IOMMU_PTE_PAGE(*pte);
 		pte = &pte[PM_LEVEL_INDEX(level, address)];
+
+		if ((PM_PTE_LEVEL(*pte) == 0) && level != map_size) {
+			pte = NULL;
+			break;
+		}
 	}
 
 	return pte;
@@ -757,7 +762,7 @@
 	for (i = dma_dom->aperture[index]->offset;
 	     i < dma_dom->aperture_size;
 	     i += PAGE_SIZE) {
-		u64 *pte = fetch_pte(&dma_dom->domain, i);
+		u64 *pte = fetch_pte(&dma_dom->domain, i, PM_MAP_4k);
 		if (!pte || !IOMMU_PTE_PRESENT(*pte))
 			continue;
 
@@ -2192,7 +2197,7 @@
 	iova  &= PAGE_MASK;
 
 	for (i = 0; i < npages; ++i) {
-		iommu_unmap_page(domain, iova);
+		iommu_unmap_page(domain, iova, PM_MAP_4k);
 		iova  += PAGE_SIZE;
 	}
 
@@ -2207,7 +2212,7 @@
 	phys_addr_t paddr;
 	u64 *pte;
 
-	pte = fetch_pte(domain, iova);
+	pte = fetch_pte(domain, iova, PM_MAP_4k);
 
 	if (!pte || !IOMMU_PTE_PRESENT(*pte))
 		return 0;