美文网首页
浏览器用TensorFlow PoseNet进行姿态估计,支持本

浏览器用TensorFlow PoseNet进行姿态估计,支持本

作者: 雨田君的记事本 | 来源:发表于2020-02-19 17:47 被阅读0次
姿态识别结果

TensorFlow 在18年的一篇文章里面有报道:
英文版:Real-time Human Pose Estimation in the Browser with TensorFlow.js
翻译版:在浏览器中用TensorFlow.js进行实时人体姿态估计
简单介绍下,就是利用TensorFlow.js版本的PoseNet模型来识别,现在官网中也能看到了。

posenet模型

点击TensorFlow的姿态估计直接跳到了GitHub,里面有详细的介绍,教你如何配置,并且还贴心的给出了一个在线的demo,但是需要梯子才能运行。

使用介绍

大概看了下文档,使用起来还是比较简单的,demo里面其实给了两个例子,一个是静态图片的姿态识别,一个是摄像头的实时姿态识别。

下面就先用静态图片识别做个介绍,分解一下主要就是下面4个步骤:

  1. 首先需要加载 TensorFlow.js 和 Posenet
  2. 执行posenet.load方法拿到模型对象
  3. 调用模型对象的estimateSinglePose方法去识别图片
  4. 上一步会返回17个关键点的坐标,将坐标点绘制到网页即可

看起来还挺简单,但是理想很丰满,现实很骨感呀,下面说说可能遇到的坑:

TensorFlow.js无法加载

demo给出的TensorFlow.jsposenet 地址无法访问,得用梯子才能拿到,所以最好是保存到本地,然后每次从自己服务器加载。

模型load不了

执行posenet.load半天没反应,打开浏览器的控制台,发现一堆bin文件的网络请求。

bin文件请求

原来是模型文件都在google服务器存着的,每次load都是根据你的配置信息在线下载对应的模型文件,所以运行的时候还得用梯子,好在文档中说load方法有个modelUrl参数,可以指定模型文件的位置。modelUrl是设置.json文件的存放位置,框架会根据这个.json自动去下载对应的.bin文件,所以只需要把.json文件和.bin文件手动下载下来一起放到自己服务器上就行。然后调用如下:

posenet.load({ modelUrl: '/pose_models/model-stride16.json' })
.then(net => {
  console.log(net);
})
.catch(err => {
  console.log(err);
})

模型文件一般都是几十兆甚至上百兆,如果自己服务器带宽比较小,可以放到七牛或者阿里云的oss上。

辅助函数

识别之后拿到了坐标,还需要自己去绘制出来,所以得自己去写绘制方法。

下面是源码

<!DOCTYPE html>
<html>
  <head>
    <!-- 加载 TensorFlow.js,建议换成自己的服务器地址 -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <!-- 加载 Posenet,建议换成自己的服务器地址 -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/posenet"></script>
    <style type="text/css">
    .wrap {
      display: flex;
      justify-content: center;
      align-items: center;
      flex-direction: row;
    }
    #myImg {
      width: 300px;
    }
    </style>
  </head>

  <body>
    <div class="wrap">
      <img id="myImg" crossOrigin="anonymous" src="https://img.haomeiwen.com/i33776/129d3871bccfaa22.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240"/>
      <canvas id="output"></canvas>
    </div>
  </body>
  <script>
    const color = "#ff0000";
    const minConfidence = 0.2;
    const lineWidth = 1;
    // ######################################### 工具函数
    // 坐标转换
    function toTuple({ y, x }) {
      return [y, x];
    }
    // 将图片绘制到canvas
    function renderImageToCanvas(image, size, canvas) {
      canvas.width = size[0];
      canvas.height = size[1];
      const ctx = canvas.getContext('2d');

      ctx.drawImage(image, 0, 0, size[0], size[1]);
    }

    // 画关键点
    function drawKeypoints(keypoints, ctx, scale = 1) {
      for (let i = 0; i < keypoints.length; i++) {
        const keypoint = keypoints[i];
        if (keypoint.score < minConfidence) {
          continue;
        }
        const { y, x } = keypoint.position;
        drawPoint(ctx, y * scale, x * scale, 3, color);
      }
    }
    // canvas画点
    function drawPoint(ctx, y, x, r, color) {
      ctx.beginPath();
      ctx.arc(x, y, r, 0, 2 * Math.PI);
      ctx.fillStyle = color;
      ctx.fill();
    }

    // 关键点连线
    function drawSkeleton(keypoints, ctx, scale = 1) {
      const adjacentKeyPoints =
        posenet.getAdjacentKeyPoints(keypoints, minConfidence);

      adjacentKeyPoints.forEach((keypoints) => {
        drawSegment(
          toTuple(keypoints[0].position), toTuple(keypoints[1].position), color, scale, ctx);
      });
    }
    // canvas画线
    function drawSegment([ay, ax], [by, bx], color, scale, ctx) {
      ctx.beginPath();
      ctx.moveTo(ax * scale, ay * scale);
      ctx.lineTo(bx * scale, by * scale);
      ctx.lineWidth = lineWidth;
      ctx.strokeStyle = color;
      ctx.stroke();
    }

    // ######################################### 识别图片
    function detectImg() {
      let imageElement = document.getElementById('myImg');
      let canvas = document.getElementById('output');
      
      // 设置加载模型走自己服务器,不设置则走google的服务器
      posenet.load({ modelUrl: '/pose_models/model-stride16.json' }).then(net => { 
      // posenet.load({}).then(net => { 
        return net.estimateSinglePose(imageElement, {
          flipHorizontal: false
        });
      }).then(pose => {
        console.log(pose);
        renderImageToCanvas(imageElement, [imageElement.width, imageElement.height], canvas);
        let ctx = canvas.getContext('2d');
        drawKeypoints(pose.keypoints, ctx);
        drawSkeleton(pose.keypoints, ctx);
      });
    }
    
    // 识别
    detectImg();
  </script>
</html>

结果如下,这年头文章不放两张妹子图都没人看呀:

image.png
代码优化

上面的代码还有个问题就是,每次加载页面都需要去加载模型文件,虽然说是在自己的服务器上,但是毕竟模型文件还是有几十兆呀。好在浏览器能缓存模型文件,但是还是有缓存失效的问题,要是其他页面要用也要再加载一遍,太浪费资源了。
不久前,看到有些文章说 TensorFlow.js 微信小程序插件开始支持模型缓存了,然后去翻官方文档,发现确实可以将模型缓存到本地的,并且支持多种方式缓存,文档地址:https://www.tensorflow.org/js/guide/save_load
浏览器上建议缓存到indexeddb,一来没有文件大小限制,二来同域名下均可读取,完美。
但是posenet这个库有点坑,居然没有开放出这个方法,没办法,只能翻了源码之后自己写一个:

      const MOBILENET_V1_CONFIG = {
        architecture: 'MobileNetV1',
        outputStride: 16,
        multiplier: 0.75,
        inputResolution: 257,
      }
      async function loadMobileNet(config=MOBILENET_V1_CONFIG) {
        const outputStride = config.outputStride;
        const quantBytes = config.quantBytes;
        const multiplier = config.multiplier;

        var graphModel = null;
        try {
          graphModel = await tf.loadGraphModel('indexeddb://my-model');
          console.log('从缓存中加载模型');
        } catch(e) {
          console.log(e);
          graphModel = await tf.loadGraphModel('/pose_models/model-stride16.json');
          graphModel.save('indexeddb://my-model');
          console.log('从网络中加载模型');
        }
        const mobilenet = new posenet.MobileNet(graphModel, outputStride);
        
        const validInputResolution = [config.inputResolution, config.inputResolution];
        return new posenet.PoseNet(mobilenet, validInputResolution);
      }

ok,现在把以前的posenet.load方法替换成loadMobileNet就行了,这样第一次会从网络加载,后面就走本地缓存了,而且同域名下其他页面也能直接从缓存加载:

    loadMobileNet().then(net => {
      return net.estimateSinglePose(imageElement, {
        flipHorizontal: false
      });
    }).then(pose => {
      console.log(pose);
    });
实时视频识别

这个得先拉起用户摄像头,然后一帧一帧的去绘制,其实和识别图片原理是一样的。官方的demo里面有具体代码可以自行查看。


实时识别

相关文章

网友评论

      本文标题:浏览器用TensorFlow PoseNet进行姿态估计,支持本

      本文链接:https://www.haomeiwen.com/subject/iwaafhtx.html