A simple baseline: recall by next_item co-occurrence counter (LB 0.29+)
In [1]:
# Notebook-wide imports and display configuration.
import warnings
warnings.simplefilter('ignore')  # ignore every warning category to keep cell output clean
import gc  # NOTE(review): not referenced in the visible cells
import re  # NOTE(review): not referenced in the visible cells
from collections import defaultdict, Counter
import numpy as np  # NOTE(review): not referenced in the visible cells
import pandas as pd
pd.set_option('display.max_columns', None)  # show every column when a frame is displayed
from tqdm.auto import tqdm  # progress bars that render as widgets inside Jupyter
In [2]:
# df_prod = pd.read_csv('data/products_train.csv')
# df_prod
In [3]:
# Training sessions; the recall cells below read the `prev_items` history string
# and the `next_item` label from each row.
df_sess = pd.read_csv(filepath_or_buffer='data/sessions_train.csv')
df_sess
Out[3]:
In [4]:
# Task-1 test sessions; only `prev_items` (and, when writing the submission,
# `locale`) are read from this frame.
df_test = pd.read_csv(filepath_or_buffer='data/sessions_test_task1.csv')
df_test
Out[4]:
In [5]:
def str2list(x):
    """Parse a stringified item array such as "['A' 'B'\\n 'C']" into a list of ids.

    Square brackets and single quotes are removed, then the remainder is split
    on any run of whitespace (spaces, newlines, carriage returns), so the
    result never contains empty tokens.
    """
    strip_brackets_and_quotes = str.maketrans('', '', "[]'")
    return x.translate(strip_brackets_and_quotes).split()
In [6]:
# Build the item -> observed-successors lists from the training sessions.
# Every consecutive pair inside a session's history, plus the hop from the last
# history item to the labelled next_item, is recorded as one transition.
next_item_dict = defaultdict(list)
for prev_str, next_item in tqdm(
    zip(df_sess['prev_items'], df_sess['next_item']), total=len(df_sess)
):
    # Appending the label makes the final prev_item -> next_item hop just another
    # consecutive pair; zipping the sequence with its 1-shifted copy yields every
    # pair in order. An empty history now yields no pairs instead of raising
    # IndexError (the old `prev_items[0]` branch).
    sequence = str2list(prev_str) + [next_item]
    for src, dst in zip(sequence, sequence[1:]):
        next_item_dict[src].append(dst)
In [7]:
# Test sessions carry no label, but the transitions inside their histories are
# still observed behaviour, so they are folded into the same successor lists.
for history in tqdm(df_test['prev_items'], total=len(df_test)):
    items = str2list(history)
    # Histories with fewer than two items produce no pairs.
    for cur, nxt in zip(items, items[1:]):
        next_item_dict[cur].append(nxt)
In [8]:
# For each item keep at most its 100 most frequent successors, most frequent first.
next_item_map = {
    src: [candidate for candidate, _count in Counter(successors).most_common(100)]
    for src, successors in tqdm(next_item_dict.items())
}
In [9]:
# Flatten the successor lists into one (item, next_item) row per recorded
# transition; the next cell counts next_item popularity from this frame.
k = list(next_item_dict.keys())
v = list(next_item_dict.values())
df_next = (
    pd.DataFrame({'item': k, 'next_item': v})
    .explode('next_item')
    .reset_index(drop=True)
)
df_next
Out[9]:
In [10]:
# The 200 most frequent successor items overall, most frequent first;
# used later as the fallback and padding pool for predictions.
top200 = df_next['next_item'].value_counts().head(200).index.tolist()
In [11]:
# Candidates are looked up from the last item of each test session's history.
last_items = df_test['prev_items'].map(lambda history: str2list(history)[-1])
df_test['last_item'] = last_items
# Mapping through a dict: last items never seen in training/test transitions become NaN.
df_test['next_item_prediction'] = last_items.map(next_item_map)
df_test
Out[11]:
In [12]:
# Turn each row's raw candidate list into the final prediction of (up to) 100 items.
TOP_K = 100


def complete_prediction(candidates, prev_items, popular, k=TOP_K):
    """Return a fresh list of up to ``k`` predicted item ids for one test session.

    candidates: successor list looked up for the session's last item, or NaN
        (a float) when that item has no entry in next_item_map.
    prev_items: the session's history ids; they are skipped when padding.
    popular: item ids ordered most popular first, used as the fallback and
        as the padding pool.
    """
    if not isinstance(candidates, list):
        # Unseen last item: fall back to global popularity. The history is not
        # filtered here, matching the original baseline's behaviour.
        return popular[:k]
    # Copy before appending. The lists come from next_item_map via Series.map,
    # so every row sharing a last item holds the SAME list object. Appending in
    # place (as before) leaked one row's padding into later rows and into
    # next_item_map itself.
    pred = list(candidates[:k])
    if len(pred) < k:
        excluded = set(pred) | set(prev_items)
        for item in popular:
            if item not in excluded:
                pred.append(item)
                if len(pred) >= k:
                    break
    return pred


preds = [
    complete_prediction(candidates, str2list(history), top200)
    for candidates, history in tqdm(
        zip(df_test['next_item_prediction'], df_test['prev_items']),
        total=len(df_test),
    )
]
In [13]:
# Replace the raw candidate lists with the padded/trimmed final predictions.
df_test = df_test.assign(next_item_prediction=preds)
df_test
Out[13]:
In [14]:
# Sanity check: distribution of prediction lengths (the padding step targets 100 per session).
df_test['next_item_prediction'].map(len).describe()
Out[14]:
In [66]:
# Write the submission: one row per test session with its locale and prediction list.
(
    df_test.loc[:, ['locale', 'next_item_prediction']]
    .to_parquet('submission_task1.parquet', engine='pyarrow')
)
In [ ]:
Content
Comments
You must login before you can post a comment.
great share
Thanks, Heng. Really helpful!
Just for discussion: I think this idea tries to find the following 2 patterns:
WDYT? Happy to hear your opinion. Thanks!