-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathKDETool.py
More file actions
55 lines (40 loc) · 1.48 KB
/
KDETool.py
File metadata and controls
55 lines (40 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from race.hashes import *
import numpy as np
import argparse
import sys
import os
# Tool to evaluate ground-truth KDEs.
#
# Loads a dataset and a query set from .npy files, averages a kernel over
# every (data row, query) pair, and writes one density per query to
# "<data path without extension>.gtruth".
#
# P_L2 and P_SRP are not defined in this file; presumably they come from the
# wildcard import of race.hashes.


def build_parser():
    """Return the argparse parser describing the tool's four positionals."""
    parser = argparse.ArgumentParser(description = "Evaluate ground truth KDEs. Produces a results file data.gtruth")
    parser.add_argument("data", help="npy file with (n x d) data entries")
    parser.add_argument("queries", help="npy file with (m x d) queries")
    parser.add_argument("kernel_id", type=int, help="0: L2 LSH kernel, 1: Angular kernel")
    parser.add_argument("bandwidth", type=float, help="density estimate bandwidth")
    return parser


def select_kernel(kernel_id):
    """Return the kernel callable ``kernel(x, y, w)`` for ``kernel_id``.

    0 selects the L2 LSH kernel and 1 the angular kernel. Any other id
    returns None so that the caller decides how to report the error.
    """
    # NOTE: a Gaussian kernel could be plugged in here instead:
    # kernel = lambda x,y,w : np.exp(-1.0/(2*w) * np.linalg.norm(x - y)**2)
    if kernel_id == 0:
        return lambda x, y, w: P_L2(np.linalg.norm(x - y), w)
    if kernel_id == 1:
        # The bandwidth is truncated to an integer and used as an exponent.
        return lambda x, y, w: P_SRP(x, y) ** (int(w))
    return None


def validate_shapes(dataset, queries):
    """Check that dataset and queries are compatible 2-D arrays.

    Returns:
        ``(N, NQ)``: the number of dataset rows and of query rows.

    Raises:
        ValueError: if either array is not 2-D, if the row widths differ,
            or if the dataset has no rows (the average would be 0 / 0).
    """
    if dataset.ndim != 2 or queries.ndim != 2:
        raise ValueError(
            "data and queries must be 2-D arrays, got shapes {0} and {1}".format(
                dataset.shape, queries.shape))
    n_data, data_dim = dataset.shape
    n_queries, query_dim = queries.shape
    if data_dim != query_dim:
        raise ValueError(
            "dimension mismatch: data has d={0} but queries have d={1}".format(
                data_dim, query_dim))
    if n_data == 0:
        raise ValueError("data file contains no vectors")
    return n_data, n_queries


def evaluate_kde(dataset, queries, kernel, bandwidth, progress=None):
    """Return the mean of ``kernel(data, query, bandwidth)`` per query.

    Args:
        dataset: (n x d) array of data vectors.
        queries: (m x d) array of query vectors.
        kernel: callable taking ``(data_row, query_row, bandwidth)``.
        bandwidth: value forwarded unchanged to ``kernel``.
        progress: optional text stream with ``write``/``flush``; when given,
            a carriage-return progress line is refreshed every 100 data rows
            and finished with a newline at 100 %.

    Returns:
        1-D array of length m holding the averaged kernel values.

    Raises:
        ValueError: propagated from ``validate_shapes``.
    """
    n_data, n_queries = validate_shapes(dataset, queries)
    results = np.zeros(n_queries)
    for j, data in enumerate(dataset):
        for i, query in enumerate(queries):
            results[i] += kernel(data, query, bandwidth)
        if progress is not None and j % 100 == 0:
            progress.write('\r')
            progress.write('Progress: {0:.4f}'.format(100.0 * j / n_data) + ' %')
            progress.flush()
    if progress is not None:
        # Finish the line so later output does not land on the progress text.
        progress.write('\r')
        progress.write('Progress: {0:.4f}'.format(100.0) + ' %\n')
        progress.flush()
    return results / n_data


def main(argv=None):
    """Parse the command line, compute the ground truth, and save it.

    Args:
        argv: argument list to parse; None means ``sys.argv[1:]``.

    Invalid kernel ids and incompatible input shapes are reported through
    ``parser.error``, which prints the usage to stderr and exits with a
    non-zero status.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    kernel = select_kernel(args.kernel_id)
    if kernel is None:
        parser.error("Unsupported kernel id.")

    dataset = np.load(args.data)
    queries = np.load(args.queries)
    try:
        N, NQ = validate_shapes(dataset, queries)
    except ValueError as err:
        parser.error(str(err))

    print("Processing ground truth for",NQ," queries and",N," dataset vectors")
    sys.stdout.flush()

    results = evaluate_kde(dataset, queries, kernel, args.bandwidth,
                           progress=sys.stdout)

    # Results go next to the data file: same path, extension swapped.
    output_filename = os.path.splitext(args.data)[0]+'.gtruth'
    np.savetxt(output_filename, results, delimiter=',')


if __name__ == "__main__":
    main()