-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSocial_Dataset_utils.py
More file actions
203 lines (185 loc) · 8.41 KB
/
Social_Dataset_utils.py
File metadata and controls
203 lines (185 loc) · 8.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os

import numpy as np
## Utility functions for Social Dataset
## Segmentation functions
##################################################################################
## Takes two sets of segments and returns two lists: one that gives all breakpoints, and another that tells which animals have data between the breakpoints.
#TODO: TREAT THE ENDS CORRECTLY: extrapolate out to the same length, and prevent cutoff of the last entry.
def order_segsets(segseta,segsetb):
    """Merge two segment sets into one ordered list of sub-intervals.

    NOTE: a function with the same name is defined again further down in this
    file; that later definition is the one bound to ``order_segsets`` once the
    module is imported. This copy is kept consistent with it: the debug
    ``print`` of every interval and the unused ``ind_id`` local are removed.

    Args:
        segseta: iterable of segments for the first animal; each segment is a
            sequence of endpoints (e.g. ``[start, end]`` as built by
            ``find_segments``).
        segsetb: same, for the second animal.

    Returns:
        (intervals_trimmed, ids_trimmed):
            intervals_trimmed -- list of ``(left, right)`` tuples between
                consecutive sorted endpoints, zero-length intervals dropped.
            ids_trimmed -- list of length-2 arrays, one per kept interval;
                entry ``k`` is 1 when set ``k`` has seen an odd number of
                endpoints up to ``left`` (i.e. a segment is open), else 0.
        Both lists are empty when neither set contains any endpoint.
    """
    segsets = [segseta,segsetb]
    ## Flatten every segment into (endpoint, set index) pairs.
    inds_tagged = []
    for s,segset in enumerate(segsets):
        inds_tagged.extend((entry,s) for tup in segset for entry in tup)
    ## Nothing to order: avoid the unpacking error that zip(*[]) would raise.
    if not inds_tagged:
        return [],[]
    ## Sort by endpoint (ties broken by set index), then unzip.
    inds,ids = zip(*sorted(inds_tagged))
    ## One-hot encode the set index of each endpoint; a running sum mod 2
    ## toggles the "inside a segment" state of each set at every endpoint.
    template = np.eye(len(segsets))
    ids_array = np.array([template[i] for i in ids])
    ids_sorted = np.cumsum(ids_array,axis = 0)%2
    ## Drop zero-length intervals created by coinciding endpoints.
    intervals_trimmed = []
    ids_trimmed = []
    for i,interval in enumerate(zip(inds[:-1],inds[1:])):
        if interval[0] != interval[1]:
            intervals_trimmed.append(interval)
            ids_trimmed.append(ids_sorted[i])
    return intervals_trimmed,ids_trimmed
## Make a segment dictionary that splits on the starts and ends of either trajectory.
def find_segments(indices):
    """Group a sequence of indices into runs of consecutive values.

    Two neighbouring entries belong to the same run when they differ by
    exactly 1. Each run is reported as ``[first, last + 1]`` (half-open).

    Fixes relative to the previous version:
      * a single-element input now yields one interval instead of ``[]``
        (the final interval used to be closed only inside the loop over
        differences, which never ran for one element);
      * an empty input returns ``[]`` instead of raising ``IndexError``.

    Args:
        indices: 1-D sequence (list or array) of indices, in the order in
            which runs should be detected.

    Returns:
        List of ``[start, stop]`` lists, one per run.
    """
    if len(indices) == 0:
        return []
    all_intervals = []
    ## Start of the run currently being built.
    start = indices[0]
    for prev,curr in zip(indices[:-1],indices[1:]):
        if curr-prev != 1:
            ## Gap found: close the current run and open a new one.
            all_intervals.append([start,prev+1])
            start = curr
    ## Always close the last run, whatever the input length.
    all_intervals.append([start,indices[-1]+1])
    return all_intervals
def order_segsets(segseta,segsetb):
    """Merge two segment sets into one ordered list of sub-intervals.

    Every endpoint of every segment in both sets becomes a breakpoint. The
    breakpoints are sorted, consecutive pairs form intervals, and for each
    interval we record which set currently has a segment open.

    Changes relative to the previous version: the unused ``ind_id`` local is
    gone, and two empty sets return ``([], [])`` instead of failing with an
    unpacking ``ValueError``.

    Args:
        segseta: iterable of segments for the first animal; each segment is a
            sequence of endpoints (e.g. ``[start, end]`` as built by
            ``find_segments``).
        segsetb: same, for the second animal.

    Returns:
        (intervals_trimmed, ids_trimmed):
            intervals_trimmed -- list of ``(left, right)`` tuples between
                consecutive sorted endpoints, zero-length intervals dropped.
            ids_trimmed -- list of length-2 arrays, one per kept interval;
                entry ``k`` is 1 when set ``k`` has seen an odd number of
                endpoints up to ``left`` (i.e. a segment is open), else 0.
    """
    segsets = [segseta,segsetb]
    ## Flatten every segment into (endpoint, set index) pairs.
    inds_tagged = []
    for s,segset in enumerate(segsets):
        inds_tagged.extend((entry,s) for tup in segset for entry in tup)
    ## Nothing to order: avoid the unpacking error that zip(*[]) would raise.
    if not inds_tagged:
        return [],[]
    ## Sort by endpoint (ties broken by set index), then unzip.
    inds,ids = zip(*sorted(inds_tagged))
    ## One-hot encode the set index of each endpoint; a running sum mod 2
    ## toggles the "inside a segment" state of each set at every endpoint.
    template = np.eye(len(segsets))
    ids_array = np.array([template[i] for i in ids])
    ids_sorted = np.cumsum(ids_array,axis = 0)%2
    ## Drop zero-length intervals created by coinciding endpoints.
    intervals_trimmed = []
    ids_trimmed = []
    for i,interval in enumerate(zip(inds[:-1],inds[1:])):
        if interval[0] != interval[1]:
            intervals_trimmed.append(interval)
            ids_trimmed.append(ids_sorted[i])
    return intervals_trimmed,ids_trimmed
## Form this array in a few steps. 1. take your set of intervals, and append to each start and end an identity marker. then just argsort all of the indices, and return the sorted timestamps and the sorted id markers. do binary cumsums over the id markers to keep track of when we are in one interval or the other.
def ind_to_dict_split_for_config(indices,nb_part):
    """Return every breakpoint interval and its presence mask for one part.

    Entry ``nb_part`` of ``indices`` is used for the first animal and entry
    ``nb_part + 5`` for the second (presumably five tracked parts per animal
    -- TODO confirm against the caller). Only column 0 of each entry is used.

    Unlike ``ind_to_dict_split``, intervals where neither animal has data are
    kept in the output.

    Args:
        indices: indexable collection of 2-D arrays.
        nb_part: index of the part for the first animal.

    Returns:
        (processed, mask): the intervals from ``order_segsets`` as an array,
        and the matching array of per-animal flags. As computed here the
        flags only take the values 0 (no data) and 1 (data present).
    """
    first_animal = indices[nb_part]
    second_animal = indices[nb_part+5]
    ## Consecutive runs of indices for each animal.
    segs_first = find_segments(first_animal[:,0])
    segs_second = find_segments(second_animal[:,0])
    ## Split on the endpoints of both animals' runs at once.
    breakpoints,presence = order_segsets(segs_first,segs_second)
    return np.array(breakpoints),np.array(presence)
def ind_to_dict_split(indices,nb_part):
    """Return breakpoint intervals where at least one animal has data.

    Same construction as ``ind_to_dict_split_for_config`` (entry ``nb_part``
    for the first animal, ``nb_part + 5`` for the second, column 0 only),
    followed by removal of every interval whose mask row sums to zero.

    Args:
        indices: indexable collection of 2-D arrays.
        nb_part: index of the part for the first animal.

    Returns:
        (processed_rel, mask_rel): the retained intervals and their per-animal
        flags. As computed here the flags only take the values 0 and 1.
    """
    first_animal = indices[nb_part]
    second_animal = indices[nb_part+5]
    segs_first = find_segments(first_animal[:,0])
    segs_second = find_segments(second_animal[:,0])
    breakpoints,presence = order_segsets(segs_first,segs_second)
    processed = np.array(breakpoints)
    mask = np.array(presence)
    ## Rows with an all-zero mask are gaps where neither animal has data;
    ## keep only the others.
    has_data = np.sum(mask,axis = 1) != 0
    return processed[has_data],mask[has_data]
def val_dist(trajraw,intervals,mask,currind,currid,mouseind,mouseid):
    """Distance from the end of one trajectory to the start of another.

    Compares the position of animal ``currid`` at the first frame of interval
    ``currind`` with the position of animal ``mouseid`` at the last frame of
    interval ``mouseind``. An animal's coordinates are read from columns
    ``2*id`` and ``2*id + 1`` of ``trajraw``.

    Args:
        trajraw: 2-D array indexed as [frame, column].
        intervals: sequence of intervals; first entry is the start frame and
            last entry is one past the final frame.
        mask: 2-D array indexed as [interval, animal]; only the entry for
            (``currind``, ``currid``) is checked.
        currind, currid: interval and animal of the starting trajectory.
        mouseind, mouseid: interval and animal of the ending trajectory;
            ``mouseid`` is cast to int before use.

    Returns:
        ``np.nan`` when ``mask[currind, currid]`` is 0, otherwise the
        Euclidean norm of the difference between the two positions.
    """
    if mask[currind,currid] == 0:
        return np.nan
    ## Intervals are half-open, so the last valid frame is stop - 1.
    last_frame = intervals[mouseind][-1]-1
    first_frame = intervals[currind][0]
    prev_id = int(mouseid)
    col_curr = currid*2
    col_prev = prev_id*2
    gap = trajraw[first_frame,col_curr:col_curr+2]-trajraw[last_frame,col_prev:col_prev+2]
    return np.linalg.norm(gap)
def val_time(intervals,mask,currind,currid,mouseind,mouseid):
    """Number of frames from the end of one trajectory to the start of another.

    Args:
        intervals: sequence of intervals; first entry is the start frame and
            last entry is one past the final frame.
        mask: 2-D array indexed as [interval, animal]; only the entry for
            (``currind``, ``currid``) is checked.
        currind, currid: interval and animal of the starting trajectory.
        mouseind: interval of the ending trajectory.
        mouseid: accepted for signature parity with ``val_dist``; not used.

    Returns:
        ``np.nan`` when ``mask[currind, currid]`` is 0, otherwise the start
        frame of ``currind`` minus the last frame (stop - 1) of ``mouseind``.
    """
    if mask[currind,currid] == 0:
        return np.nan
    ## Intervals are half-open, so the last valid frame is stop - 1.
    last_frame = intervals[mouseind][-1]-1
    first_frame = intervals[currind][0]
    return first_frame-last_frame
### To process intra-segment path length (total variation).
def intra_tv(trajraw,intervals,mask,i,m):
    """Total path length travelled by animal ``m`` within interval ``i``.

    Args:
        trajraw: 2-D array indexed as [frame, column]; animal ``m`` occupies
            columns ``2*m`` and ``2*m + 1``.
        intervals: 2-D array; row ``i`` is unpacked into ``slice(...)`` to
            select the frames.
        mask: 2-D array indexed as [interval, animal].
        i: interval index.
        m: animal index.

    Returns:
        ``np.nan`` when ``mask[i, m]`` is 0, otherwise the sum of Euclidean
        distances between consecutive frames in the interval.
    """
    if mask[i,m] == 0:
        return np.nan
    window = slice(*intervals[i,:])
    xy = [2*m,2*m+1]
    ## Frame-to-frame displacement vectors for this animal.
    steps = np.diff(trajraw[window][:,xy],axis = 0)
    step_lengths = np.linalg.norm(steps,axis = 1)
    return np.sum(step_lengths)
### To process intra-segment distances.
def intra_dist(trajraw,intervals,mask,i,m):
    """Straight-line displacement of animal ``m`` across interval ``i``.

    Args:
        trajraw: 2-D array indexed as [frame, column]; animal ``m`` occupies
            columns ``2*m`` and ``2*m + 1``.
        intervals: 2-D array; row ``i`` holds the start frame first and one
            past the final frame last.
        mask: 2-D array indexed as [interval, animal].
        i: interval index.
        m: animal index.

    Returns:
        ``np.nan`` when ``mask[i, m]`` is 0, otherwise the Euclidean distance
        between the positions at the first and last frame of the interval.
    """
    if mask[i,m] == 0:
        return np.nan
    first_frame = intervals[i,0]
    ## Intervals are half-open, so the last valid frame is stop - 1.
    last_frame = intervals[i,-1]-1
    xy = [2*m,2*m+1]
    displacement = trajraw[last_frame,xy]-trajraw[first_frame,xy]
    return np.linalg.norm(displacement)
def intra_time(intervals,mask,i,m):
    """Length of interval ``i`` (last entry minus first entry).

    Args:
        intervals: 2-D array of intervals, one per row.
        mask: 2-D array indexed as [interval, animal].
        i: interval index.
        m: animal index (used only to look up the mask).

    Returns:
        ``np.nan`` when ``mask[i, m]`` is 0, otherwise
        ``intervals[i, -1] - intervals[i, 0]``.
    """
    if mask[i,m] == 0:
        return np.nan
    return intervals[i,-1]-intervals[i,0]
###############################################################3
## Path handling functions
def filepaths(folderpath):
    """List the tracking files in ``folderpath`` that should be analyzed.

    A directory entry is selected when, after splitting its name on ``.``,
    the last piece is ``h5`` and the piece before it contains
    ``cropped_part`` (i.e. names of the form ``*cropped_part*.h5``).

    Entries with no ``.`` in their name are skipped; previously a file named
    exactly ``h5`` raised ``IndexError``.

    Args:
        folderpath: directory to scan.

    Returns:
        List of ``folderpath + '/' + name`` strings, in ``os.listdir`` order.
    """
    data = []
    for fileset in os.listdir(folderpath):
        parts = fileset.split('.')
        ## Need at least "<stem>.<ext>" before looking at parts[-2].
        if len(parts) < 2:
            continue
        if parts[-1] == 'h5' and 'cropped_part' in parts[-2]:
            data.append(folderpath+'/'+fileset)
    return data
def moviepath(filepath):
    """Derive a movie path from a tracking-file path.

    Everything before the first occurrence of ``DeepCut`` in ``filepath`` is
    kept (the whole string if it does not occur) and ``.mp4`` is appended.

    Args:
        filepath: path string of the tracking file.

    Returns:
        The derived ``.mp4`` path string.
    """
    stem,_,_ = filepath.partition('DeepCut')
    return stem+'.mp4'
def datapaths(folderpath):
    """List dataset files in ``folderpath``.

    An entry is selected when its name contains ``dataset_`` and does not
    contain ``ethogram``.

    Args:
        folderpath: directory to scan.

    Returns:
        List of ``folderpath + '/' + name`` strings, in ``os.listdir`` order.
    """
    matches = []
    for name in os.listdir(folderpath):
        if 'ethogram' in name or 'dataset_' not in name:
            continue
        matches.append(folderpath+'/'+name)
    return matches
def excelpaths(folderpath):
    """List behavior spreadsheets in ``folderpath``.

    An entry is selected when the text after its last ``.`` is ``xlsx`` and
    its name contains ``Behavior``.

    Args:
        folderpath: directory to scan.

    Returns:
        List of ``folderpath + '/' + name`` strings, in ``os.listdir`` order.
    """
    matches = []
    for name in os.listdir(folderpath):
        extension = name.rsplit(".",1)[-1]
        if extension != "xlsx" or "Behavior" not in name:
            continue
        matches.append(folderpath+'/'+name)
    return matches