// (image placeholder left over from the web page this listing was copied from: image.png)
package main.scala.chap05
import com.cra.figaro.language._
import com.cra.figaro.library.atomic.discrete.{Binomial, Poisson}
import com.cra.figaro.library.compound.If
import com.cra.figaro.algorithm.sampling.Importance
object ProductDistribution {
// Social network of a (potential) customer.
class Network(popularity: Double) {
// Number of people in the target customer's social network: a Poisson-distributed
// element whose mean is `popularity`.
val numNodes = Poisson(popularity)
}
/**
 * Generative model of how a product spreads from one target customer through
 * that customer's social network.
 *
 * @param targetPopularity mean of the Poisson distribution over the size of the target's network
 * @param productQuality   probability that a person who hears about the product likes it
 * @param affordability    probability that a person who likes the product buys it
 */
class Model(targetPopularity: Double, productQuality: Double, affordability: Double) {
  /**
   * Simulates word-of-mouth promotion inside a network of `numFriends` people and
   * returns an element over the total number of people who like the product.
   * The simulation starts from one person (the target) who has been visited and
   * likes the product.
   *
   * @param numFriends     number of people in the network
   * @param productQuality probability that a newly visited person likes the product
   * @return element over the total number of people who like the product
   */
  def generateLikes(numFriends: Int, productQuality: Double): Element[Int] = {
    /**
     * Fraction of the network (excluding the starting person) that has not been
     * visited yet, used as the success probability of a Binomial.
     *
     * The raw formula divides by `numFriends - 1`, which yields NaN for a network
     * of one person, a value of 1 or more for an empty network, and a negative
     * value once more people have been visited than the network contains. The
     * result is therefore forced into [0, 1], and a network with nobody left to
     * visit yields 0.
     */
    def unvisitedFraction(friendsVisited: Int): Double =
      if (numFriends <= 1) 0.0
      else {
        val raw = 1.0 - (friendsVisited.toDouble - 1) / (numFriends - 1)
        math.max(0.0, math.min(1.0, raw))
      }

    /**
     * One step of the promotion process: one person who likes the product tells
     * up to two friends about it.
     *
     * @param friendsVisited   number of people who have already heard about the product
     * @param totalLikes       number of visited people who like the product
     * @param unprocessedLikes number of people who like the product but whose own
     *                         promotion to their friends has not been simulated yet
     * @return element over the final number of people who like the product
     */
    def helper(friendsVisited: Int, totalLikes: Int, unprocessedLikes: Int): Element[Int] = {
      if (unprocessedLikes == 0) Constant(totalLikes)
      else {
        // Each of the two contacted friends is new with probability unvisitedFraction.
        val newlyVisited = Binomial(2, unvisitedFraction(friendsVisited))
        // Each newly visited friend likes the product with probability productQuality.
        val newlyLikes = Binomial(newlyVisited, Constant(productQuality))
        Chain(newlyVisited, newlyLikes,
          (visited: Int, likes: Int) =>
            // One liker has now been processed (-1); the new likers join the queue.
            helper(friendsVisited + visited, totalLikes + likes, unprocessedLikes + likes - 1))
      }
    }

    helper(1, 1, 1)
  }

  // Randomly generated social network of the target.
  val targetSocialNetwork = new Network(targetPopularity)

  // Whether the target likes the product; depends on the product quality.
  val targetLikes = Flip(productQuality)

  // If the target likes the product, the number of people in the network who end up
  // liking it; otherwise nobody hears about it and the count is 0.
  val numberFriendsLike =
    Chain(targetLikes, targetSocialNetwork.numNodes,
      (targetLikesProduct: Boolean, networkSize: Int) =>
        if (targetLikesProduct) generateLikes(networkSize, productQuality) else Constant(0))

  // Number of buyers: each person who likes the product buys it with probability
  // `affordability` (binomial distribution).
  val numberBuy = Binomial(numberFriendsLike, Constant(affordability))
}
/**
 * Estimates the expected number of buyers for the given scenario using
 * importance sampling with 1000 samples.
 *
 * @param targetPopularity mean size of the target's social network
 * @param productQuality   probability that a person likes the product
 * @param affordability    probability that a person who likes the product buys it
 * @return the estimated expectation of `Model.numberBuy`
 */
def predict(targetPopularity: Double, productQuality: Double, affordability: Double): Double = {
  val model = new Model(targetPopularity, productQuality, affordability)
  val algorithm = Importance(1000, model.numberBuy)
  algorithm.start()
  // Query the expectation, and always release the algorithm, even if the query throws.
  try algorithm.expectation(model.numberBuy, (i: Int) => i.toDouble)
  finally algorithm.kill()
}
/**
 * Entry point. Here Figaro is used as a simulation language: the model predicts
 * what will happen in the future (the expected number of buyers) for several
 * combinations of popularity, product quality and affordability.
 *
 * The original note adds that the model can also be used to reason in the
 * reverse direction.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  println("Popularity\tProduct quality\tAffordability\tPredicted number of buyers")
  // Same eight scenarios, in the same order: popularity 100 then 10, each with
  // quality 0.5/0.9 and, for every quality, affordability 0.5/0.9.
  for {
    popularity <- Seq(100, 10)
    quality <- Seq(0.5, 0.9)
    affordability <- Seq(0.5, 0.9)
  } println(popularity + " \t" + quality + " \t" + affordability + " \t" +
    predict(popularity, quality, affordability))
}
}
// (page footer left over from the web page this listing was copied from: "Reader comments")