目的
用于对函数参数类型进行检验。
除指定参数类型为python原始数据类型外,包括以下几种情况
- 没有指定数据类型的参数;
- typing包中的类型;
- 自定义类作为参数类型。
此外,除了能对普通函数进行参数检验外,本装饰器还可以能对类的成员函数进行参数检验。
现状
python作为一门弱类型语言,带有自动类型推导,在一定程度上能减少开发工作量,带来便利。但也正是因为这个特性,使得我们在开发过程中需要验证函数参数类型时非常麻烦,除非你认为重复在函数头部写上一堆assert语句是一项轻松的工作。
在网络上也有一些为了解决这个问题而开发的代码,比如这篇文章https://www.jianshu.com/p/7a2c9133a002 但仍然存在一些问题:
- 每次使用还需要在装饰器函数中写上参数类型;
- 由于1的原因,必须要对每一个参数进行类型注解;
- 对于类的函数无法正确地进行参数类型检验;
- 检验出类型错误时的提示不够友好。
以上种种,都说明需要进一步开发来进行完善。
解决方案
下面是我改进后的代码,真正做到了简单易用,最大限度减轻了开发者的工作量
import inspect
from inspect import signature
from functools import wraps
def type_assert( func ):
"""作为装饰器,用于对函数参数类型进行检验。
除指定参数类型为python原始数据类型外,包括以下几种情况
1. 没有指定数据类型的参数;
2. typing包中的类型;
3. 自定义类作为参数类型。
此外,除了能对普通函数进行参数检验外,本装饰器还可以能对类的成员函数进行参数检验。
Args:
func (_type_): _description_
Raises:
TypeError: _description_
TypeError: _description_
Returns:
_type_: _description_
"""
sig = signature( func )
params = sig.parameters
@wraps( func )
def wrapper( *args, **kargs ):
paramNames = list( params.keys() )
paramValues = list( args )
# 对类函数进行处理
if paramNames[0] in [ 'self', 'cls' ]:
paramNames.pop(0)
paramValues.pop(0)
# 对传参时没有显示指定参数名的参数进行校验
for paramName, paramValue in zip( paramNames, paramValues ):
validate_param_value_type( paramName, paramValue )
# 对传参时显示指定参数名的参数进行校验
for paramName, paramValue in kargs.items():
validate_param_value_type( paramName, paramValue )
return func( *args, **kargs )
def validate_param_value_type( paramName, paramValue):
expectType = params[ paramName ].annotation
# 将 typing 中的类型(List、Dict、Tuple等)转换成对应的 python 原始数据类型
if hasattr( expectType, '__origin__' ):
expectType = expectType.__origin__
# 没有指定参数类型,任何数据类型都可以作为参数
if expectType == inspect._empty:
return
if not isinstance( paramValue, expectType ):
raise TypeError( f'Argument {paramName} must be {expectType}, but recived {type(paramValue)}.' )
return wrapper
使用示例
下面是两个例子
@type_assert
def s(a:int, b:List, c:Dict=None, d:List=[]):
print('done')
s(1, b=[])
class Foot:
pass
class H:
@type_assert
def __init__(self, a:int, b:int, c:Foot=None) -> None:
self.a = a
self.b = b
self.c = c
@classmethod
@type_assert
def data(cls, data:list):
pass
t = H(1,1,Foot())
网友评论