sqlalchemy 是很强大的 orm 框架,一般的教程都是从新建一个表开始教,这种情况下我们能直接拿到表对象然后方便的操作,但是对于已经存在的表的操作,网上教程并不多,而实际上,对于测试同学来说,这种情况才是最普遍的。我在系统集成测试框架中数据库操作便使用了 sqlalchemy,这里分享一下,顺便记录踩过的一些不那么明显却很致命的坑。
首先,我们可以创建一个自定义对象,它可以去组合 sqlalchemy 提供的种种功能。

from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.ext.declarative import declared_attr, declarative_base
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy.orm import sessionmaker, class_mapper, Query

class DbRoot(object):

    def __init__(self, **kwargs):
        """
        orm基础db对象,通过实例化该对象得到db实例,然后创建类对象继承自db.Model,便可以对相应表进行操作
        :param kwargs: dialect 数据库类型
                        driver 数据库驱动
                        user 用户名
                        password 用户密码
                        host 数据库地址
                        port 端口
                        database 数据库名
        """
        url = '{dialect}+{driver}://{user}:{password}@{host}:{port}/{database}?charset=utf8'.format(**kwargs)
        engine = create_engine(url, echo=False)

        class Base(object):

            @declared_attr
            def __table__(cls):
                return Table(cls.__tablename__, MetaData(), autoload=True, autoload_with=engine)

        self._base = Base
        self.Model = self.make_declarative_base()
        self.session = sessionmaker(bind=engine)

    def make_declarative_base(self):
        base = declarative_base(cls=self._base)
        base.query = _QueryProperty(self)
        base.query_class = Query
        return base

如代码所示,DbRoot 对象接收数据库 url 的各项参数,然后创建引擎,组合 declarative_base 方法生成的 Model 对象 (Model 对象是 Base 对象的子类),以及通过 sessionmaker 方法生成的 session 数据库会话对象 (实例化它便可以产生一个数据库会话,相当于操作句柄)。
然后在 make_declarative_base 方法中,Model 对象还组合了_QueryProperty,也就是 query 属性,用于查询操作,_QueryProperty是一个描述符对象,见下面代码。这样,我们在做比如 query.filter_by 等操作时,都是转给了__get__方法去执行,该方法最终调用了 t 也就是 Model 对象 query_class 方法 (其实就是 sqlalchemy.orm 提供的 Query 方法) 做查询,这里尤其注意,session 参数传的是 DbRoot 对象的 session 属性,该属性一定要在这里实例化,这样每次操作才能重新生成一个 session 会话。

class _QueryProperty(object):

    def __init__(self, sa):
        self.sa = sa

    def __get__(self, obj, t):
        """
        这里一定要注意,session要每次重新生成,不然session会话会自动关闭,导致下一次操作句柄为空
        :param obj:
        :param t:
        :return:
        """
        try:
            mapper = class_mapper(t)
            if mapper:
                return t.query_class(mapper, session=self.sa.session())
        except UnmappedClassError:
            return None

现在,我们等于有了 db.Model 对象,然后我们要去映射表生成表对象。其实就是创建一个类继承自 DbRoot.Model,然后将__tablename__属性设置为想要生成的表对象表名,这样便大功告成了。我这里采用了动态生成类的方式。

def gen_orm_class( db=None, table_name=None):
    """
    动态生成数据库表映射Model类
    :param db: db对象
    :param table_name: 表名称
    :return:
    """
    return type(
        table_name.title(),
        (db.Model,),
        {
            '__tablename__': table_name
        }
    )

有了这个基础,如果结合 pytest 使用的话会更方便,pytest 有个很牛的功能叫 fixture,我们可以把初始化数据库的操作做出一个 scope 为 module 的 fixture

@pytest.fixture(scope='module')
def mysql(request, config_init):
    """
    mysql数据库操作实例
    :param request:
    :param config_init:
    :return:
    """
    db_roots = {}
    mysql_conf = config_init.get('mysql')
    databases = mysql_conf.pop('databases')
    dbs = request.module.config.get('mysql_dbs', {})

    for db, table in dbs.items():
        db_conf = databases.get(db)
        mysql_conf.update(db_conf)
        db_roots.update({table: gen_orm_class(db_name=db, db=DbRoot(**mysql_conf), table_name=table)})

    return db_roots

然后在测试用例中配置数据库连接信息。

# test_case.py
config = {
    'mysql_dbs': {
        'dev': 'user'
    }
}

这样在进入该模块 (test_case.py) 开始测试之前,pytest 会根据配置自动初始化数据库连接生成{‘表名’: ‘DbRoot表对象’}结构的字典,之后在测试代码中,就可以通过 mysql.get('user') 拿到 User 表对象,从而对 dev.User 表进行查询操作。

def test_user(self, mysql, user_id):
    user_tbl = mysql.get('user')
    # id是主键,所以可以通过get方法直接查询
    assert user_tbl.query.get(user_id),'该用户{}不存在'.format(user_id)


↙↙↙阅读原文可查看相关链接,并与作者交流