Julius 4.2
libsent/src/phmm/gms.c
説明を見る。
00001 
00022 /*
00023  * Copyright (c) 1991-2011 Kawahara Lab., Kyoto University
00024  * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology
00025  * Copyright (c) 2005-2011 Julius project team, Nagoya Institute of Technology
00026  * All rights reserved
00027  */
00028 
00029 /*
00030   Implementation of Gaussian Mixture Selection (old doc...)
00031   
  It is called from gs_calc_selected_mixture_and_cache_{safe,heu,beam} in
  the first pass for each frame.  It calculates the outprobs of all GS HMM
  states for the given input frame and gets the N-best GS HMM states. Then,
00035        for the selected (N-best) states:
00036            calculate the corresponding codebook,
00037            and set fallback_score[t][book] to LOG_ZERO.
00038        else:
00039            set fallback_score[t][book] to the GS HMM outprob.
00040   Later, when calculating state outprobs, the fallback_score[t][book]
00041   is consulted and,
00042        if fallback_score[t][book] == LOG_ZERO:
00043            it means it has been selected, so calculate the outprob with
00044            the corresponding codebook and its weights.
00045        else:
00046            it means it was pruned, so use the fallback_score[t][book]
00047            as its outprob.
00048 
00049            
00050   For triphone, GS HMMs should be assigned to each state.
00051   So the fallback_score[][] is kept according to the GS state ID,
00052   and corresponding GS HMM state id for each triphone state id should be
00053   kept beforehand.
00054   GS HMM Calculation:
00055        for the selected (N-best) GS HMM states:
00056            set fallback_score[t][gs_stateid] to LOG_ZERO.
00057        else:
00058            set fallback_score[t][gs_stateid] to the GS HMM outprob.
00059   triphone HMM probabilities are assigned as:
00060        if fallback_score[t][state2gs[tri_stateid]] == LOG_ZERO:
00061            it has been selected, so calculate the original outprob.
00062        else:
           as it was pruned, re-use fallback_score[t][state2gs[tri_stateid]]
           as its outprob.
00065 */
00066 
00067 
00068 #include <sent/stddefs.h>
00069 #include <sent/htk_hmm.h>
00070 #include <sent/htk_param.h>
00071 #include <sent/hmm.h>
00072 #include <sent/hmm_calc.h>
00073 
00074 #undef NORMALIZE_GS_SCORE       /* normalize score (ad-hoc) */
00075 
00076   /* GS HMMs must be defined at STATE level using "~s NAME" macro,
00077      where NAMES are like "i:4m", "s2m", etc. */
00078 
00079 
00086 static void
00087 build_gsset(HMMWork *wrk)
00088 {
00089   HTK_HMM_State *st;
00090 
00091   /* allocate */
00092   wrk->gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * wrk->OP_gshmm->totalstatenum);
00093   wrk->gsset_num = wrk->OP_gshmm->totalstatenum;
00094   /* make ID */
00095   for(st = wrk->OP_gshmm->ststart; st; st=st->next) {
00096     wrk->gsset[st->id].state = st;
00097   }
00098 }
00099 
00106 static void
00107 free_gsset(HMMWork *wrk)
00108 {
00109   free(wrk->gsset);
00110 }
00111 
00119 static boolean
00120 build_state2gs(HMMWork *wrk)
00121 {
00122   HTK_HMM_Data *dt;
00123   HTK_HMM_State *st, *cr;
00124   int i;
00125   char gstr[MAX_HMMNAME_LEN], cbuf[MAX_HMMNAME_LEN];
00126   boolean ok_p = TRUE;
00127 
00128   /* initialize */
00129   wrk->state2gs = (int *)mymalloc(sizeof(int) * wrk->OP_hmminfo->totalstatenum);
00130   for(i=0;i<wrk->OP_hmminfo->totalstatenum;i++) wrk->state2gs[i] = -1;
00131 
00132   /* parse through all HMM macro to register their state */
00133   for(dt = wrk->OP_hmminfo->start; dt; dt=dt->next) {
00134     if (strlen(dt->name) >= MAX_HMMNAME_LEN - 2) {
00135       jlog("Error: gms: too long hmm name (>%d): \"%s\"\n",
00136            MAX_HMMNAME_LEN-3, dt->name);
00137       jlog("Error: gms: change value of MAX_HMMNAME_LEN\n");
00138       ok_p = FALSE;
00139       continue;
00140     }
00141     for(i=1;i<dt->state_num-1;i++) { /* for all state */
00142       st = dt->s[i];
00143       /* skip if already assigned */
00144       if (wrk->state2gs[st->id] != -1) continue;
00145       /* set corresponding gshmm name */
00146       sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00147       /* look up the state in OP_gshmm */
00148       if ((cr = state_lookup(wrk->OP_gshmm, gstr)) == NULL) {
00149         jlog("Error: gms: GS HMM \"%s\" not defined\n", gstr);
00150         ok_p = FALSE;
00151         continue;
00152       }
00153       /* store its ID */
00154       wrk->state2gs[st->id] = cr->id;
00155     }
00156   }
00157 #ifdef PARANOIA
00158   {
00159     HTK_HMM_State *st;
00160     for(st=wrk->OP_hmminfo->ststart; st; st=st->next) {
00161       printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00162              (wrk->gsset[wrk->state2gs[st->id]].state)->name);
00163     }
00164   }
00165 #endif
00166   return ok_p;
00167 }
00168 
00175 static void
00176 free_state2gs(HMMWork *wrk)
00177 {
00178   free(wrk->state2gs);
00179 }
00180 
00181 
/* sort to find N-best states.
   Heap positions are 1-based: position A lives in idx[A-1].
   The macros expect locals named `idx` (int *), `fs` (score array) and
   `s` (held element) in the scope where they are used.
   Arguments are parenthesized so expressions like SD(a ? 2 : 1) work. */
#define SD(A) (idx[(A)-1])       ///< Index macro for heap sort
#define SCOPY(D,S) ((D) = (S))   ///< Element copy macro for heap sort
#define SVAL(A) (fs[idx[(A)-1]]) ///< Element evaluation macro for heap sort
#define STVAL (fs[s])            ///< Element current value macro for heap sort
00187 
00194 static void
00195 sort_gsindex_upward(HMMWork *wrk)
00196 {
00197   int n,root,child,parent;
00198   int s;
00199   int *idx;
00200   LOGPROB *fs;
00201   int neednum, totalnum;
00202 
00203   idx = wrk->gsindex;
00204   fs = wrk->t_fs;
00205   neednum = wrk->my_nbest;
00206   totalnum = wrk->gsset_num;
00207 
00208   for (root = totalnum/2; root >= 1; root--) {
00209     SCOPY(s, SD(root));
00210     parent = root;
00211     while ((child = parent * 2) <= totalnum) {
00212       if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00213         child++;
00214       }
00215       if (STVAL >= SVAL(child)) {
00216         break;
00217       }
00218       SCOPY(SD(parent), SD(child));
00219       parent = child;
00220     }
00221     SCOPY(SD(parent), s);
00222   }
00223   n = totalnum;
00224   while ( n > totalnum - neednum) {
00225     SCOPY(s, SD(n));
00226     SCOPY(SD(n), SD(1));
00227     n--;
00228     parent = 1;
00229     while ((child = parent * 2) <= n) {
00230       if (child < n && SVAL(child) < SVAL(child+1)) {
00231         child++;
00232       }
00233       if (STVAL >= SVAL(child)) {
00234         break;
00235       }
00236       SCOPY(SD(parent), SD(child));
00237       parent = child;
00238     }
00239     SCOPY(SD(parent), s);
00240   }
00241 }
00242 
00249 static void
00250 do_gms(HMMWork *wrk)
00251 {
00252   int i;
00253   
00254   /* compute all gshmm scores (in gs_score.c) */
00255   compute_gs_scores(wrk);
00256   /* sort and select */
00257   sort_gsindex_upward(wrk);
00258   for(i=wrk->gsset_num - wrk->my_nbest;i<wrk->gsset_num;i++) {
00259     /* set scores of selected states to LOG_ZERO */
00260     wrk->t_fs[wrk->gsindex[i]] = LOG_ZERO;
00261   }
00262 
00263   /* power e -> 10 */
00264 #ifdef NORMALIZE_GS_SCORE
00265   /* normalize other fallback scores (rate of max) */
00266   for(i=0;i<wrk->gsset_num;i++) {
00267     if (wrk->t_fs[i] != LOG_ZERO) {
00268       wrk->t_fs[i] *= 0.975;
00269     }
00270   }
00271 #endif
00272 }  
00273 
00274 
00282 boolean
00283 gms_init(HMMWork *wrk)
00284 {
00285   int i;
00286   
00287   /* Check gshmm type */
00288   if (wrk->OP_gshmm->is_triphone) {
00289     jlog("Error: gms: GS HMM should be a monophone model\n");
00290     return FALSE;
00291   }
00292   if (wrk->OP_gshmm->is_tied_mixture) {
00293     jlog("Error: gms: GS HMM should not be a tied mixture model\n");
00294     return FALSE;
00295   }
00296 
00297   /* Register all GS HMM states in GS_SET */
00298   build_gsset(wrk);
00299   /* Make correspondence of all triphone states to GS HMM states */
00300   if (build_state2gs(wrk) == FALSE) {
00301     jlog("Error: gms: failed in assigning GS HMM state for each state\n");
00302     return FALSE;
00303   }
00304   jlog("Stat: gms: GS HMMs are mapped to HMM states\n");
00305 
00306   /* prepare index buffer for heap sort */
00307   wrk->gsindex = (int *)mymalloc(sizeof(int) * wrk->gsset_num);
00308   for(i=0;i<wrk->gsset_num;i++) wrk->gsindex[i] = i;
00309 
00310   /* init cache status */
00311   wrk->fallback_score = NULL;
00312   wrk->gms_is_selected = NULL;
00313   wrk->gms_allocframenum = -1;
00314 
00315   /* initialize gms_gprune functions */
00316   gms_gprune_init(wrk);
00317   
00318   return TRUE;
00319 }
00320 
00329 boolean
00330 gms_prepare(HMMWork *wrk, int framenum)
00331 {
00332   LOGPROB *tmp;
00333   int t;
00334 
00335   /* allocate cache */
00336   if (wrk->gms_allocframenum < framenum) {
00337     if (wrk->fallback_score != NULL) {
00338       free(wrk->fallback_score[0]);
00339       free(wrk->fallback_score);
00340       free(wrk->gms_is_selected);
00341     }
00342     wrk->fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00343     tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->gsset_num * framenum);
00344     for(t=0;t<framenum;t++) {
00345       wrk->fallback_score[t] = &(tmp[wrk->gsset_num * t]);
00346     }
00347     wrk->gms_is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00348     wrk->gms_allocframenum = framenum;
00349   }
00350   /* clear */
00351   for(t=0;t<framenum;t++) wrk->gms_is_selected[t] = FALSE;
00352 
00353   /* prepare gms_gprune functions */
00354   gms_gprune_prepare(wrk);
00355   
00356   return TRUE;
00357 }
00358 
00365 void
00366 gms_free(HMMWork *wrk)
00367 {
00368   free_gsset(wrk);
00369   free_state2gs(wrk);
00370   free(wrk->gsindex);
00371   if (wrk->fallback_score != NULL) {
00372     free(wrk->fallback_score[0]);
00373     free(wrk->fallback_score);
00374     free(wrk->gms_is_selected);
00375   }
00376   gms_gprune_free(wrk);
00377 }
00378 
00379 
00380 
00393 LOGPROB
00394 gms_state(HMMWork *wrk)
00395 {
00396   LOGPROB gsprob;
00397   if (wrk->OP_last_time != wrk->OP_time) { /* different frame */
00398     /* set current buffer */
00399     wrk->t_fs = wrk->fallback_score[wrk->OP_time];
00400     /* select state if not yet */
00401     if (!wrk->gms_is_selected[wrk->OP_time]) {
00402       do_gms(wrk);
00403       wrk->gms_is_selected[wrk->OP_time] = TRUE;
00404     }
00405   }
00406   if ((gsprob = wrk->t_fs[wrk->state2gs[wrk->OP_state_id]]) != LOG_ZERO) {
00407     /* un-selected: return the fallback value */
00408     return(gsprob);
00409   }
00410   /* selected: calculate the real outprob of the state */
00411   return((*(wrk->calc_outprob))(wrk));
00412 }