Expectation Maximization

EM dans un cas très simple

 nb_gaussian <- 3 #Nombre de gaussiennes dans le jeu de données
type <- rep(1:nb_gaussian,200)
N <- length(type)

#m <- runif(nb_gaussian) + (1:nb_gaussian)-1
#s <- runif(nb_gaussian) / 2
m <- c(0,2,4)
s <- c(0.5,0.25,0.3)

#generation aléatoire des données 
data <- rep(0,N)
for (i in 1:N){ 
	data[i] <- rnorm(1,m=m[type[i]],sd=s[type[i]])
}


#estime la moyenne/écart type en supposant les classes connues 
gaussian_estimate <- function (data, z, n) {
	g <- list()
	g$m <- rep(0,n)
	g$s <- rep(0,n)
	for (i in 1:n) {
		g$m[i] <- weighted.mean( data , z[i,] )
		g$s[i] <- sqrt(weighted.mean( (data- g$m[i])^2  , z[i,] )) 
	}
	return(g)
}

#estime la probabilité d'appartenance aux classes  
classes_estimate <- function (data, g) {
	N <- length(data)
	#le outer suivant est équilalent en nettement plus rapide à
	#
	#z <- matrix(0,length(g$m),N)
	#for (i in 1:N){
	#	for (j in 1:length(g$m)) {
	#		z[j,i] <- dnorm(data[i],g$m[j],g$s[j])
	#	} 
	#}
	z <- outer(1:length(g$m),1:N ,function(i,j) dnorm(data[j],g$m[i],g$s[i]) ) 
	
	return( prop.table(z,2) )
}


#preparation affichage
build_density <- function(x,g){
	n <- length(x)
	d <- rep(0,n)
	for (i in 1:n){
		for (j in 1:length(g$m)) {
			d[i] <- d[i] + dnorm(x[i],mean=g$m[j],sd=g$s[j])
		} 
	}
	return(d/length(g$m))
}


x <- seq(min(data),max(data),length=200)
h <- hist(data,30,col=rgb(1,0,0,0.5))
bar_width <- (max(data)-min(data))/length(h$counts)


#Le coeur de la chose inutile ici de calculer la log vraisemblance, 
#car un minimum est atteint quand z ne change plus.
#À noter qu il vaut mieux relancer plusieurs fois la procédure
#pour éviter un maximum local / point de selle
nbG <- 3
g <- list()
g$m = runif(nbG) #ici plutôt que aléatoire, c'est mieux d'utiliser les données pour initialiser
g$s = runif(nbG) #idem
z <- matrix(0,nbG,N)
i <- 0
repeat {
	i <- i + 1 
	zold <- z
	z <- classes_estimate(data , g)
	g <- gaussian_estimate(data, z, nbG)
	lines(x,N*bar_width*build_density(x,g),col=rgb(0,1-min(i*1/100,1),min(i*1/100,1),min(i*1/50,1)))
	if (sum(abs(z-zold),na.rm=T) < 10^-10) break
} 


R plot

Design selector