diff --git a/drivers/gpio/gpio-omap.c b/drivers/gpio/gpio-omap.c
index 6788c8a..a1a3b9d 100644
--- a/drivers/gpio/gpio-omap.c
+++ b/drivers/gpio/gpio-omap.c
@@ -72,9 +72,11 @@
 	bool loses_context;
 	int stride;
 	u32 width;
+	int context_loss_count; /* sampled before idle, compared on resume */
 	u16 id;
 
 	void (*set_dataout)(struct gpio_bank *bank, int gpio, int enable);
+	int (*get_context_loss_count)(struct device *dev); /* may be NULL */
 
 	struct omap_gpio_reg_offs *regs;
 };
@@ -1179,6 +1181,7 @@
 	bank->stride = pdata->bank_stride;
 	bank->width = pdata->bank_width;
 	bank->loses_context = pdata->loses_context;
+	bank->get_context_loss_count = pdata->get_context_loss_count;
 	bank->regs = pdata->regs;
 
 	if (bank->regs->set_dataout && bank->regs->clr_dataout)
@@ -1323,11 +1326,11 @@
 
 #ifdef CONFIG_ARCH_OMAP2PLUS
 
-static int workaround_enabled;
+static void omap_gpio_save_context(struct gpio_bank *bank);
+static void omap_gpio_restore_context(struct gpio_bank *bank);
 
 void omap2_gpio_prepare_for_idle(int off_mode)
 {
-	int c = 0;
 	struct gpio_bank *bank;
 
 	list_for_each_entry(bank, &omap_gpio_list, node) {
@@ -1347,7 +1350,7 @@
 		 * non-wakeup GPIOs.  Otherwise spurious IRQs will be
 		 * generated.  See OMAP2420 Errata item 1.101. */
 		if (!(bank->enabled_non_wakeup_gpios))
-			continue;
+			goto save_gpio_context;
 
 		if (cpu_is_omap24xx() || cpu_is_omap34xx()) {
 			bank->saved_datain = __raw_readl(bank->base +
@@ -1384,13 +1387,13 @@
 			__raw_writel(l2, bank->base + OMAP4_GPIO_RISINGDETECT);
 		}
 
-		c++;
+save_gpio_context: /* banks with no enabled non-wakeup GPIOs jump here */
+		if (bank->get_context_loss_count)
+			bank->context_loss_count =
+				bank->get_context_loss_count(bank->dev);
+
+		omap_gpio_save_context(bank);
 	}
-	if (!c) {
-		workaround_enabled = 0;
-		return;
-	}
-	workaround_enabled = 1;
 }
 
 void omap2_gpio_resume_after_idle(void)
@@ -1398,6 +1401,7 @@
 	struct gpio_bank *bank;
 
 	list_for_each_entry(bank, &omap_gpio_list, node) {
+		int context_lost_cnt_after; /* loss count read on resume */
 		u32 l = 0, gen, gen0, gen1;
 		int j;
 
@@ -1407,8 +1411,13 @@
 		for (j = 0; j < hweight_long(bank->dbck_enable_mask); j++)
 			clk_enable(bank->dbck);
 
-		if (!workaround_enabled)
-			continue;
+		if (bank->get_context_loss_count) { /* hook may be NULL */
+			context_lost_cnt_after = /* 0 also forces a restore */
+				bank->get_context_loss_count(bank->dev);
+			if (context_lost_cnt_after != bank->context_loss_count
+				|| !context_lost_cnt_after)
+				omap_gpio_restore_context(bank);
+		}
 
 		if (!(bank->enabled_non_wakeup_gpios))
 			continue;
@@ -1486,74 +1495,50 @@
 			}
 		}
 	}
-
 }
 
-#endif
-
-#ifdef CONFIG_ARCH_OMAP3
-void omap_gpio_save_context(void)
+static void omap_gpio_save_context(struct gpio_bank *bank)
 {
-	struct gpio_bank *bank;
-
-	list_for_each_entry(bank, &omap_gpio_list, node) {
-
-		if (!bank->loses_context)
-			continue;
-
-		bank->context.irqenable1 =
-			__raw_readl(bank->base + OMAP24XX_GPIO_IRQENABLE1);
-		bank->context.irqenable2 =
-			__raw_readl(bank->base + OMAP24XX_GPIO_IRQENABLE2);
-		bank->context.wake_en =
-			__raw_readl(bank->base + OMAP24XX_GPIO_WAKE_EN);
-		bank->context.ctrl =
-			__raw_readl(bank->base + OMAP24XX_GPIO_CTRL);
-		bank->context.oe =
-			__raw_readl(bank->base + OMAP24XX_GPIO_OE);
-		bank->context.leveldetect0 =
-			__raw_readl(bank->base + OMAP24XX_GPIO_LEVELDETECT0);
-		bank->context.leveldetect1 =
-			__raw_readl(bank->base + OMAP24XX_GPIO_LEVELDETECT1);
-		bank->context.risingdetect =
-			__raw_readl(bank->base + OMAP24XX_GPIO_RISINGDETECT);
-		bank->context.fallingdetect =
-			__raw_readl(bank->base + OMAP24XX_GPIO_FALLINGDETECT);
-		bank->context.dataout =
-			__raw_readl(bank->base + OMAP24XX_GPIO_DATAOUT);
-	}
+	bank->context.irqenable1 = /* snapshot registers in bank->context */
+		__raw_readl(bank->base + OMAP24XX_GPIO_IRQENABLE1);
+	bank->context.irqenable2 =
+		__raw_readl(bank->base + OMAP24XX_GPIO_IRQENABLE2);
+	bank->context.wake_en =
+		__raw_readl(bank->base + OMAP24XX_GPIO_WAKE_EN);
+	bank->context.ctrl = __raw_readl(bank->base + OMAP24XX_GPIO_CTRL);
+	bank->context.oe = __raw_readl(bank->base + OMAP24XX_GPIO_OE);
+	bank->context.leveldetect0 =
+		__raw_readl(bank->base + OMAP24XX_GPIO_LEVELDETECT0);
+	bank->context.leveldetect1 =
+		__raw_readl(bank->base + OMAP24XX_GPIO_LEVELDETECT1);
+	bank->context.risingdetect = /* NOTE(review): OMAP4 offset differs? */
+		__raw_readl(bank->base + OMAP24XX_GPIO_RISINGDETECT);
+	bank->context.fallingdetect =
+		__raw_readl(bank->base + OMAP24XX_GPIO_FALLINGDETECT);
+	bank->context.dataout =
+		__raw_readl(bank->base + OMAP24XX_GPIO_DATAOUT);
 }
 
-void omap_gpio_restore_context(void)
+static void omap_gpio_restore_context(struct gpio_bank *bank)
 {
-	struct gpio_bank *bank;
-
-	list_for_each_entry(bank, &omap_gpio_list, node) {
-
-		if (!bank->loses_context)
-			continue;
-
-		__raw_writel(bank->context.irqenable1,
-				bank->base + OMAP24XX_GPIO_IRQENABLE1);
-		__raw_writel(bank->context.irqenable2,
-				bank->base + OMAP24XX_GPIO_IRQENABLE2);
-		__raw_writel(bank->context.wake_en,
-				bank->base + OMAP24XX_GPIO_WAKE_EN);
-		__raw_writel(bank->context.ctrl,
-				bank->base + OMAP24XX_GPIO_CTRL);
-		__raw_writel(bank->context.oe,
-				bank->base + OMAP24XX_GPIO_OE);
-		__raw_writel(bank->context.leveldetect0,
-				bank->base + OMAP24XX_GPIO_LEVELDETECT0);
-		__raw_writel(bank->context.leveldetect1,
-				bank->base + OMAP24XX_GPIO_LEVELDETECT1);
-		__raw_writel(bank->context.risingdetect,
-				bank->base + OMAP24XX_GPIO_RISINGDETECT);
-		__raw_writel(bank->context.fallingdetect,
-				bank->base + OMAP24XX_GPIO_FALLINGDETECT);
-		__raw_writel(bank->context.dataout,
-				bank->base + OMAP24XX_GPIO_DATAOUT);
-	}
+	__raw_writel(bank->context.irqenable1, /* write bank->context back */
+			bank->base + OMAP24XX_GPIO_IRQENABLE1);
+	__raw_writel(bank->context.irqenable2,
+			bank->base + OMAP24XX_GPIO_IRQENABLE2);
+	__raw_writel(bank->context.wake_en,
+			bank->base + OMAP24XX_GPIO_WAKE_EN);
+	__raw_writel(bank->context.ctrl, bank->base + OMAP24XX_GPIO_CTRL);
+	__raw_writel(bank->context.oe, bank->base + OMAP24XX_GPIO_OE);
+	__raw_writel(bank->context.leveldetect0,
+			bank->base + OMAP24XX_GPIO_LEVELDETECT0);
+	__raw_writel(bank->context.leveldetect1,
+			bank->base + OMAP24XX_GPIO_LEVELDETECT1);
+	__raw_writel(bank->context.risingdetect, /* NOTE(review): OMAP4 too? */
+			bank->base + OMAP24XX_GPIO_RISINGDETECT);
+	__raw_writel(bank->context.fallingdetect,
+			bank->base + OMAP24XX_GPIO_FALLINGDETECT);
+	__raw_writel(bank->context.dataout,
+			bank->base + OMAP24XX_GPIO_DATAOUT);
 }
 #endif
 
