#!/usr/bin/env python3

'''
Take a stratified sample of article quality labels.

For now we just stratify by label type.
Later we might add date.
Later we might stratify by wikiproject too.

A key limitation of this approach is that we can only sample at the level of the page.
We'd really like to be able to sample at the level of the edit session,
but that isn't possible because of how article assessments work.
'''
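# Overview: convert wikiq TSV dumps to parquet, join WikiProject talk-page
# assessments to the article revisions they describe, then draw a sample of
# articles stratified by WP 1.0 quality label.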
from itertools import islice, chain
from pathlib import Path
import pandas as pd
import numpy as np
random = np.random.RandomState(1968)
import json
import pyarrow.feather as feather
import fire
from collections import Counter
from pyRemembeR import Remember
from enum import IntEnum, unique
from datetime import datetime
from dataclasses import dataclass, asdict
from multiprocessing import Pool
from urllib.parse import unquote
from pyspark.sql import functions as f
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from numpy import dtype
import csv

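# One-time preprocessing: convert the raw wikiq TSVs to parquet, then normalize
# column types and percent-encoded strings with Spark.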
def wikiq_to_parquet():

    path = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/")
    outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_parquet/")
    files = list(map(Path,path.glob("*.tsv")))
    dumpfile = files[0]

    def wikiq_tsv_to_parquet(dumpfile, outpath = Path("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet/")):
        outfile = outpath / (dumpfile.name + ".parquet")
        outpath.mkdir(parents=True, exist_ok=True)
        _wikiq_tsv_to_parquet(dumpfile,outfile)

    dumpfile = Path("/gscratch/comdata/users/nathante/wikiqRunning/wikiq_output/enwiki-20200301-pages-meta-history12-p4980874p5038451.tsv")

    def _wikiq_tsv_to_parquet(dumpfile, outfile):

        dtypes = {'anon': dtype('O'), 'articleid': dtype('int64'), 'deleted': dtype('bool'), 'editor': dtype('O'), 'editor_id': dtype('float64'), 'minor': dtype('bool'), 'namespace': dtype('int64'), 'revert': dtype('O'), 'reverteds': dtype('O'), 'revid': dtype('int64'), 'sha1': dtype('O'), 'text_chars': dtype('float64'), 'title': dtype('O')}

        print(dumpfile)
        df = pd.read_csv(dumpfile,sep='\t',quoting=csv.QUOTE_NONE,error_bad_lines=False, warn_bad_lines=True,parse_dates=['date_time'],dtype=dtypes)

        df.to_parquet(outfile)

    with Pool(28) as pool:
        jobs = pool.imap_unordered(wikiq_tsv_to_parquet, files)
        list(jobs)

    spark = SparkSession.builder.getOrCreate()

    @udf(StringType())
    def decode_strip_udf(val):
        if val is None:
            return ""
        else:
            return unquote(val).strip('\"')

    df = spark.read.parquet('/gscratch/comdata/output/wikiq_enwiki_20200301_nathante.parquet')
    df = df.withColumnRenamed("anon","anonRaw")
    df = df.withColumn("anon",f.when(f.col("anonRaw")=="TRUE",True).otherwise(False))
    df = df.drop("anonRaw")
    df = df.withColumnRenamed("text_chars","text_chars_raw")
    df = df.withColumn("text_chars",f.col("text_chars_raw").cast('int'))
    df = df.drop("text_chars_raw")
    df = df.withColumnRenamed("editor_id",'editor_id_raw')
    df = df.withColumn("editor_id",f.col("editor_id_raw").cast("int"))
    df = df.drop("editor_id_raw")
    df = df.withColumnRenamed("revert","revert_raw")
    df = df.withColumn("revert",f.when(f.col("revert_raw")=="TRUE",True).otherwise(False))
    df = df.drop("revert_raw")
    df = df.withColumnRenamed("title","title_raw")
    df = df.withColumn("title", decode_strip_udf(f.col("title_raw")))
    df = df.drop("title_raw")
    df = df.withColumnRenamed("editor","editor_raw")
    df = df.withColumn("editor", decode_strip_udf(f.col("editor_raw")))
    df = df.drop("editor_raw")
    df = df.repartition(400,'articleid')
    df.write.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet",mode='overwrite')

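# Ordinal encoding of the WP 1.0 assessment classes used as labels.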
@unique
class WP10(IntEnum):
    start = 1
    stub = 2
    c = 3
    b = 4
    a = 5
    ga = 6
    fa = 7

    @staticmethod
    def from_string(s):
        return {'start':WP10.start,
                'stub':WP10.stub,
                'c':WP10.c,
                'b':WP10.b,
                'a':WP10.a,
                'ga':WP10.ga,
                'fa':WP10.fa}.get(s,None)

    def to_string(self):
        return {WP10.start:'start',
                WP10.stub:'stub',
                WP10.c:'c',
                WP10.b:'b',
                WP10.a:'a',
                WP10.ga:'ga',
                WP10.fa:'fa'}[self]


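# A quality label applied to a page at a point in time; serializes to and from
# the JSON label records read in main().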
@dataclass
class PageLabel:
    timestamp:datetime
    wp10:WP10

    @staticmethod
    def from_json(obj):
        timestamp = obj.get('timestamp',None)
        if timestamp is not None:
            timestamp = datetime.strptime(obj['timestamp'],'%Y%m%d%H%M%S')
        else:
            timestamp = None

        return PageLabel(timestamp=timestamp,
                         wp10=WP10.from_string(obj.get('wp10')))

    @staticmethod
    def from_row(row):
        return PageLabel(timestamp = row.timestamp,
                         wp10 = WP10(row.wp10))

    def to_json(self):
        d = asdict(self)

        if self.timestamp is not None:
            d['timestamp'] = self.timestamp.strftime('%Y%m%d%H%M%S')

        if self.wp10 is not None:
            d['wp10'] = self.wp10.to_string()

        return json.dumps(d)

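# A label as recorded by a WikiProject assessment on an article's talk page.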
@dataclass
class TalkPageLabel(PageLabel):
    dump_talk_page_title:str
    talk_page_id:int
    project:str

    @staticmethod
    def from_json(obj):
        res = PageLabel.from_json(obj)

        return TalkPageLabel(dump_talk_page_title=obj.get('dump_talk_page_title',None),
                             talk_page_id=obj.get('talk_page_id',None),
                             project=obj.get("project",None),
                             **asdict(res)
                             )

    @staticmethod
    def from_row(row):
        res = PageLabel.from_row(row)
        return TalkPageLabel(dump_talk_page_title = row.dump_talk_page_title,
                             talk_page_id = row.talk_page_id,
                             project = row.project,
                             **asdict(res))


@dataclass
class ArticlePageLabel(PageLabel):
    '''class representing a label attached to an article page'''
    title: str
    articleid: int
    revid:int

    @staticmethod
    def from_json(obj):
        res = PageLabel.from_json(obj)

        return ArticlePageLabel(title=obj.get('title',None),
                                articleid=obj.get('articleid',None),
                                revid=obj.get('revid',None),
                                **asdict(res)
                                )

    @staticmethod
    def from_row(row):
        res = PageLabel.from_row(row)
        return ArticlePageLabel(title = row.title,
                                articleid = row.articleid,
                                revid = row.revid,
                                **asdict(res))

# Module-level copies of main()'s default arguments (presumably left over from interactive runs).
infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json"; samplesize=5000*7

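# Build the labelled sample:
#  1. parse the talk-page label JSON and keep each talk page's highest, most recent label
#  2. join labels to revision histories in Spark to find the article revision in effect
#     at the time of the label
#  3. draw up to samplesize / |WP10| articles per label class and write the sample to disk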
def main(infiles="enwiki-20200301-pages-meta-history*.xml-p*.7z_article_labelings.json", samplesize=5000*7):
    path = Path('data')
    infiles = path.glob(infiles)

    pool = Pool(28)

    lines = chain(* map(lambda f: open(f,'r'), infiles))

    parsed = pool.imap_unordered(json.loads, lines, chunksize=int(1e3))
    formatted = pool.imap_unordered(TalkPageLabel.from_json, parsed, chunksize=int(1e3))
    dicted = pool.imap_unordered(asdict,formatted, chunksize=int(1e3))

    # data frame of the latest labels.
    df = pd.DataFrame(dicted)

    df = df.loc[df.timestamp <= datetime(2019,1,1)]

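    # keep, for each talk page, the highest label, then the most recent timestamp
    # and the first project that assigned it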
    groups = df.groupby(["talk_page_id"])
    max_labels = groups.wp10.max().reset_index()

    df2 = pd.merge(df,max_labels,on=['talk_page_id','wp10'],how='right')
    last_timestamp = df2.groupby(['talk_page_id']).timestamp.max().reset_index()

    df2 = pd.merge(df2, last_timestamp, on=['talk_page_id','timestamp'], how='right')
    first_project = df2.groupby(['talk_page_id']).project.first().reset_index()
    df2 = pd.merge(df2, first_project,on=['talk_page_id','project'], how='right')

    tpid = df2

    tpid = tpid.loc[~tpid.dump_talk_page_title.isna()]

    # pick out just the samples we want.
    spark = SparkSession.builder.getOrCreate()

    sparkdf = spark.read.parquet("/gscratch/comdata/output/wikiq_enwiki_20200301_nathante_partitioned.parquet")

    tpid['timestamp'] = tpid['timestamp'].dt.tz_localize('utc')
    labels = spark.createDataFrame(tpid)
    talks = sparkdf.filter(sparkdf.namespace==1)
    articles = sparkdf.filter(sparkdf.namespace==0)

    # labels = labels.join(talks,on=[labels.talk_page_id == talks.articleid],how='left_outer')

    talks = talks.join(labels,on=[labels.talk_page_id == talks.articleid])

    #talks.filter(talks.wp10==7).select('talk_page_id').distinct().count()

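    # for each labelled talk page, keep the talk-page revision closest in time
    # at or before the label timestamp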
    talks = talks.withColumn('timediff', f.datediff(talks.timestamp, talks.date_time))

    talks = talks.filter(talks.timediff <= 0)

    win = Window.partitionBy("talk_page_id")
    talks = talks.withColumn('best_timediff', f.max('timediff').over(win))
    talks = talks.filter(talks.timediff == talks.best_timediff)

    talks = talks.withColumn('article_title',f.substring_index(f.col("title"),':',-1))
    talks = talks.select(['article_title','wp10',f.col('timestamp').alias('timestamp'),'talk_page_id']).distinct()

    articles = articles.join(talks,on=[talks.article_title == articles.title])

    articles = articles.withColumn('timediff', f.datediff(articles.timestamp, articles.date_time))
    articles = articles.filter(articles.timediff <= 0)

    win2 = Window.partitionBy("articleid")
    articles = articles.filter(f.col("revert")==False)
    articles = articles.withColumn('best_timediff', f.max('timediff').over(win2))
    articles = articles.filter(articles.timediff == articles.best_timediff)
    articles = articles.select(['revid','timestamp','wp10','articleid','title'])

    articles = articles.groupby(['timestamp','wp10','articleid','title']).agg(f.first(f.col("revid")).alias("revid"))

    articles.write.parquet("data/article_quality_data.parquet",mode='overwrite')

    tpid = pd.read_parquet("data/article_quality_data.parquet")

    # we want to sample /pages/ not /labels/.
    # so we need to do a /full/ groupby over pages.
    # this is why we have a lot of RAM!
    # we need the number of pages with each label so we can stratify the sample.
    label_counts = {}
    sample_page_ids = {}
    label_max_samplesize = int(samplesize / len(WP10))
    sample_chunks = []

    for lab in WP10:
        print(lab)
        page_ids = tpid.loc[tpid.wp10==lab].articleid
        label_counts[lab] = len(page_ids)
        print(lab,label_counts)
        if label_counts[lab] <= label_max_samplesize:
            sample_page_ids[lab] = page_ids
        else:
            sample_page_ids[lab] = random.choice(page_ids,label_max_samplesize,replace=False)

        # get the labels for each sampled article
        sample_data_lab = tpid.loc[(tpid.articleid.isin(sample_page_ids[lab]))]

        sample_chunks.append(sample_data_lab)

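    # record sampling parameters and per-class counts with pyRemembeR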
    remember = Remember(f='remember_sample_quality_labels.RDS')

    remember(label_max_samplesize, 'label_max_samplesize')

    # Note that different wikiprojects can have different labels
    sample = pd.concat(sample_chunks,ignore_index=True)

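    # count 2019 revisions per article so we can report how many revisions fall
    # in each quality class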
    revisions_per_article = sparkdf.filter(sparkdf.namespace==0).select(['revid','articleid','date_time','title'])
    revisions_per_article = revisions_per_article.filter(f.col("date_time") >= datetime(2019,1,1))
    revisions_per_article = revisions_per_article.filter(f.col("date_time") <= datetime(2019,12,31))
    revisions_per_article = revisions_per_article.groupby(["articleid",'title']).count().toPandas()

    revisions_per_article['title'] = revisions_per_article.title.apply(lambda s: unquote(s).strip('\"'))

    revisions_per_article = pd.merge(revisions_per_article,tpid,left_on='articleid',right_on='articleid')

    revisions_per_class = revisions_per_article.groupby('wp10').agg({'count':'sum'}).reset_index()
    revisions_per_class['wp10'] = revisions_per_class.wp10.apply(lambda s: WP10(s).to_string())

    label_counts = pd.DataFrame({'wp10':map(lambda x: x.to_string(),label_counts.keys()),'n_articles':label_counts.values()})
    label_counts = pd.merge(label_counts,revisions_per_class,left_on='wp10',right_on='wp10')
    label_counts = label_counts.rename(columns={'count':'n_revisions'})

    remember(label_counts, 'label_sample_counts')

    sample.to_feather("data/20200301_article_labelings_sample.feather")

    sample = pd.read_feather("data/20200301_article_labelings_sample.feather")
    sample_counts = sample.articleid.groupby(sample.wp10).count().reset_index()
    remember(sample_counts,'sample_counts')

    sample_labels = sample.apply(ArticlePageLabel.from_row,axis=1)
    sample_labels = map(PageLabel.to_json, sample_labels)

    with open("data/20200301_article_labelings_sample.json",'w') as of:
        of.writelines((l + '\n' for l in sample_labels))

    pool.close()

if __name__ == "__main__":
    fire.Fire(main)
