if PARAM_DISABLE_ROLLBACK in params:
disable_rollback = params.get(PARAM_DISABLE_ROLLBACK)
- if disable_rollback in (True, False):
- kwargs[PARAM_DISABLE_ROLLBACK] = disable_rollback
+ if str(disable_rollback).lower() == 'true':
+ kwargs[PARAM_DISABLE_ROLLBACK] = True
+ elif str(disable_rollback).lower() == 'false':
+ kwargs[PARAM_DISABLE_ROLLBACK] = False
else:
raise ValueError("Unexpected value for parameter %s : %s" %
(PARAM_DISABLE_ROLLBACK, disable_rollback))
self.assertTrue('disable_rollback' in args)
self.assertTrue(args.get('disable_rollback'))
+ args = api.extract_args({'disable_rollback': 'True'})
+ self.assertTrue('disable_rollback' in args)
+ self.assertTrue(args.get('disable_rollback'))
+
+ args = api.extract_args({'disable_rollback': 'true'})
+ self.assertTrue('disable_rollback' in args)
+ self.assertTrue(args.get('disable_rollback'))
+
def test_disable_rollback_extract_false(self):
args = api.extract_args({'disable_rollback': False})
self.assertTrue('disable_rollback' in args)
self.assertFalse(args.get('disable_rollback'))
+ args = api.extract_args({'disable_rollback': 'False'})
+ self.assertTrue('disable_rollback' in args)
+ self.assertFalse(args.get('disable_rollback'))
+
+ args = api.extract_args({'disable_rollback': 'false'})
+ self.assertTrue('disable_rollback' in args)
+ self.assertFalse(args.get('disable_rollback'))
+
def test_disable_rollback_extract_bad(self):
self.assertRaises(ValueError, api.extract_args,
{'disable_rollback': 'bad'})