Julius 4.2
libsent/src/ngram/ngram_access.c
説明を見る。
00001 
00019 /*
00020  * Copyright (c) 1991-2011 Kawahara Lab., Kyoto University
00021  * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology
00022  * Copyright (c) 2005-2011 Julius project team, Nagoya Institute of Technology
00023  * All rights reserved
00024  */
00025 
00026 #include <sent/stddefs.h>
00027 #include <sent/ngram2.h>
00028 
00029 #undef ADEBUG
00030 
00042 static NNID
00043 search_ngram_core(NGRAM_INFO *ndata, int n, NNID nid_prev, WORD_ID wkey)
00044 {
00045   NGRAM_TUPLE_INFO *t, *tprev;
00046   NNID nnid;
00047   NNID left,right,mid;
00048   NNID x;
00049 
00050   if (ndata->bigram_index_reversed && n == 2) {
00051     /* old binary format builds 1gram->2gram mapping using LR 2-gram,
00052        although the main model is RL 3-gram.  This hacks this problem */
00053     x = nid_prev;
00054     nid_prev = wkey;
00055     wkey = x;
00056   }
00057 
00058   t = &(ndata->d[n-1]);
00059   tprev = &(ndata->d[n-2]);
00060   
00061   if (tprev->ct_compaction) {
00062     nnid = tprev->nnid2ctid_upper[nid_prev];
00063     if (nnid == NNID_INVALID_UPPER) return (NNID_INVALID);
00064     nnid = (nnid << 16) + (NNID)(tprev->nnid2ctid_lower[nid_prev]);
00065   } else {
00066     nnid = nid_prev;
00067   }
00068   if (t->is24bit) {
00069     left = t->bgn_upper[nnid];
00070     if (left == NNID_INVALID_UPPER) return (NNID_INVALID);
00071     left = (left << 16) + (NNID)(t->bgn_lower[nnid]);
00072   } else {
00073     left = t->bgn[nnid];
00074     if (left == NNID_INVALID) return (NNID_INVALID);
00075   }
00076   right = left + t->num[nnid] - 1;
00077 
00078   while(left < right) {
00079     mid = (left + right) / 2;
00080     if (t->nnid2wid[mid] < wkey) {
00081       left = mid + 1;
00082     } else {
00083       right = mid;
00084     }
00085   }
00086   if (t->nnid2wid[left] == wkey) {
00087     return (left);
00088   } else {
00089     return (NNID_INVALID);
00090   }
00091 }
00092 
00102 NNID
00103 search_ngram(NGRAM_INFO *ndata, int n, WORD_ID *w)
00104 {
00105   int i;
00106   NNID prev, next;
00107 
00108   if (n == 1) {
00109     /* wid = nnid in 1-gram */
00110     return(w[0]);
00111   }
00112 
00113   prev = w[0];
00114   for(i=2;i<=n;i++) {
00115     next = search_ngram_core(ndata, i, prev, w[i-1]);
00116     if (next == NNID_INVALID) {
00117       return NNID_INVALID;
00118     }
00119     prev = next;
00120   }
00121   return(next);
00122 }
00123 
00124 
00134 LOGPROB
00135 ngram_prob(NGRAM_INFO *ndata, int n, WORD_ID *w)
00136 {
00137   int i;
00138   NNID prev, next, bid;
00139   LOGPROB p;
00140   NGRAM_TUPLE_INFO *t;
00141 
00142   if (n > ndata->n) {
00143     jlog("ERROR: no %d-gram exist (max %d)\n", n, ndata->n);
00144     return LOG_ZERO;
00145   }
00146 
00147 #ifdef ADEBUG
00148   printf("[");
00149   if (n > 1) {
00150     for(i=0;i<n-1;i++) printf("%s ", ndata->wname[w[i]]);
00151     printf("| ");
00152   }
00153   printf("%s]\n", ndata->wname[w[n-1]]);
00154 #endif
00155 
00156   /* unigram */
00157   if (n == 1) {
00158     p = ndata->d[0].prob[w[0]];
00159     if (w[0] == ndata->unk_id) p -= ndata->unk_num_log;
00160 #ifdef ADEBUG
00161     printf("hit: %f\n", p);
00162 #endif
00163     return(p);
00164   }
00165 
00166   /* parse for ngram to reach the N-gram tuple */
00167   prev = w[0];
00168   for(i=2;i<=n;i++) {
00169     next = search_ngram_core(ndata, i, prev, w[i-1]);
00170     if (next == NNID_INVALID) break;
00171     prev = next;
00172   }
00173   if (next == NNID_INVALID) {   /* not reached */
00174     /* both back-off or fallback uses (n-1) gram of the target word */
00175     /* recursive call to get the fallback likelihood */
00176 #ifdef ADEBUG
00177     printf("--(not found)->\n");
00178 #endif
00179     p = ngram_prob(ndata, n-1, &(w[1]));
00180     if (i == n) {     /* the last parse was terminated at last step */
00181       /* get back-off weight on prev */
00182       t = &(ndata->d[i-2]);
00183       if (t->ct_compaction) {
00184         if ((bid = t->nnid2ctid_upper[prev]) == NNID_INVALID_UPPER) {
00185           /* in case back-off entry not found, it means bo_wt == 0.0 */
00186 #ifdef ADEBUG
00187           printf("fall: %f\n", p);
00188 #endif
00189           return(p);
00190         } else {
00191           bid = (bid << 16) + (NNID)(t->nnid2ctid_lower[prev]);
00192         }
00193       } else {
00194         bid = prev;
00195       }
00196       /* return back-off likelihood */
00197 #ifdef ADEBUG
00198       printf("back: %f + %f\n", t->bo_wt[bid], p);
00199 #endif
00200       return(t->bo_wt[bid] + p);
00201     } else {
00202       /* previous context not found, fallback to (n-1)-gram */
00203       return(p);
00204     }
00205   }
00206   /* n-gram found */
00207   /* trigram exist */
00208   p = ndata->d[n-1].prob[next];
00209   if (w[n-1] == ndata->unk_id) p -= ndata->unk_num_log;
00210 
00211 #ifdef ADEBUG
00212   printf("hit: %f\n", p);
00213 #endif
00214   return(p);
00215 }
00216 
00217 /* ---------------------------------------------------------------------- */
00218 /* separate access functions for the 1st pass */
00219 
00228 LOGPROB
00229 uni_prob(NGRAM_INFO *ndata, WORD_ID w)
00230 {
00231   if (w != ndata->unk_id) {
00232     return(ndata->d[0].prob[w]);
00233   } else {
00234     return(ndata->d[0].prob[w] - ndata->unk_num_log);
00235   }
00236 }
00237 
00248 static NNID
00249 search_bigram(NGRAM_INFO *ndata, WORD_ID w_context, WORD_ID w)
00250 {
00251   /* do binary search to find bigram entry */
00252   /* assume ct_compaction and is24bit is FALSE on 2-gram */
00253   NNID left,right,mid;          /* n2 */
00254   NGRAM_TUPLE_INFO *t;
00255 
00256   t = &(ndata->d[1]);
00257 
00258   if ((left = t->bgn[w_context]) == NNID_INVALID) /* has no bigram */
00259     return (NNID_INVALID);
00260   right = left + t->num[w_context] - 1;
00261   while(left < right) {
00262     mid = (left + right) / 2;
00263     if (t->nnid2wid[mid] < w) {
00264       left = mid + 1;
00265     } else {
00266       right = mid;
00267     }
00268   }
00269   if (t->nnid2wid[left] == w) {
00270     return (left);
00271   } else {
00272     return (NNID_INVALID);
00273   }
00274 }
00275 
00287 static LOGPROB
00288 bi_prob_normal(NGRAM_INFO *ndata, WORD_ID w1, WORD_ID w2)
00289 {
00290   NNID n2;
00291   LOGPROB prob;
00292 
00293   /* index is LR */
00294   /* prob is in main N-gram area */
00295   if ((n2 = search_bigram(ndata, w1, w2)) != NNID_INVALID) {
00296     prob = ndata->d[1].prob[n2];
00297   } else {
00298     prob = ndata->d[0].bo_wt[w1] + ndata->d[0].prob[w2];
00299   }
00300   if (w2 != ndata->unk_id) {
00301     return(prob);
00302   } else {
00303     return(prob - ndata->unk_num_log);
00304   }
00305 }
00306 
00319 static LOGPROB
00320 bi_prob_additional_oldbin(NGRAM_INFO *ndata, WORD_ID w1, WORD_ID w2)
00321 {
00322   NNID n2;
00323   LOGPROB prob;
00324 
00325   /* index is LR */
00326   /* prob is in additional N-gram area */
00327   if ((n2 = search_bigram(ndata, w1, w2)) != NNID_INVALID) {
00328     prob = ndata->p_2[n2];
00329   } else {
00330     prob = ndata->bo_wt_1[w1] + ndata->d[0].prob[w2];
00331   }
00332   if (w2 != ndata->unk_id) {
00333     return(prob);
00334   } else {
00335     return(prob - ndata->unk_num_log);
00336   }
00337 }
00338 
00339 
00350 static LOGPROB
00351 bi_prob_additional(NGRAM_INFO *ndata, WORD_ID w1, WORD_ID w2)
00352 {
00353   NNID n2;
00354   LOGPROB prob;
00355 
00356   /* index is RL */
00357   /* prob is in additional N-gram area */
00358   if ((n2 = search_bigram(ndata, w2, w1)) != NNID_INVALID) {
00359     prob = ndata->p_2[n2];
00360   } else {
00361     prob = ndata->bo_wt_1[w1] + ndata->d[0].prob[w2];
00362   }
00363   if (w2 != ndata->unk_id) {
00364     return(prob);
00365   } else {
00366     return(prob - ndata->unk_num_log);
00367   }
00368 }
00369 
00370 
00382 static LOGPROB
00383 bi_prob_compute(NGRAM_INFO *ndata, WORD_ID w1, WORD_ID w2)
00384 {
00385   NNID n2;
00386   LOGPROB prob;
00387 
00388   /* index is RL */
00389   /* no additional N-gram, compute it directly */
00390   /* get p(w1|w2) */
00391   if ((n2 = search_bigram(ndata, w2, w1)) != NNID_INVALID) {
00392     prob = ndata->d[1].prob[n2];
00393   } else {
00394     prob = ndata->d[0].bo_wt[w2] + ndata->d[0].prob[w1];
00395   }
00396   /* p(w2|w1) = p(w1|w2) * p(w2) / p(w1) */
00397   prob = prob + ndata->d[0].prob[w2] - ndata->d[0].prob[w1];
00398   if (w2 != ndata->unk_id) {
00399     return(prob);
00400   } else {
00401     return(prob - ndata->unk_num_log);
00402   }
00403 }
00404 
00405 
00418 LOGPROB
00419 bi_prob(NGRAM_INFO *ndata, WORD_ID w1, WORD_ID w2)
00420 {
00421   LOGPROB p;
00422   if (ndata->bigram_index_reversed) {
00423     /* old binary format */
00424     /* RL 3-gram with additional LR 2-gram, index by LR */
00425     /* indexes are LR (swap not needed), probs are in additional area */
00426     p = bi_prob_additional_oldbin(ndata, w1, w2);
00427   } else if (ndata->dir == DIR_LR) {
00428     /* LR 3-gram, index by LR */
00429     p = bi_prob_normal(ndata, w1, w2);
00430   } else if (ndata->bo_wt_1 != NULL) {
00431     /* RL 3-gram with additional LR 2-gram, index by RL */
00432     p = bi_prob_additional(ndata, w1, w2);
00433   } else {
00434     /* RL 3-gram only, index by RL */
00435     p = bi_prob_compute(ndata, w1, w2);
00436   }
00437   return p;
00438 }
00439 
00448 void
00449 bi_prob_func_set(NGRAM_INFO *ndata)
00450 {
00451   if (ndata->bigram_index_reversed) {
00452     /* old binary format */
00453     /* RL 3-gram with additional LR 2-gram, index by LR */
00454     /* indexes are LR (swap not needed), probs are in additional area */
00455     ndata->bigram_prob = bi_prob_additional_oldbin;
00456   } else if (ndata->dir == DIR_LR) {
00457     /* LR 3-gram, index by LR */
00458     ndata->bigram_prob = bi_prob_normal;
00459   } else if (ndata->bo_wt_1 != NULL) {
00460     /* RL 3-gram with additional LR 2-gram, index by RL */
00461     ndata->bigram_prob = bi_prob_additional;
00462   } else {
00463     /* RL 3-gram only, index by RL */
00464     ndata->bigram_prob = bi_prob_compute;
00465   }
00466 }