Deconstructing Deep Learning + δeviations
This post was previously published on my blog. Today we will look at pruning and the different approaches to it.
Hello. It has been a long hiatus. I have been coding a lot but writing too little, although I did end up writing a lot of documentation, haha. I also switched my environment to Vim (more on that later). Without further rants, let's get to today's topic.
Note that most of this article is based on an excellent paper by Davis Blalock et al., "What Is the State of Neural Network Pruning?"
[Cite] : Blalock, D., Ortiz, J. J. G., Frankle, J., & Guttag, J. (2020). What is the state of neural network pruning? arXiv preprint arXiv:2003.03033.
Pruning is something I have been interested in for a long time, but somehow I could never get around to implementing it. It interests me for a lot of reasons, mainly being able to reduce the size, cost, and computational requirements of my models, all while maintaining the accuracy (sort of, at least).
TL;DR: Generally, this comes about by removing parameters in some form or fashion.
Rather than applying a mask, we can prune certain parts of the network (aka its weights and biases) by setting them to 0, or by dropping them entirely if required.
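As a minimal sketch of the difference, here is masking versus zeroing on a made-up weight vector (the values and the 0.05 threshold are purely illustrative):
w = [0.5, -0.02, 0.8, 0.01, -0.3]    # toy weights
mask = abs.(w) .> 0.05               # binary mask: keep only the larger-magnitude weights
w_masked = w .* mask                 # masked-out entries become 0.0
w[abs.(w) .<= 0.05] .= 0.0           # or simply zero the small weights in place, with no mask kept around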
In most cases, the network is first trained for a while and then pruned. Pruning reduces its accuracy, so the network is trained again (fine-tuning). This cycle is repeated until we get the results we require.
There are many such methods, and they are usually categorized by what they change from this basic recipe. Note that, in the end, the main idea remains the same; only how it is done varies.
Extending this from the previous topic, let us try to code a bit.
NOTE: These snippets are TOY examples meant to illustrate the major ideas. For proper code, refer to my repository here. Link
Let us get some random values in a nested array. In all cases, assume these are the weights of a network, where each sub-array represents a layer.
weights = [rand(10) for _ in 1:10]    # 10 "layers", each with 10 random weights
We also take the number of values (or the fraction of values) to drop as an input.
Okay, so we first flatten the nested array. Then we can easily find the n smallest values because it is now a single long array. We can write a tiny function that returns 0 (or any chosen value) if a value is smaller than a given threshold, and the original value otherwise.
function setval(val, cmpval, setter)
    # return `setter` (e.g. 0) if `val` falls below the threshold `cmpval`,
    # otherwise keep the original value
    if val < cmpval
        return setter
    else
        return val
    end
end
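For example, with an illustrative threshold of 0.05:
setval(0.01, 0.05, 0)    # returns 0    (below the threshold, so it gets pruned)
setval(0.20, 0.05, 0)    # returns 0.2  (not below the threshold, so it is kept)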
We can then sort this flattened list and take its nth element as the threshold. After that, we run an element-wise map (over each layer) and pass each element to our previous function. This sets every element smaller than the threshold to 0 (or a value of our own choosing), effectively dropping it.
"""
weights = array
n = lowest nth smallest value to prune below ( aka take the nth smallest value and prune below it ) (eg: 3)
"""
function global_mag_prune(weights, n = 3, setter = 0)
temp = collect(Iterators.flatten(weights))
sort!(temp)
map.( x-> setval(x,temp[n], setter) , weights )
end
We also write a function to determine how sparse the network has become. (Aka how many zeros).
return temp2, sum(x->x==0, collect(Iterators.flatten(temp2)) , dims=1)/(size(temp)[1]*size(temp)[1])*100
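Using the toy weights from above, a quick sanity check: with distinct random values, n = 5 zeroes the 4 values strictly below the 5th smallest, so we expect roughly 4% sparsity.
pruned, sparsity = global_mag_prune(weights, 5)
@info sparsity    # roughly 4.0 (percentage of weights that are now zero)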
Next, we modify the global version to work per layer instead. To do this, we first make a copy of the weights. Then, for every layer in the array, we find the n smallest values, take the nth one as the threshold, and set everything below it to 0. As an edge case, if n is greater than the length of the layer, we clamp it to the layer length (which zeroes out essentially the whole layer).
"""
weights = array
n = lowest nth smallest value to prune below ( aka take the nth smallest value and prune below it ) (eg: 3)
"""
function layer_mag_prune(weights, n = 3, setter = 0)
backup = deepcopy(weights)
for layer in 1:length(weights)
sort!(backup[layer])
if n>length(weights[layer])
n = length(weights[layer])
end
backup[layer] = map.( x-> setval(x,weights[layer][n], setter) , weights[layer] )
end
return backup, sum(x->x==0, collect(Iterators.flatten(backup)) , dims=1)/(size(weights)[1]*size(weights)[1])*100
end
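Same check for the per-layer version: n = 3 on 10 layers of 10 weights should zero 2 values per layer, i.e. about 20% sparsity.
pruned, sparsity = layer_mag_prune(weights, 3)
@info sparsity    # roughly 20.0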
For gradient-based pruning, we would need the gradients of the weights so far. Well, how about I just use a random array as the "gradient", as a proof of concept so to speak. So we basically take the previous code and change the initial part to work on the product of the weights and the gradient.
function global_grad_prune(weights, n = 3, setter = 0)
    # stand-in "gradients": one random value per weight, multiplied element-wise into the weights
    scores = [w .* rand(length(w)) for w in weights]
    temp = collect(Iterators.flatten(scores))
    sort!(temp)                                   # temp[n] is now the nth smallest product
    temp2 = map.(x -> setval(x, temp[n], setter), scores)
    return temp2, sum(x -> x == 0, collect(Iterators.flatten(temp2))) / length(temp) * 100
end
"""
weights = array
n = lowest nth smallest value to prune below ( aka take the nth smallest value and prune below it ) (eg: 3)
"""
function layer_grad_prune(weightsnew, n = 3, setter = 0)
weights = deepcopy(weightsnew)*rand(size(weightsnew))
backup = deepcopy(weights)
for layer in 1:length(weights)
sort!(backup[layer])
if n>length(weights[layer])
n = length(weights[layer])
end
backup[layer] = map.( x-> setval(x,weights[layer][n], setter) , weights[layer] )
end
return backup, sum(x->x==0, collect(Iterators.flatten(backup)) , dims=1)/(size(weights)[1]*size(weights)[1])*100
end
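And the same sanity check for the gradient-flavoured versions, which prune (and return) the weight x "gradient" products instead; the exact numbers vary because the stand-in gradients are random.
g_pruned, g_sparsity = global_grad_prune(weights, 5)
l_pruned, l_sparsity = layer_grad_prune(weights, 3)
@info "sparsity" g_sparsity l_sparsity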
For random pruning, we first compute the number of values to prune: the total number of weights multiplied by the fraction to remove. We then flatten the array, set the randomly chosen values to 0, and rebuild the original nested shape.
"""
frac = percentage of values to remove
"""
function random_prune(weights, frac = .3, setter = 0)
num_prune = Int(round(frac*(size(weights)[1]*size(weights)[1])))
@info num_prune
backup = collect(Iterators.flatten(deepcopy(weights)))
backup[rand(1:length(backup), num_prune)] .= 0
@info size(weights)
return reshape(backup, size(weights))
end
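A final check: with frac = 0.3 on the 100 toy weights, 30 of them should come back as exact zeros.
pruned = random_prune(weights, 0.3)
@info sum(x -> x == 0, collect(Iterators.flatten(pruned)))    # 30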
That ends our little journey into pruning. Wow, that took longer than expected. I should go sleep.