Skip to content

统一的校验模块

统一的校验模块

说明

  • 支持的类型有 str、int、float、bool、datetime(格式为 2000-01-01)、list、dict
  • 支持的检查有:
    1. 是否可选
    2. 默认值
    3. 字符串长度
    4. 整型和浮点数的最大值、最小值
    5. 自定义检查(传入任意可调用对象,其返回值即作为参数值保存)

BasePermitView 改动

@classmethod
def _inner_view_(cls, **initkwargs):
    """Build the per-request view function for this view class.

    The returned ``view`` creates a fresh ``cls`` instance for each request,
    records the positional/keyword URL arguments, aliases ``head`` to ``get``
    when only ``get`` exists, and resolves the handler for
    ``self.request_method``. Parameter validation (``self._param_check()``)
    runs before the handler: a truthy result is returned instead of calling
    the handler.
    """
    def view(request, *args, **kwargs):
        self = cls(request=request, **initkwargs)
        self.args, self.kwargs = args, kwargs
        if hasattr(self, 'get') and not hasattr(self, 'head'):
            self.head = self.get

        method = self.request_method
        handler = (getattr(self, method, self.http_method_not_allowed)
                   if method in self.http_method_names
                   else self.http_method_not_allowed)

        # Added for parameter checking: a truthy return value from
        # _param_check (an error response) short-circuits dispatch.
        early_response = self._param_check()
        if early_response:
            return early_response
        return handler(request, *args, **kwargs)

    # Copy the class's name and docstring onto the view function.
    update_wrapper(view, cls, updated=())
    return view


def _param_check(self):
    """Default parameter-validation hook: accept every request.

    Returns None, so the dispatcher goes on to call the handler.
    Subclasses override this to return an error response instead.
    """
    return None

CommonBaseView 改动

from website.utils.param_check import GmError, ErrCode, ParamChecker as pc



class CommonBaseView(BasePermitView, FormAjaxResponseMixin):
    """View base class that validates request parameters before dispatch.

    Subclasses declare ``param_check_dict`` (parameter name -> ``Param``).
    ``_param_check`` fills ``self.params`` with the validated values and
    returns None, or returns ``self.JsonErrorResponse(message)`` for the
    first ``GmError`` raised.
    """

    # Parameter name -> Param spec; subclasses override this.
    param_check_dict = dict()
    # Optional hook. A subclass may define ``extra_param_check(self)``; it is
    # called after all per-parameter checks, with ``self.params`` populated.
    # It may raise GmError to reject the request, or return a non-None
    # response to short-circuit the handler.
    extra_param_check = None

    def _param_check(self):
        """Validate the request parameters against ``param_check_dict``.

        Returns:
            None when every check passes (``self.params`` is populated),
            otherwise the JSON error response for the first ``GmError``.
        """
        self.params = {}
        # __init__ only assigns request_param for GET and POST; without this
        # fallback any other method (e.g. HEAD) crashes with AttributeError.
        request_param = getattr(self, 'request_param', {})
        try:
            for key, spec in self.param_check_dict.items():
                self.params[key] = self._check_one(request_param, key, spec)
            # Run the subclass hook last so it can read and extend self.params.
            if self.extra_param_check is not None:
                return self.extra_param_check()
        except GmError as e:
            return self.JsonErrorResponse(str(e))
        return None

    @staticmethod
    def _check_one(request_param, key, spec):
        """Validate and convert one parameter according to its ``Param`` spec.

        Built-in types dispatch to ``ParamChecker.check_<type name>``. Any
        other ``param_type`` is called with the raw value, and its return
        value is stored. It should therefore act as a converter: return the
        parsed value, or raise ValueError/TypeError on bad input.

        NOTE: a bare predicate such as ``lambda x: x in [...]`` stores
        True/False and never rejects a value.

        Raises:
            GmError: the parameter is missing or its value is invalid.
        """
        # Callables without __name__ (e.g. functools.partial) use the custom path.
        type_name = getattr(spec.param_type, '__name__', '')
        check_func = getattr(pc, 'check_' + type_name, None)
        if check_func is not None:
            return check_func(param_map=request_param, key=key,
                              min=spec.min, max=spec.max,
                              opt=spec.optional, default=spec.default)
        if key not in request_param:
            if spec.optional:
                return spec.default
            raise GmError(ErrCode.param_err, '缺少{0}参数'.format(key))
        try:
            return spec.param_type(request_param[key])
        except (TypeError, ValueError):
            # Report a bad value as a parameter error instead of letting the
            # converter's exception escape unhandled.
            raise GmError(ErrCode.param_err, '{0}参数的值不合法'.format(key))

    @log
    def __init__(self, request, *args, **kwargs):
        super().__init__(request, *args, **kwargs)
        # Added: capture the request parameters consumed by _param_check.
        if self.request_method == 'post':
            self.request_param = request.POST.dict()
        elif self.request_method == 'get':
            self.request_param = request.GET.dict()
        # ... (rest of the existing __init__ body, elided in this excerpt)

param_check.py

from datetime import datetime
from enum import IntEnum
import json
from math import inf


def str2datetime(s):
    """Convert a ``YYYY-MM-DD`` string to a ``datetime`` at midnight.

    Used instead of ``datetime.strptime`` for speed; the original author
    measured it at roughly 7x faster for this simple format. Malformed
    input raises ValueError. That covers the wrong number of ``-``-separated
    fields, non-numeric fields, and out-of-range dates.
    """
    fields = s.split('-')
    year, month, day = fields
    return datetime(year=int(year), month=int(month), day=int(day))


class Param(object):
    """Validation spec for a single request parameter.

    Args:
        param_type: a type whose ``__name__`` matches a
            ``ParamChecker.check_<name>`` validator (str, int, float, bool,
            datetime, list, dict), or any other callable, which is applied
            to the raw value.
        min, max: inclusive bounds, checked against the string length for
            ``str`` and the value for ``int``/``float``. Unbounded by default.
        optional: when False, a missing parameter is an error.
        default: value for an absent optional parameter; ``None`` lets the
            type checker fall back to its own default.
    """

    def __init__(self, param_type, min=-inf, max=inf,
                 optional=False, default=None):
        (self.param_type, self.min, self.max,
         self.optional, self.default) = (param_type, min, max,
                                         optional, default)


class ErrCode(IntEnum):
    """Numeric error categories carried in ``GmError.err_code``."""

    param_err = 1     # raised by the parameter checkers for bad/missing params
    mongo_err = 2     # NOTE(review): presumably MongoDB failures -- confirm
    mysql_err = 3     # NOTE(review): presumably MySQL failures -- confirm
    business_err = 4  # NOTE(review): presumably business-rule errors -- confirm


class GmError(Exception):
    """Application error carrying a numeric code and a message.

    Args:
        err_code: error category (e.g. an ``ErrCode`` member); stored as int.
        err_str: human-readable message; stored as str, and ``str(error)``
            returns it.
    """

    def __init__(self, err_code, err_str):
        self.err_code = int(err_code)
        self.err_str = str(err_str)
        # Populate Exception.args so repr() and tracebacks show the code and
        # message. It also lets pickling work: exceptions are rebuilt as
        # cls(*args), which fails with an empty args tuple.
        super().__init__(self.err_code, self.err_str)

    def __str__(self):
        return self.err_str


class ParamChecker(object):
    """Per-type request-parameter validators.

    ``CommonBaseView._param_check`` looks a checker up as
    ``'check_' + param_type.__name__``. It calls the checker with the keyword
    arguments ``param_map`` (request parameters), ``key``, ``min``, ``max``,
    ``opt`` and ``default``. Every checker returns the converted value or
    raises ``GmError(ErrCode.param_err, ...)``.
    """

    def __check_optional(func):
        """Decorator that enforces required keys and forwards defaults.

        Raises GmError when ``opt`` is false and ``key`` is missing from
        ``param_map``. A ``default`` of None is not forwarded, so the wrapped
        checker's own keyword default ('' / 0 / '[]' / ...) applies.
        """
        def _wrapper(param_map, key, opt, default=None, **kwarg):
            if not opt and key not in param_map:
                raise GmError(ErrCode.param_err, '缺少{0}参数'.format(key))
            if default is None:
                return func(param_map=param_map, key=key, opt=opt, **kwarg)
            return func(param_map=param_map, key=key, opt=opt,
                        default=default, **kwarg)

        # Keep the checker's identity for debugging and introspection.
        _wrapper.__name__ = func.__name__
        _wrapper.__doc__ = func.__doc__
        return _wrapper

    @staticmethod
    def _load_json(param_map, key, default, expected_type, err_msg):
        """Decode ``param_map[key]`` (or ``default``) as JSON of ``expected_type``.

        A default that already is ``expected_type`` (e.g. ``Param(list,
        default=[])``) is returned as a shallow copy, so requests never share
        one mutable object.

        Raises:
            GmError: the value is not valid JSON or not of ``expected_type``.
        """
        raw = param_map.get(key, default)
        if isinstance(raw, expected_type):
            return expected_type(raw)
        try:
            result = json.loads(raw)
        except (TypeError, ValueError):
            raise GmError(ErrCode.param_err, err_msg.format(key))
        # json.loads accepts any JSON value; e.g. '{}' is not a list.
        if not isinstance(result, expected_type):
            raise GmError(ErrCode.param_err, err_msg.format(key))
        return result

    @staticmethod
    @__check_optional
    def check_str(param_map, key, min, max, opt, default='', **kwarg):
        """Return the value as ``str`` with ``min <= len(value) <= max``."""
        try:
            result = str(param_map.get(key, default))
        except ValueError:
            raise GmError(ErrCode.param_err, '{0}参数的值不是字符串'.format(key))
        result_len = len(result)
        if result_len < min:
            raise GmError(ErrCode.param_err,
                          '{0}参数最短长度为{1}'.format(key, min))
        if result_len > max:
            raise GmError(ErrCode.param_err,
                          '{0}参数太长长度为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_int(param_map, key, min, max, opt, default=0, **kwarg):
        """Return the value as ``int`` with ``min <= value <= max``."""
        try:
            result = int(param_map.get(key, default))
        except (TypeError, ValueError):
            # TypeError covers non-string values such as None.
            raise GmError(ErrCode.param_err, '{0}参数的值不是整数'.format(key))
        if result < min:
            raise GmError(ErrCode.param_err, '{0}参数最小为{1}'.format(key, min))
        if result > max:
            raise GmError(ErrCode.param_err, '{0}参数最大为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_float(param_map, key, min, max, opt, default=0, **kwarg):
        """Return the value as ``float`` with ``min <= value <= max``."""
        try:
            result = float(param_map.get(key, default))
        except (TypeError, ValueError):
            raise GmError(ErrCode.param_err, '{0}参数的值不是浮点数'.format(key))
        # NaN compares False against every bound, so it would slip past the
        # min/max checks below; reject it explicitly.
        if result != result:
            raise GmError(ErrCode.param_err, '{0}参数的值不是浮点数'.format(key))
        if result < min:
            raise GmError(ErrCode.param_err, '{0}参数最小为{1}'.format(key, min))
        if result > max:
            raise GmError(ErrCode.param_err, '{0}参数最大为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_bool(param_map, key, opt, default=0, **kwarg):
        """Return the value as ``bool``; only 0 or 1 is accepted."""
        try:
            value = int(param_map.get(key, default))
        except (TypeError, ValueError):
            raise GmError(ErrCode.param_err, '{0}参数的值必须是0或1'.format(key))
        # Enforce the contract stated in the error message (previously any
        # integer was accepted and coerced).
        if value not in (0, 1):
            raise GmError(ErrCode.param_err, '{0}参数的值必须是0或1'.format(key))
        return bool(value)

    @staticmethod
    @__check_optional
    def check_datetime(param_map, key, opt, default=None, **kwarg):
        """Parse a ``YYYY-MM-DD`` value with ``str2datetime``.

        When the key is absent, returns ``default``, or the current time if
        no default was given. "Now" is computed per call; a ``datetime.now()``
        parameter default would be evaluated once, at import time.
        """
        if key not in param_map:
            return datetime.now() if default is None else default
        try:
            return str2datetime(param_map[key])
        except (AttributeError, ValueError):
            # AttributeError: a non-string value has no .split().
            raise GmError(ErrCode.param_err,
                          '{0}参数的值必须是2000-01-01形式的日期'.format(key))

    @staticmethod
    @__check_optional
    def check_list(param_map, key, opt, check_list=None, default='[]',
                   **kwarg):
        """Return the value decoded from a JSON array.

        ``check_list`` is optional and currently unused; the view never
        passes it, and making it required broke every list check.
        """
        return ParamChecker._load_json(param_map, key, default, list,
                                       '{0}参数的值不是列表')

    @staticmethod
    @__check_optional
    def check_dict(param_map, key, opt, check_dict=None, default='{}',
                   **kwarg):
        """Return the value decoded from a JSON object.

        ``check_dict`` is optional and currently unused; the view never
        passes it, and making it required broke every dict check.
        """
        return ParamChecker._load_json(param_map, key, default, dict,
                                       '{0}参数的值不是字典')


使用示例


class StockOutView(CommonBaseView):
    # Search stock-out (outbound) receipts.

    # Validation spec read by CommonBaseView._param_check; validated values
    # are stored in self.params.
    param_check_dict = {
        # Required dates, parsed from 'YYYY-MM-DD' strings by
        # ParamChecker.check_datetime.
        'start_date': Param(datetime),
        'end_date': Param(datetime),
        # Required string with no length limits.
        'date_type': Param(str),
        # Optional paging parameters.
        'offset': Param(int, optional=True, default=0),
        'limit': Param(int, optional=True, default=10),
        # NOTE(review): this lambda is not a ParamChecker type, so
        # CommonBaseView stores its *return value* (True/False) in
        # self.params['status'] instead of the status, and never rejects a
        # bad value. Request values also appear to be strings
        # (request.POST/GET .dict()), in which case '1' in [1, 2, 3] is
        # False -- confirm, and consider a converter (int plus a range
        # check that raises ValueError) instead of a predicate.
        'status': Param(lambda x: x in [1, 2, 3], optional=True),
        # Optional; when absent, check_str returns '' because a None
        # default is not forwarded to the checker.
        'query_data': Param(str, optional=True)
    }

    def __init__(self, request, *args, **kwargs):
        super().__init__(request, *args, **kwargs)
        # Per-instance result holders; presumably filled by the request
        # handler, which is not visible in this excerpt.
        self.__receipts = []
        self.__count = 0

    def extra_param_check(self):
        # Extra check hook: adds the station id to the validated parameters.
        # NOTE(review): self.station_id is not set in this excerpt
        # (presumably by a base class); also confirm that
        # CommonBaseView._param_check actually invokes this hook.
        self.params['station_id'] = self.station_id
...