00001
00041
00042
00043
00044
00045
00046
00047 #include <julius.h>
00048
00049 #undef MES
00050
00051 static LOGPROB *gmm_score;
00052 static int framecount;
00053
00054
00055 static LOGPROB *OP_calced_score;
00056 static int *OP_calced_id;
00057 static int OP_calced_num;
00058 static int OP_calced_maxnum;
00059 static int OP_gprune_num;
00060 static VECT *OP_vec;
00061 static short OP_veclen;
00062
00082 static int
00083 gmm_find_insert_point(LOGPROB score, int len)
00084 {
00085
00086 int left = 0;
00087 int right = len - 1;
00088 int mid;
00089
00090 while (left < right) {
00091 mid = (left + right) / 2;
00092 if (OP_calced_score[mid] > score) {
00093 left = mid + 1;
00094 } else {
00095 right = mid;
00096 }
00097 }
00098 return(left);
00099 }
00100
00121 static int
00122 gmm_cache_push(int id, LOGPROB score, int len)
00123 {
00124 int insertp;
00125
00126 if (len == 0) {
00127 OP_calced_score[0] = score;
00128 OP_calced_id[0] = id;
00129 return(1);
00130 }
00131 if (OP_calced_score[len-1] >= score) {
00132 if (len < OP_gprune_num) {
00133 OP_calced_score[len] = score;
00134 OP_calced_id[len] = id;
00135 len++;
00136 }
00137 return len;
00138 }
00139 if (OP_calced_score[0] < score) {
00140 insertp = 0;
00141 } else {
00142 insertp = gmm_find_insert_point(score, len);
00143 }
00144 if (len < OP_gprune_num) {
00145 memmove(&(OP_calced_score[insertp+1]), &(OP_calced_score[insertp]), sizeof(LOGPROB)*(len - insertp));
00146 memmove(&(OP_calced_id[insertp+1]), &(OP_calced_id[insertp]), sizeof(int)*(len - insertp));
00147 } else if (insertp < len - 1) {
00148 memmove(&(OP_calced_score[insertp+1]), &(OP_calced_score[insertp]), sizeof(LOGPROB)*(len - insertp - 1));
00149 memmove(&(OP_calced_id[insertp+1]), &(OP_calced_id[insertp]), sizeof(int)*(len - insertp - 1));
00150 }
00151 OP_calced_score[insertp] = score;
00152 OP_calced_id[insertp] = id;
00153 if (len < OP_gprune_num) len++;
00154 return(len);
00155 }
00156
00175 static LOGPROB
00176 gmm_compute_g_base(HTK_HMM_Dens *binfo)
00177 {
00178 VECT tmp, x;
00179 VECT *mean;
00180 VECT *var;
00181 VECT *vec = OP_vec;
00182 short veclen = OP_veclen;
00183
00184 if (binfo == NULL) return(LOG_ZERO);
00185 mean = binfo->mean;
00186 var = binfo->var->vec;
00187 tmp = 0.0;
00188 for (; veclen > 0; veclen--) {
00189 x = *(vec++) - *(mean++);
00190 tmp += x * x * *(var++);
00191 }
00192 return((tmp + binfo->gconst) * -0.5);
00193 }
00194
00215 static LOGPROB
00216 gmm_compute_g_safe(HTK_HMM_Dens *binfo, LOGPROB thres)
00217 {
00218 VECT tmp, x;
00219 VECT *mean;
00220 VECT *var;
00221 VECT *vec = OP_vec;
00222 short veclen = OP_veclen;
00223 VECT fthres = thres * (-2.0);
00224
00225 if (binfo == NULL) return(LOG_ZERO);
00226 mean = binfo->mean;
00227 var = binfo->var->vec;
00228 tmp = binfo->gconst;
00229 for (; veclen > 0; veclen--) {
00230 x = *(vec++) - *(mean++);
00231 tmp += x * x * *(var++);
00232 if (tmp > fthres) return LOG_ZERO;
00233 }
00234 return(tmp * -0.5);
00235 }
00236
00251 static void
00252 gmm_gprune_safe_init(HTK_HMM_INFO *hmminfo, int prune_num)
00253 {
00254
00255 OP_gprune_num = prune_num;
00256
00257 OP_calced_maxnum = hmminfo->maxmixturenum;
00258
00259 OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
00260 OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
00261 }
00262
00288 static void
00289 gmm_gprune_safe(HTK_HMM_Dens **g, int gnum)
00290 {
00291 int i, num = 0;
00292 LOGPROB score, thres;
00293
00294 thres = LOG_ZERO;
00295 for (i = 0; i < gnum; i++) {
00296 if (num < OP_gprune_num) {
00297 score = gmm_compute_g_base(g[i]);
00298 } else {
00299 score = gmm_compute_g_safe(g[i], thres);
00300 if (score <= thres) continue;
00301 }
00302 num = gmm_cache_push(i, score, num);
00303 thres = OP_calced_score[num-1];
00304 }
00305 OP_calced_num = num;
00306 }
00307
00308
00325 static LOGPROB
00326 gmm_calc_mix(HTK_HMM_State *s)
00327 {
00328 int i;
00329 LOGPROB logprob = LOG_ZERO;
00330
00331
00332 gmm_gprune_safe(s->b, s->mix_num);
00333
00334
00335
00336
00337
00338 for(i=0;i<OP_calced_num;i++) {
00339 OP_calced_score[i] += s->bweight[OP_calced_id[i]];
00340 }
00341 logprob = addlog_array(OP_calced_score, OP_calced_num);
00342 if (logprob <= LOG_ZERO) return LOG_ZERO;
00343 return (logprob * INV_LOG_TEN);
00344 }
00345
00367 static LOGPROB
00368 outprob_state_nocache(int t, HTK_HMM_State *stateinfo, HTK_Param *param)
00369 {
00370
00371 OP_vec = param->parvec[t];
00372 OP_veclen = param->veclen;
00373 return(gmm_calc_mix(stateinfo));
00374 }
00375
00376
00377
00378
00379
00396 void
00397 gmm_init(HTK_HMM_INFO *gmm, int gmm_prune_num)
00398 {
00399 HTK_HMM_Data *d;
00400
00401
00402
00403 if (gmm->is_tied_mixture) {
00404 j_exit("Error: mixture-tying GMM is not supported yet.\n");
00405 }
00406
00407 for(d=gmm->start;d;d=d->next) {
00408 if (d->state_num > 3) {
00409 j_exit("Error: GMM has more than 1 output state! [%s]\n", d->name);
00410 }
00411 }
00412
00413
00414
00415
00416 gmm_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gmm->totalhmmnum);
00417
00418
00419 gmm_gprune_safe_init(gmm, gmm_prune_num);
00420
00421
00422 if (!gmm->variance_inversed) {
00423
00424 htk_hmm_inverse_variances(gmm);
00425 gmm->variance_inversed = TRUE;
00426 }
00427
00428 }
00429
00443 void
00444 gmm_prepare(HTK_HMM_INFO *gmm)
00445 {
00446 HTK_HMM_Data *d;
00447 int i;
00448
00449
00450 i = 0;
00451 for(d=gmm->start;d;d=d->next) {
00452 gmm_score[i] = 0.0;
00453 i++;
00454 }
00455 framecount = 0;
00456 }
00457
00476 void
00477 gmm_proceed(HTK_HMM_INFO *gmm, HTK_Param *param, int t)
00478 {
00479 HTK_HMM_Data *d;
00480 int i;
00481
00482 framecount++;
00483 i = 0;
00484 for(d=gmm->start;d;d=d->next) {
00485 gmm_score[i] += outprob_state_nocache(t, d->s[1], param);
00486 #ifdef MES
00487 printf("[%d:total=%f avg=%f]\n", i, gmm_score[i], gmm_score[i] / (float)framecount);
00488 #endif
00489 i++;
00490 }
00491 }
00492
00493 static HTK_HMM_Data *max_d;
00494 #ifdef CONFIDENCE_MEASURE
00495 static LOGPROB gmm_max_cm;
00496 #endif
00497 static HTK_HMM_INFO *gmm_local;
00498
00519 void
00520 gmm_end(HTK_HMM_INFO *gmm)
00521 {
00522 HTK_HMM_Data *d;
00523 LOGPROB maxprob, sum;
00524 int i;
00525
00526
00527 i = 0;
00528 maxprob = LOG_ZERO;
00529 for(d=gmm->start;d;d=d->next) {
00530 if (maxprob < gmm_score[i]) {
00531 max_d = d;
00532 maxprob = gmm_score[i];
00533 }
00534 i++;
00535 }
00536 #ifdef CONFIDENCE_MEASURE
00537
00538 sum = 0.0;
00539 i = 0;
00540 for(d=gmm->start;d;d=d->next) {
00541 sum += pow(10, cm_alpha * (gmm_score[i] - maxprob));
00542 i++;
00543 }
00544 gmm_max_cm = 1.0 / sum;
00545 #endif
00546
00547
00548 gmm_local = gmm;
00549 result_gmm();
00550 }
00551
00569 boolean
00570 gmm_valid_input()
00571 {
00572 if (max_d == NULL) return FALSE;
00573 if (strstr(gmm_reject_cmn_string, max_d->name)) {
00574 return FALSE;
00575 }
00576 return TRUE;
00577 }
00578
00579
00580
00581
00591 void
00592 ttyout_gmm(){
00593 HTK_HMM_Data *d;
00594 int i;
00595
00596 if (debug2_flag) {
00597 j_printf("--- GMM result begin ---\n");
00598 i = 0;
00599 for(d=gmm_local->start;d;d=d->next) {
00600 j_printf(" [%8s: total=%f avg=%f]\n", d->name, gmm_score[i], gmm_score[i] / (float)framecount);
00601 i++;
00602 }
00603 j_printf(" max = \"%s\"", max_d->name);
00604 #ifdef CONFIDENCE_MEASURE
00605 j_printf(" (CM: %f)", gmm_max_cm);
00606 #endif
00607 j_printf("\n");
00608 j_printf("--- GMM result end ---\n");
00609 } else if (verbose_flag) {
00610 j_printf("GMM: max = \"%s\"", max_d->name);
00611 #ifdef CONFIDENCE_MEASURE
00612 j_printf(" (CM: %f)", gmm_max_cm);
00613 #endif
00614 j_printf("\n");
00615 } else {
00616 j_printf("[GMM: %s]\n", max_d->name);
00617 }
00618 }
00619
00629 void
00630 msock_gmm()
00631 {
00632 module_send(module_sd, "<GMM RESULT=\"%s\"", max_d->name);
00633 #ifdef CONFIDENCE_MEASURE
00634 module_send(module_sd, " CMSCORE=\"%f\"", gmm_max_cm);
00635 #endif
00636 module_send(module_sd, "/>\n.\n");
00637 }