首页 > 财经 >

Flask 自定义参数校验、序列化和反序列化

发表于: 2023-06-15 15:15:40 来源:博客园


(资料图片)

项目总体结构

我的工厂函数factory.py

from urllib.parse import quote

from settings import setting
from flask import Flask
from models.models import db
from flask_migrate import Migrate
from urls.router import bp_te, bp_lo
from utils.log import set_log

# See also: https://www.cnblogs.com/Du704/p/13281032.html

# Connection settings are read once, at import time, from the settings module.
mysql_host = setting.MYSQL_HOST
mysql_port = setting.MYSQL_PORT
mysql_user = setting.MYSQL_USER
mysql_pwd = setting.MYSQL_PASSWORD
mysql_database = setting.MYSQL_DATABASE
env_cnf = setting.ENV_CNF


def create_app():
    """Build and configure the Flask application.

    Sets up logging, points Flask-SQLAlchemy at the MySQL database described
    by the settings module, binds ``db`` and Flask-Migrate to the app and
    registers the ``bp_te`` and ``bp_lo`` blueprints.

    Returns:
        The configured ``Flask`` instance.
    """
    set_log()
    application = Flask(__name__)

    # The user name and password are percent-encoded so that characters such
    # as "@", ":" or "/" cannot be mistaken for URL delimiters.
    db_user = quote(str(mysql_user), safe="")
    db_pwd = quote(str(mysql_pwd), safe="")
    db_uri = (
        f"mysql+pymysql://{db_user}:{db_pwd}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )
    application.config["SQLALCHEMY_DATABASE_URI"] = db_uri
    # Modification tracking is normally left off because it costs performance.
    application.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    # Whether the SQL statements being executed are echoed.
    application.config["SQLALCHEMY_ECHO"] = False

    # Bind the shared db object to this application.
    db.app = application
    db.init_app(application)
    migrate = Migrate()
    migrate.init_app(application, db)

    # Register blueprints.
    application.register_blueprint(bp_te)
    application.register_blueprint(bp_lo)
    return application


application = create_app()

配置文件setting.py,读取数据库等配置信息

# Settings module: loads connection settings from config.ini and exposes them
# as upper-case module constants.  If any option cannot be read, every value
# falls back to the built-in defaults below.
from configparser import ConfigParser
from pathlib import Path

from utils.encryption import getDataAes

BASE_DIR = Path(__file__).resolve().parent.parent

conf = ConfigParser()
# The file name is relative, so it is looked up from the working directory.
conf.read("config.ini", encoding="utf-8")

try:
    # --- mysql ---
    mysqlhost = conf.get("mysql", "host")
    mysqlport = conf.get("mysql", "port")
    mysqluser = conf.get("mysql", "user")
    mysqlpassword = conf.get("mysql", "password")
    mysqlname = conf.get("mysql", "name")

    # --- serve ---
    secret = conf.get("serve", "secret")
    env_cnf = conf.get("serve", "env")

    # --- redis ---
    redishost = conf.get("redis", "host")
    redisport = conf.get("redis", "port")
    redispwd = conf.get("redis", "password")
    redislibrary = conf.get("redis", "library")

    # --- minio ---
    MINIOHOST = conf.get("minio", "clienthost")
    MINIOPORT = conf.get("minio", "clientport")
    miniopwd = conf.get("minio", "password")
    miniouser = conf.get("minio", "user")
    MINIOWEBHOST = conf.get("minio", "webhost")
    MINIOWEBPORT = conf.get("minio", "webport")
    MINIOHTTP = conf.get("minio", "http")
except Exception as exc:
    # Any failure discards whatever was read so far and uses the defaults.
    print(exc)

    mysqlhost = "127.0.0.1"
    mysqlport = 3306
    mysqluser = "root"
    mysqlpassword = "000000"
    mysqlname = "0"

    secret = "Wchime"
    env_cnf = "develop"

    redishost = "127.0.0.1"
    redisport = "6379"
    redispwd = "000000"
    redislibrary = "1"

    MINIOHOST = "127.0.0.1"
    MINIOPORT = "9000"
    MINIOWEBPORT = "9000"
    miniopwd = "000000"
    miniouser = "000000"
    MINIOHTTP = "http://"
    MINIOWEBHOST = "127.0.0.1"

# Public constants.  Credentials are passed through getDataAes together with
# the secret (presumably decryption - confirm in utils.encryption).
MYSQL_HOST = mysqlhost
MYSQL_PORT = mysqlport
MYSQL_USER = getDataAes(secret, mysqluser)
MYSQL_PASSWORD = getDataAes(secret, mysqlpassword)
MYSQL_DATABASE = mysqlname

REDIS_HOST = redishost
REDIS_PORT = redisport
REDIS_PASSWORD = getDataAes(secret, redispwd)
REDIS_LIBRARY = redislibrary

MINIOPWD = getDataAes(secret, miniopwd)
MINIOUSER = getDataAes(secret, miniouser)

ENV_CNF = env_cnf

if __name__ == "__main__":
    print(mysqlpassword)

models.py数据库模型文件

import datetime

from utils.core import db
from sqlalchemy_serializer import SerializerMixin


class Uu(db.Model, SerializerMixin):
    """Row of the ``uu`` table, optionally linked to one ``Ux`` row."""

    __tablename__ = "uu"

    id = db.Column(db.Integer, autoincrement=True, primary_key=True)
    name = db.Column(db.String(20), nullable=False)
    age = db.Column(db.Integer, nullable=False)
    # The foreign key is nullable and uses ON DELETE SET NULL.
    ux_id = db.Column(
        db.Integer,
        db.ForeignKey("ux.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Also creates the reverse attribute ``Ux.uu`` through the backref.
    ux = db.relationship("Ux", backref="uu")
    des = db.Column(db.String(20), nullable=True)
    img = db.Column(db.String(128), nullable=True)


class Ux(db.Model, SerializerMixin):
    """Row of the ``ux`` table, referenced by ``Uu.ux_id``."""

    # NOTE(review): presumably keeps the ``uu`` backref out of the serialized
    # output - confirm against the sqlalchemy-serializer rules syntax.
    serialize_rules = ("-uu",)
    __tablename__ = "ux"

    id = db.Column(db.Integer, autoincrement=True, primary_key=True)
    name = db.Column(db.String(20), nullable=False)

序列化文件serializes.py

from models import models
from utils.base import Serialize, DeSerialize


class TestSerialize(Serialize):
    """Serializer for ``models.Uu`` exposing ``id`` and ``name`` plus two
    computed fields declared in ``build_fiels``."""

    model = models.Uu
    fields = ["id", "name"]
    # Each entry either names a dotted attribute path ("source") or, with
    # "method": True, is filled by calling ``get_<name>`` on this serializer.
    build_fiels = [
        {"name": "ux_name", "source": "ux.name"},
        {"name": "img", "method": True},
    ]

    def get_img(self, instance):
        """Return the value of the computed ``img`` field (always empty)."""
        return []


class TestDeSerialize(DeSerialize):
    """Deserializer for ``models.Ux``; ``name`` is the only required field."""

    model = models.Ux
    # Assigning a list here replaces the ``required_fields`` property of the
    # base class, so the list is used exactly as written.
    required_fields = ["name"]
    ser_fields = ["id", "name"]

base.py:自定义序列化、反序列化和参数解析文件

import logging

from models.models import db
from flask_restful import abort
from sqlalchemy import inspect

logger = logging.getLogger(__name__)


class DeSerialize(object):
    """Deserializer: validate request data, then create, update or delete.

    Subclasses configure it through class attributes:

    * ``model``        - the model class to operate on.
    * ``req_fields``   - names that must be present in the request data;
      ``None`` means every serializable key of the model except the
      primary key.
    * ``other_fields`` - optional names copied when present.
    * ``ser_fields``   - names returned by ``data``; empty means every
      serializable key of the model.
    """

    model = None
    req_fields = None
    other_fields = []
    req_data = {}
    insatance = None
    ser_fields = []

    def __init__(self, insatance=None, data=None):
        """Store the target instance and the request data.

        :param insatance: existing model instance to update or delete, or
            ``None`` to create a new one on ``save``.
        :param data: mapping of request values; ``None`` means empty.
        """
        self.insatance = insatance
        # A fresh dict per instance avoids sharing one default between calls.
        self.req_data = {} if data is None else data

    @property
    def required_fields(self):
        """Names that must be present in the request data."""
        return self._get_fileds(self.req_fields)

    @property
    def serializer_fields(self):
        """Names included in the serialized result."""
        return self.ser_fields if self.ser_fields else self.model().serializable_keys

    def _get_fileds(self, fileds):
        """Return ``fileds`` (or all serializable keys) without the primary key.

        :param fileds: iterable of names, or ``None`` for all keys.
        :return: a new list of names.
        """
        if fileds is None:
            values_valid = self.model().serializable_keys
        else:
            values_valid = fileds
        values_valid = list(values_valid)
        # The primary key is never supplied by the client; fall back to "id"
        # when the model cannot be inspected.
        try:
            primary_key = inspect(self.model).primary_key[0].name
        except Exception:
            primary_key = "id"
        if primary_key in values_valid:
            values_valid.remove(primary_key)
        return values_valid

    def _get_vaild_values(self):
        """Collect the values to write from the request data.

        :return: ``(values, message)``; ``values`` is ``False`` when a
            required field is missing, and may be an empty dict when nothing
            usable was supplied (``message`` then says the data is empty).
        """
        vaild_dict = {}
        for key in self.required_fields:
            value = self.req_data.get(key)
            if value is None:
                return False, f"{key} is required"
            vaild_dict[key] = value
        for key in self.other_fields:
            value = self.req_data.get(key)
            if value is not None:
                vaild_dict[key] = value
        return vaild_dict, "request data is empty"

    def _create(self):
        """Insert a new row built from the validated request data.

        :return: ``(ok, message)``.
        """
        vaild_data, err = self._get_vaild_values()
        if not vaild_data:
            return False, err
        try:
            instance = self.model(**vaild_data)
            db.session.add(instance)
            db.session.commit()
        except Exception:
            # A failed commit leaves the session unusable until rolled back.
            db.session.rollback()
            logger.exception("create failed")
            return False, "please correct fileds"
        self.insatance = instance
        return True, "success"

    def _update(self):
        """Write the validated request data onto the stored instance.

        :return: ``(ok, message)``.
        """
        vaild_data, err = self._get_vaild_values()
        if not vaild_data:
            return False, err
        instance = self.insatance
        if instance is None:
            return False, "not find data"
        try:
            for key, value in vaild_data.items():
                setattr(instance, key, value)
            db.session.commit()
        except Exception:
            db.session.rollback()
            logger.exception("update failed")
            return False, "please correct fileds"
        return True, "success"

    def save(self):
        """Create when no instance was given, otherwise update.

        Aborts the request with HTTP 400 when validation or the write fails.
        """
        if self.insatance is None:
            ret, msg = self._create()
        else:
            ret, msg = self._update()
        if ret is False:
            abort(400, msg=msg)

    @property
    def data(self):
        """Serialized form of the stored instance.

        Aborts with HTTP 500 when there is no instance yet.
        """
        if self.insatance is None:
            msg = "data is not save"
            abort(500, msg=msg)
        return self.insatance.to_dict(only=tuple(self.serializer_fields))

    def delete(self):
        """Delete the stored instance; aborts with HTTP 400 on failure."""
        if self.insatance is None:
            msg = "not find data"
            abort(400, msg=msg)
        else:
            try:
                db.session.delete(self.insatance)
                db.session.commit()
            except Exception:
                db.session.rollback()
                logger.exception("delete failed")
                msg = "delete exception"
                abort(400, msg=msg)


class Serialize(object):
    """Serializer: turn one model instance or an iterable of them into dicts.

    Subclasses configure it through class attributes:

    * ``model``       - model class, used when ``fields`` is ``"__all__"``.
    * ``fields``      - names to include, or ``"__all__"``.
    * ``build_fiels`` - extra computed entries; each is a dict with ``name``
      and either ``source`` (dotted attribute path) or ``method: True``
      (value comes from ``get_<name>(instance)``).
    * ``date_format`` / ``datetime_format`` / ``time_format`` - passed to
      ``to_dict``.
    """

    model = None
    fields = "__all__"
    modelsDatas = []
    many = True
    date_format = "%Y-%m-%d"
    datetime_format = "%Y-%m-%d %H:%M:%S"
    time_format = "%H:%M:%S"
    build_fiels = []

    def __init__(self, serializers, many=True):
        """Store the object(s) to serialize.

        :param serializers: an iterable of instances when ``many`` is true,
            otherwise a single instance.
        :param many: whether ``serializers`` is a collection.
        """
        self.modelsDatas = serializers
        self.many = many

    @property
    def data(self):
        """List of serialized dicts (one element when ``many`` is false).

        Aborts with HTTP 500 when serialization raises.
        """
        li = []
        try:
            rows = self.modelsDatas if self.many else [self.modelsDatas]
            for row in rows:
                li.append(self._serialize_one(row))
        except Exception:
            logger.exception("serialize error")
            msg = "serialize error"
            abort(500, msg=msg)
        return li

    def _serialize_one(self, instance):
        """Serialize one instance and merge in its computed entries."""
        da = instance.to_dict(
            only=tuple(self.serializer_fields),
            date_format=self.date_format,
            datetime_format=self.datetime_format,
            time_format=self.time_format,
        )
        da.update(self._get_build_files_values(instance))
        return da

    @property
    def serializer_fields(self):
        """Names passed to ``to_dict`` as ``only``."""
        return self.fields if self.fields != "__all__" else self.model().serializable_keys

    def _get_build_files_values(self, data):
        """Compute the ``build_fiels`` entries for one instance.

        :param data: the model instance being serialized.
        :return: dict mapping each entry name to its value.
        """
        dit = {}
        for build in self.build_fiels:
            if build.get("method"):
                # Plain concatenation: nesting double quotes inside a
                # double-quoted f-string is a syntax error before Python 3.12.
                getter = getattr(self, "get_" + build["name"])
                dit[build["name"]] = getter(data)
                continue
            obj = data
            value = None
            # Walk the dotted path; stop at the first missing or falsy hop.
            for source in build["source"].split("."):
                value = getattr(obj, source, None)
                if not value:
                    break
                obj = value
            dit[build["name"]] = value
        return dit


class ParseQuery(object):
    """Build a filtered, ordered query from request data.

    Filter names have the form ``<column>__<operator>`` where the operator is
    one of ``filer_query``.  Order names are column names, prefixed with
    ``-`` for descending order.
    """

    filer_query = frozenset(["gt", "ge", "lt", "le", "ne", "eq", "ic", "ni", "in"])

    def __init__(self, model, req_data, filter_list=None, order_by=None):
        """Store the query description.

        :param model: model class to query.
        :param req_data: mapping holding the filter values.
        :param filter_list: filter names to look up in ``req_data``;
            ``None`` means no filters.
        :param order_by: iterable of order names, a single name, or ``None``.
        """
        self.model = model
        self.req_data = req_data
        self.filter_list = [] if filter_list is None else filter_list
        self.order_by = order_by
        # Operator suffix -> method building the matching column expression.
        self._operator_funcs = {
            "gt": self.__gt_model,
            "ge": self.__ge_model,
            "lt": self.__lt_model,
            "le": self.__le_model,
            "ne": self.__ne_model,
            "eq": self.__eq_model,
            "ic": self.__ic_model,
            "ni": self.__ni_model,
            "in": self.__in_model,
        }

    @property
    def _filter_data(self):
        """Filters that have a value and name an attribute of the model.

        Falsy values (``None``, ``""``, ``0``, empty lists) count as absent.
        """
        search_dict = {}
        for fit in self.filter_list:
            val = self.req_data.get(fit)
            key = fit.split("__")[0]
            if val and hasattr(self.model, key):
                search_dict[fit] = val
        return search_dict

    def _parse_fields(self):
        """Turn the active filters into a list of column expressions.

        Names without a recognised ``__<operator>`` suffix are skipped.
        """
        li = []
        for search_key, value in self._filter_data.items():
            # partition never raises, unlike unpacking the result of split.
            key, _, ope = search_key.partition("__")
            if ope in self.filer_query:
                li.append(self._operator_funcs[ope](key=key, value=value))
        return li

    def _filter(self):
        """Return the model query with filters and ordering applied."""
        quety_data = self.model.query.filter(*self._parse_fields())
        if self.order_by:
            quety_data = quety_data.order_by(*self._parse_order_by())
        return quety_data

    @property
    def query(self):
        """The filtered and ordered query."""
        return self._filter()

    def pagination_class(self, page_num=1, page_size=10, max_page_size=50, error_out=False):
        """Paginate ``query``.

        :return: ``(items, total)`` taken from the pagination object.
        """
        pagin = self.query.paginate(
            page=page_num,
            per_page=page_size,
            error_out=error_out,
            max_per_page=max_page_size
        )
        return pagin.items, pagin.total

    def _parse_order_by(self):
        """Translate ``order_by`` into a tuple of ordering expressions.

        A leading ``-`` selects descending order.  Names that are not
        attributes of the model are skipped.
        """
        names = self.order_by
        if isinstance(names, str):
            # A bare string would otherwise be iterated character by character.
            names = [names]
        li = []
        for name in names:
            descending = name.startswith("-")
            column = self.__by_model(name[1:] if descending else name)
            if column is None:
                continue
            li.append(column.desc() if descending else column.asc())
        return tuple(li)

    def __by_model(self, key):
        """Return the model attribute used for ordering, or ``None``."""
        return getattr(self.model, key, None)

    def __gt_model(self, key, value):
        """Greater than."""
        return getattr(self.model, key) > value

    def __ge_model(self, key, value):
        """Greater than or equal to."""
        return getattr(self.model, key) >= value

    def __lt_model(self, key, value):
        """Less than."""
        return getattr(self.model, key) < value

    def __le_model(self, key, value):
        """Less than or equal to."""
        return getattr(self.model, key) <= value

    def __eq_model(self, key, value):
        """Equal to."""
        return getattr(self.model, key) == value

    def __ne_model(self, key, value):
        """Not equal to."""
        return getattr(self.model, key) != value

    def __ic_model(self, key, value):
        """Contains (``LIKE %value%``)."""
        return getattr(self.model, key).like("%{}%".format(value))

    def __ni_model(self, key, value):
        """Does not contain (``NOT LIKE %value%``)."""
        return getattr(self.model, key).notlike("%{}%".format(value))

    def __in_model(self, key, value):
        """Matches any of several values (``IN``)."""
        return getattr(self.model, key).in_(value)

自定义序列化和反序列化后,接口将变得简单

上面的类视图将只有短短的几行代码

标签:
x 广告

Copyright ©  2015-2022 东方公司网版权所有  备案号:沪ICP备2020036824号-8   联系邮箱:562 66 29@qq.com