Clustering de Dirichlet démystifié

Le clustering de Dirichlet a connu à partir du début des années 2000 un véritable engouement de la communauté Machine Learning. Il a parfois un peu à tord eu la réputation d'un algorithme un peu magique reposant sur des maths compliquées. Le but de cette page est de tenter de démystifier cet algorithme.

Loi de Dirichlet

Il s'agit de la généralistion à plusieurs variables de la loi bêta. Pour rappel, une possibilité d'interprétation de la loi bêta de paramètres α,β est de la voir comme la vraissemblance du paramètre p d'une loi binômiale après avoir observé α − 1 succès (avec probabilité p de succès) et β − 1 échecs (avec probabilité 1 − p). On dit que beta est la congugée de la loi binômiale.

On va définir de la même façon la loi de Dirichlet comme la conjuguée d'une loi multinômiale.

Soit un entier K≥2 et un vecteur

Formula : $$\qquad\alpha=(\alpha_1,\dots,\alpha_K) \in {(\mathbb{R}^{+})}^K $$

La loi de Dirichlet d'ordre K et de paramètres α a pour densité :

Formula : $$f(x_1,\dots, x_{K}; \alpha) = \frac{1}{\mathrm{B}(\alpha)} \prod_{i=1}^K x_i^{\alpha_i - 1}$$

Sur l'hyper-triangle (simplexe) défini par x1, ..., xK > 0 et x1 +... + xK-1< 1. Avec la somme des xi valant 1. En dehors de cet espace la densité vaut 0.

La constante de normalisation est la fonction bêta multinomiale, qui s'exprime à l'aide de la fonction gamma (généralisation de la factorielle):

Formula : $$\mathrm{B}(\alpha) = \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma\bigl(\sum_{i=1}^K \alpha_i\bigr)}$$

Sous R la loi est disponibles dans le package MCMCpack.

 
#chagement d'une librairie offrant une fonction de calcul de la densité
#de Dirichlet (ddirichlet)
library("MCMCpack")

#preparation d'une grille
x1 <- seq(0.01, 0.99, length= 40)
x2 <- x1

#fabrication de la densité pour alpha=(3,6,3) par produit tensoriel
#on aurait pu aussi procéder avec 2 boucles
f <- function(x,y) {  ddirichlet( x=cbind(x,y, (1-x-y) ), alpha=c(3,6,3) );  }
z <- outer(x1, x2, f) 

#affichage
op <- par(bg = "white")
persp(x1, x2, z, theta = 30, phi = 30, expand = 0.5, col = "lightblue",
      ltheta = 120, shade = 0.75, ticktype = "detailed",
      xlab = "X1", ylab = "X2", zlab = "Dirichlet(3,6,3)"
)->res
round(res,3);

R plot

Ce graphique correspond à la vraisemblance d'une loi multinômiale à 3 possibilités (comme un dé, peut-être pipé, à 3 faces) qui aurait généré 2 fois le 1, 5 fois le 2 et 2 fois le 3. Le maximum se trouve en (0.222,0.555) ce qui signifie que le plus probable est que la face 1 ait une probabilité 2/(2+2+5) = 2/9 = 0.222222… de sortir, la 2e ait une proba de 5/9=0.555… et la 3e ce qui reste (2/9). La position du max n'est pas très interessante, c'est plutôt la répartition de la vraisemblance autour du pic qui est recherchée, en particulier pour construire des intervalles de confiance.

Évidemment si l'on augmente le nombre de tirages, l'on 'pique' la densité :

R plot

On notera

Formula : $$Dir(\alpha_1,\ldots, \alpha_K ) $$

la densité correspondant aux paramètres αi (en particulier il s'agit d'une fonction de dimension K)

Processus de Dirichlet

Une fois le principe de la loi de Dirichlet intégré, il faut maintenant parler des processus du même nom. En fait c'est une généralisation de la loi de Dirichlet au cas ou K est infini. L'on a coutume de dire que les processus de Dirichlet sont une distribution sur les distributions (si H tend vers l'infini le vecteur alpha correspondant se met plus ou moins à représenter une densité).

En particuler cela signifie que si l'on fait un tirage selon un processus de Dirichlet DP, l'on va obtenir une loi de probabilité (sur un espace infini).

Soit G0 une mesure de probabilité et α un réel strictement positif. Pour toute partition (A1…AK ) d'un espace mesurable, on définit G par

Formula : $$G(A_1), \ldots, G(A_K) \rightsquigarrow Dir(\alpha G_0(A_1),\ldots, \alpha G_0(A_K) ) $$
que l'on note :
Formula : $$G \rightsquigarrow DP(G_0,\alpha) $$

Cela signifie que la mesure d'un ensemble est elle même tirée selon une loi de Dirichlet dont les paramètres sont proportionnels à une mesure 'initiale' G0. De plus, presque sûrement, les tirages sont discrets (Sethuraman 94) :

Formula : $$P\left(G(\theta)\right) = \sum_{i=1}^{\infty} w_i \delta_{\theta_i}(\theta)  $$

Un autre truc qui tombe quand même bien c'est que la loi postérieure d'un processus de Dirichlet après n observations est aussi un processus de Dirichlet (et en fait c'est vraiment très très pratique car cela donne un accès à l'estimation de G0 facilement à partir des réalisations).

Formula : $$P\left(G(\theta)|\theta_1,\ldots,\theta_n\right) = DP\left(\frac{\alpha}{\alpha+n}G_0 + \frac1{\alpha+n} \sum_{i=1}^n\delta_{\theta_i}, \alpha+n \right)  $$

Du coup, on peut en faire une simulation numérique par la méthode du restaurant chinois qui met mieux en évidence la tendance au clustering des processus de Dirichlet. Par exemple avec G0 étant une N(0,1) et avec α=1 :

 
#tirage selon une multinomiale attend un vecteur x sommant <= 1 
multinomiale <- function (x) {
	stopifnot(sum(x)<=1)
	v <- runif(1)
	s <- 0
	for (i in 1:length(x) ) {
		s <- s + x[i]
		if (v < s) return(i)
		}
	return(length(x)+1)
	}
	
#Binding de rnorm pour pouvoir faire un appel sans paramètres
NormaleCentreeReduite <- function() {
	return(rnorm(1))
	}

#Le restaurant chinois (simulation de processus de Dirichlet)
DP <- function(nb,G0,alpha) {
	r<-list()
	#premier client à la première table valeur tirée selon G0
	x <- G0()
	n <- 1
	 
	#les clients suivants 
	for (i in 2:nb){
		table  <- multinomiale ( n / (i-1+alpha) )
		if (table > length(x) ){
				#nouvelle table
				x <- c( x, G0() )
				n <- c( n, 1)
			} else {
				#Nouveau client à une ancienne table
				n[table] <- n[table] + 1
			} 
		}
	
	#On reconstuit le vecteur complet et on mélange
	#return( sample(rep(x, n)) );
	
	#Mais ici on veut montrer la convergence vers G0
	#Donc on va retourner les valeurs générées et 
	#leur nombre d'occurences
	return(list(x=x,n=n))
	}
d <- DP(1000,NormaleCentreeReduite, 1)
plot (d$x, d$n, pch=16, col="blue", xlab="x", ylab="#occurences", type="h")

R plot

On voit bient le phénomène d'accumulation et le support discret. Si on fait grandir alpha on aura un comportement plus proche de G0 (normal puisque le nouveau client choisit plus souvent une nouvelle table, donc un tirage selon G0.

R plot

R plot

Appartée : Échantillonnage de Gibbs

Il s'agit d'un cas particulier de de Métropolis-Hasting pour simuler des tirages d'une loi jointe alors que l'on est juste capable de générer des tirages selon les probabilités conddtionnelles.

On veut donc pouvoir simuler des tirages selon une loi Formula : $\pi(\theta)=\pi(\theta_1,\ldots,\theta_k)$ alors que l'on juste capable de simuler Formula : $\pi_i(\theta_i|\theta_{j\neq i})$

On commence par choisir un point de départ (fixe ou aléatoire)

Formula : $\theta = (\theta_1,\ldots,\theta_k)$

Puis on effectue un grand nombre de fois la série de remplacement suivants (en pratique 1000, mais d'un point de vue théorique il faut être beaucoup plus propre que cela et vérifier qu'il n'y a pas eu de problèmes type convergence)

Formula : 
\begin{eqnarray*}
\theta_0 & \leftarrow  & \pi_0(\theta_0|\theta_{i\neq 0}) \\
\theta_1 & \leftarrow & \pi_1(\theta_1|\theta_{i\neq 1}) \\
&\ldots& \\
\theta_k &\leftarrow & \pi_k(\theta_k|\theta_{i\neq k}) \\
\end{eqnarray*}

Tous les remplacements sont effectués immédiatement et non en parallèle. Après "suffisamment" d'itérations, Formula : $\theta$ contient un échantillon tiré selon la loi Formula : $\pi$.

Clustering de Dirichlet

En fait la terminologie est un peu trompeuse. On devrait plutôt parler de Mixtures par Processus de Dirichlet. En effet, ces processus étant par nature discrets, ils sont très mal adaptés pour une utilisation en tant que prior pour faire du clustering de données continues.

Bon on devine la suite : on va choisir une densité de probabilité dont les paramètres vont suivre un DPM ce qui autorisera théoriquement un nombre infini de cluster (en pratique le nombre de cluster sera lié au choix d'alpha). Classiquement on choisira une loi normale, c'est exactement comme si chaque table du restaurant chinois possédait sa propre gaussienne dont les paramètres sont fournis par un DP notons ces paramètres Formula : $\eta_1,\ldots,\eta_n$. Ces paramètres étant issus d'un DP certains d'entre eux seront identiques ce qui nous fournira le clustering.

Évidemment le problème est d'estimer les paramètres de notre modèle en fonction des observations.Autant le dire tout de suite, c'est de ce coté que se situe la vraie difficulté pratique du clustering de Dirichlet. La méthode la plus simple est d'avoir recours à l'échantillonnage de Gibbs. En fait une des meilleures solutions connues est de faire un blocked Gibbs. Dans ce cas notre objectif est de pouvoir échantilloner selon la loi suivante (en notant x1,…,xn les données observées).

Formula : $$ P(\eta_1,\ldots,\eta_n | x_1,\ldots,x_n) $$

Notons qu'en général on se contente d'échantillonner selon ce loi et de prendre un tirage, car le calcul final de l'intégrale n'a pas vraiement de sens avec des clusters.

En plus manque de bol l'estimation de la densité d'un processus du Dirichlet est problématique car la loi étant presque sûrement discrète, sa densité est consitué de diracs. Une astuce possible est d'essayer d'estimer la densité du processus convolé à une autre loi (typiquement une gaussienne) afin "d'étaler" les diracs.

En pratique l'agorithme est vraiment très simple. On choisit la loi de nos clusters (la loi qui se fait mélanger) et il faut être capable de donner la vraisemblance pour une nouvelle donnée d'être tirée selon la loi de ce cluster ainsi que de pouvoir estimer les paramètres de cette loi à partir des données.

Ensuite un pas de la marche aléatoire est donné à partir d'une estimation courante des classes par la méthode suivante :

Parcourir tous les points de l'ensemble des données. Notons X le point courant.

  1. Supprimer le point X de "son" cluster. Si le cluster ne contient plus de points il est supprimé.
  2. Réestimer les paramètres de la loi de chaque cluster (sans tenir compte de X).
  3. Contruire le vecteur de probabilité P (de longueur le nombre de clusters "actifs") contenant la vraisemblance pour chaque cluster d'avoir généré X et alpha en dernière position.
  4. Tirer la nouvelle classe de X selon la multinômiale donnée par P. Si la dernière position est choisie alors on rend actif un nouveau cluster (qui devrait avoir ses paramètres tirés selon G0, mais en pratique on estime les paramètres de ce nouveau cluster à partir des points).

Il ne reste plus qu'à répéter la chose plusieurs fois et à prendre la réalisation possédant la plus grande vraisemblance.

Un petit test

Pour simplifier un peu on va se placer en dimension 2 et en considérant des mélanges de gausiennes dont les deux dimensions sont indépendantes. G0 est donc une distribution sur R4 (pour chaque gaussienne correspondant à un cluster l'on a besoin des deux moyennes et des deux écarts types). Mais en utilisant la formule sur la postérieure des DP, l'on a pas besoin d'expliciter G0.

 
##############################################################
#                Génération des données                      #
##############################################################
#initialisation du générateur aléatoire avec une constante
#pour la reproductibilité des essais
set.seed(100)  

#Generation données notons que cette génération peut se voir 
#comme une réalisation d'une mixture par DPM de gaussiennes.
nb_gaussian <- 3 
dim <- 2
type <- rep(1:nb_gaussian, 200*(1:nb_gaussian) )
N <- length(type)
m <- matrix ( c(0,2, 0,0, 3,1) ,ncol=dim, byrow=TRUE )
s <- matrix ( c(0.5,0.5, 0.25,0.1, 1,0.3), ncol=dim, byrow=TRUE )


data <- matrix(0,N,dim)
for (i in 1:N){
	for (j in 1:dim){
		data[i,j] <- rnorm(1,m=m[type[i],j],sd=s[type[i],j])
		} 	
	}

##############################################################
#                Clustering de Dirichlet                     #
##############################################################

	
#tirage selon une multinomiale proportionnelement à un vecteur x sommant à 1
multinomiale <- function (x) {
	v <- runif(1)
	s <- 0
	for (i in 1:length(x) ) {
		s <- s + x[i]
		if (v < s) break;
		}
	return(i) 
	}


# Donne le log de la vraisemblance qu'une donnée data appartienne 
# au cluster paramétré par params. Ici on suppose que l'on a 
# affaire à des lois normales
logVraisemblance <- function(params,data) {
	lv <- 0;
	for (dim in 1:length(data) ) {
		m <- params[2*(dim-1)+1]
		s <- params[2*(dim-1)+2]
		lv <- lv + dnorm(data[dim], m, s, log=TRUE)
	}
	return(lv)	
}

# Estime les paramètres à l'intérieur d'un cluster
estimeParametres <- function (datas) {
	if (is.vector(datas)) datas<-matrix(datas,nrow=1)
	p <- c()
	for (dim in 1:ncol(datas) ) {
		m <- mean( datas[,dim] )
		s <-   sd( datas[,dim] )
		if (is.na(s)) s<-1 #Si 1 seul point la variance est non définie 
		p <- c(p,m,s) 
	}
	return(p)
}

DP_gibbs <- function(datas, alpha, niter=100, classes = NULL) {
	
	# Pour pouvoir faire un démarrage "à chaud" sur les classes
	if (is.null(classes)) classes <- rep(1,nrow(datas))
	nn <- as.numeric(table(factor(classes))) #nombre de points par cluster 
		
	# Liste des paramètres des lois des clusters
	eta <- list()
	for (i in 1:max(classes)) eta[[i]] <- estimeParametres(datas[classes==i,])   
	
	maxLogVraisemblance <- -Inf

	# itérations de l'échantillonneur de Gibbs
	for (iter in 1:niter) {
		# Itération sur chaque point faisant partie des données
		for (i in 1:nrow(datas) ){
			
			# Supprime la classe de l'élément no j ainsi que le 
			# cluster correspondant si c'était le seul point
			nn[classes[i]] <- nn[classes[i]] - 1
			if (nn[classes[i]]==0) { #dernier élément de ce cluster
				#supprime les paramètres de la Gausienne
				eta[[ classes[i] ]] <- NULL 
				nn <- nn[-classes[i]]
				#renumérote les classes pour tenir compte de la suppression d'un cluster
				classes[ classes>classes[i] ] <-  classes[ classes>classes[i] ] - 1 
				
			} else {
				#Réestimer les paramètres eta pour le cluster dont 
				#on vient de supprimer un point
				eta[[classes[i]]] <- estimeParametres(datas[classes == classes[i],]) 
			}
			
			# Calcule dans l'espace log la probabilité d'appartenir
			# au cluster k pour le i-ème point. Le alpha est la 
			# pour laisser la possibilité d' affecter le point à un 
			# nouveau cluster
			
			proba <- log(c(nn,alpha)) 
			for (k in 1:length(eta)){
				proba[k] <- proba[k] + logVraisemblance(eta[[k]],datas[i,]);
				}
			proba <- exp(proba - max(proba)); #-max pour éviter des pbm numériques
    		proba <- proba / sum(proba);
				
			#Assigne le point courant à un cluster (éventuellent en créant le cluster)
			classes[i] <- multinomiale(proba)
			if (length(nn) < classes[i]) nn[classes[i]]<-0 
			nn[classes[i]] <- nn[classes[i]] + 1
			eta[[classes[i]]] <- estimeParametres(datas[classes == classes[i],])
		}
		
		if (niter>1){
			#On garde l'itération la plus vraisemblable
			v <- 0;
			for (i in nrow(data) ){
				v <- v +  logVraisemblance(eta[[classes[i]]],datas[i,]);
			}
			if (v> maxLogVraisemblance){
				maxLogVraisemblance <- v
				classesVraisemblables <- classes
			} 
		} else {
			classesVraisemblables <- classes
		}
	}
	return(classesVraisemblables);			
}

classes <- DP_gibbs(data, alpha=0.1, 100) 
plot(data,col=classes,pch=16)
	

R plot

Pour comparaison voici les vraies classes

R plot

Et une petite animation représentant l'évolution du clustering au cours de l'échantillonnage de Gibbs.

Évidemment si vous avez suivi vous avez noté l'arnaque. Pour le moment la classe retournée est simplement une réalisation d'un processus de Dirichlet. Il faudrait calculer l'intégrale, mais moyenner sur une classe pose problème. Les plus attentifs auront remarqué que la valeur de alpha est de 0.1 dans l'exemple. Quand le nombre de points augmente, le clustering de dirichlet tend aussi à augmenter le nombre de clusters ce qui n'est pas toujours un comportement souhaitable.

Et avec les méthodes classiques ?

On peut aussi essayer un kmeans (en fournissant k qui ne se retrouve pas très bien dans l'inertie sur ce jeu de données). Cela n'est pas très efficace car les groupes n'ont pas la même variance ce qui embête kmeans. Par contre avec un clustering hiérarchique avec la méthode de Ward :

R plot

Ce qui donne en découpant en 3 ou 5 classes (selon ce que l'on pense voir sur le dendrogramme)

R plot

R plot

Conclusion

Le clustering de Dirichlet est un beau formalisme et offre une manière élégante de choisir le nombre de clusters…

Est-ce que cela apporte vraiment quelque chose en pratique par rapport méthodes classiques (kmeans/EM/Clustering hiérarchique) ? Oui : cela marche beaucoup mieux pour publier à NIPS… Pour les applications industrielles… Joker… En tout cas pour le temps de calcul c'est quand même pas terrible (MCMC powered). Par contre comme souvent avec le MCMC cela offre une solution pouvant être acceptable dans les cas difficiles (quand tout le reste a échoué).

Design selector