diff --git a/tests/__init__.py b/tests/__init__.py
index 47ea5e8..547e947 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,32 +1,80 @@
 import os
 import sys
-import unittest
 import warnings
 
 sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 '../src'))
 
-class CatchWarningsMixin(object):
-    class assertWarns(object):
-        def __init__(self, warningtype, msg=''):
-            self.warningtype = warningtype
-            warnings.filterwarnings('error')
-            self.failureException = unittest.TestCase.failureException
+
+class _AssertWarnsContext(object):
+    """Context manager asserting that a warning category is issued.
+
+    While the ``with`` body runs, every warning is escalated to an
+    exception, so the first warning aborts the body and is checked in
+    ``__exit__``.  The warning filters in effect on entry are restored
+    on exit, so the escalation does not leak into later tests.
+    """
+
+    def __init__(self, warningtype, testcase, msg=''):
+        # warningtype may be a class or a tuple of classes (anything
+        # issubclass() accepts); msg is appended to failure messages.
+        self.warningtype = warningtype
+        self.failureException = testcase.failureException
+        self.msg = msg
+        self._saved_filters = None
+
+    def __enter__(self):
+        # Snapshot the global warning filters before escalating warnings
+        # to errors, so __exit__ can scope the 'error' filter to this
+        # with-block instead of leaving it installed for the whole run.
+        self._saved_filters = warnings.catch_warnings()
+        self._saved_filters.__enter__()
+        warnings.filterwarnings('error')
+        return self
 
-        def __enter__(self):
-            return self
+    def _expected_name(self):
+        # A tuple of classes has no __name__; fall back to str() for it.
+        try:
+            return self.warningtype.__name__
+        except AttributeError:
+            return str(self.warningtype)
 
-        def __exit__(self, exc_type, exc_value, tb):
-            if exc_type is None:
-                try:
-                    exc_name = self.warningtype.__name__
-                except AttributeError:
-                    exc_name = str(self.warningtype)
-                raise self.failureException(
-                    "{0} not raised".format(exc_name))
+    def _fail(self, standard_msg):
+        # Append the caller-supplied msg, as unittest's assert* methods do.
+        if self.msg:
+            standard_msg = '%s : %s' % (standard_msg, self.msg)
+        raise self.failureException(standard_msg)
 
-            if not issubclass(exc_type, self.warningtype):
-                raise self.failureException('"%s" does not match "%s"' %
-                     (self.warningtype.__name__, str(exc_type.__name__)))
+    def __exit__(self, exc_type, exc_value, tb):
+        # Restore the saved filters first, whatever the outcome below.
+        self._saved_filters.__exit__(exc_type, exc_value, tb)
+
+        if exc_type is None:
+            self._fail("{0} not raised".format(self._expected_name()))
+
+        if not issubclass(exc_type, self.warningtype):
+            self._fail('"%s" does not match "%s"' %
+                       (self._expected_name(), exc_type.__name__))
+
+        # Suppress the expected warning so the with-block exits cleanly.
+        return True
 
-            return True
+
+class CatchWarningsMixin(object):
+    """Mixin adding an ``assertWarns`` helper to a test-case class.
+
+    The host class must provide ``failureException`` (as
+    ``unittest.TestCase`` does).  Two calling styles are supported:
+
+    * ``with self.assertWarns(SomeWarning): ...`` returns a context
+      manager that fails unless the body issues ``SomeWarning``;
+    * ``self.assertWarns(SomeWarning, func, *args, **kwargs)`` calls
+      ``func(*args, **kwargs)`` under that same check.
+    """
+
+    def assertWarns(self, wrnClass, callableObj=None, *args, **kwargs):
+        context = _AssertWarnsContext(wrnClass, self)
+        if callableObj is None:
+            return context
+        with context:
+            callableObj(*args, **kwargs)