Kaggle Text Mining Winning Solution Code Walkthrough (Part 1): Data Preprocessing

By 马尔克ov | Published 2017-08-12 16:25

    I. Problem Background

    These three Kaggle competitions are all about text similarity: the task is to model how relevant two pieces of text are to each other.

    Two e-commerce scenarios:
    https://www.kaggle.com/c/crowdflower-search-relevance
    https://www.kaggle.com/c/home-depot-product-search-relevance
    The task: estimate the relevance between the user's search query and the matched product.

    One forum scenario:
    https://www.kaggle.com/c/quora-question-pairs
    The task: estimate the similarity between two questions.

    Although the application scenarios differ, a common set of techniques can be applied to all of them.

    II. The Third-Place Solution

    All code shown in this article comes from this GitHub repository: https://github.com/ChenglongChen/Kaggle_HomeDepot

    III. Code Walkthrough

    1. The pattern-replace base class

    Initialization: takes a pattern_replace_pair_list; text matching each pattern is rewritten to the corresponding replacement.
    transform: applies each (pattern, replace) tuple in turn, then collapses whitespace and strips leading/trailing spaces (strip()).

    # The snippets in this post assume the imports from the repo's data_processor.py, roughly:
    #   import csv, multiprocessing, regex, nltk
    #   from collections import Counter
    #   from bs4 import BeautifulSoup
    #   import pandas as pd
    class BaseReplacer:
        def __init__(self, pattern_replace_pair_list=[]):
            self.pattern_replace_pair_list = pattern_replace_pair_list
            
        def transform(self, text):
            for pattern, replace in self.pattern_replace_pair_list:
                try:
                    text = regex.sub(pattern, replace, text)
                except:
                    pass
            return regex.sub(r"\s+", " ", text).strip()
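
    A quick sanity check of the base class (a standalone sketch, not from the repo; it assumes the class above plus the third-party regex package):

    import regex

    # a made-up pattern/replace pair: normalize a spelling variant
    replacer = BaseReplacer([(r"colour", r"color")])
    print(replacer.transform("a colour   swatch "))   # -> "a color swatch"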
    

    2. Convert to lowercase

    Overrides the parent class's transform method.

    class LowerCaseConverter(BaseReplacer):
        def transform(self, text):
            return text.lower()
    

    3. Splitting words that were glued together, where the second word is capitalized

    Turns: hidden from viewDurable rich finishLimited lifetime warrantyEncapsulated panels
    into: hidden from view Durable rich finish Limited lifetime warranty Encapsulated panels

    class LowerUpperCaseSplitter(BaseReplacer):
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"(\w)[\.?!]([A-Z])", r"\1 \2"),
                (r"(?<=( ))([a-z]+)([A-Z]+)", r"\2 \3"),
            ]
    

    Regular expressions: parentheses () create capture groups, and \1 \2 \3 refer back to those groups in the replacement.
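
    For instance, a standalone check of these two patterns with the regex package (not from the repo):

    import regex

    # \1 = the word character before the punctuation, \2 = the capital letter after it
    print(regex.sub(r"(\w)[\.?!]([A-Z])", r"\1 \2", "rich finish.Limited warranty"))
    # -> "rich finish Limited warranty"

    # a lowercase run glued to an uppercase run, preceded by a space, gets split
    print(regex.sub(r"(?<=( ))([a-z]+)([A-Z]+)", r"\2 \3", "hidden from viewDurable"))
    # -> "hidden from view Durable"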

    4. Word replacement

    Given a dictionary, replace words one for one; this is used to correct typos.
    At initialization the subclass only accumulates pattern_replace_pair_list; nothing is replaced yet.
    The replacement happens only when transform is called.

    class WordReplacer(BaseReplacer):
        def __init__(self, replace_fname):
            self.replace_fname = replace_fname  # path to the dictionary file
            self.pattern_replace_pair_list = []
            for line in csv.reader(open(self.replace_fname)):
                if len(line) == 1 and line[0].startswith("#"):
                    continue
                try:
                    pattern = r"(?<=\W|^)%s(?=\W|$)"%line[0]   
                    replace = line[1]
                    self.pattern_replace_pair_list.append( (pattern, replace) )
                except:
                    print(line)
                    pass
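
    The dictionary is a plain CSV with one typo,correction pair per line (the real path comes from config.WORD_REPLACER_DATA; the rows below are made up for illustration):

    import os, tempfile

    # write a tiny made-up typo dictionary to a temp file for the demo
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
        f.write("# lines starting with '#' are skipped\n")
        f.write("toliet,toilet\n")
        f.write("vynal,vinyl\n")
        fname = f.name

    replacer = WordReplacer(replace_fname=fname)
    print(replacer.transform("vynal toliet seat"))   # -> "vinyl toilet seat"
    os.remove(fname)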
    

    5. Splitting words joined by a connector

    Same use of () capture groups as above.

    class LetterLetterSplitter(BaseReplacer):
        """
        For letter and letter
        /:
        Cleaner/Conditioner -> Cleaner Conditioner
    
        -:
        Vinyl-Leather-Rubber -> Vinyl Leather Rubber
    
        For digit and digit, we keep it as we will generate some features via math operations,
        such as approximate height/width/area etc.
        /:
        3/4 -> 3/4
    
        -:
        1-1/4 -> 1-1/4
        """
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"([a-zA-Z]+)[/\-]([a-zA-Z]+)", r"\1 \2"),
            ]
    

    6. Digits glued to letters

    a digit followed by '.' or '-' then letters
    letters followed by '.' or '-' then a digit

    class DigitLetterSplitter(BaseReplacer):
        """
        x:
        1x1x1x1x1 -> 1 x 1 x 1 x 1 x 1
        19.875x31.5x1 -> 19.875 x 31.5 x 1
    
        -:
        1-Gang -> 1 Gang
        48-Light -> 48 Light
    
        .:
        includes a tile flange to further simplify installation.60 in. L x 36 in. W x 20 in. ->
        includes a tile flange to further simplify installation. 60 in. L x 36 in. W x 20 in.
        """
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"(\d+)[\.\-]*([a-zA-Z]+)", r"\1 \2"),
                (r"([a-zA-Z]+)[\.\-]*(\d+)", r"\1 \2"),
            ]
    

    7. Removing commas inside numbers

    class DigitCommaDigitMerger(BaseReplacer):
        """
        1,000,000 -> 1000000
        """
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"(?<=\d+),(?=000)", r""),
            ]
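
    Because the digits on both sides sit inside lookarounds, only the comma itself is consumed, so every thousands separator disappears in a single pass (standalone check with the regex package):

    import regex

    print(regex.sub(r"(?<=\d+),(?=000)", "", "1,000,000"))             # -> "1000000"
    print(regex.sub(r"(?<=\d+),(?=000)", "", "covers 2,500 sq. ft."))  # unchanged: the lookahead needs a literal "000"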
    

    8. Converting number words to Arabic digits

    class NumberDigitMapper(BaseReplacer):
        """
        one -> 1
        two -> 2
        """
        def __init__(self):
            numbers = [
                "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
                "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen", "eighteen",
                "nineteen", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"
            ]
            digits = [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                16, 17, 18, 19, 20, 30, 40, 50, 60, 70, 80, 90
            ]
            self.pattern_replace_pair_list = [
                (r"(?<=\W|^)%s(?=\W|$)"%n, str(d)) for n,d in zip(numbers, digits)
            ]
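
    The lookarounds act like word boundaries, so a number word embedded in a longer word is left alone (usage sketch, assuming the class above):

    mapper = NumberDigitMapper()
    print(mapper.transform("two stone tiles and one box"))   # -> "2 stone tiles and 1 box"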
    

    9. Normalizing units

    class UnitConverter(BaseReplacer):
        """
        shadeMature height: 36 in. - 48 in.Mature width
        PUT one UnitConverter before LowerUpperCaseSplitter
        """
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"([0-9]+)( *)(inches|inch|in|in.|')\.?", r"\1 in. "),
                (r"([0-9]+)( *)(pounds|pound|lbs|lb|lb.)\.?", r"\1 lb. "),
                (r"([0-9]+)( *)(foot|feet|ft|ft.|'')\.?", r"\1 ft. "),
                (r"([0-9]+)( *)(square|sq|sq.) ?\.?(inches|inch|in|in.|')\.?", r"\1 sq.in. "),
                (r"([0-9]+)( *)(square|sq|sq.) ?\.?(feet|foot|ft|ft.|'')\.?", r"\1 sq.ft. "),
                (r"([0-9]+)( *)(cubic|cu|cu.) ?\.?(inches|inch|in|in.|')\.?", r"\1 cu.in. "),
                (r"([0-9]+)( *)(cubic|cu|cu.) ?\.?(feet|foot|ft|ft.|'')\.?", r"\1 cu.ft. "),
                (r"([0-9]+)( *)(gallons|gallon|gal)\.?", r"\1 gal. "),
                (r"([0-9]+)( *)(ounces|ounce|oz)\.?", r"\1 oz. "),
                (r"([0-9]+)( *)(centimeters|cm)\.?", r"\1 cm. "),
                (r"([0-9]+)( *)(milimeters|mm)\.?", r"\1 mm. "),
                (r"([0-9]+)( *)(minutes|minute)\.?", r"\1 min. "),
                (r"([0-9]+)( *)(°|degrees|degree)\.?", r"\1 deg. "),
                (r"([0-9]+)( *)(v|volts|volt)\.?", r"\1 volt. "),
                (r"([0-9]+)( *)(wattage|watts|watt)\.?", r"\1 watt. "),
                (r"([0-9]+)( *)(amperes|ampere|amps|amp)\.?", r"\1 amp. "),
                (r"([0-9]+)( *)(qquart|quart)\.?", r"\1 qt. "),
                (r"([0-9]+)( *)(hours|hour|hrs.)\.?", r"\1 hr "),
                (r"([0-9]+)( *)(gallons per minute|gallon per minute|gal per minute|gallons/min.|gallons/min)\.?", r"\1 gal. per min. "),
                (r"([0-9]+)( *)(gallons per hour|gallon per hour|gal per hour|gallons/hour|gallons/hr)\.?", r"\1 gal. per hr "),
            ]
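
    A usage sketch (assuming the class above): each match is rewritten to a canonical abbreviation, and the base class then collapses the extra whitespace.

    converter = UnitConverter()
    print(converter.transform("36 inches wide, holds 2 gallons"))   # -> "36 in. wide, holds 2 gal."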
    

    10. Stripping HTML tags

    class HtmlCleaner:
        def __init__(self, parser):
            self.parser = parser
    
        def transform(self, text):
            bs = BeautifulSoup(text, self.parser)
            text = bs.get_text(separator=" ")
            return text
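
    A usage sketch (assuming from bs4 import BeautifulSoup): get_text drops the tags and joins the text nodes with spaces; the leftover spacing gets cleaned up by the other processors later.

    cleaner = HtmlCleaner(parser="html.parser")
    print(cleaner.transform("<p>Durable <b>rich</b> finish<br/>Limited lifetime warranty</p>"))
    # -> roughly "Durable rich finish Limited lifetime warranty" (modulo extra spaces)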
    

    11. Replacing garbled punctuation and HTML entities with the original characters

    class QuartetCleaner(BaseReplacer):
        def __init__(self):
            self.pattern_replace_pair_list = [
                (r"<.+?>", r""),
                # html codes
                (r"&nbsp;", r" "),
                (r"&amp;", r"&"),
                (r"&#39;", r"'"),
                (r"/>/Agt/>", r""),
                (r"</a<gt/", r""),
                (r"gt/>", r""),
                (r"/>", r""),
                (r"<br", r""),
                # do not remove [".", "/", "-", "%"] as they are useful in numbers, e.g., 1.97, 1-1/2, 10%, etc.
                (r"[ &<>)(_,;:!?\+^~@#\$]+", r" "),
                ("'s\\b", r""),
                (r"[']+", r""),
                (r"[\"]+", r""),
            ]
    

    12. Lemmatization and stemming

    Difference: a stemmer crudely strips suffixes, so its output is often not a real word, while a lemmatizer maps each word to its dictionary form. For example:
    Stemmer: angles -> angl
    Lemmatizer: angles -> angle

    class Lemmatizer:
        def __init__(self):
            self.Tokenizer = nltk.tokenize.TreebankWordTokenizer()
            self.Lemmatizer = nltk.stem.wordnet.WordNetLemmatizer()
    
        def transform(self, text):
            tokens = [self.Lemmatizer.lemmatize(token) for token in self.Tokenizer.tokenize(text)]
            return " ".join(tokens)
    
    
    ## stemming
    class Stemmer:
        def __init__(self, stemmer_type="snowball"):
            self.stemmer_type = stemmer_type
            if self.stemmer_type == "porter":
                self.stemmer = nltk.stem.PorterStemmer()
            elif self.stemmer_type == "snowball":
                self.stemmer = nltk.stem.SnowballStemmer("english")
    
        def transform(self, text):
            tokens = [self.stemmer.stem(token) for token in text.split(" ")]
            return " ".join(tokens)
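
    A side-by-side comparison (assuming nltk is installed and the WordNet data has been downloaded, e.g. via nltk.download("wordnet")):

    text = "angles leaves"
    print(Stemmer(stemmer_type="snowball").transform(text))   # -> "angl leav"  (stems need not be real words)
    print(Lemmatizer().transform(text))                       # -> "angle leaf" (dictionary forms)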
    

    13. Query expansion

    Because the product title and the search term are both short texts, some extra processing is used to enrich them.
    title_ngram: n-grams of the product title, with each run of n consecutive words joined by a separator.
    search_term_alt: for each search term, the most frequent title_ngram among all products sharing that search term.

    class QueryExpansion:
        def __init__(self, df, ngram=3, stopwords_threshold=0.9, base_stopwords=set()):
            self.df = df[["search_term", "product_title"]].copy()
            self.ngram = ngram
            self.stopwords_threshold = stopwords_threshold
            self.stopwords = set(base_stopwords).union(self._get_customized_stopwords())
            
        def _get_customized_stopwords(self):
            words = " ".join(list(self.df["product_title"].values)).split(" ")
            counter = Counter(words)
            num_uniq = len(list(counter.keys()))
            num_stop = int((1.-self.stopwords_threshold)*num_uniq)
            stopwords = set()
            for e,(w,c) in enumerate(sorted(counter.items(), key=lambda x: x[1])):
                if e == num_stop:
                    break
                stopwords.add(w)
            return stopwords
    
        def _ngram(self, text):
            tokens = text.split(" ")
            tokens = [token for token in tokens if token not in self.stopwords]
            return ngram_utils._ngrams(tokens, self.ngram, " ")
    
        def _get_alternative_query(self, df):
            res = []
            for v in df:
                res += v
            c = Counter(res)
            value, count = c.most_common()[0]
            return value
    
        def build(self):
            self.df["title_ngram"] = self.df["product_title"].apply(self._ngram)
            corpus = self.df.groupby("search_term").apply(lambda df: self._get_alternative_query(df["title_ngram"]))
            corpus = corpus.reset_index()
            corpus.columns = ["search_term", "search_term_alt"]
            self.df = pd.merge(self.df, corpus, on="search_term", how="left")
            return self.df["search_term_alt"].values
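
    A toy run (assuming pandas, the repo's ngram_utils module, and the QueryExpansion class above). With stopwords_threshold=1.0 no extra stopwords are added, and every row that shares a search term gets the most frequent title bigram for that term:

    import pandas as pd

    df = pd.DataFrame({
        "search_term":   ["shower door", "shower door", "led bulb"],
        "product_title": ["frameless glass shower door",
                          "sliding glass shower door",
                          "dimmable led bulb 60w"],
    })
    qe = QueryExpansion(df, ngram=2, stopwords_threshold=1.0)
    print(qe.build())
    # -> something like ['glass shower' 'glass shower' 'dimmable led']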
    

    The n-gram helper, with n=2 as an example:

    def _bigrams(words, join_string, skip=0):
        """
           Input: a list of words, e.g., ["I", "am", "Denny"]
           Output: a list of bigram, e.g., ["I_am", "am_Denny"]
           I use _ as join_string for this example.
        """
        assert type(words) == list
        L = len(words)
        if L > 1:
            lst = []
            for i in range(L-1):
                for k in range(1,skip+2):
                    if i+k < L:
                        lst.append( join_string.join([words[i], words[i+k]]) )
        else:
            # set it as unigram
            lst = _unigrams(words)
        return lst
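
    For example (assuming the helper above together with its _unigrams fallback):

    print(_bigrams(["I", "am", "Denny"], "_"))           # -> ['I_am', 'am_Denny']
    print(_bigrams(["I", "am", "Denny"], "_", skip=1))   # -> ['I_am', 'I_Denny', 'am_Denny']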
    

    14. Extracting the product name

    Before transform, the text is first normalized with the classes above (see preprocess).

    class ProductNameExtractor(BaseReplacer):
        def __init__(self):
            self.pattern_replace_pair_list = [
                # Remove descriptions (text between parentheses/brackets)
                ("[ ]?[[(].+?[])]", r""),
                # Remove "made in..."
                ("made in [a-z]+\\b", r""),
                # Remove descriptions (hyphen or comma followed by space then at most 2 words, repeated)
                ("([,-]( ([a-zA-Z0-9]+\\b)){1,2}[ ]?){1,}$", r""),
                # Remove descriptions (prepositions starting with: with, for, by, in)
                ("\\b(with|for|by|in|w/) .+$", r""),
                # colors & sizes
                ("size: .+$", r""),
                ("size [0-9]+[.]?[0-9]+\\b", r""),
                (COLORS_PATTERN, r""),
                # dimensions
                (DIM_PATTERN_NxNxN, r""),
                (DIM_PATTERN_NxN, r""),
                # measurement units
                (UNITS_PATTERN, r""),
                # others
                ("(value bundle|warranty|brand new|excellent condition|one size|new in box|authentic|as is)", r""),
                # stop words
                ("\\b(in)\\b", r""),
                # hyphenated words
                ("([a-zA-Z])-([a-zA-Z])", r"\1\2"),
                # special characters
                ("[ &<>)(_,.;:!?/+#*-]+", r" "),
                # numbers that are not part of a word
                ("\\b[0-9]+\\b", r""),
            ]
            
        def preprocess(self, text):
            pattern_replace_pair_list = [
                # Remove single & double apostrophes
                ("[\"]+", r""),
                # Remove product codes (long words (>5 characters) that are all caps, numbers, or a mix of both)
                # don't use raw string format
                ("[ ]?\\b[0-9A-Z-]{5,}\\b", ""),
            ]
            text = BaseReplacer(pattern_replace_pair_list).transform(text)
            text = LowerCaseConverter().transform(text)
            text = DigitLetterSplitter().transform(text)
            text = UnitConverter().transform(text)
            text = DigitCommaDigitMerger().transform(text)
            text = NumberDigitMapper().transform(text)
            text = UnitConverter().transform(text)
            return text
            
        def transform(self, text):
            text = super().transform(self.preprocess(text))
            text = Lemmatizer().transform(text)
            text = Stemmer(stemmer_type="snowball").transform(text)
            # last two words in product
            text = " ".join(text.split(" ")[-2:])
            return text
    

    15. Processing product attributes

    Two helpers, one for each output format.

    def _split_attr_to_text(text):
        attrs = text.split(config.ATTR_SEPARATOR)
        return " ".join(attrs)
    
    def _split_attr_to_list(text):
        attrs = text.split(config.ATTR_SEPARATOR)        
        if len(attrs) == 1:
            # missing
            return [[attrs[0], attrs[0]]]
        else:
            return [[n,v] for n,v in zip(attrs[::2], attrs[1::2])]
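
    A quick check in a scratch script, pasting the two helpers above and using a stand-in for config.ATTR_SEPARATOR (the real value lives in the repo's config.py):

    class config:                       # stand-in for the repo's config module, demo only
        ATTR_SEPARATOR = " | "

    s = "Color Family | Beige | Material | Steel"
    print(_split_attr_to_text(s))       # -> "Color Family Beige Material Steel"
    print(_split_attr_to_list(s))       # -> [['Color Family', 'Beige'], ['Material', 'Steel']]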
    

    16. Handling input and output for different data structures

    The data to be processed may live in a list or in a DataFrame, so the handling differs slightly. (ProcessorWrapper, used below, is defined elsewhere in the repo; a stand-in sketch follows these classes.)

    class ListProcessor:
        """
        WARNING: This class will operate on the original input list itself
        """
        def __init__(self, processors):
            self.processors = processors
    
        def process(self, lst):
            for i in range(len(lst)):
                for processor in self.processors:
                    lst[i] = ProcessorWrapper(processor).transform(lst[i])
            return lst
    
    
    class DataFrameProcessor:
        """
        WARNING: This class will operate on the original input dataframe itself
        """
        def __init__(self, processors):
            self.processors = processors
    
        def process(self, df):
            for processor in self.processors:
                df = df.apply(ProcessorWrapper(processor).transform)
            return df
    
    
    class DataFrameParallelProcessor:
        """
        WARNING: This class will operate on the original input dataframe itself
    
        https://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
        """
        def __init__(self, processors, n_jobs=4):
            self.processors = processors
            self.n_jobs = n_jobs
    
        def process(self, dfAll, columns):
            df_processor = DataFrameProcessor(self.processors)
            p = multiprocessing.Pool(self.n_jobs)
            dfs = p.imap(df_processor.process, [dfAll[col] for col in columns])
            for col,df in zip(columns, dfs):
                dfAll[col] = df
            return dfAll
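
    ProcessorWrapper is not shown in this post (it lives elsewhere in the repo); conceptually it just lets transform be applied uniformly whatever the cell type is. A minimal stand-in sufficient for the snippets here:

    class ProcessorWrapper:
        """Minimal stand-in: apply the wrapped processor, coercing non-string input to str."""
        def __init__(self, processor):
            self.processor = processor

        def transform(self, x):
            return self.processor.transform(x if isinstance(x, str) else str(x))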
    

    IV. Summary

    Putting all of the code above together gives the main function. A few things in this code are well worth learning from:

    1. Repeated logic lives in classes and class inheritance (the various replacers), which cuts duplication and keeps the logic in one place.
    2. Flags in the config file control whether parts of the code run (the various if config.SOME_FLAG checks), which beats repeatedly commenting code in and out.
    3. All settings are centralized in a single config file, so you do not waste time hunting for them while debugging.
    def main():
    
        ###########
        ## Setup ##
        ###########
        logname = "data_processor_%s.log"%now
        logger = logging_utils._get_logger(config.LOG_DIR, logname)
    
        # put product_attribute_list, product_attribute and product_description first as they are
        # quite time consuming to process
        columns_to_proc = [
            # # product_attribute_list is very time consuming to process
            # # so we just process product_attribute which is of the form 
            # # attr_name1 | attr_value1 | attr_name2 | attr_value2 | ...
            # # and split it into a list afterwards
            # "product_attribute_list",
            "product_attribute_concat",
            "product_description",
            "product_brand", 
            "product_color",
            "product_title",
            "search_term", 
        ]
        if config.PLATFORM == "Linux":
            config.DATA_PROCESSOR_N_JOBS = len(columns_to_proc)
    
        # clean using a list of processors
        processors = [
            LowerCaseConverter(), 
            # See LowerUpperCaseSplitter and UnitConverter for why we put UnitConverter here
            UnitConverter(),
            LowerUpperCaseSplitter(), 
            WordReplacer(replace_fname=config.WORD_REPLACER_DATA), 
            LetterLetterSplitter(),
            DigitLetterSplitter(), 
            DigitCommaDigitMerger(), 
            NumberDigitMapper(),
            UnitConverter(), 
            QuartetCleaner(), 
            HtmlCleaner(parser="html.parser"), 
            Lemmatizer(),
        ]
        stemmers = [
            Stemmer(stemmer_type="snowball"), 
            Stemmer(stemmer_type="porter")
        ][0:1]
    
        ## simple test
        text = "1/2 inch rubber lep tips Bullet07"
        print("Original:")
        print(text)
        list_processor = ListProcessor(processors)
        print("After:")
        print(list_processor.process([text]))
    
        #############
        ## Process ##
        #############
        ## load raw data
        dfAll = pkl_utils._load(config.ALL_DATA_RAW)
        columns_to_proc = [col for col in columns_to_proc if col in dfAll.columns]
    
    
        ## extract product name from search_term and product_title
        ext = ProductNameExtractor()
        dfAll["search_term_product_name"] = dfAll["search_term"].apply(ext.transform)
        dfAll["product_title_product_name"] = dfAll["product_title"].apply(ext.transform)
        if config.TASK == "sample":
            print(dfAll[["search_term", "search_term_product_name", "product_title_product_name"]])
    
    
        ## clean using GoogleQuerySpellingChecker
        # MUST BE IN FRONT OF ALL THE PROCESSING
        if config.GOOGLE_CORRECTING_QUERY:
            logger.info("Run GoogleQuerySpellingChecker at search_term")
            checker = GoogleQuerySpellingChecker()
            dfAll["search_term"] = dfAll["search_term"].apply(checker.correct)
    
    
        ## clean using a list of processors
        df_processor = DataFrameParallelProcessor(processors, config.DATA_PROCESSOR_N_JOBS)
        df_processor.process(dfAll, columns_to_proc)
        # split product_attribute_concat into product_attribute and product_attribute_list
        dfAll["product_attribute"] = dfAll["product_attribute_concat"].apply(_split_attr_to_text)
        dfAll["product_attribute_list"] = dfAll["product_attribute_concat"].apply(_split_attr_to_list)
        if config.TASK == "sample":
            print(dfAll[["product_attribute", "product_attribute_list"]])
        # query expansion
        if config.QUERY_EXPANSION:
            list_processor = ListProcessor(processors)
            base_stopwords = set(list_processor.process(list(config.STOP_WORDS)))
            qe = QueryExpansion(dfAll, ngram=3, stopwords_threshold=0.9, base_stopwords=base_stopwords)
            dfAll["search_term_alt"] = qe.build()
            if config.TASK == "sample":
                print(dfAll[["search_term", "search_term_alt"]])
        # save data
        logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED)
        columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
        pkl_utils._save(config.ALL_DATA_LEMMATIZED, dfAll[columns_to_save])
    
    
        ## auto correcting query
        if config.AUTO_CORRECTING_QUERY:
            logger.info("Run AutoSpellingChecker at search_term")
            checker = AutoSpellingChecker(dfAll, exclude_stopwords=False, min_len=4)
            dfAll["search_term_auto_corrected"] = list(dfAll["search_term"].apply(checker.correct))
            columns_to_proc += ["search_term_auto_corrected"]
            if config.TASK == "sample":
                print(dfAll[["search_term", "search_term_auto_corrected"]])
            # save query_correction_map and spelling checker
            fname = "%s/auto_spelling_checker_query_correction_map_%s.log"%(config.LOG_DIR, now)
            checker.save_query_correction_map(fname)
            # save data
            logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED)
            columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
            pkl_utils._save(config.ALL_DATA_LEMMATIZED, dfAll[columns_to_save])
    
    
        ## clean using stemmers
        df_processor = DataFrameParallelProcessor(stemmers, config.DATA_PROCESSOR_N_JOBS)
        df_processor.process(dfAll, columns_to_proc)
        # split product_attribute_concat into product_attribute and product_attribute_list
        dfAll["product_attribute"] = dfAll["product_attribute_concat"].apply(_split_attr_to_text)
        dfAll["product_attribute_list"] = dfAll["product_attribute_concat"].apply(_split_attr_to_list)
        # query expansion
        if config.QUERY_EXPANSION:
            list_processor = ListProcessor(stemmers)
            base_stopwords = set(list_processor.process(list(config.STOP_WORDS)))
            qe = QueryExpansion(dfAll, ngram=3, stopwords_threshold=0.9, base_stopwords=base_stopwords)
            dfAll["search_term_alt"] = qe.build()
            if config.TASK == "sample":
                print(dfAll[["search_term", "search_term_alt"]])
        # save data
        logger.info("Save to %s"%config.ALL_DATA_LEMMATIZED_STEMMED)
        columns_to_save = [col for col in dfAll.columns if col != "product_attribute_concat"]
        pkl_utils._save(config.ALL_DATA_LEMMATIZED_STEMMED, dfAll[columns_to_save])
    
