Add script for finding optimal anchor shapes #59

Open
mihaimartalogu wants to merge 2 commits into BichenWuUCB:master from mihaimartalogu:master

@mihaimartalogu

Hi @BichenWuUCB,
Many thanks for publishing the sources for your model.

I'm working on fitting it to my own dataset (more smartphone-camera shaped), and one of the things I needed to do was to adapt the shapes of the anchors.

Here I'm sharing the script I used for finding the optimal anchor sizes.
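
For readers new to the approach, here is a minimal sketch of k-means clustering over box shapes. It is not the script from this pull request: the Euclidean distance, the area-ordered initialisation, the function name `kmeans_anchors`, and the sample `shapes` are all assumptions made for illustration.

```python
def kmeans_anchors(shapes, k, max_iter=100):
    """Cluster (width, height) pairs with plain Euclidean k-means."""
    # Deterministic initialisation (an assumption of this sketch): order the
    # shapes by area and pick k of them at evenly spaced positions.
    ordered = sorted(shapes, key=lambda s: s[0] * s[1])
    centers = [ordered[i * len(ordered) // k] for i in range(k)]
    for _ in range(max_iter):
        # Assignment step: attach every box shape to its nearest center.
        clusters = [[] for _ in range(k)]
        for w, h in shapes:
            nearest = min(
                range(k),
                key=lambda i: (w - centers[i][0]) ** 2 + (h - centers[i][1]) ** 2,
            )
            clusters[nearest].append((w, h))
        # Update step: move each center to the mean of its members.
        new_centers = []
        for i, members in enumerate(clusters):
            if members:
                mean_w = sum(w for w, _ in members) / len(members)
                mean_h = sum(h for _, h in members) / len(members)
                new_centers.append((mean_w, mean_h))
            else:
                new_centers.append(centers[i])  # keep an empty cluster's center
        if new_centers == centers:  # no center moved, so we have converged
            break
        centers = new_centers
    return sorted(centers)


# Hypothetical box shapes (width, height) in pixels, assumed to be already
# rescaled to the network input size.
shapes = [(30, 25), (32, 28), (28, 26), (120, 60), (130, 70), (125, 66),
          (390, 170), (380, 160), (400, 175)]
anchors = kmeans_anchors(shapes, k=3)
print(anchors)
```

Each returned pair is the mean width and height of one cluster, which is what would be used as an anchor shape. A different distance (for example one based on IoU) or a different initialisation can give noticeably different anchors on the same data.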

Note, however, that for the KITTI dataset my results differ considerably from the anchors in the repository:
In config/kitti_res50_config.py:

         [[  94.,  49.], [ 225., 161.], [ 170.,  91.],
           [ 390., 181.], [  41.,  32.], [ 128.,  64.],
           [ 298., 164.], [ 232.,  99.], [  65.,  42.]])]

In config/kitti_squeezeDet_config.py:

          [[  36.,  37.], [ 366., 174.], [ 115.,  59.],
           [ 162.,  87.], [  38.,  90.], [ 258., 173.],
           [ 224., 108.], [  78., 170.], [  72.,  43.]])]

What I get instead is:

$ python scripts/kmeans_anchors.py --geometry 1248x384 --kmeans-max-iter 1000000
...
[[70.45, 41.96], [390.36, 165.54], [125.10, 66.58],
[98.62, 186.29], [29.57, 26.18], [43.72, 94.05],
[356.65, 339.55], [269.63, 170.46], [198.84, 101.97]]

[Image: screenshot, 2017-06-16 17:52:44]

Note for example that the "wide and short" sizes don't get an anchor, but maybe it's fair... Have I missed something? Could you have a look, and consider merging it if the implementation is correct? (I'm new to the field of machine learning, so not 100% confident)

Cheers,
Mihai

@ilystsov left a comment

I've added a docstring.
Removed the comment about Python 2, since it is outdated and no longer relevant.
Replaced the nonlocals dictionary with the more modern nonlocal keyword available in Python 3.
Simplified the method by removing unnecessary checks and updating the tqdm progress bars.
Used os.path.join for creating paths.
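
As a minimal sketch of the two closure styles mentioned above, the following compares the mutable-dict workaround with the Python 3 `nonlocal` keyword. The counter functions are hypothetical and exist only to illustrate the language feature; they are not part of `kmeans_anchors.py`.

```python
def make_counter_with_dict():
    """Python 2 compatible workaround: mutate a dict held in the enclosing scope."""
    state = {'count': 0}

    def increment():
        state['count'] += 1  # mutates the dict; no outer name is rebound
        return state['count']

    return increment


def make_counter_with_nonlocal():
    """Python 3 style: rebind the enclosing variable directly."""
    count = 0

    def increment():
        nonlocal count  # without this, `count += 1` raises UnboundLocalError
        count += 1
        return count

    return increment


old_style = make_counter_with_dict()
new_style = make_counter_with_nonlocal()
old_results = [old_style(), old_style()]
new_results = [new_style(), new_style()]
print(old_results, new_results)  # [1, 2] [1, 2]
```

Both versions behave identically; `nonlocal` simply removes the need for a container whose only job is to be mutable.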

Comment thread scripts/kmeans_anchors.py
Comment on lines +50 to +81
def get_dataset_metadata(dataset_root, input_w, input_h, max_jobs):
    """
    Load all dataset metadata into memory. You might need to adapt this if your dataset is really huge.
    """
    nonlocals = {  # Python 2 doesn't support nonlocal, using a mutable dict() instead
        'entries_done': 0,
        'metadata': dict(),
        'entries_done_pbar': None
    }
    with open(os.path.join(dataset_root, 'ImageSets', 'trainval.txt')) as f:
        dataset_entries = f.read().splitlines()
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as pool:

        for entry in tqdm(dataset_entries, desc='Scheduling jobs'):
            if nonlocals['entries_done_pbar'] is None:
                # instantiating here so that it appears after the 'Scheduling jobs' one
                nonlocals['entries_done_pbar'] = tqdm(total=len(dataset_entries), desc='Retrieving metadata')

            def entry_done(future):
                """ Record progress """
                nonlocals['entries_done'] += 1
                nonlocals['entries_done_pbar'].update(1)
                fr = future.result()
                if fr is not None:
                    local_entry, value = fr  # do NOT use the entry variable from the scope!
                    nonlocals['metadata'][local_entry] = value

            future = pool.submit(get_entry_metadata, dataset_root, entry, input_w, input_h)
            future.add_done_callback(entry_done)  # FIXME: doesn't work if chained directly to submit(). bug in futures? reproduce and submit report.
    nonlocals['entries_done_pbar'].close()
    assert len(nonlocals['metadata'].values()) >= 0.9 * len(dataset_entries)  # catch if entry_done doesn't update the dict correctly
    return nonlocals['metadata']

Suggested change
def get_dataset_metadata(dataset_root, input_w, input_h, max_jobs):
    """
    Load all dataset metadata into memory.
    Args:
    - dataset_root (str): The root directory of the dataset.
    - input_w (int): Width of the input.
    - input_h (int): Height of the input.
    - max_jobs (int): Maximum number of concurrent processes to use.
    Returns:
    - dict: Metadata for the dataset.
    """
    entries_done = 0
    metadata = {}
    with open(os.path.join(dataset_root, 'ImageSets', 'trainval.txt')) as f:
        dataset_entries = f.read().splitlines()
    with tqdm(total=len(dataset_entries), desc='Retrieving metadata') as entries_done_pbar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as pool:
            futures = [pool.submit(get_entry_metadata, dataset_root, entry, input_w, input_h) for entry in dataset_entries]
            for future in concurrent.futures.as_completed(futures):
                fr = future.result()
                if fr is not None:
                    local_entry, value = fr
                    metadata[local_entry] = value
                entries_done += 1
                entries_done_pbar.update(1)
    assert len(metadata.values()) >= 0.9 * len(dataset_entries), "Entry_done didn't update the dict correctly."
    return metadata
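
The submit-then-`as_completed` pattern can be tried in isolation with the sketch below. It rests on two assumptions: `get_entry_metadata_stub` is a hypothetical stand-in for the real `get_entry_metadata`, and `ThreadPoolExecutor` replaces `ProcessPoolExecutor` so that the example runs without a `__main__` guard. Both executors return the same kind of `Future`.

```python
import concurrent.futures


def get_entry_metadata_stub(entry):
    """Hypothetical stand-in for get_entry_metadata: returns (entry, value) or None."""
    if entry.startswith('bad'):
        return None  # simulate an entry whose metadata could not be read
    return entry, len(entry)


def collect_metadata(entries, max_jobs=4):
    metadata = {}
    # ThreadPoolExecutor is an assumption of this sketch; the suggestion above
    # uses ProcessPoolExecutor, which exposes the same Future API.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_jobs) as pool:
        futures = [pool.submit(get_entry_metadata_stub, e) for e in entries]
        # as_completed yields each future once it has finished, and the loop
        # body runs in the calling thread, so no callback is needed.
        for future in concurrent.futures.as_completed(futures):
            result = future.result()  # re-raises any exception from the worker
            if result is not None:
                key, value = result
                metadata[key] = value
    return metadata


metadata = collect_metadata(['000001', '000002', 'bad_entry', '000004'])
print(sorted(metadata.items()))
```

Because results are consumed in the caller rather than in a done-callback, the dictionary is only ever touched from one place, and worker exceptions surface at `future.result()` instead of being lost.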
