samedi 5 septembre 2015

Scala : Récursivité terminale @tailrec

Parlons de la pile d'exécution : l'origine de StackOverFlowError

Lors de l’exécution d’une classe Java, la JVM commence l'exécution en chargeant (loading) la classe contenant la méthode main. Ce travail de chargement est assuré par la ClassLoader qui opère à la demande ou de façon anticipée en pré chargeant certaines classes. Un thread est créé pour exécuter la méthode main. Au démarrage d’un thread, la JVM lui associe une pile de mémoire. 

La pile est une zone de mémoire qui stocke les paramètres, les résultats intermédiaires, les retours de méthode/fonction. Pour chaque méthode/fonction exécutée, il existe un contexte d’exécution (stack frame) allouée sur la pile. Un frame est utilisé pour stocker des données locales, des résultats partiels, des valeurs de retour pour les méthodes, et les pointeurs vers 'autres objets liés à la méthode. Un frame n'existe que le temps d'existence d'une méthode et est empilée sur la pile.

Lors de l’appel d’une fonction f(), elle est empilée sur la pile avec ses paramètres et ses variables locales dans le stack frame. A la fin de son exécution, elle est dépilée ainsi que ses variables locales. Ci-dessus, un exemple pour clarifier mes propos :



Vous pouvez remarquer avec cet exemple simplifié qu'un grand nombre d’appels de fonctions imbriqués peut grossir la pile et engendrer un risque important de débordement (d’où l’exception StackOverFlowError, l’option Xss vous permet d’augmenter la taille de la pile). Cela arrive généralement pour les fonctions récursives. En effet, une fonction récursive effectue un (plusieurs) appel(s) à elle-même.

Les fonctions récursives sont très souvent élégantes mais peuvent parfois nécessiter un trop grand nombre d’appels récursifs, entraînant ainsi un débordement de la pile. Chaque appel de la fonction récursive, un stack frame est empilé sur la pile d’exécution. La pile d’exécution peut rapidement s’agrandir : on se retrouve avec autant de frames que d’appels récursifs. 

Pour remédier à cette problématique, le compilateur Scala transforme les fonctions récursives en remplaçant les appels récursifs par des boucles. Cette transformation n’est possible que pour les fonctions récursives terminales

Une fonction récursive f est terminale si la dernière instruction est un appel récursif à la fonction (l’appel récursif est du type : retrun f(…)). Un exemple vaut mieux qu’un long discours ! 


def factorielle(n: Int): Int = {
  if (n < 2) 1
  else n * factorielle(n - 1)
}

La fonction récursive factorielle n’est pas terminale. En effet l’appel récursif n’est pas la dernière instruction de la fonction, puisqu’il faut d’abord effectuer le calcul de factorielle(n - 1) puis multiplier le résultat par n.

Pour transformer la fonction factorielle à une fonction récursive terminale, il suffit d’ajouter un paramètre supplémentaire nous permettant de cumuler le résultat (le calcul de factorielle n se fait par factorielle(n, 1) ) :


def factorielle(n: Int, resultat: Int): Int = {
  if (n < 2) resultat
  else factorielle(n - 1, resultat * n)
}

Pour éviter le problème de débordement de la pile, le compilateur transforme cette fonction récursive à une fonction équivalente utilisant une boucle :


def fact_iter(n: Int, resultat: Int) = {
    var valeur = n
    var res = resultat
    while (valeur >= 2) {
      res = res * valeur
      valeur = valeur - 1
    }
    res
  }

L'annotation @tailrec Scala est là pour vous aider !

Le compilateur Scala vous permet d’utiliser l’annotation @annotation.tailrec   pour indiquer une fonction récursive terminale. Si la fonction n’est pas récursive terminale le compilateur Scala produit une erreur.


Nous remarquons que la fonction récursive fibonacci n’est pas terminale. En effet l’appel récursif n’est pas la dernière instruction de la fonction, puisqu’il faut d’abord effectuer le calcul de fibonacci (n - 1) puis additionner avec fibonacci(n-2). En plus des appels récursifs, il y a une opération addition à effectuer.

La version récursive terminale :


def fibonacciTailVersion (count: Int): Int = {
    @tailrec
    def fibonacciHelper (count: Int, value: Int = 1, accum: Int = 0): Int = count match {
      case 0 => accum
      case _ => fibonacciHelper(count - 1, accum, value + accum)
    }
    fibonacciHelper(count)
  }

Aucun commentaire:

Enregistrer un commentaire