quantization / quantization_metric /expert_number.py
chen459664's picture
Add files using upload-large-folder tool
55c92b3 verified
raw
history blame contribute delete
5.62 kB
import os
os.environ['HF_HOME'] = "/home/yiren/new_ssd2/MoLA/huggingface_cache"
os.environ['HF_DATASETS_CACHE'] = "/home/yiren/new_ssd2/MoLA/huggingface_cache"
os.environ['TRANSFORMERS_CACHE'] = "/home/yiren/new_ssd2/MoLA/huggingface_cache"
import argparse
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
import tqdm
from optimum.quanto.nn.qlinear import QLinear
# from optimum.quanto import unpack
# print(unpack)
def exponential_scaling(values, target_sum, exponent):
values = np.array(values)
scaled_values = np.power(values, exponent)
scaled_integers = np.round((scaled_values / scaled_values.sum()) * target_sum).astype(int)
while scaled_integers.sum() != target_sum:
difference = target_sum - scaled_integers.sum()
if difference > 0:
scaled_integers[np.argmin(scaled_values - scaled_integers)] += 1
else:
scaled_integers[np.argmax(scaled_values - scaled_integers)] -= 1
return scaled_integers
def fix_finger(w, bins=100, pl_fitting=True, EVALS_THRESH=1e-4, filter_zeros=False):
eigs = torch.square(torch.linalg.svdvals(w).flatten())
eigs, _ = torch.sort(eigs, descending=False)
if filter_zeros:
nz_eigs = eigs[eigs > EVALS_THRESH]
N = len(nz_eigs)
else:
# print(f"{name} Skip Filter Zero")
nz_eigs = eigs
N = len(nz_eigs)
log_nz_eigs = torch.log(nz_eigs)
alphas = torch.zeros(N - 1)
Ds = torch.ones(N - 1)
if pl_fitting:
hist_nz_eigs = torch.log10(nz_eigs)
min_e, max_e = hist_nz_eigs.min(), hist_nz_eigs.max()
counts = torch.histc(hist_nz_eigs, bins, min=min_e, max=max_e)
boundaries = torch.linspace(min_e, max_e, bins + 1)
h = counts, boundaries
ih = torch.argmax(h[0])
xmin2 = 10 ** h[1][ih]
xmin_min = torch.log10(0.95 * xmin2)
xmin_max = 1.5 * xmin2
for i, xmin in enumerate(nz_eigs[:-1]):
if pl_fitting == True:
if xmin < xmin_min:
continue
if xmin > xmin_max:
break
n = float(N - i)
#seq = torch.arange(n).cuda(nz_eigs.device)
alpha = 1 + n / (torch.sum(log_nz_eigs[i:]) - n * log_nz_eigs[i])
alphas[i] = alpha
if alpha > 1:
seq = torch.arange(n, device=nz_eigs.device)
Ds[i] = torch.max(torch.abs(
1 - (nz_eigs[i:] / xmin) ** (-alpha + 1) - seq / n
))
min_D_index = torch.argmin(Ds)
final_alpha = alphas[min_D_index]
return final_alpha
class WrappedGPT:
def __init__(self, layer, layer_id=0, layer_name="none"):
self.layer = layer
self.dev = layer.weight.device
self.rows, self.columns = layer.weight.data.shape
self.scaler_row = torch.zeros(self.columns, device=self.dev)
self.nsamples = 0
def add_batch(self, inp, out):
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, torch.nn.Linear) and len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
self.scaler_row *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
self.scaler_row += torch.norm(inp.float(), p=2, dim=1) ** 2 / self.nsamples
# def find_layers(module, layers=[nn.Linear], name=''):
def find_layers(module, layers=[QLinear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(
child, layers=layers, name=name + '.' + name1 if name != '' else name1
))
return res
def calculate_expert(model):
all_layer_alpha = []
layers = model.model.layers
for i, layer in enumerate(layers):
subset = find_layers(layer)
print(f"Processing layer {i+1}--subset--{subset}")
# for name in subset:
# print(subset[name].weight)
# print(unpack)
# layer_final_alpha = [fix_finger(subset[name].weight._data._data.float()) for name in subset]
layer_final_alpha = [fix_finger(subset[name].weight.data.float()) for name in subset]
all_layer_alpha.append(torch.stack(layer_final_alpha).mean().item())
print(f"alpha value of layer {i+1} ---{torch.stack(layer_final_alpha).mean().item()} ")
torch.cuda.empty_cache()
return all_layer_alpha
def get_llm(model_name):
return AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model', default="mistralai/Mistral-7B-v0.1", type=str)
parser.add_argument('--seed', type=int, default=25)
parser.add_argument('--beta', type=float, default=2.5)
parser.add_argument('--target_sum', type=int, default=160)
args = parser.parse_args()
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
model = get_llm(args.model)
model.eval()
distribution = calculate_expert(model)
print("Distribution:", distribution)
quantized_vector = exponential_scaling(distribution, args.target_sum, args.beta)
print("Total expert number:", sum(quantized_vector))
print("expert number: ", ','.join(map(str, quantized_vector)))
topkk = [2 if n > 1 else 1 for n in quantized_vector]
topkk = ','.join(map(str, topkk))
print("top_k: ", topkk)
if __name__ == '__main__':
main()