00001
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include <sent/stddefs.h>
00068 #include <sent/htk_hmm.h>
00069 #include <sent/htk_param.h>
00070 #include <sent/hmm.h>
00071 #include <sent/gprune.h>
00072 #include "globalvars.h"
00073
/* Compile-time switch: when NORMALIZE_GS_SCORE is defined, do_gms() also
 * multiplies every fallback score that is not LOG_ZERO by 0.975.
 * It is explicitly disabled here. */
00074 #undef NORMALIZE_GS_SCORE
00075
00076
00077
00078
00079
/* Number of best-scoring GS states marked per frame in do_gms();
 * set from the argument of gms_init(). */
00080 static int my_nbest;
/* Number of frames for which fallback_score / is_selected are currently
 * allocated; gms_init() sets it to -1 (nothing allocated yet). */
00081 static int allocframenum;
00082
00083
/* Table of GS HMM states, indexed by GS state id (filled by build_gsset()). */
00084 static GS_SET *gsset;
/* Number of entries in gsset (OP_gshmm->totalstatenum). */
00085 static int gsset_num;
/* Maps a state id of OP_hmminfo to the id of its GS HMM state;
 * -1 means "not assigned" (filled by build_state2gs()). */
00086 static int *state2gs;
00087
00088
/* Per-frame flag: TRUE once do_gms() has been run for that frame. */
00089 static boolean *is_selected;
/* fallback_score[t][g]: score of GS state g at frame t. All rows live in one
 * contiguous buffer whose start is fallback_score[0]. */
00090 static LOGPROB **fallback_score = NULL;
00091
00092
/* Work array of GS state ids, reordered by sort_gsindex_upward(). */
00093 static int *gsindex;
/* Row of fallback_score for the frame currently being processed (OP_time). */
00094 static LOGPROB *t_fs;
00095
00096
00101 static void
00102 build_gsset()
00103 {
00104 HTK_HMM_State *st;
00105
00106
00107 gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * OP_gshmm->totalstatenum);
00108 gsset_num = OP_gshmm->totalstatenum;
00109
00110 for(st = OP_gshmm->ststart; st; st=st->next) {
00111 gsset[st->id].state = st;
00112 }
00113 }
00114
00119 static void
00120 free_gsset()
00121 {
00122 free(gsset);
00123 }
00124
00125 #define MAXHMMNAMELEN 40
00126
00127
00132 static boolean
00133 build_state2gs()
00134 {
00135 HTK_HMM_Data *dt;
00136 HTK_HMM_State *st, *cr;
00137 int i;
00138 char gstr[MAXHMMNAMELEN], cbuf[MAXHMMNAMELEN];
00139 boolean ok_p = TRUE;
00140
00141
00142 state2gs = (int *)mymalloc(sizeof(int) * OP_hmminfo->totalstatenum);
00143 for(i=0;i<OP_hmminfo->totalstatenum;i++) state2gs[i] = -1;
00144
00145
00146 for(dt = OP_hmminfo->start; dt; dt=dt->next) {
00147 if (strlen(dt->name) >= MAXHMMNAMELEN - 2) {
00148 j_printerr("Error: too long hmm name (>%d): \"%s\"\n",
00149 MAXHMMNAMELEN-3, dt->name);
00150 ok_p = FALSE;
00151 continue;
00152 }
00153 for(i=1;i<dt->state_num-1;i++) {
00154 st = dt->s[i];
00155
00156 if (state2gs[st->id] != -1) continue;
00157
00158 sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00159
00160 if ((cr = state_lookup(OP_gshmm, gstr)) == NULL) {
00161 j_printerr("Error: GS HMM \"%s\" not defined\n", gstr);
00162 ok_p = FALSE;
00163 continue;
00164 }
00165
00166 state2gs[st->id] = cr->id;
00167 }
00168 }
00169 #ifdef PARANOIA
00170 {
00171 HTK_HMM_State *st;
00172 for(st=OP_hmminfo->ststart; st; st=st->next) {
00173 printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00174 (gsset[state2gs[st->id]].state)->name);
00175 }
00176 }
00177 #endif
00178 return ok_p;
00179 }
00180
00185 static void
00186 free_state2gs()
00187 {
00188 free(state2gs);
00189 }
00190
00191
00192
/* Helper macros for the heap sort on gsindex. Heap positions are 1-based.
 * Arguments are fully parenthesized so that any expression can be passed. */
/* gsindex entry at 1-based heap position A */
#define SD(A) (gsindex[(A)-1])
/* copy an index value from S to D */
#define SCOPY(D,S) ((D) = (S))
/* score (in t_fs) of the entry at 1-based heap position A */
#define SVAL(A) (t_fs[gsindex[(A)-1]])
/* score of the index held in the caller's local variable `s` */
#define STVAL (t_fs[s])
00197
00198
00204 static void
00205 sort_gsindex_upward(int neednum, int totalnum)
00206 {
00207 int n,root,child,parent;
00208 int s;
00209 for (root = totalnum/2; root >= 1; root--) {
00210 SCOPY(s, SD(root));
00211 parent = root;
00212 while ((child = parent * 2) <= totalnum) {
00213 if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00214 child++;
00215 }
00216 if (STVAL >= SVAL(child)) {
00217 break;
00218 }
00219 SCOPY(SD(parent), SD(child));
00220 parent = child;
00221 }
00222 SCOPY(SD(parent), s);
00223 }
00224 n = totalnum;
00225 while ( n > totalnum - neednum) {
00226 SCOPY(s, SD(n));
00227 SCOPY(SD(n), SD(1));
00228 n--;
00229 parent = 1;
00230 while ((child = parent * 2) <= n) {
00231 if (child < n && SVAL(child) < SVAL(child+1)) {
00232 child++;
00233 }
00234 if (STVAL >= SVAL(child)) {
00235 break;
00236 }
00237 SCOPY(SD(parent), SD(child));
00238 parent = child;
00239 }
00240 SCOPY(SD(parent), s);
00241 }
00242 }
00243
00248 static void
00249 do_gms()
00250 {
00251 int i;
00252
00253
00254 compute_gs_scores(gsset, gsset_num, t_fs);
00255
00256 sort_gsindex_upward(my_nbest, gsset_num);
00257 for(i=gsset_num - my_nbest;i<gsset_num;i++) {
00258
00259 t_fs[gsindex[i]] = LOG_ZERO;
00260 }
00261
00262
00263 #ifdef NORMALIZE_GS_SCORE
00264
00265 for(i=0;i<gsset_num;i++) {
00266 if (t_fs[i] != LOG_ZERO) {
00267 t_fs[i] = t_fs[i] * 0.975;
00268 }
00269 }
00270 #endif
00271 }
00272
00273
00281 boolean
00282 gms_init(int nbest)
00283 {
00284 int i;
00285
00286
00287 if (OP_gshmm->is_triphone) {
00288 j_printerr("Error: GS HMM should be a monophone model\n");
00289 return FALSE;
00290 }
00291 if (OP_gshmm->is_tied_mixture) {
00292 j_printerr("Error: GS HMM should not be a tied mixture model\n");
00293 return FALSE;
00294 }
00295
00296
00297 my_nbest = nbest;
00298
00299
00300 build_gsset();
00301
00302 j_printerr("Mapping HMM states to GS HMM...");
00303 if (build_state2gs() == FALSE) {
00304 j_printerr("Error: failed in assigning GS HMM state for each state\n");
00305 return FALSE;
00306 }
00307 j_printerr("done\n");
00308
00309
00310 gsindex = (int *)mymalloc(sizeof(int) * gsset_num);
00311 for(i=0;i<gsset_num;i++) gsindex[i] = i;
00312
00313
00314 fallback_score = NULL;
00315 is_selected = NULL;
00316 allocframenum = -1;
00317
00318
00319 gms_gprune_init(OP_hmminfo, gsset_num);
00320
00321 return TRUE;
00322 }
00323
00331 boolean
00332 gms_prepare(int framenum)
00333 {
00334 LOGPROB *tmp;
00335 int t;
00336
00337
00338 if (allocframenum < framenum) {
00339 if (fallback_score != NULL) {
00340 free(fallback_score[0]);
00341 free(fallback_score);
00342 free(is_selected);
00343 }
00344 fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00345 tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gsset_num * framenum);
00346 for(t=0;t<framenum;t++) {
00347 fallback_score[t] = &(tmp[gsset_num * t]);
00348 }
00349 is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00350 allocframenum = framenum;
00351 }
00352
00353 for(t=0;t<framenum;t++) is_selected[t] = FALSE;
00354
00355
00356 gms_gprune_prepare();
00357
00358 return TRUE;
00359 }
00360
00365 void
00366 gms_free()
00367 {
00368 free_gsset();
00369 free_state2gs();
00370 free(gsindex);
00371 if (fallback_score != NULL) {
00372 free(fallback_score[0]);
00373 free(fallback_score);
00374 free(is_selected);
00375 }
00376 gms_gprune_free();
00377 }
00378
00379
00380
00391 LOGPROB
00392 gms_state()
00393 {
00394 LOGPROB gsprob;
00395 if (OP_last_time != OP_time) {
00396
00397 t_fs = fallback_score[OP_time];
00398
00399 if (!is_selected[OP_time]) {
00400 do_gms();
00401 is_selected[OP_time] = TRUE;
00402 }
00403 }
00404 if ((gsprob = t_fs[state2gs[OP_state_id]]) != LOG_ZERO) {
00405
00406 return(gsprob);
00407 }
00408
00409 return(calc_outprob());
00410 }