@ -15,9 +15,10 @@
import configparser
import os
import pkg_resources
from unittest import mock
import importlib . metadata as importlib_metadata
import trove
from trove . common import extensions
from trove . extensions . routes . mgmt import Mgmt
@ -29,11 +30,12 @@ DEFAULT_EXTENSION_MAP = {
' MYSQL ' : [ Mysql , extensions . ExtensionDescriptor ]
}
EP_TEXT = '''
mgmt = trove . extensions . routes . mgmt : Mgmt
mysql = trove . extensions . routes . mysql : Mysql
invalid = trove . tests . unittests . api . common . test_extensions : InvalidExtension
'''
INVALID_EXTENSION_MAP = {
' mgmt ' : ' trove.extensions.routes.mgmt:Mgmt ' ,
' mysql ' : ' trove.extensions.routes.mysql:Mysql ' ,
' invalid ' : ' trove.tests.unittests.api.common. '
' test_extensions:InvalidExtension '
}
class InvalidExtension ( object ) :
@ -68,8 +70,8 @@ class TestExtensionLoading(trove_testtools.TestCase):
for clazz in DEFAULT_EXTENSION_MAP [ alias ] :
self . assertIsInstance ( ext , clazz , " Improper extension class " )
@mock.patch ( " pkg_resources.iter _entry_points" )
def test_default_extensions ( self , mock_ iter_ep s) :
@mock.patch ( " stevedore.enabled.EnabledExtensionManager.list _entry_points" )
def test_default_extensions ( self , mock_ extension s) :
trove_base = os . path . abspath ( os . path . join (
os . path . dirname ( trove . __file__ ) , " .. " ) )
setup_path = " %s /setup.cfg " % trove_base
@ -79,20 +81,31 @@ class TestExtensionLoading(trove_testtools.TestCase):
parser . read ( setup_path )
entry_points = parser . get (
' entry_points ' , extensions . ExtensionManager . EXT_NAMESPACE )
eps = pkg_resources . EntryPoint . parse_group ( ' plugins ' , entry_points )
mock_iter_eps . return_value = eps . values ( )
test_extensions = list ( )
for entry in entry_points . split ( ' \n ' ) [ 1 : ] :
name = entry . split ( " = " ) [ 0 ] . strip ( )
value = entry . split ( " = " ) [ 1 ] . strip ( )
test_extensions . append ( importlib_metadata . EntryPoint (
name = name ,
value = value ,
group = extensions . ExtensionManager . EXT_NAMESPACE ) )
mock_extensions . return_value = test_extensions
extension_mgr = extensions . ExtensionManager ( )
self . assertEqual ( sorted ( DEFAULT_EXTENSION_MAP . keys ( ) ) ,
sorted ( extension_mgr . extensions . keys ( ) ) ,
" Invalid extension names " )
self . _assert_default_extensions ( extension_mgr . extensions )
@mock.patch ( " pkg_resources.iter_entry_points " )
def test_invalid_extension ( self , mock_iter_eps ) :
eps = pkg_resources . EntryPoint . parse_group ( ' mock ' , EP_TEXT )
mock_iter_eps . return_value = eps . values ( )
@mock.patch ( " stevedore.enabled.EnabledExtensionManager.list_entry_points " )
def test_invalid_extension ( self , mock_extensions ) :
test_extensions = list ( )
for k , v in INVALID_EXTENSION_MAP . items ( ) :
test_extensions . append ( importlib_metadata . EntryPoint (
name = k ,
value = v ,
group = extensions . ExtensionManager . EXT_NAMESPACE ) )
mock_extensions . return_value = test_extensions
extension_mgr = extensions . ExtensionManager ( )
self . assertEqual ( len ( DEFAULT_EXTENSION_MAP . keys ( ) ) ,
len ( extension_mgr . extensions ) ,
self . assertEqual ( 2 , len ( extension_mgr . extensions ) ,
" Loaded invalid extensions " )
self . _assert_default_extensions ( extension_mgr . extensions )