今天遇到了一个场景:查看一组二维的中间数据,print之后比较乱,如下图所示
原始打印效果.png
这是严格的10×25的数组,其中有一些英文空格,占位不足,所以打算用一个函数去优化他,这里考虑三种情况:
(1)如果数据类型不是List而是Tensor或者Numpy
(2)如果第一维度的list长度不一致
(3)如果输入数据的0,1维度倒置,也就是下面这种情况
数据结构.png
这里用prettytable包来实现了,不得不说确实是比较美观,和一开始的效果差距还是比较大,终端输出如下:
prettytable效果.png
具体的代码如下,调用row_table、column_table就可以看到打印效果,column_table会将0,1维度调换:
import prettytable as pt
import numpy as np
import torch
def row_table(d):
data = pre_type(d)
tb = pt.PrettyTable()
row_len = len(data[0])
row_num = len(data)
print(row_len,row_num)
tb.field_names = list(map(lambda x:str(x), range(0, row_len+1)))
for i in range(0,row_num):
row = data[i]
row.insert(0,str(i))
tb.add_row(row)
print(tb)
def column_table(d):
# 如果1维度是按列(备选组)的,那就转化成按行(一个句子)的,然后再打印
# 毕竟只是展示,还是打印成一行看着清楚
# 所有预处理之前都用pre_type先转成二维list的形式,不接受Ternsor,Numpy类型
data = pre_type(d)
data = pre_size(data)
row_table(data)
def pre_size(d):
data = np.array(d)
data.transpose()
return data.tolist()
def pre_type(d):
data = None
if isinstance(d,list):
data = d
elif type(d) is np.ndarray:
data = data.tolist()
elif torch.is_tensor(d):
data = data.numpy().tolist()
max_len = max(list(map(lambda x:len(x), data)))
data=list(map(lambda l:l+[' ']*(max_len-len(l)), data))
print(data)
return data
网友评论