Julius 4.2
|
00001 00022 /* 00023 * Copyright (c) 1991-2011 Kawahara Lab., Kyoto University 00024 * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology 00025 * Copyright (c) 2005-2011 Julius project team, Nagoya Institute of Technology 00026 * All rights reserved 00027 */ 00028 00029 /* 00030 Implementation of Gaussian Mixture Selection (old doc...) 00031 00032 It is called from gs_calc_selected_mixture_and_cache_{safe,heu,beam} in 00033 the first pass for each frame. It calculates all GS HMM outprob for 00034 given input frame and get the N-best GS HMM states. Then, 00035 for the selected (N-best) states: 00036 calculate the corresponding codebook, 00037 and set fallback_score[t][book] to LOG_ZERO. 00038 else: 00039 set fallback_score[t][book] to the GS HMM outprob. 00040 Later, when calculating state outprobs, the fallback_score[t][book] 00041 is consulted and, 00042 if fallback_score[t][book] == LOG_ZERO: 00043 it means it has been selected, so calculate the outprob with 00044 the corresponding codebook and its weights. 00045 else: 00046 it means it was pruned, so use the fallback_score[t][book] 00047 as its outprob. 00048 00049 00050 For triphone, GS HMMs should be assigned to each state. 00051 So the fallback_score[][] is kept according to the GS state ID, 00052 and corresponding GS HMM state id for each triphone state id should be 00053 kept beforehand. 00054 GS HMM Calculation: 00055 for the selected (N-best) GS HMM states: 00056 set fallback_score[t][gs_stateid] to LOG_ZERO. 00057 else: 00058 set fallback_score[t][gs_stateid] to the GS HMM outprob. 00059 triphone HMM probabilities are assigned as: 00060 if fallback_score[t][state2gs[tri_stateid]] == LOG_ZERO: 00061 it has been selected, so calculate the original outprob. 00062 else: 00063 as it was pruned, re-use the fallback_score[t][stateid] 00064 as its outprob. 00065 */ 00066 00067 00068 #include <sent/stddefs.h> 00069 #include <sent/htk_hmm.h> 00070 #include <sent/htk_param.h> 00071 #include <sent/hmm.h> 00072 #include <sent/hmm_calc.h> 00073 00074 #undef NORMALIZE_GS_SCORE /* normalize score (ad-hoc) */ 00075 00076 /* GS HMMs must be defined at STATE level using "~s NAME" macro, 00077 where NAMES are like "i:4m", "s2m", etc. */ 00078 00079 00086 static void 00087 build_gsset(HMMWork *wrk) 00088 { 00089 HTK_HMM_State *st; 00090 00091 /* allocate */ 00092 wrk->gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * wrk->OP_gshmm->totalstatenum); 00093 wrk->gsset_num = wrk->OP_gshmm->totalstatenum; 00094 /* make ID */ 00095 for(st = wrk->OP_gshmm->ststart; st; st=st->next) { 00096 wrk->gsset[st->id].state = st; 00097 } 00098 } 00099 00106 static void 00107 free_gsset(HMMWork *wrk) 00108 { 00109 free(wrk->gsset); 00110 } 00111 00119 static boolean 00120 build_state2gs(HMMWork *wrk) 00121 { 00122 HTK_HMM_Data *dt; 00123 HTK_HMM_State *st, *cr; 00124 int i; 00125 char gstr[MAX_HMMNAME_LEN], cbuf[MAX_HMMNAME_LEN]; 00126 boolean ok_p = TRUE; 00127 00128 /* initialize */ 00129 wrk->state2gs = (int *)mymalloc(sizeof(int) * wrk->OP_hmminfo->totalstatenum); 00130 for(i=0;i<wrk->OP_hmminfo->totalstatenum;i++) wrk->state2gs[i] = -1; 00131 00132 /* parse through all HMM macro to register their state */ 00133 for(dt = wrk->OP_hmminfo->start; dt; dt=dt->next) { 00134 if (strlen(dt->name) >= MAX_HMMNAME_LEN - 2) { 00135 jlog("Error: gms: too long hmm name (>%d): \"%s\"\n", 00136 MAX_HMMNAME_LEN-3, dt->name); 00137 jlog("Error: gms: change value of MAX_HMMNAME_LEN\n"); 00138 ok_p = FALSE; 00139 continue; 00140 } 00141 for(i=1;i<dt->state_num-1;i++) { /* for all state */ 00142 st = dt->s[i]; 00143 /* skip if already assigned */ 00144 if (wrk->state2gs[st->id] != -1) continue; 00145 /* set corresponding gshmm name */ 00146 sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1); 00147 /* look up the state in OP_gshmm */ 00148 if ((cr = state_lookup(wrk->OP_gshmm, gstr)) == NULL) { 00149 jlog("Error: gms: GS HMM \"%s\" not defined\n", gstr); 00150 ok_p = FALSE; 00151 continue; 00152 } 00153 /* store its ID */ 00154 wrk->state2gs[st->id] = cr->id; 00155 } 00156 } 00157 #ifdef PARANOIA 00158 { 00159 HTK_HMM_State *st; 00160 for(st=wrk->OP_hmminfo->ststart; st; st=st->next) { 00161 printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name, 00162 (wrk->gsset[wrk->state2gs[st->id]].state)->name); 00163 } 00164 } 00165 #endif 00166 return ok_p; 00167 } 00168 00175 static void 00176 free_state2gs(HMMWork *wrk) 00177 { 00178 free(wrk->state2gs); 00179 } 00180 00181 00182 /* sort to find N-best states */ 00183 #define SD(A) idx[A-1] ///< Index macro for heap sort 00184 #define SCOPY(D,S) D = S ///< Element copy macro for heap sort 00185 #define SVAL(A) (fs[idx[A-1]]) ///< Element evaluation macro for heap sort 00186 #define STVAL (fs[s]) ///< Element current value macro for heap sort 00187 00194 static void 00195 sort_gsindex_upward(HMMWork *wrk) 00196 { 00197 int n,root,child,parent; 00198 int s; 00199 int *idx; 00200 LOGPROB *fs; 00201 int neednum, totalnum; 00202 00203 idx = wrk->gsindex; 00204 fs = wrk->t_fs; 00205 neednum = wrk->my_nbest; 00206 totalnum = wrk->gsset_num; 00207 00208 for (root = totalnum/2; root >= 1; root--) { 00209 SCOPY(s, SD(root)); 00210 parent = root; 00211 while ((child = parent * 2) <= totalnum) { 00212 if (child < totalnum && SVAL(child) < SVAL(child+1)) { 00213 child++; 00214 } 00215 if (STVAL >= SVAL(child)) { 00216 break; 00217 } 00218 SCOPY(SD(parent), SD(child)); 00219 parent = child; 00220 } 00221 SCOPY(SD(parent), s); 00222 } 00223 n = totalnum; 00224 while ( n > totalnum - neednum) { 00225 SCOPY(s, SD(n)); 00226 SCOPY(SD(n), SD(1)); 00227 n--; 00228 parent = 1; 00229 while ((child = parent * 2) <= n) { 00230 if (child < n && SVAL(child) < SVAL(child+1)) { 00231 child++; 00232 } 00233 if (STVAL >= SVAL(child)) { 00234 break; 00235 } 00236 SCOPY(SD(parent), SD(child)); 00237 parent = child; 00238 } 00239 SCOPY(SD(parent), s); 00240 } 00241 } 00242 00249 static void 00250 do_gms(HMMWork *wrk) 00251 { 00252 int i; 00253 00254 /* compute all gshmm scores (in gs_score.c) */ 00255 compute_gs_scores(wrk); 00256 /* sort and select */ 00257 sort_gsindex_upward(wrk); 00258 for(i=wrk->gsset_num - wrk->my_nbest;i<wrk->gsset_num;i++) { 00259 /* set scores of selected states to LOG_ZERO */ 00260 wrk->t_fs[wrk->gsindex[i]] = LOG_ZERO; 00261 } 00262 00263 /* power e -> 10 */ 00264 #ifdef NORMALIZE_GS_SCORE 00265 /* normalize other fallback scores (rate of max) */ 00266 for(i=0;i<wrk->gsset_num;i++) { 00267 if (wrk->t_fs[i] != LOG_ZERO) { 00268 wrk->t_fs[i] *= 0.975; 00269 } 00270 } 00271 #endif 00272 } 00273 00274 00282 boolean 00283 gms_init(HMMWork *wrk) 00284 { 00285 int i; 00286 00287 /* Check gshmm type */ 00288 if (wrk->OP_gshmm->is_triphone) { 00289 jlog("Error: gms: GS HMM should be a monophone model\n"); 00290 return FALSE; 00291 } 00292 if (wrk->OP_gshmm->is_tied_mixture) { 00293 jlog("Error: gms: GS HMM should not be a tied mixture model\n"); 00294 return FALSE; 00295 } 00296 00297 /* Register all GS HMM states in GS_SET */ 00298 build_gsset(wrk); 00299 /* Make correspondence of all triphone states to GS HMM states */ 00300 if (build_state2gs(wrk) == FALSE) { 00301 jlog("Error: gms: failed in assigning GS HMM state for each state\n"); 00302 return FALSE; 00303 } 00304 jlog("Stat: gms: GS HMMs are mapped to HMM states\n"); 00305 00306 /* prepare index buffer for heap sort */ 00307 wrk->gsindex = (int *)mymalloc(sizeof(int) * wrk->gsset_num); 00308 for(i=0;i<wrk->gsset_num;i++) wrk->gsindex[i] = i; 00309 00310 /* init cache status */ 00311 wrk->fallback_score = NULL; 00312 wrk->gms_is_selected = NULL; 00313 wrk->gms_allocframenum = -1; 00314 00315 /* initialize gms_gprune functions */ 00316 gms_gprune_init(wrk); 00317 00318 return TRUE; 00319 } 00320 00329 boolean 00330 gms_prepare(HMMWork *wrk, int framenum) 00331 { 00332 LOGPROB *tmp; 00333 int t; 00334 00335 /* allocate cache */ 00336 if (wrk->gms_allocframenum < framenum) { 00337 if (wrk->fallback_score != NULL) { 00338 free(wrk->fallback_score[0]); 00339 free(wrk->fallback_score); 00340 free(wrk->gms_is_selected); 00341 } 00342 wrk->fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum); 00343 tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->gsset_num * framenum); 00344 for(t=0;t<framenum;t++) { 00345 wrk->fallback_score[t] = &(tmp[wrk->gsset_num * t]); 00346 } 00347 wrk->gms_is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum); 00348 wrk->gms_allocframenum = framenum; 00349 } 00350 /* clear */ 00351 for(t=0;t<framenum;t++) wrk->gms_is_selected[t] = FALSE; 00352 00353 /* prepare gms_gprune functions */ 00354 gms_gprune_prepare(wrk); 00355 00356 return TRUE; 00357 } 00358 00365 void 00366 gms_free(HMMWork *wrk) 00367 { 00368 free_gsset(wrk); 00369 free_state2gs(wrk); 00370 free(wrk->gsindex); 00371 if (wrk->fallback_score != NULL) { 00372 free(wrk->fallback_score[0]); 00373 free(wrk->fallback_score); 00374 free(wrk->gms_is_selected); 00375 } 00376 gms_gprune_free(wrk); 00377 } 00378 00379 00380 00393 LOGPROB 00394 gms_state(HMMWork *wrk) 00395 { 00396 LOGPROB gsprob; 00397 if (wrk->OP_last_time != wrk->OP_time) { /* different frame */ 00398 /* set current buffer */ 00399 wrk->t_fs = wrk->fallback_score[wrk->OP_time]; 00400 /* select state if not yet */ 00401 if (!wrk->gms_is_selected[wrk->OP_time]) { 00402 do_gms(wrk); 00403 wrk->gms_is_selected[wrk->OP_time] = TRUE; 00404 } 00405 } 00406 if ((gsprob = wrk->t_fs[wrk->state2gs[wrk->OP_state_id]]) != LOG_ZERO) { 00407 /* un-selected: return the fallback value */ 00408 return(gsprob); 00409 } 00410 /* selected: calculate the real outprob of the state */ 00411 return((*(wrk->calc_outprob))(wrk)); 00412 }