简介
ClassificationBox是MachineBox的一部分,提供了一个通用的分类器模型。 关于MachineBox的简介以及环境相关的设置可以参照我的介绍文章,这里就不赘述了。我们直接启动
docker run -p 8080:8080 -e "MB_KEY=<your-mb-key>" machinebox/classificationbox
接下来我们来实现一个猫狗识别的例子
例子
打开http://localhost:8080, 官方已经提供了详尽的API文档。
创建模型
首先第一步, 我们需要创建一个模型,来看API
POST http://localhost:8080/classificationbox/models
{
"id": "sentiment1",
"name": "sentimentModel",
"options": {
"ngrams": 1,
"skipgrams": 1
},
"classes": [
"class1",
"class2",
"class3"
]
}
-
id
- (optional string) 模型id,用于后续操作。 不指定则会自动生成一个 -
name
- (string) 模型的名字 -
options
- (optional object) 模型的一些可选参数, 目前支持的选项有-
ngrams
- (number) 文本特征化n-grams算法中的n值 (默认值为1) -
skipgrams
- (number) 文本特征化n-grams算法中的skip值 (默认值为0,不跳跃)
-
-
classes
- (array of string) 模型需要识别的类别集合 - 需要将
Accept
和Content-Type
headers 设置成"application/json; charset=utf-8"
在实际开发中,MachineBox也提供了官方的SDK, 不过这里我就直接发送请求了
curl -X POST \
http://localhost:8080/classificationbox/models \
-H 'accept: application/json; charset=utf-8' \
-H 'content-type: application/json; charset=utf-8' \
-d '{
"id": "catVsDog",
"name": "catVsDog",
"classes": [
"cat",
"dog"
]
}'
好了,到此我们已经创建好了一个ID为catVsDog
的模型
训练模型
接下来我们需要训练模型,先看单个样本的API
POST http://localhost:8080/classificationbox/models/{model_id}/teach
{
"class": "class1",
"inputs": [
{"key": "user_age", "type": "number", "value": "25"},
{"key": "user_interests", "type": "list", "value": "music,cooking,ml"},
{"key": "user_location", "type": "keyword", "value": "London"}
]
}
-
{model_id}
- (string parameter) 模型的ID -
class
- (string) 样本所属类别 -
inputs
- (array) 样本的feature数组 -
inputs[].key
- (string) feature的标识 -
inputs[].type
- (string) feature的类型,支持的类型有-
number
- 数值型,包括整数和浮点数 -
text
- 文本型,会被token化。 可以进一步指定语言,支持的语言有-
text_en
- 英语 -
text_sp
- 西班牙语 -
text_fr
- 法语 -
text_ru
- 俄语 -
text_sv
- 瑞典语 -
text_zh
- 中文 -
text_ge
- 德语 -
text_nl
- 荷兰语 -
text_pt
- 葡萄牙语
-
-
keyword
- 不会被token化的文本关键词 -
list
- 关键词列表 -
image_url
- 图片url,ClassificationBox会自动下载 -
image_base64
-base64编码的图片数据
-
-
inputs[].value
- (string) feature的值, 注意类型是string型
为了方便使用,也同样提供了多个样本的API
POST http://localhost:8080/classificationbox/models/{model_id}/teach-multi
{
"examples": [
{
"class": "class1",
"inputs": [{"key": "user_age", "type": "number", "value": "25"}]
},
{
"class": "class2",
"inputs": [{"key": "user_age", "type": "number", "value": "55"}]
},
{...}
]
}
接下来我们用几个简单的猫狗识别的图片来训练模型
类别 | 图片 | url |
---|---|---|
dog | https://img.haomeiwen.com/i13069854/ec88087820d9877f.png | |
cat | https://img.haomeiwen.com/i13069854/7668001c3341e1fc.png |
还是直接调用API
curl -X POST \
http://localhost:8080/classificationbox/models/catVsDog/teach-multi \
-H 'accept: application/json; charset=utf-8' \
-H 'content-type: application/json; charset=utf-8' \
-d '{
"examples": [
{
"class": "cat",
"inputs": [{
"key": "image",
"type": "image_url",
"value": "https://img.haomeiwen.com/i13069854/7668001c3341e1fc.png"
}]
},
{
"class": "dog",
"inputs": [{
"key": "image",
"type": "image_url",
"value": "https://img.haomeiwen.com/i13069854/ec88087820d9877f.png"
}]
}
]
}'
模型预测
接下来我们使用模型来预测, 还是先来看API
POST http://localhost:8080/classificationbox/models/{model_id}/predict
{
"limit": 10,
"inputs": [
{"key": "user_age", "type": "number", "value": "25"},
{"key": "user_interests", "type": "list", "value": "music,cooking,ml"},
{"key": "user_location", "type": "keyword", "value": "London"}
]
}
-
{model_id}
- (string parameter) 模型ID -
limit
- (optional number) 返回按置信度排序的前多少个分类预测信息 (默认值为10
) -
inputs
- (optional array) 同训练
我们来尝试对下面的图形来做预测, 它的url为https://img.haomeiwen.com/i13069854/494c4e162690d316.png
同样直接上PostMan
curl -X POST \
http://localhost:8080/classificationbox/models/catVsDog/predict \
-H 'accept: application/json; charset=utf-8' \
-H 'content-type: application/json; charset=utf-8' \
-d '{
"inputs":[{
"key": "image",
"type": "image_url",
"value": "https://img.haomeiwen.com/i13069854/494c4e162690d316.png"
}]
}'
我们得到结果
{
"success": true,
"classes": [
{
"id": "cat",
"score": 0.587596
},
{
"id": "dog",
"score": 0.412404
}
]
}
结果并不理想,这主要是由于我们训练的样本太少造成的。 大家可以自行使用用Kaggle catVsDog数据集去实验
模型导出
导出模型是机器学习中极其重要的一步,我们可以通过API
GET http://localhost:8080/classificationbox/state/{model_id}
-
{model_id}
- (string parameter) 模型ID
直接下载到一个.classificationbox
的模型文件
模型导入
同样也有导入的API
POST http://localhost:8080/classificationbox/state
- predict_only - (bool) 默认值为false. 如果设成true,那么训练将会被禁用,调用训练API会返回400。
该API支持file
,url
, base64
3种不同的输入
除了上述例子中涉及到的API, ClassificationBox还提供了模型查看, 删除等的API, 可以在http://localhost:8080中看到, 这里就不再一一罗列。 另外所有的MachineBox的镜像都提供了诸如/healthz
, /liveness
, /readyz
等实用的API, 具体可以查阅官方文档
网友评论