Spaces:
Sleeping
Sleeping
| import torch | |
def extract_text_feature(prompt, model, processor, device="cpu"):
    """Extract an L2-normalised text embedding for a prompt.

    Args:
        prompt: a single text query
        model: OwlViT model; must expose ``owlvit.text_model`` and
            ``owlvit.text_projection``
        processor: OwlViT processor; called as ``processor(text=prompt)`` and
            expected to return a mapping with an ``"input_ids"`` entry
        device (str, optional): device to run on when CUDA is not available.
            Defaults to 'cpu'. CUDA is always preferred when it is available
            (existing callers rely on this auto-detection), and a CUDA device
            requested on a machine without CUDA falls back to 'cpu'.

    Returns:
        tuple: ``(input_ids, query_embeds)`` where ``input_ids`` is the token
        id tensor fed to the text model and ``query_embeds`` is the projected,
        L2-normalised text embedding.
    """
    # Resolve the device. Previously the argument was overwritten
    # unconditionally, so it had no effect at all.
    if torch.cuda.is_available():
        run_device = "cuda"
    elif torch.device(device).type == "cuda":
        run_device = "cpu"
    else:
        run_device = device

    with torch.no_grad():
        input_ids = torch.as_tensor(
            processor(text=prompt)["input_ids"]
        ).to(run_device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # NOTE(review): element 1 is presumably the pooled text
        # representation -- confirm against the text model's output layout.
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # The small epsilon keeps an all-zero embedding from dividing by zero.
        text_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        )
    return input_ids, text_embeds
def prompt2vec(prompt: str, model, processor):
    """Encode a text prompt into NumPy arrays.

    Args:
        prompt (str): Text to be tokenized and embedded
        model: model forwarded to ``extract_text_feature``
        processor: processor forwarded to ``extract_text_feature``

    Returns:
        tuple: ``(input_ids, xq)`` -- the token ids and the embedding returned
        by ``extract_text_feature``, both detached, moved to the CPU and
        converted to NumPy arrays.
    """
    token_ids, embedding = extract_text_feature(prompt, model, processor)
    # Same conversion for both tensors: detach -> host memory -> NumPy.
    return tuple(
        tensor.detach().cpu().numpy() for tensor in (token_ids, embedding)
    )
def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier and return its updated weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``
        X (list or numpy.ndarray): Training samples, passed through to
            ``clf.fit`` unchanged
        y (list of floats or numpy.ndarray): Scores given by user, one per
            sample in ``X``
        iters (int, optional): iterations of updates to be run

    Returns:
        Whatever ``clf.get_weights()`` returns after training.

    Raises:
        ValueError: if ``X`` and ``y`` do not have the same length.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    # train the classifier
    clf.fit(X, y, iters=iters)
    # extract new vector
    return clf.get_weights()
class Classifier:
    """Multi-class zero-shot classifier refined from user feedback.

    The classifier keeps one weight vector per query (class). The vectors
    start as the prompt embeddings and are nudged by gradient steps computed
    from the user's labels, so they can replace the original query vectors in
    later retrieval rounds.

    Labels ``0 .. num_class - 1`` mark a sample as a positive for that class.
    Samples labelled ``num_class + 1`` are used as negatives for every class.
    Samples carrying any other label are ignored when training a given class.
    """

    # Step size applied to every gradient update.
    _LEARNING_RATE = 0.01

    def __init__(self, client, obj_db: str, xq: list):
        """Initialise the weights from the query vectors.

        Args:
            client: database client; ``client.query(sql).named_results()``
                must yield rows with a ``grad`` entry
            obj_db (str): name of the table queried for gradients
            xq: 2-D array-like (nested list, NumPy array or tensor) holding
                one initial query vector per class

        Raises:
            ValueError: if ``xq`` is not two-dimensional.
        """
        # Copy the data: `fit` updates the weights in place and must never
        # write through to the caller's array.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        if init_weight.dim() != 2:
            raise ValueError(
                "xq must be 2-dimensional (num_class, dims), "
                f"got shape {tuple(init_weight.shape)}"
            )
        # Read the shape from the tensor so plain nested lists work too.
        self.num_class, self.DIMS = init_weight.shape
        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5):
        """Run ``iters`` gradient steps on the class weights.

        Args:
            X (list): object ids of the labelled samples (used in
                ``obj_id IN ...``)
            y (list): label of each sample in ``X``
            iters (int, optional): number of update steps. Defaults to 5.
        """
        # The per-class sample selection does not depend on the weights, so
        # it is built once up front.
        batches = [self._class_batch(X, y, n) for n in range(self.num_class)]

        for _ in range(iters):
            # Normalize the weight before inference
            # This will constrain the gradient or you will have an explosion
            # on query vector
            self._normalize_weight()
            # Serialise the *current* weights on every step; doing it once
            # before the loop would evaluate each gradient at the stale
            # starting point.
            xq_s = self._weights_as_sql()

            grad = []
            for n, (labels, objs) in enumerate(batches):
                if not objs:
                    # No feedback for this class: leave its weights untouched.
                    grad.append(torch.zeros_like(self.weight[n]))
                    continue
                # NOTE from @fangruil
                # Use SQL to calculate the gradient
                # For binary cross entropy we have
                #     g = (1/(1+\exp(-XW))-Y)^TX
                # To simplify the query, we separated
                # the calculation into class numbers
                #
                # NOTE(review): the statement is built by string
                # interpolation, so `obj_db` and the ids in `X` must come
                # from a trusted source.
                # NOTE(review): GT follows the order of `objs`, while
                # groupArray follows the row order returned by the database;
                # presumably these must line up -- confirm.
                grad_q_str = f"""
                SELECT avgForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
                FROM (
                    SELECT groupArray(arrayPopBack(prelogit)) AS X,
                        groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y, {labels} AS GT
                    FROM {self.obj_db} WHERE obj_id IN {objs})"""
                rows = self.client.query(grad_q_str).named_results()
                grad_ = [r['grad'] for r in rows][0]
                grad.append(torch.as_tensor(grad_, dtype=self.weight.dtype))

            # update weights
            self.weight -= self._LEARNING_RATE * torch.stack(grad, dim=0)
        self._normalize_weight()

    def get_weights(self):
        """Return the current weights as a NumPy array.

        The array comes from ``Tensor.numpy()``, which shares memory with the
        weight tensor, so later in-place updates are visible through it.
        """
        xq = self.weight.detach().numpy()
        return xq

    def _normalize_weight(self):
        """Scale every weight vector to unit L2 norm, in place."""
        self.weight /= torch.norm(self.weight, p=2, dim=-1, keepdim=True)

    def _weights_as_sql(self):
        """Render each weight vector, with a trailing 1 appended, as a SQL array literal."""
        return [
            "[" + ", ".join(str(float(fnum)) for fnum in _xq + [1]) + "]"
            for _xq in self.get_weights().tolist()
        ]

    def _class_batch(self, X, y, n):
        """Select class ``n``'s training samples as ``(labels, object_ids)``."""
        labels, objs = [], []
        for i, x in enumerate(X):
            # Keep positives of class n and the shared negatives only.
            if y[i] in (n, self.num_class + 1):
                labels.append(1 if y[i] == n else 0)
                objs.append(x)
        return labels, objs
class SplitLayer(torch.nn.Module):
    """Module that splits its input into size-1 chunks along the last dimension."""

    def forward(self, x):
        """Return the size-1 slices of ``x`` taken along ``dim=-1``."""
        pieces = x.split(1, dim=-1)
        return pieces