美文网首页python进阶Flask实践
博客系列:对SQLAlchemy进行改写。可直接拿到别的项目中使

博客系列:对SQLAlchemy进行改写。可直接拿到别的项目中使

作者: 我的昵称很霸气 | 来源:发表于2018-08-21 16:27 被阅读8次
    • model这一块用的是SQLAlchemy
    • 继承重写Query查询

    第一个:Query的改写

    # 所谓的改写并不是在源码中改写,而是继承之后重写这个方法
    # 在flask_app>orm>base.py 
    from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
    from sqlalchemy import inspect, Column, Integer, SmallInteger, orm
    from contextlib import contextmanager
    
    from common.error import NotFoundError
    from flask_app.orm import transfer
    
    class SQLAlchemy(_SQLAlchemy):
        @contextmanager
        def auto_commit(self):
            try:
                yield
                self.session.commit()
            except Exception as e:
                db.session.rollback()
                raise e
    
    class Query(BaseQuery):
        def filter_by(self, **kwargs):
            # 这个是每个model中都加入了status。只有等于0才会被找到
            if 'status' not in kwargs.keys():
                kwargs['status'] = 0
            return super(Query, self).filter_by(**kwargs)
    
        def get_or_404(self, ident):
            rv = self.get(ident)
            if not rv:
                raise NotFoundError(msg="数据不存在")
            return rv
    
        def first_or_404(self):
            rv = self.first()
            if not rv:
                raise NotFoundError(msg="数据不存在")
            return rv
    
    db = SQLAlchemy(session_options={'autocommit': True},query_class=Query)
    
    # 这样用的话  找不到直接报出异常,很好使用
    

    第二个:Model改写

    class BaseModel(db.Model):
        __abstract__ = True
    
        def insert(self):
            self._before_insert()
            self.try_to_add_ip()
            self.try_to_add_device_info()
            db.session.add(self)
            db.session.flush()
            self._after_insert()
            return self
    
        def update(self):
            self._before_update()
            db.session.merge(self)
            db.session.flush()
            self._after_update()
            return self
    
        def delete(self):
            self._before_delete()
            db.session.delete(self)
            db.session.flush()
            self._after_delete()
    
        def _before_insert(self):
            pass
    
        def _after_insert(self):
            pass
    
        def _before_update(self):
            pass
    
        def _after_update(self):
            pass
    
        def _before_delete(self):
            pass
    
        def _after_delete(self):
            pass
    
        @classmethod
        def load_all_data_field(cls):
            """
            获取类自身所有数据表映射字段名
            :return:
            """
            if hasattr(cls, '__table__'):
                return [c.name for c in cls.__table__.columns]
    
        def try_to_add_ip(self):
            ip_column = 'ip'
            # 检查是否有ip这个field,如果有并且没有值,则从flask request对象里面取
            if ip_column in self.load_all_data_field():
                ip_val = getattr(self,ip_column)
                if not ip_val:
                    from flask import request
                    setattr(self,ip_column,request.headers.get('X-Forwarded-For', None) or request.remote_addr)
    
        def try_to_add_device_info(self):
            device_infos = ['imei', 'mac']
            # 检查是否有ip这个field,如果有并且没有值,则从flask request对象里面取
            for info in device_infos:
                if info in self.load_all_data_field():
                    info_val = getattr(self, info)
                    if not info_val:
                        from flask import request
    
                        info_val = request.args.get(info, "")
                        if info_val:
                            setattr(self, info, info_val)
    
        def to_dict(self, without=(), include=()):
            """
            主要是将model转换为字典返回
            """
            return transfer.orm_obj2dict(self, without, include)
    
        def update_from_json(self,json_str):
            """
            接受json_str更新原本信息
            """
            return transfer.json_up_orm_obj(json_str,self)
    
        @classmethod
        def from_dict(cls, dic):
            return transfer.dict2obj(dic, cls)
    
        def __repr__(self):
            return self.to_json()
    
        def to_json(self, without=(), include=()):
            return transfer.orm_obj2json(self, without, include)
    
        def save(self):
            if self.id:
                self.update()
            else:
                self.insert()
    
    
    class MixinJSONSerializer:
        @orm.reconstructor
        def init_on_load(self):
            self._fields = []
            # self._include = []
            self._exclude = []
    
            self._set_fields()
            self.__prune_fields()
    
        def _set_fields(self):
            pass
    
        def __prune_fields(self):
            columns = inspect(self.__class__).columns
            if not self._fields:
                all_columns = set(columns.keys())
                self._fields = list(all_columns - set(self._exclude))
    
        def hide(self, *args):
            for key in args:
                self._fields.remove(key)
            return self
    
        def keys(self):
            return self._fields
    
        def __getitem__(self, key):
            return getattr(self, key)
    
    • BaseModel 和MixinJSONSerializer 都是可用于model继承的

    相关文章

      网友评论

      本文标题:博客系列:对SQLAlchemy进行改写。可直接拿到别的项目中使

      本文链接:https://www.haomeiwen.com/subject/kzjjiftx.html