资讯专栏INFORMATION COLUMN

好球还是坏球(棒球中术语),用tensorflow.js预测一下?

chinafgj / 2397人阅读

摘要:感谢像这样的框架,使得这些数据集可以应用于机器学习领域。蓝点被标记为坏球,橙点被标记为好球标注来自大联盟裁判员使用构建模型将机器学习带入和领域。使用库将预测结果呈现为热图。好球区域位于本垒板上方至英尺之间。

在这篇文章中,我们将使用TensorFlow.js,D3.js和网络的力量来可视化训练模型的过程,以预测棒球数据中的坏球(蓝色区域)和好球(橙色区域)。在整个训练过程中,我们将一步一步的将模型预测出的好球区域动态的展示出来。您可以通过访问Observable notebook网站在浏览器中运行此模型。

体育方面的高级指标

如今的职业体育环境里充满了大量的数据。这些数据被团队、业余爱好者和粉丝应用于各种案例。感谢像TensorFlow这样的框架,使得这些数据集可以应用于机器学习领域。

美国职业棒球大联盟高级媒体(MLBAM)发布了一个可供公众研究的大型数据集。该数据集包含有关过去几年在MLB游戏中投掷的投球的传感器信息。从这个数据集中挑选了一个包含5000个样本(2,500个坏球和2,500个好球)的训练集用于此处实验。

以下是训练数据的具体数据格式示例:

以下是绘制好球区域时的训练数据分布。蓝点被标记为坏球,橙点被标记为好球(标注来自大联盟裁判员)

使用TensorFlow.js构建模型

TensorFlow.js将机器学习带入JavaScript和Web领域。我们将使用这个优秀的框架来构建一个深度神经网络模型。这个模型将能够以大联盟裁判的精确度来区分好球和坏球。

该模型从PITCHf/x中选出以下评测指标进行训练:

协调球越过本垒的位置("px"和"pz")

击球手站在球场的哪一侧

击球区(击球手的躯干)的高度,以英尺为单位。

击球区底部的高度(击球手的膝盖)以英尺为单位

该次击球是好球还是坏球(由裁判员判定的)

结构

我们将使用TensorFlow.js的Layers API定义此模型。Layers API基于Keras,对以前使用过Keras框架的人来说应该很熟悉:

const model = tf.sequential();

// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: "relu", inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: "relu"}));
model.add(tf.layers.dropout({rate: 0.01}));

// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: "softmax"}));

model.compile({
  optimizer: tf.train.adam(0.01),
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"]
});
加载和准备数据

精选的训练集可以在GitHub gist获取。该数据集是CSV格式的,需要下载下来在本地转换成符合TensorFlow.js的格式。

const data = [];
csvData.forEach((values) => {
  // "logit" data uses the 5 fields:
  const x = [];
  x.push(parseFloat(values.px));
  x.push(parseFloat(values.pz));
  x.push(parseFloat(values.sz_top));
  x.push(parseFloat(values.sz_bot));
  x.push(parseFloat(values.left_handed_batter));
  // The label is simply "is strike" or "is ball":
  const y = parseInt(values.is_strike, 10);
  data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);

解析CSV数据后,需要将JS类型转换为Tensor batches才能进行训练和评估。有关此过程的详细信息,请参阅code lab。TensorFlow.js团队正在开发一种新的数据API接口,以便使数据获取在将来变得更容易。

训练模型

让我们把前期的准备都综合起来吧。定义好了模型,准备好了训练数据,现在我们将要开始训练了。以下的异步方法训练了一批训练样本并更新热图:

// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
  const history = await model.fit(batches[index].x, batches[index].y, {
    epochs: 1,
    shuffle: false,
    validationData: [batches[index].x, batches[index].y],
    batchSize: CONSTANTS.BATCH_SIZE
  });

  // Don"t block the UI frame by using tf.nextFrame()
  await tf.nextFrame();
  updateHeatmap();
  await tf.nextFrame();
}
可视化模型的准确度

使用来自均匀放置在本垒板上方的 4英尺x4英尺 栅格的预测矩阵来构建热图。在每个训练步骤之后将该矩阵传递到模型中以检查模型的准确度。使用D3库将预测结果呈现为热图。

建立预测矩阵

热图中所使用的预测矩阵从本垒板的中间开始,向左和向右各延伸2英尺宽,高度从本垒板的底部到4英尺高。好球区域位于本垒板上方1.5至3.5英尺之间。下图在二维平面上呈现出各个矩阵之间的关系:

将预测矩阵与模型一起使用

当每个批次的训练数据都在模型中训练之后,我们将预测矩阵传递到模型中,这样就可以去预测好球和坏球了。

function predictZone() {
  const predictions = model.predictOnBatch(predictionMatrix.data);
  const values = predictions.dataSync();

  // Sort each value so the higher prediction is the first element in the array:
  const results = [];
  let index = 0;
  for (let i = 0; i < values.length; i++) {
    let list = [];
    list.push({value: values[index++], strike: 0});
    list.push({value: values[index++], strike: 1});
    list = list.sort((a, b) => b.value - a.value);
    results.push(list);
  }
  return results;
}
使用D3生成热图

我们可以使用D3来显示预测结果。50x50尺寸的每个元素在SVG中呈现为10px x 10px的矩形。每个矩形的颜色取决于预测结果(好球或坏球)以及模型对该结果的确定程度(从50%-100%)。以下代码段显示了如何使用D3 svg 矩形组去更新数据:

function updateHeatmap() {
  rects.data(generateHeatmapData());
  rects
    .attr("x", (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
    .attr("y", (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
    .attr("width", CONSTANTS.HEATMAP_SIZE)
    .attr("height", CONSTANTS.HEATMAP_SIZE)
    .style("fill", (coord) => {
      if (coord.strike) {
        return strikeColorScale(coord.value);
      } else {
        return ballColorScale(coord.value);
      }
  });
}

有关使用D3绘制热图的完整详细信息,请参阅此部分。

总结

如今web前端有许多令人惊叹的库和工具来创建可视化视觉效果。把这些与机器学习的强大功能和TensorFlow.js相结合,可以使开发人员创建一些非常有趣的demo。

注:本文为译文,点击此处预览原文

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/19841.html

相关文章

  • 决胜圣诞,女神心情不猜!

    摘要:万万没想到,在圣诞节前夕,女神居然答应了在下的约会请求。想在下正如在座的一些看官一样,虽玉树临风风流倜傥,却总因猜不透女孩的心思,一不留神就落得个母胎单身。在内部将张量表示为基本数据类型的维数组。 showImg(https://segmentfault.com/img/remote/1460000017498745); 本文将结合移动设备摄像能力与 TensorFlow.js,在浏览...

    nanfeiyan 评论0 收藏0
  • Move Mirror:使 TensorFlow.js 在浏览器预测姿势之 AI 实验

    摘要:文和,创意实验室创意技术专家在机器学习和计算机视觉领域,姿势预测或根据图像数据探测人体及其姿势的能力,堪称最令人兴奋而又最棘手的一个话题。使用,用户可以直接在浏览器中运行机器学习模型,无需服务器。 文 /  Jane Friedhoff 和 Irene Alvarado,Google 创意实验室创意技术专家在机器学习和计算机视觉领域,姿势预测或根据图像数据探测人体及其姿势的能力,堪称最令人兴...

    MiracleWong 评论0 收藏0

发表评论

0条评论

最新活动
阅读需要支付1元查看
<