统一的校验模块
说明
- 支持的类型有 str, int, float, bool, datetime（日期，格式如 2000-01-01）, list, dict（list 和 dict 以 JSON 字符串传入）
- 支持的检查有:
  - 是否可选
  - 默认值
  - 字符串长度
  - 整型和浮点数的最大最小值
  - 自定义检查
BasePermitView 改动
@classmethod
def _inner_view_(cls, **initkwargs):
    """Build the request-handling function for this view class.

    Each call of the returned ``view`` instantiates ``cls`` with the request
    and ``initkwargs``, picks the handler for ``self.request_method`` (falling
    back to ``http_method_not_allowed``), runs ``_param_check()`` and then
    dispatches.

    ``_param_check()`` returns None when the parameters are valid; any other
    value is returned directly as the response and the handler is skipped.
    """
    def view(request, *args, **kwargs):
        self = cls(request=request, **initkwargs)
        self.args = args
        self.kwargs = kwargs
        # Serve HEAD with the GET handler when the class defines no head().
        if hasattr(self, 'get') and not hasattr(self, 'head'):
            self.head = self.get
        if self.request_method in self.http_method_names:
            handler = getattr(
                self, self.request_method, self.http_method_not_allowed)
        else:
            handler = self.http_method_not_allowed
        # Validate parameters before dispatching. Compare with None instead
        # of using ``or``: an error response object that happens to be falsy
        # must still short-circuit the handler.
        error_response = self._param_check()
        if error_response is not None:
            return error_response
        return handler(request, *args, **kwargs)

    # Take name and docstring from the class.
    update_wrapper(view, cls, updated=())
    return view
def _param_check(self):
    """Hook run before the request handler to validate parameters.

    The base implementation accepts every request and reports success with
    None. Subclasses override it and return an error response instead of
    None to stop the request before the handler runs.
    """
    return None
CommonBaseView 改动
from website.utils.param_check import GmError, ErrCode, ParamChecker as pc
class CommonBaseView(BasePermitView, FormAjaxResponseMixin):
    """Base view that validates request parameters before dispatch.

    Subclasses declare ``param_check_dict`` (parameter name -> ``Param``).
    ``_param_check`` converts and validates each declared parameter into
    ``self.params``. If the subclass defines an ``extra_param_check(self)``
    method, it is called afterwards for checks that need the view instance.
    """

    # Parameter name -> Param spec; override in subclasses.
    param_check_dict = dict()
    # Optional hook: subclasses may define a method taking only ``self``.
    # It runs after the declared checks and signals failure by raising GmError.
    extra_param_check = None

    def _param_check(self):
        """Validate the parameters declared in ``param_check_dict``.

        Fills ``self.params`` with the converted values.

        Returns:
            None when every check passes; otherwise
            ``self.JsonErrorResponse(<message>)`` for the first GmError raised.
        """
        self.params = {}
        # ``request_param`` is only assigned for GET/POST in the visible
        # ``__init__``. Treat other methods as having no parameters rather
        # than failing with AttributeError.
        request_param = getattr(self, 'request_param', {})
        try:
            for key, spec in self.param_check_dict.items():
                param_type = spec.param_type
                # Built-in types map to ParamChecker.check_<type name>.
                # Anything else (e.g. a lambda) is a custom callable.
                # Some callables (e.g. functools.partial) have no __name__.
                type_name = getattr(param_type, '__name__', '')
                check_func = getattr(pc, 'check_' + type_name, None)
                if check_func is not None:
                    self.params[key] = check_func(
                        param_map=request_param,
                        key=key,
                        min=spec.min,
                        max=spec.max,
                        opt=spec.optional,
                        default=spec.default,
                    )
                elif key in request_param:
                    # Custom callable: applied to the raw value; its return
                    # value is stored. A conversion failure is reported as a
                    # parameter error instead of escaping as a server error.
                    try:
                        self.params[key] = param_type(request_param[key])
                    except (ValueError, TypeError) as err:
                        raise GmError(ErrCode.param_err,
                                      '{0}参数的值不合法'.format(key)) from err
                elif spec.optional:
                    self.params[key] = spec.default
                else:
                    raise GmError(ErrCode.param_err, '缺少{0}参数'.format(key))
            # Run the subclass hook once the declared parameters are in
            # ``self.params``. A GmError raised there is reported like any
            # other parameter error.
            if callable(self.extra_param_check):
                self.extra_param_check()
        except GmError as e:
            return self.JsonErrorResponse(str(e))
        return None

    @log
    def __init__(self, request, *args, **kwargs):
        super().__init__(request, *args, **kwargs)
        # Raw request parameters as a plain dict, taken from request.POST or
        # request.GET depending on the request method.
        if self.request_method == 'post':
            self.request_param = request.POST.dict()
        elif self.request_method == 'get':
            self.request_param = request.GET.dict()
        ...  # remainder of __init__ is elided in the original excerpt
param_check.py
from datetime import datetime
from enum import IntEnum
import json
from math import inf
def str2datetime(s):
    """Parse a ``YYYY-MM-DD`` string into a ``datetime`` at midnight.

    Deliberately avoids ``datetime.strptime()``: the input format is fixed and
    simple, and the original author measured this split-based parser as more
    than 7x faster.

    Raises:
        ValueError: if *s* does not split on '-' into exactly three parts, a
            part is not an integer, or the resulting date is invalid.
    """
    pieces = s.split('-')
    y_txt, m_txt, d_txt = pieces
    return datetime(year=int(y_txt), month=int(m_txt), day=int(d_txt))
class Param(object):
    """Declarative description of a single request parameter.

    Args:
        param_type: a type whose name has a matching
            ``ParamChecker.check_<name>`` (str, int, float, bool, datetime,
            list, dict), or any other callable applied to the raw value.
        min: lower bound (string length for str, value for int/float).
        max: upper bound (same meaning as ``min``).
        optional: whether the parameter may be absent from the request.
        default: fallback when an optional parameter is absent; None means
            the checker's own default is used.
    """

    def __init__(self, param_type, min=-inf, max=inf,
                 optional=False, default=None):
        self.param_type = param_type
        # Bounds; checkers that do not use them ignore them.
        self.min, self.max = min, max
        self.optional, self.default = optional, default
        # Not implemented: a ``scope`` attribute for enum-style
        # (allowed-values) checks was sketched here originally.
class ErrCode(IntEnum):
    """Integer error codes carried by ``GmError``."""

    param_err = 1     # invalid or missing request parameter
    mongo_err = 2     # presumably MongoDB-related failures
    mysql_err = 3     # presumably MySQL-related failures
    business_err = 4  # presumably business-rule violations
class GmError(Exception):
    """Application error carrying a numeric code and a message.

    ``str(error)`` yields only the message; the views send that text back to
    the client via ``JsonErrorResponse``.
    """

    def __init__(self, err_code, err_str):
        code = int(err_code)
        message = str(err_str)
        self.err_code = code
        self.err_str = message

    def __str__(self):
        return self.err_str
class ParamChecker(object):
    """Type-specific request-parameter validators.

    Each ``check_<type>`` reads ``param_map[key]``, converts and checks it,
    and raises ``GmError(ErrCode.param_err, ...)`` on failure. Views select a
    checker by name: ``'check_' + param_type.__name__``.

    Keyword arguments shared by all checkers (handled by ``__check_optional``):
        param_map: mapping of raw request values.
        key: parameter name.
        opt: if False, a missing key is an error.
        default: fallback for a missing key; None means "use the checker's
            own default".
        min, max: length bounds (str) or value bounds (int/float); the other
            checkers ignore them.
    """

    def __check_optional(func):
        # Decorator used only inside this class body. It rejects a missing
        # required key, and forwards ``default`` only when it is not None so
        # that the wrapped checker's own default applies otherwise.
        def _wrapper(param_map, key, opt=False, default=None, **kwarg):
            if not opt and key not in param_map:
                raise GmError(ErrCode.param_err, '缺少{0}参数'.format(key))
            if default is None:
                return func(param_map=param_map, key=key, opt=opt, **kwarg)
            return func(param_map=param_map, key=key, opt=opt,
                        default=default, **kwarg)
        # Keep the checker's name/doc for debugging and introspection.
        _wrapper.__name__ = func.__name__
        _wrapper.__doc__ = func.__doc__
        return _wrapper

    @staticmethod
    @__check_optional
    def check_str(param_map, key, min=-inf, max=inf, opt=False, default='',
                  **kwarg):
        """Return the value as ``str``; ``min``/``max`` bound its length."""
        try:
            result = str(param_map.get(key, default))
        except ValueError:
            raise GmError(ErrCode.param_err, '{0}参数的值不是字符串'.format(key))
        result_len = len(result)
        if result_len < min:
            raise GmError(ErrCode.param_err,
                          '{0}参数最短长度为{1}'.format(key, min))
        if result_len > max:
            raise GmError(ErrCode.param_err,
                          '{0}参数太长长度为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_int(param_map, key, min=-inf, max=inf, opt=False, default=0,
                  **kwarg):
        """Return the value as ``int``, checked against ``min``/``max``."""
        try:
            result = int(param_map.get(key, default))
        except (ValueError, TypeError):
            raise GmError(ErrCode.param_err, '{0}参数的值不是整数'.format(key))
        if result < min:
            raise GmError(ErrCode.param_err, '{0}参数最小为{1}'.format(key, min))
        if result > max:
            raise GmError(ErrCode.param_err, '{0}参数最大为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_float(param_map, key, min=-inf, max=inf, opt=False, default=0,
                    **kwarg):
        """Return the value as ``float``, checked against ``min``/``max``."""
        try:
            result = float(param_map.get(key, default))
        except (ValueError, TypeError):
            raise GmError(ErrCode.param_err, '{0}参数的值不是浮点数'.format(key))
        # float('nan') compares False with every bound, so it would slip past
        # the min/max checks below; NaN is the only value unequal to itself.
        if result != result:
            raise GmError(ErrCode.param_err, '{0}参数的值不是浮点数'.format(key))
        if result < min:
            raise GmError(ErrCode.param_err, '{0}参数最小为{1}'.format(key, min))
        if result > max:
            raise GmError(ErrCode.param_err, '{0}参数最大为{1}'.format(key, max))
        return result

    @staticmethod
    @__check_optional
    def check_bool(param_map, key, opt=False, default=0, **kwarg):
        """Return ``bool(int(value))``; any integer is accepted, non-zero is True."""
        try:
            return bool(int(param_map.get(key, default)))
        except (ValueError, TypeError):
            raise GmError(ErrCode.param_err, '{0}参数的值必须是0或1'.format(key))

    @staticmethod
    @__check_optional
    def check_datetime(param_map, key, opt=False, default=None, **kwarg):
        """Parse a ``YYYY-MM-DD`` value into a ``datetime``.

        A missing key yields ``default``, or the current time when no default
        was given.
        """
        if key not in param_map:
            # Evaluate "now" per call. A ``datetime.now()`` default in the
            # signature would be computed once, at import time.
            return datetime.now() if default is None else default
        try:
            return str2datetime(param_map[key])
        except (ValueError, AttributeError):
            raise GmError(ErrCode.param_err,
                          '{0}参数的值必须是2000-01-01形式的日期'.format(key))

    @staticmethod
    @__check_optional
    def check_list(param_map, key, opt=False, check_list=None, default='[]',
                   **kwarg):
        """Return the value as a ``list``.

        String values are parsed as JSON and must decode to a list. Values
        that already are lists (e.g. a Python list given as the default) are
        shallow-copied. ``check_list`` is unused and kept only for
        compatibility; it used to be required, although callers never pass it.
        """
        raw = param_map.get(key, default)
        if isinstance(raw, list):
            return list(raw)
        try:
            result = json.loads(raw)
        except (ValueError, TypeError):
            result = None
        if not isinstance(result, list):
            raise GmError(ErrCode.param_err, '{0}参数的值不是列表'.format(key))
        return result

    @staticmethod
    @__check_optional
    def check_dict(param_map, key, opt=False, check_dict=None, default='{}',
                   **kwarg):
        """Return the value as a ``dict``.

        String values are parsed as JSON and must decode to an object. Values
        that already are dicts (e.g. a Python dict given as the default) are
        shallow-copied. ``check_dict`` is unused and kept only for
        compatibility; it used to be required, although callers never pass it.
        """
        raw = param_map.get(key, default)
        if isinstance(raw, dict):
            return dict(raw)
        try:
            result = json.loads(raw)
        except (ValueError, TypeError):
            result = None
        if not isinstance(result, dict):
            raise GmError(ErrCode.param_err, '{0}参数的值不是字典'.format(key))
        return result
使用示例
class StockOutView(CommonBaseView):
    """Usage example: a view that searches stock-out (outbound) orders.

    Only the parameter declarations, ``__init__`` and the extra check are
    shown; the rest of the view is elided in this excerpt.
    """
    # Search stock-out orders: request parameters validated by the base view.
    param_check_dict = {
        # Dispatched to ParamChecker.check_datetime; expected form '2000-01-01'.
        'start_date': Param(datetime),
        'end_date': Param(datetime),
        'date_type': Param(str),
        # Optional ints (presumably pagination) with explicit defaults.
        'offset': Param(int, optional=True, default=0),
        'limit': Param(int, optional=True, default=10),
        # NOTE(review): a param_type without a matching ParamChecker.check_<name>
        # is called on the raw request value, and its *return value* is
        # stored in self.params. This lambda therefore stores a bool, not the
        # status. Also, request values come from request.GET/POST.dict() and
        # are presumably strings, so '1' in [1, 2, 3] is False. Consider a
        # converter (e.g. int plus a range check) instead — confirm intent.
        'status': Param(lambda x: x in [1, 2, 3], optional=True),
        'query_data': Param(str, optional=True)
    }
    def __init__(self, request, *args, **kwargs):
        super().__init__(request, *args, **kwargs)
        # Per-request state (name-mangled private attributes).
        self.__receipts = []
        self.__count = 0
    def extra_param_check(self):
        # Instance-level check: adds station_id (an attribute presumably set
        # by a base class — verify) to the validated parameters.
        self.params['station_id'] = self.station_id
        ...