Julius 4.2
libsent/src/phmm/vsegment.c
Go to the documentation of this file.
00001 
00018 /*
00019  * Copyright (c) 1991-2011 Kawahara Lab., Kyoto University
00020  * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology
00021  * Copyright (c) 2005-2011 Julius project team, Nagoya Institute of Technology
00022  * All rights reserved
00023  */
00024 
00025 #include <sent/stddefs.h>
00026 #include <sent/htk_param.h>
00027 #include <sent/hmm.h>
00028 
00052 LOGPROB
00053 viterbi_segment(HMM *hmm, HTK_Param *param, HMMWork *wrk, boolean multipath, int *endstates, int ulen, int **id_ret, int **seg_ret, LOGPROB **uscore_ret, int *slen_ret)
00054 {
00055   /* for viterbi */
00056   LOGPROB *nodescore[2];        /* node buffer */
00057   SEGTOKEN **tokenp[2];         /* propagating token which holds segment info */
00058   int startt, endt;
00059   int *from_node;
00060   int *u_end, *u_start; /* the node is an end of the word, or -1 for non-multipath mode*/
00061   int i, n;
00062   unsigned int t;
00063   int tl,tn;
00064   LOGPROB tmpsum;
00065   A_CELL *ac;
00066   SEGTOKEN *newtoken, *token, *tmptoken, *root;
00067   LOGPROB result_score;
00068   LOGPROB maxscore, minscore;   /* for debug */
00069   int maxnode;                  /* for debug */
00070   int *id, *seg, slen;
00071   LOGPROB *uscore;
00072 
00073   /* assume more than 1 units */
00074   if (ulen < 1) {
00075     jlog("Error: vsegment: no unit?\n");
00076     return LOG_ZERO;
00077   }
00078 
00079   if (!multipath) {
00080     /* initialize unit start/end marker */
00081     u_start = (int *)mymalloc(hmm->len * sizeof(int));
00082     u_end   = (int *)mymalloc(hmm->len * sizeof(int));
00083     for (n = 0; n < hmm->len; n++) {
00084       u_start[n] = -1;
00085       u_end[n] = -1;
00086     }
00087     u_start[0] = 0;
00088     u_end[endstates[0]] = 0;
00089     for (i=1;i<ulen;i++) {
00090       u_start[endstates[i-1]+1] = i;
00091       u_end[endstates[i]] = i;
00092     }
00093 #if 0
00094     for (i=0;i<hmm->len;i++) {
00095       printf("unit %d: start=%d, end=%d\n", i, u_start[i], u_end[i]);
00096     }
00097 #endif
00098   }
00099 
00100   /* initialize node buffers */
00101   tn = 0;
00102   tl = 1;
00103   root = NULL;
00104   for (i=0;i<2;i++){
00105     nodescore[i] = (LOGPROB *)mymalloc(hmm->len * sizeof(LOGPROB));
00106     tokenp[i] = (SEGTOKEN **)mymalloc(hmm->len * sizeof(SEGTOKEN *));
00107     for (n = 0; n < hmm->len; n++) {
00108       tokenp[i][n] = NULL;
00109     }
00110   }
00111   for (n = 0; n < hmm->len; n++) {
00112     nodescore[tn][n] = LOG_ZERO;
00113     newtoken = (SEGTOKEN *)mymalloc(sizeof(SEGTOKEN));
00114     newtoken->last_id = -1;
00115     newtoken->last_end_frame = -1;
00116     newtoken->last_end_score = 0.0;
00117     newtoken->list = root;
00118     root = newtoken;
00119     newtoken->next = NULL;
00120     tokenp[tn][n] = newtoken;
00121   }
00122   from_node = (int *)mymalloc(sizeof(int) * hmm->len);
00123   
00124   /* first frame: only set initial score */
00125   /*if (hmm->state[0].is_pseudo_state) {
00126     jlog("Warning: state %d: pseudo state?\n", 0);
00127     }*/
00128   if (multipath) {
00129     nodescore[tn][0] = 0.0;
00130   } else {
00131     nodescore[tn][0] = outprob(wrk, 0, &(hmm->state[0]), param);
00132   }
00133 
00134   /* do viterbi for rest frame */
00135   if (multipath) {
00136     startt = 0;  endt = param->samplenum;
00137   } else {
00138     startt = 1;  endt = param->samplenum - 1;
00139   }
00140   for (t = startt; t <= endt; t++) {
00141     i = tl;
00142     tl = tn;
00143     tn = i;
00144     maxscore = LOG_ZERO;
00145     minscore = 0.0;
00146 
00147     /* clear next scores */
00148     for (i=0;i<hmm->len;i++) {
00149       nodescore[tn][i] = LOG_ZERO;
00150       from_node[i] = -1;
00151     }
00152 
00153     /* select viterbi path for each node */
00154     for (n = 0; n < hmm->len; n++) {
00155       if (nodescore[tl][n] <= LOG_ZERO) continue;
00156       for (ac = hmm->state[n].ac; ac; ac = ac->next) {
00157         tmpsum = nodescore[tl][n] + ac->a;
00158         if (nodescore[tn][ac->arc] < tmpsum) {
00159           nodescore[tn][ac->arc] = tmpsum;
00160           from_node[ac->arc] = n;
00161         }
00162       }
00163     }
00164     /* propagate token, appending new if path was selected between units */
00165     if (multipath) {
00166       for (n = 0; n < hmm->len; n++) {
00167         if (from_node[n] == -1 || nodescore[tn][n] <= LOG_ZERO) {
00168           /*tokenp[tn][n] = NULL;*/
00169         } else {
00170           i=0;
00171           while (from_node[n] > endstates[i]) i++;
00172           if (n > endstates[i]) {
00173             newtoken = (SEGTOKEN *)mymalloc(sizeof(SEGTOKEN));
00174             newtoken->last_id = i;
00175             newtoken->last_end_frame = t-1;
00176             newtoken->last_end_score = nodescore[tl][from_node[n]];
00177             newtoken->list = root;
00178             root = newtoken;
00179             newtoken->next = tokenp[tl][from_node[n]];
00180             tokenp[tn][n] = newtoken;
00181           } else {
00182             tokenp[tn][n] = tokenp[tl][from_node[n]];
00183           }
00184         }
00185       }
00186     } else {                    /* not multipath */
00187       for (n = 0; n < hmm->len; n++) {
00188         if (from_node[n] == -1) {
00189           tokenp[tn][n] = NULL;
00190         } else if (nodescore[tn][n] <= LOG_ZERO) {
00191           tokenp[tn][n] = tokenp[tl][from_node[n]];
00192         } else {
00193           if (u_end[from_node[n]] != -1 && u_start[n] != -1
00194               && from_node[n] !=  n) {
00195             newtoken = (SEGTOKEN *)mymalloc(sizeof(SEGTOKEN));
00196             newtoken->last_id = u_end[from_node[n]];
00197             newtoken->last_end_frame = t-1;
00198             newtoken->last_end_score = nodescore[tl][from_node[n]];
00199             newtoken->list = root;
00200             root = newtoken;
00201             newtoken->next = tokenp[tl][from_node[n]];
00202             tokenp[tn][n] = newtoken;
00203           } else {
00204             tokenp[tn][n] = tokenp[tl][from_node[n]];
00205           }
00206         }
00207       }
00208     }
00209 
00210     if (multipath) {
00211       /* if this is next of last frame, loop ends here */
00212       if (t == param->samplenum) break;
00213     }
00214         
00215     /* calc outprob to new nodes */
00216     for (n = 0; n < hmm->len; n++) {
00217       if (multipath) {
00218         if (hmm->state[n].out.state == NULL) continue;
00219       }
00220       if (nodescore[tn][n] > LOG_ZERO) {
00221         if (hmm->state[n].is_pseudo_state) {
00222           jlog("Warning: vsegment: state %d: pseudo state?\n", n);
00223         }
00224         nodescore[tn][n] += outprob(wrk, t, &(hmm->state[n]), param);
00225       }
00226       if (nodescore[tn][n] > maxscore) { /* for debug */
00227         maxscore = nodescore[tn][n];
00228         maxnode = n;
00229       }
00230     }
00231     
00232 #if 0
00233     for (i=0;i<ulen;i++) {
00234       printf("%d: unit %d(%d-%d): begin_frame = %d\n", t - 1, i,
00235              (i > 0) ? endstates[i-1]+1 : 0, endstates[i],
00236              (multipath && tokenp[tl][endstates[i]] == NULL) ? -1 : tokenp[tl][endstates[i]]->last_end_frame + 1);
00237     }
00238 #endif
00239 
00240     /* printf("t=%3d max=%f n=%d\n",t,maxscore, maxnode); */
00241     
00242   }
00243 
00244   result_score = nodescore[tn][hmm->len-1];
00245 
00246   /* parse back the last token to see the trail of best viterbi path */
00247   /* and store the informations to returning buffer */
00248   slen = 0;
00249   if (!multipath) slen++;
00250   for(token = tokenp[tn][hmm->len-1]; token; token = token->next) {
00251     if (token->last_end_frame == -1) break;
00252     slen++;
00253   }
00254   id = (int *)mymalloc(sizeof(int)*slen);
00255   seg = (int *)mymalloc(sizeof(int)*slen);
00256   uscore = (LOGPROB *)mymalloc(sizeof(LOGPROB)*slen);
00257 
00258   if (multipath) {
00259     i = slen - 1;
00260   } else {
00261     id[slen-1] = ulen - 1;
00262     seg[slen-1] = t - 1;
00263     uscore[slen-1] = result_score;
00264     i = slen - 2;
00265   }
00266   for(token = tokenp[tn][hmm->len-1]; token; token = token->next) {
00267     if (i < 0 || token->last_end_frame == -1) break;
00268     id[i] = token->last_id;
00269     seg[i] = token->last_end_frame;
00270     uscore[i] = token->last_end_score;
00271     i--;
00272   }
00273 
00274   /* normalize scores by frame */
00275   for (i=slen-1;i>0;i--) {
00276     uscore[i] = (uscore[i] - uscore[i-1]) / (seg[i] - seg[i-1]);
00277   }
00278   uscore[0] = uscore[0] / (seg[0] + 1);
00279 
00280   /* set return value */
00281   *id_ret = id;
00282   *seg_ret = seg;
00283   *uscore_ret = uscore;
00284   *slen_ret = slen;
00285 
00286   /* free memory */
00287   if (!multipath) {
00288     free(u_start);
00289     free(u_end);
00290   }
00291   free(from_node);
00292   token = root;
00293   while(token) {
00294     tmptoken = token->list;
00295     free(token);
00296     token = tmptoken;
00297   }
00298   for (i=0;i<2;i++) {
00299     free(nodescore[i]);
00300     free(tokenp[i]);
00301   }
00302 
00303   return(result_score);
00304 
00305 }