From b367690eebc97a2fea3da1c1694ba365c72fd713 Mon Sep 17 00:00:00 2001 From: Dingyuan Wang Date: Sun, 19 Oct 2014 10:32:23 +0800 Subject: [PATCH 1/2] use prefix dict instead of trie, add a command line interface, and a few small improvements --- jieba/__init__.py | 325 +++++++++++++++++-------------------- jieba/__main__.py | 35 ++++ jieba/analyse/__init__.py | 82 ++++++---- jieba/analyse/analyzer.py | 20 +-- jieba/finalseg/__init__.py | 36 ++-- jieba/posseg/__init__.py | 135 ++++++++------- jieba/posseg/viterbi.py | 41 +++-- 7 files changed, 350 insertions(+), 324 deletions(-) create mode 100644 jieba/__main__.py diff --git a/jieba/__init__.py b/jieba/__init__.py index acd6bd11..222cab04 100644 --- a/jieba/__init__.py +++ b/jieba/__init__.py @@ -16,14 +16,13 @@ DICTIONARY = "dict.txt" DICT_LOCK = threading.RLock() -trie = None # to be initialized +pfdict = None # to be initialized FREQ = {} min_freq = 0.0 -total =0.0 -user_word_tag_tab={} +total = 0.0 +user_word_tag_tab = {} initialized = False - log_console = logging.StreamHandler(sys.stderr) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -33,85 +32,80 @@ def setLogLevel(log_level): global logger logger.setLevel(log_level) -def gen_trie(f_name): +def gen_pfdict(f_name): lfreq = {} - trie = {} + pfdict = set() ltotal = 0.0 with open(f_name, 'rb') as f: lineno = 0 for line in f.read().rstrip().decode('utf-8').split('\n'): lineno += 1 try: - word,freq,_ = line.split(' ') + word,freq = line.split(' ')[:2] freq = float(freq) lfreq[word] = freq - ltotal+=freq - p = trie - for c in word: - if c not in p: - p[c] ={} - p = p[c] - p['']='' #ending flag + ltotal += freq + for ch in range(len(word)): + pfdict.add(word[:ch+1]) except ValueError as e: - logger.debug('%s at line %s %s' % (f_name, lineno, line)) + logger.debug('%s at line %s %s' % (f_name, lineno, line)) raise e - return trie, lfreq,ltotal + return pfdict, lfreq, ltotal def initialize(*args): - global trie, FREQ, total, min_freq, initialized - if len(args)==0: + global pfdict, FREQ, total, min_freq, initialized + if not args: dictionary = DICTIONARY else: dictionary = args[0] with DICT_LOCK: if initialized: return - if trie: - del trie - trie = None - _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) + if pfdict: + del pfdict + pfdict = None + _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) abs_path = os.path.join(_curpath,dictionary) - logger.debug("Building Trie..., from %s" % abs_path) + logger.debug("Building prefix dict from %s ..." 
% abs_path) t1 = time.time() - if abs_path == os.path.join(_curpath,"dict.txt"): #defautl dictionary - cache_file = os.path.join(tempfile.gettempdir(),"jieba.cache") - else: #customer dictionary - cache_file = os.path.join(tempfile.gettempdir(),"jieba.user."+str(hash(abs_path))+".cache") + if abs_path == os.path.join(_curpath, "dict.txt"): #default dictionary + cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache") + else: #custom dictionary + cache_file = os.path.join(tempfile.gettempdir(), "jieba.user.%s.cache" % hash(abs_path)) load_from_cache_fail = True - if os.path.exists(cache_file) and os.path.getmtime(cache_file)>os.path.getmtime(abs_path): - logger.debug("loading model from cache %s" % cache_file) + if os.path.exists(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path): + logger.debug("Loading model from cache %s" % cache_file) try: with open(cache_file, 'rb') as cf: - trie,FREQ,total,min_freq = marshal.load(cf) - load_from_cache_fail = False + pfdict,FREQ,total,min_freq = marshal.load(cf) + # prevent conflict with old version + load_from_cache_fail = not isinstance(pfdict, set) except: load_from_cache_fail = True if load_from_cache_fail: - trie,FREQ,total = gen_trie(abs_path) + pfdict,FREQ,total = gen_pfdict(abs_path) FREQ = dict([(k,log(float(v)/total)) for k,v in FREQ.items()]) #normalize min_freq = min(FREQ.values()) - logger.debug("dumping model to file cache %s" % cache_file) + logger.debug("Dumping model to file cache %s" % cache_file) try: tmp_suffix = "."+str(random.random()) with open(cache_file+tmp_suffix,'wb') as temp_cache_file: - marshal.dump((trie,FREQ,total,min_freq),temp_cache_file) - if os.name=='nt': - import shutil - replace_file = shutil.move + marshal.dump((pfdict,FREQ,total,min_freq), temp_cache_file) + if os.name == 'nt': + from shutil import move as replace_file else: replace_file = os.rename - replace_file(cache_file+tmp_suffix,cache_file) + replace_file(cache_file + tmp_suffix, cache_file) except: - logger.error("dump cache file failed.") - logger.exception("") + logger.exception("Dump cache file failed.") initialized = True - logger.debug("loading model cost %s seconds." % (time.time() - t1)) - logger.debug("Trie has been built succesfully.") + logger.debug("Loading model cost %s seconds." 
% (time.time() - t1)) + logger.debug("Prefix dict has been built succesfully.") def require_initialized(fn): @@ -132,145 +126,136 @@ def __cut_all(sentence): dag = get_DAG(sentence) old_j = -1 for k,L in dag.items(): - if len(L)==1 and k>old_j: + if len(L) == 1 and k > old_j: yield sentence[k:L[0]+1] old_j = L[0] else: for j in L: - if j>k: + if j > k: yield sentence[k:j+1] old_j = j def calc(sentence,DAG,idx,route): N = len(sentence) - route[N] = (0.0,'') - for idx in range(N-1,-1,-1): - candidates = [ ( FREQ.get(sentence[idx:x+1],min_freq) + route[x+1][0],x ) for x in DAG[idx] ] + route[N] = (0.0, '') + for idx in range(N-1, -1, -1): + candidates = [(FREQ.get(sentence[idx:x+1],min_freq) + route[x+1][0], x) for x in DAG[idx]] route[idx] = max(candidates) @require_initialized def get_DAG(sentence): - N = len(sentence) - i,j=0,0 - p = trie + global pfdict, FREQ DAG = {} - while i=N: - i+=1 - j=i - p=trie - else: - p = trie - i+=1 - j=i - for i in range(len(sentence)): - if i not in DAG: - DAG[i] =[i] + N = len(sentence) + for k in range(N): + tmplist = [] + i = k + frag = sentence[k] + while i < N and frag in pfdict: + if frag in FREQ: + tmplist.append(i) + i += 1 + frag = sentence[k:i+1] + if not tmplist: + tmplist.append(k) + DAG[k] = tmplist return DAG def __cut_DAG_NO_HMM(sentence): re_eng = re.compile(r'[a-zA-Z0-9]',re.U) DAG = get_DAG(sentence) - route ={} - calc(sentence,DAG,0,route=route) + route = {} + calc(sentence, DAG, 0, route=route) x = 0 N = len(sentence) buf = '' - while x0: + if buf: yield buf buf = '' yield l_word - x =y - if len(buf)>0: + x = y + if buf: yield buf buf = '' def __cut_DAG(sentence): DAG = get_DAG(sentence) - route ={} - calc(sentence,DAG,0,route=route) + route = {} + calc(sentence, DAG, 0, route=route) x = 0 - buf ='' + buf = '' N = len(sentence) - while x0: - if len(buf)==1: + if buf: + if len(buf) == 1: yield buf - buf='' + buf = '' else: if (buf not in FREQ): - regognized = finalseg.cut(buf) - for t in regognized: + recognized = finalseg.cut(buf) + for t in recognized: yield t else: for elem in buf: yield elem - buf='' + buf = '' yield l_word - x =y + x = y - if len(buf)>0: - if len(buf)==1: + if buf: + if len(buf) == 1: yield buf + elif (buf not in FREQ): + recognized = finalseg.cut(buf) + for t in recognized: + yield t else: - if (buf not in FREQ): - regognized = finalseg.cut(buf) - for t in regognized: - yield t - else: - for elem in buf: - yield elem - -def cut(sentence,cut_all=False,HMM=True): + for elem in buf: + yield elem + +def cut(sentence, cut_all=False, HMM=True): '''The main function that segments an entire sentence that contains Chinese characters into seperated words. Parameter: - - sentence: The String to be segmented - - cut_all: Model. True means full pattern, false means accurate pattern. - - HMM: Whether use Hidden Markov Model. + - sentence: The str to be segmented. + - cut_all: Model type. True for full pattern, False for accurate pattern. + - HMM: Whether to use the Hidden Markov Model. ''' if isinstance(sentence, bytes): try: sentence = sentence.decode('utf-8') except UnicodeDecodeError: - sentence = sentence.decode('gbk','ignore') - ''' - \\u4E00-\\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. Will be handled with re_han - \r\n|\s : whitespace characters. Will not be Handled. - ''' - re_han, re_skip = re.compile(r"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(r"(\r\n|\s)") + sentence = sentence.decode('gbk', 'ignore') + + # \u4E00-\u9FA5a-zA-Z0-9+#&\._ : All non-space characters. 
Will be handled with re_han + # \r\n|\s : whitespace characters. Will not be handled. + if cut_all: - re_han, re_skip = re.compile(r"([\u4E00-\u9FA5]+)", re.U), re.compile(r"[^a-zA-Z0-9+#\n]") + re_han, re_skip = re.compile(r"([\u4E00-\u9FA5]+)", re.U), re.compile(r"[^a-zA-Z0-9+#\n]", re.U) + else: + re_han, re_skip = re.compile(r"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(r"(\r\n|\s)", re.U) blocks = re_han.split(sentence) - if HMM: + if cut_all: + cut_block = __cut_all + elif HMM: cut_block = __cut_DAG else: cut_block = __cut_DAG_NO_HMM - if cut_all: - cut_block = __cut_all for blk in blocks: - if len(blk)==0: + if not blk: continue if re_han.match(blk): for word in cut_block(blk): @@ -286,15 +271,15 @@ def cut(sentence,cut_all=False,HMM=True): else: yield x -def cut_for_search(sentence,HMM=True): - words = cut(sentence,HMM=HMM) +def cut_for_search(sentence, HMM=True): + words = cut(sentence, HMM=HMM) for w in words: - if len(w)>2: + if len(w) > 2: for i in range(len(w)-1): gram2 = w[i:i+2] if gram2 in FREQ: yield gram2 - if len(w)>3: + if len(w) > 3: for i in range(len(w)-2): gram3 = w[i:i+3] if gram3 in FREQ: @@ -312,79 +297,69 @@ def load_userdict(f): ... Word type may be ignored ''' - global trie,total,FREQ if isinstance(f, str): f = open(f, 'rb') content = f.read().decode('utf-8') line_no = 0 for line in content.split("\n"): - line_no+=1 - if line.rstrip()=='': continue - tup =line.split(" ") - word,freq = tup[0],tup[1] - if freq.isdigit() is False: continue - if line_no==1: + line_no += 1 + if not line.rstrip(): + continue + tup = line.split(" ") + word, freq = tup[0], tup[1] + if freq.isdigit() is False: + continue + if line_no == 1: word = word.replace('\ufeff',"") #remove bom flag if it exists - if len(tup)==3: - add_word(word, freq, tup[2]) - else: - add_word(word, freq) + add_word(*tup) @require_initialized def add_word(word, freq, tag=None): - global FREQ, trie, total, user_word_tag_tab - freq = float(freq) - FREQ[word] = log(freq / total) + global FREQ, pfdict, total, user_word_tag_tab + FREQ[word] = log(float(freq) / total) if tag is not None: user_word_tag_tab[word] = tag.strip() - p = trie - for c in word: - if c not in p: - p[c] = {} - p = p[c] - p[''] = '' # ending flag + for ch in range(len(word)): + pfdict.add(word[:ch+1]) __ref_cut = cut __ref_cut_for_search = cut_for_search def __lcut(sentence): - return list(__ref_cut(sentence,False)) + return list(__ref_cut(sentence, False)) def __lcut_no_hmm(sentence): - return list(__ref_cut(sentence,False,False)) + return list(__ref_cut(sentence, False, False)) def __lcut_all(sentence): - return list(__ref_cut(sentence,True)) + return list(__ref_cut(sentence, True)) def __lcut_for_search(sentence): return list(__ref_cut_for_search(sentence)) @require_initialized def enable_parallel(processnum=None): - global pool,cut,cut_for_search - if os.name=='nt': + global pool, cut, cut_for_search + if os.name == 'nt': raise Exception("jieba: parallel mode only supports posix system") - if sys.version_info[0]==2 and sys.version_info[1]<6: - raise Exception("jieba: the parallel feature needs Python version>2.5 ") - from multiprocessing import Pool,cpu_count - if processnum==None: + from multiprocessing import Pool, cpu_count + if processnum is None: processnum = cpu_count() pool = Pool(processnum) def pcut(sentence,cut_all=False,HMM=True): parts = re.compile('([\r\n]+)').split(sentence) if cut_all: - result = pool.map(__lcut_all,parts) + result = pool.map(__lcut_all, parts) + elif HMM: + result = pool.map(__lcut, parts) else: - 
if HMM: - result = pool.map(__lcut,parts) - else: - result = pool.map(__lcut_no_hmm,parts) + result = pool.map(__lcut_no_hmm, parts) for r in result: for w in r: yield w def pcut_for_search(sentence): parts = re.compile('([\r\n]+)').split(sentence) - result = pool.map(__lcut_for_search,parts) + result = pool.map(__lcut_for_search, parts) for r in result: for w in r: yield w @@ -403,40 +378,44 @@ def disable_parallel(): def set_dictionary(dictionary_path): global initialized, DICTIONARY with DICT_LOCK: - abs_path = os.path.normpath( os.path.join( os.getcwd(), dictionary_path ) ) + abs_path = os.path.normpath(os.path.join(os.getcwd(), dictionary_path)) if not os.path.exists(abs_path): - raise Exception("jieba: path does not exist:" + abs_path) + raise Exception("jieba: path does not exist: " + abs_path) DICTIONARY = abs_path initialized = False def get_abs_path_dict(): - _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) + _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) abs_path = os.path.join(_curpath,DICTIONARY) return abs_path -def tokenize(unicode_sentence,mode="default",HMM=True): - #mode ("default" or "search") +def tokenize(unicode_sentence, mode="default", HMM=True): + """Tokenize a sentence and yields tuples of (word, start, end) + Parameter: + - sentence: the str to be segmented. + - mode: "default" or "search", "search" is for finer segmentation. + - HMM: whether to use the Hidden Markov Model. + """ if not isinstance(unicode_sentence, str): raise Exception("jieba: the input parameter should be str.") start = 0 - if mode=='default': - for w in cut(unicode_sentence,HMM=HMM): + if mode == 'default': + for w in cut(unicode_sentence, HMM=HMM): width = len(w) - yield (w,start,start+width) - start+=width + yield (w, start, start+width) + start += width else: - for w in cut(unicode_sentence,HMM=HMM): + for w in cut(unicode_sentence, HMM=HMM): width = len(w) - if len(w)>2: + if len(w) > 2: for i in range(len(w)-1): gram2 = w[i:i+2] if gram2 in FREQ: - yield (gram2,start+i,start+i+2) - if len(w)>3: + yield (gram2, start+i, start+i+2) + if len(w) > 3: for i in range(len(w)-2): gram3 = w[i:i+3] if gram3 in FREQ: - yield (gram3,start+i,start+i+3) - yield (w,start,start+width) - start+=width - + yield (gram3, start+i, start+i+3) + yield (w, start, start+width) + start += width diff --git a/jieba/__main__.py b/jieba/__main__.py new file mode 100644 index 00000000..bdc94fa9 --- /dev/null +++ b/jieba/__main__.py @@ -0,0 +1,35 @@ +"""Jieba command line interface.""" +import sys +import jieba +from argparse import ArgumentParser + +parser = ArgumentParser(usage="%s -m jieba [options] filename" % sys.executable, description="Jieba command line interface.", version="Jieba " + jieba.__version__, epilog="If no filename specified, use STDIN instead.") +parser.add_argument("-d", "--delimiter", metavar="DELIM", default=' / ', + nargs='?', const=' ', + help="use DELIM instead of ' / ' for word delimiter; use a space if it is without DELIM") +parser.add_argument("-a", "--cut-all", + action="store_true", dest="cutall", default=False, + help="full pattern cutting") +parser.add_argument("-n", "--no-hmm", dest="hmm", action="store_false", + default=True, help="don't use the Hidden Markov Model") +parser.add_argument("-q", "--quiet", action="store_true", default=False, + help="don't print loading messages to stderr") +parser.add_argument("filename", nargs='?', help="input file") + +args = parser.parse_args() + +if args.quiet: + 
jieba.setLogLevel(60) +delim = str(args.delimiter) +cutall = args.cutall +hmm = args.hmm +fp = open(args.filename, 'r') if args.filename else sys.stdin + +jieba.initialize() +ln = fp.readline() +while ln: + l = ln.rstrip('\r\n') + print(delim.join(jieba.cut(ln.rstrip('\r\n'), cutall, hmm)).encode('utf-8')) + ln = fp.readline() + +fp.close() diff --git a/jieba/analyse/__init__.py b/jieba/analyse/__init__.py index d28f85a3..f70448b0 100644 --- a/jieba/analyse/__init__.py +++ b/jieba/analyse/__init__.py @@ -6,61 +6,77 @@ except ImportError: pass -_curpath = os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) +_curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) abs_path = os.path.join(_curpath, "idf.txt") -IDF_DICTIONARY = abs_path -STOP_WORDS = set([ - "the","of","is","and","to","in","that","we","for","an","are","by","be","as","on","with","can","if","from","which","you","it","this","then","at","have","all","not","one","has","or","that" -]) +STOP_WORDS = set(( + "the","of","is","and","to","in","that","we","for","an","are", + "by","be","as","on","with","can","if","from","which","you","it", + "this","then","at","have","all","not","one","has","or","that" +)) -def set_idf_path(idf_path): - global IDF_DICTIONARY - abs_path = os.path.normpath( os.path.join( os.getcwd(), idf_path ) ) - if not os.path.exists(abs_path): - raise Exception("jieba: path does not exist:" + abs_path) - IDF_DICTIONARY = abs_path - return +class IDFLoader: + def __init__(self): + self.path = "" + self.idf_freq = {} + self.median_idf = 0.0 -def get_idf(abs_path): - content = open(abs_path,'rb').read().decode('utf-8') - idf_freq = {} - lines = content.split('\n') - for line in lines: - word,freq = line.split(' ') - idf_freq[word] = float(freq) - median_idf = sorted(idf_freq.values())[len(idf_freq)//2] - return idf_freq, median_idf + def set_new_path(self, new_idf_path): + if self.path != new_idf_path: + content = open(new_idf_path, 'r', encoding='utf-8').read() + idf_freq = {} + lines = content.split('\n') + if lines and not lines[-1]: + lines.pop(-1) + for line in lines: + word, freq = line.split(' ') + idf_freq[word] = float(freq) + median_idf = sorted(idf_freq.values())[len(idf_freq)//2] + self.idf_freq = idf_freq + self.median_idf = median_idf + self.path = new_idf_path + + def get_idf(self): + return self.idf_freq, self.median_idf + +idf_loader = IDFLoader() +idf_loader.set_new_path(abs_path) + +def set_idf_path(idf_path): + new_abs_path = os.path.normpath(os.path.join(os.getcwd(), idf_path)) + if not os.path.exists(new_abs_path): + raise Exception("jieba: path does not exist: " + new_abs_path) + idf_loader.set_new_path(new_abs_path) def set_stop_words(stop_words_path): global STOP_WORDS - abs_path = os.path.normpath( os.path.join( os.getcwd(), stop_words_path ) ) + abs_path = os.path.normpath(os.path.join(os.getcwd(), stop_words_path)) if not os.path.exists(abs_path): - raise Exception("jieba: path does not exist:" + abs_path) + raise Exception("jieba: path does not exist: " + abs_path) content = open(abs_path,'rb').read().decode('utf-8') lines = content.split('\n') for line in lines: STOP_WORDS.add(line) - return -def extract_tags(sentence,topK=20): - global IDF_DICTIONARY +def extract_tags(sentence, topK=20): global STOP_WORDS - idf_freq, median_idf = get_idf(IDF_DICTIONARY) + idf_freq, median_idf = idf_loader.get_idf() words = jieba.cut(sentence) freq = {} for w in words: - if len(w.strip())<2: continue - if w.lower() in STOP_WORDS: continue - 
freq[w]=freq.get(w,0.0)+1.0 + if len(w.strip()) < 2: + continue + if w.lower() in STOP_WORDS: + continue + freq[w] = freq.get(w, 0.0) + 1.0 total = sum(freq.values()) freq = [(k,v/total) for k,v in freq.items()] - tf_idf_list = [(v * idf_freq.get(k,median_idf),k) for k,v in freq] - st_list = sorted(tf_idf_list,reverse=True) + tf_idf_list = [(v*idf_freq.get(k,median_idf), k) for k,v in freq] + st_list = sorted(tf_idf_list, reverse=True) - top_tuples= st_list[:topK] + top_tuples = st_list[:topK] tags = [a[1] for a in top_tuples] return tags diff --git a/jieba/analyse/analyzer.py b/jieba/analyse/analyzer.py index 615130e6..c5bfd122 100644 --- a/jieba/analyse/analyzer.py +++ b/jieba/analyse/analyzer.py @@ -15,21 +15,19 @@ accepted_chars = re.compile(r"[\u4E00-\u9FA5]+") class ChineseTokenizer(Tokenizer): - def __call__(self,text,**kargs): - words = jieba.tokenize(text,mode="search") - token = Token() + def __call__(self, text, **kargs): + words = jieba.tokenize(text, mode="search") + token = Token() for (w,start_pos,stop_pos) in words: - if not accepted_chars.match(w): - if len(w)>1: - pass - else: - continue + if not accepted_chars.match(w) and len(w)<=1: + continue token.original = token.text = w token.pos = start_pos token.startchar = start_pos token.endchar = stop_pos yield token -def ChineseAnalyzer(stoplist=STOP_WORDS,minsize=1,stemfn=stem,cachesize=50000): - return ChineseTokenizer() | LowercaseFilter() | StopFilter(stoplist=stoplist,minsize=minsize)\ - |StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize) +def ChineseAnalyzer(stoplist=STOP_WORDS, minsize=1, stemfn=stem, cachesize=50000): + return (ChineseTokenizer() | LowercaseFilter() | + StopFilter(stoplist=stoplist,minsize=minsize) | + StemFilter(stemfn=stemfn, ignore=None,cachesize=cachesize)) diff --git a/jieba/finalseg/__init__.py b/jieba/finalseg/__init__.py index 0540d8cf..5cac02fd 100644 --- a/jieba/finalseg/__init__.py +++ b/jieba/finalseg/__init__.py @@ -3,7 +3,7 @@ import marshal import sys -MIN_FLOAT=-3.14e100 +MIN_FLOAT = -3.14e100 PROB_START_P = "prob_start.p" PROB_TRANS_P = "prob_trans.p" @@ -18,20 +18,20 @@ } def load_model(): - _curpath=os.path.normpath( os.path.join( os.getcwd(), os.path.dirname(__file__) ) ) + _curpath=os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) start_p = {} abs_path = os.path.join(_curpath, PROB_START_P) with open(abs_path, mode='rb') as f: start_p = marshal.load(f) f.closed - + trans_p = {} abs_path = os.path.join(_curpath, PROB_TRANS_P) with open(abs_path, 'rb') as f: trans_p = marshal.load(f) f.closed - + emit_p = {} abs_path = os.path.join(_curpath, PROB_EMIT_P) with open(abs_path, 'rb') as f: @@ -57,40 +57,40 @@ def viterbi(obs, states, start_p, trans_p, emit_p): newpath = {} for y in states: em_p = emit_p[y].get(obs[t],MIN_FLOAT) - (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_FLOAT) + em_p ,y0) for y0 in PrevStatus[y] ]) - V[t][y] =prob + (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y, MIN_FLOAT) + em_p, y0) for y0 in PrevStatus[y]]) + V[t][y] = prob newpath[y] = path[state] + [y] path = newpath - - (prob, state) = max([(V[len(obs) - 1][y], y) for y in ('E','S')]) - + + (prob, state) = max([(V[len(obs)-1][y], y) for y in ('E','S')]) + return (prob, path[state]) def __cut(sentence): global emit_P - prob, pos_list = viterbi(sentence,('B','M','E','S'), start_P, trans_P, emit_P) + prob, pos_list = viterbi(sentence, ('B','M','E','S'), start_P, trans_P, emit_P) begin, next = 0,0 #print pos_list, sentence for i,char in enumerate(sentence): pos = 
pos_list[i] - if pos=='B': + if pos == 'B': begin = i - elif pos=='E': + elif pos == 'E': yield sentence[begin:i+1] next = i+1 - elif pos=='S': + elif pos == 'S': yield char next = i+1 - if next0: + if buf: yield pair(buf,'eng') buf = '' - yield pair(l_word,word_tag_tab.get(l_word,'x')) - x =y - if len(buf)>0: + yield pair(l_word, word_tag_tab.get(l_word, 'x')) + x = y + if buf: yield pair(buf,'eng') buf = '' def __cut_DAG(sentence): DAG = jieba.get_DAG(sentence) - route ={} + route = {} jieba.calc(sentence,DAG,0,route=route) x = 0 - buf ='' + buf = '' N = len(sentence) - while x0: - if len(buf)==1: - yield pair(buf,word_tag_tab.get(buf,'x')) - buf='' + if buf: + if len(buf) == 1: + yield pair(buf, word_tag_tab.get(buf, 'x')) + buf = '' else: if (buf not in jieba.FREQ): - regognized = __cut_detail(buf) - for t in regognized: + recognized = __cut_detail(buf) + for t in recognized: yield t else: for elem in buf: - yield pair(elem,word_tag_tab.get(elem,'x')) - buf='' - yield pair(l_word,word_tag_tab.get(l_word,'x')) - x =y - - if len(buf)>0: - if len(buf)==1: - yield pair(buf,word_tag_tab.get(buf,'x')) + yield pair(elem, word_tag_tab.get(elem, 'x')) + buf = '' + yield pair(l_word, word_tag_tab.get(l_word, 'x')) + x = y + + if buf: + if len(buf) == 1: + yield pair(buf, word_tag_tab.get(buf, 'x')) + elif (buf not in jieba.FREQ): + recognized = __cut_detail(buf) + for t in recognized: + yield t else: - if (buf not in jieba.FREQ): - regognized = __cut_detail(buf) - for t in regognized: - yield t - else: - for elem in buf: - yield pair(elem,word_tag_tab.get(elem,'x')) - -def __cut_internal(sentence,HMM=True): + for elem in buf: + yield pair(elem, word_tag_tab.get(elem, 'x')) + +def __cut_internal(sentence, HMM=True): if not isinstance(sentence, str): try: sentence = sentence.decode('utf-8') - except: - sentence = sentence.decode('gbk','ignore') + except UnicodeDecodeError: + sentence = sentence.decode('gbk', 'ignore') re_han, re_skip = re.compile(r"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)"), re.compile(r"(\r\n|\s)") - re_eng,re_num = re.compile(r"[a-zA-Z0-9]+"), re.compile(r"[\.0-9]+") + re_eng, re_num = re.compile(r"[a-zA-Z0-9]+"), re.compile(r"[\.0-9]+") blocks = re_han.split(sentence) if HMM: __cut_blk = __cut_DAG @@ -214,15 +213,15 @@ def __cut_internal(sentence,HMM=True): tmp = re_skip.split(blk) for x in tmp: if re_skip.match(x): - yield pair(x,'x') + yield pair(x, 'x') else: for xx in x: if re_num.match(xx): - yield pair(xx,'m') + yield pair(xx, 'm') elif re_eng.match(x): - yield pair(xx,'eng') + yield pair(xx, 'eng') else: - yield pair(xx,'x') + yield pair(xx, 'x') def __lcut_internal(sentence): return list(__cut_internal(sentence)) @@ -231,16 +230,16 @@ def __lcut_internal_no_hmm(sentence): @makesure_userdict_loaded -def cut(sentence,HMM=True): - if (not hasattr(jieba,'pool')) or (jieba.pool==None): - for w in __cut_internal(sentence,HMM=HMM): +def cut(sentence, HMM=True): + if (not hasattr(jieba, 'pool')) or (jieba.pool is None): + for w in __cut_internal(sentence, HMM=HMM): yield w else: parts = re.compile('([\r\n]+)').split(sentence) if HMM: - result = jieba.pool.map(__lcut_internal,parts) + result = jieba.pool.map(__lcut_internal, parts) else: - result = jieba.pool.map(__lcut_internal_no_hmm,parts) + result = jieba.pool.map(__lcut_internal_no_hmm, parts) for r in result: for w in r: yield w diff --git a/jieba/posseg/viterbi.py b/jieba/posseg/viterbi.py index a95921c1..dd707cdb 100644 --- a/jieba/posseg/viterbi.py +++ b/jieba/posseg/viterbi.py @@ -1,46 +1,45 @@ import operator 
-MIN_FLOAT=-3.14e100 -MIN_INF=float("-inf") +MIN_FLOAT = -3.14e100 +MIN_INF = float("-inf") -def get_top_states(t_state_v,K=4): +def get_top_states(t_state_v, K=4): items = t_state_v.items() - topK= sorted(items,key=operator.itemgetter(1),reverse=True)[:K] + topK = sorted(items, key=operator.itemgetter(1), reverse=True)[:K] return [x[0] for x in topK] def viterbi(obs, states, start_p, trans_p, emit_p): V = [{}] #tabular mem_path = [{}] all_states = trans_p.keys() - for y in states.get(obs[0],all_states): #init - V[0][y] = start_p[y] + emit_p[y].get(obs[0],MIN_FLOAT) + for y in states.get(obs[0], all_states): #init + V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT) mem_path[0][y] = '' - for t in range(1,len(obs)): + for t in range(1, len(obs)): V.append({}) mem_path.append({}) #prev_states = get_top_states(V[t-1]) - prev_states =[ x for x in mem_path[t-1].keys() if len(trans_p[x])>0 ] + prev_states = [x for x in mem_path[t-1].keys() if len(trans_p[x]) > 0] - prev_states_expect_next = set( (y for x in prev_states for y in trans_p[x].keys() ) ) - obs_states = states.get(obs[t],all_states) - obs_states = set(obs_states) & set(prev_states_expect_next) + prev_states_expect_next = set((y for x in prev_states for y in trans_p[x].keys())) + obs_states = set(states.get(obs[t], all_states)) & prev_states_expect_next - if len(obs_states)==0: obs_states = prev_states_expect_next - if len(obs_states)==0: obs_states = all_states + if not obs_states: + obs_states = prev_states_expect_next if prev_states_expect_next else all_states for y in obs_states: - (prob,state ) = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT) ,y0) for y0 in prev_states]) - V[t][y] =prob + prob, state = max([(V[t-1][y0] + trans_p[y0].get(y,MIN_INF) + emit_p[y].get(obs[t],MIN_FLOAT), y0) for y0 in prev_states]) + V[t][y] = prob mem_path[t][y] = state - last = [(V[-1][y], y) for y in mem_path[-1].keys() ] + last = [(V[-1][y], y) for y in mem_path[-1].keys()] #if len(last)==0: #print obs - (prob, state) = max(last) + prob, state = max(last) route = [None] * len(obs) - i = len(obs)-1 - while i>=0: + i = len(obs) - 1 + while i >= 0: route[i] = state state = mem_path[i][state] - i-=1 - return (prob, route) \ No newline at end of file + i -= 1 + return (prob, route) From 14671d4feb22844c0839eeedf4fe7f8d43c501d9 Mon Sep 17 00:00:00 2001 From: Dingyuan Wang Date: Sun, 19 Oct 2014 10:41:09 +0800 Subject: [PATCH 2/2] fix __main__.py --- jieba/__main__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jieba/__main__.py b/jieba/__main__.py index bdc94fa9..d52f3eec 100644 --- a/jieba/__main__.py +++ b/jieba/__main__.py @@ -3,7 +3,7 @@ import jieba from argparse import ArgumentParser -parser = ArgumentParser(usage="%s -m jieba [options] filename" % sys.executable, description="Jieba command line interface.", version="Jieba " + jieba.__version__, epilog="If no filename specified, use STDIN instead.") +parser = ArgumentParser(usage="%s -m jieba [options] filename" % sys.executable, description="Jieba command line interface.", epilog="If no filename specified, use STDIN instead.") parser.add_argument("-d", "--delimiter", metavar="DELIM", default=' / ', nargs='?', const=' ', help="use DELIM instead of ' / ' for word delimiter; use a space if it is without DELIM") @@ -14,6 +14,7 @@ default=True, help="don't use the Hidden Markov Model") parser.add_argument("-q", "--quiet", action="store_true", default=False, help="don't print loading messages to stderr") +parser.add_argument("-V", 
'--version', action='version', version="Jieba " + jieba.__version__) parser.add_argument("filename", nargs='?', help="input file") args = parser.parse_args() @@ -29,7 +30,7 @@ ln = fp.readline() while ln: l = ln.rstrip('\r\n') - print(delim.join(jieba.cut(ln.rstrip('\r\n'), cutall, hmm)).encode('utf-8')) + print(delim.join(jieba.cut(ln.rstrip('\r\n'), cutall, hmm))) ln = fp.readline() fp.close()
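
Note on the core change: the nested-dict trie is replaced by a flat set (the "prefix dict") containing every prefix of every dictionary word, and get_DAG now grows a fragment character by character while the fragment is still in that set, recording an end position whenever the fragment is itself a known word. Below is a minimal, self-contained sketch of that idea; the names build_prefix_set and get_dag are illustrative only, not jieba's API, and word frequencies are omitted.

    def build_prefix_set(words):
        """Return (prefixes, vocab): all prefixes of the words, and the words themselves."""
        prefixes, vocab = set(), set(words)
        for word in words:
            for i in range(1, len(word) + 1):
                prefixes.add(word[:i])
        return prefixes, vocab

    def get_dag(sentence, prefixes, vocab):
        """Map each position k to the end indices of dictionary words starting at k."""
        dag = {}
        n = len(sentence)
        for k in range(n):
            ends = []
            i = k
            frag = sentence[k]
            # extend the fragment while it is still a prefix of some dictionary word
            while i < n and frag in prefixes:
                if frag in vocab:      # the fragment is itself a full word
                    ends.append(i)
                i += 1
                frag = sentence[k:i + 1]
            if not ends:               # no dictionary word starts here: keep the single character
                ends.append(k)
            dag[k] = ends
        return dag

    prefixes, vocab = build_prefix_set(["ab", "abc", "bc", "c"])
    print(get_dag("abc", prefixes, vocab))  # {0: [1, 2], 1: [2], 2: [2]}

The patched add_word reuses the same idea: it inserts every prefix of a user-added word into the set, exactly as gen_pfdict does at load time.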
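
For reference, the new command-line interface in jieba/__main__.py reads a file (or STDIN when no filename is given) and prints one segmented line per input line. Typical invocations, assuming a UTF-8 text file named input.txt (the file name here is only an example):

    python -m jieba input.txt              # accurate mode, words joined by ' / '
    python -m jieba -d ' | ' input.txt     # use ' | ' as the word delimiter
    python -m jieba -a -q < input.txt      # full-pattern cut, suppress loading messages to stderr
    python -m jieba -n input.txt           # segment without the Hidden Markov Model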