Spaces:
Running
Running
| import torch | |
class NestedTensor:
    """Container that forwards tensor operations to each tensor in a list.

    Arithmetic, indexing, ``to`` and ``chunk`` are applied to every wrapped
    tensor and return a NEW ``NestedTensor``; the receiver's list is never
    mutated. The metadata accessors (``size``, ``shape``, ``device``,
    ``dtype``, ``layout``, ``new_ones``) report the FIRST wrapped tensor only.
    """

    def __init__(self, tensors):
        """Wrap ``tensors`` (any iterable), copying it into a fresh list."""
        self.tensors = list(tensors)
        # Marker attribute; presumably inspected by callers elsewhere to detect
        # nested inputs — NOTE(review): confirm against its readers.
        self.is_nested = True

    def __repr__(self):
        return f"{type(self).__name__}({self.tensors!r})"

    def _copy(self):
        """Return a new NestedTensor holding a shallow copy of the tensor list."""
        return NestedTensor(self.tensors)

    def apply_operation(self, other, operation):
        """Apply ``operation(tensor, operand)`` to every wrapped tensor.

        Args:
            other: Either a NestedTensor, whose tensors are paired with this
                one's position-by-position, or any other value (including
                None), passed unchanged as the operand for every tensor.
            operation: Binary callable ``(tensor, operand) -> result``.

        Returns:
            A new NestedTensor of the per-element results; ``self`` is left
            unchanged.

        Raises:
            ValueError: If ``other`` is a NestedTensor holding a different
                number of tensors.
        """
        if isinstance(other, NestedTensor):
            # Pairing by position is only meaningful when both sides line up;
            # never silently drop or overrun elements.
            if len(other.tensors) != len(self.tensors):
                raise ValueError(
                    "NestedTensor length mismatch: "
                    f"{len(self.tensors)} vs {len(other.tensors)}"
                )
            results = [operation(t, u) for t, u in zip(self.tensors, other.tensors)]
        else:
            results = [operation(t, other) for t in self.tensors]
        return NestedTensor(results)

    def __add__(self, b):
        return self.apply_operation(b, lambda x, y: x + y)

    def __radd__(self, b):
        # ``b + self`` with a non-NestedTensor ``b``: keep ``b`` on the left.
        return self.apply_operation(b, lambda x, y: y + x)

    def __sub__(self, b):
        return self.apply_operation(b, lambda x, y: x - y)

    def __rsub__(self, b):
        # ``b - self``: operand order matters for subtraction.
        return self.apply_operation(b, lambda x, y: y - x)

    def __mul__(self, b):
        return self.apply_operation(b, lambda x, y: x * y)

    def __rmul__(self, b):
        return self.apply_operation(b, lambda x, y: y * x)

    def __truediv__(self, b):
        return self.apply_operation(b, lambda x, y: x / y)

    def __rtruediv__(self, b):
        # ``b / self``: operand order matters for division.
        return self.apply_operation(b, lambda x, y: y / x)

    def __neg__(self):
        return self.apply_operation(None, lambda x, _: -x)

    def __getitem__(self, *args, **kwargs):
        """Index every wrapped tensor with the same key."""
        return self.apply_operation(None, lambda x, _: x.__getitem__(*args, **kwargs))

    def unbind(self):
        """Return the wrapped list itself (not a copy)."""
        return self.tensors

    def to(self, *args, **kwargs):
        """Return a new NestedTensor with ``Tensor.to(*args, **kwargs)`` applied to each tensor."""
        return self.apply_operation(None, lambda x, _: x.to(*args, **kwargs))

    def new_ones(self, *args, **kwargs):
        """Delegate to the first wrapped tensor's ``new_ones`` (returns a plain tensor)."""
        return self.tensors[0].new_ones(*args, **kwargs)

    def float(self):
        """Return a copy with every tensor cast to ``torch.float``."""
        return self.to(dtype=torch.float)

    def chunk(self, *args, **kwargs):
        """Chunk each wrapped tensor.

        Each element of the result is the tuple returned by ``Tensor.chunk``
        for the corresponding tensor; the result is NOT split into several
        NestedTensors.
        """
        return self.apply_operation(None, lambda x, _: x.chunk(*args, **kwargs))

    # NOTE(review): the accessors below are plain methods (``nt.shape()``),
    # unlike the attributes of torch.Tensor; callers written against the
    # tensor API may expect attribute access — confirm before changing.

    def size(self):
        """Size of the first wrapped tensor."""
        return self.tensors[0].size()

    def shape(self):
        """Shape of the first wrapped tensor."""
        return self.tensors[0].shape

    def ndim(self):
        """Largest ``ndim`` among the wrapped tensors (0 if there are none)."""
        return max((t.ndim for t in self.tensors), default=0)

    def device(self):
        """Device of the first wrapped tensor."""
        return self.tensors[0].device

    def dtype(self):
        """Dtype of the first wrapped tensor."""
        return self.tensors[0].dtype

    def layout(self):
        """Layout of the first wrapped tensor."""
        return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
    """Concatenate a sequence of NestedTensors position-by-position.

    The i-th output tensor is ``torch.cat`` of the i-th tensor of every input,
    with ``*args`` / ``**kwargs`` (e.g. ``dim``) forwarded to ``torch.cat``.

    Args:
        tensors: Non-empty sequence of NestedTensor objects, all holding the
            same number of tensors.

    Returns:
        A new NestedTensor of the concatenated tensors.

    Raises:
        ValueError: If ``tensors`` is empty or the inputs hold differing
            numbers of tensors.
    """
    if not tensors:
        raise ValueError("cat_nested() requires at least one NestedTensor")
    count = len(tensors[0].tensors)
    for j, nested in enumerate(tensors):
        # Mismatched inputs would otherwise be truncated or overrun silently.
        if len(nested.tensors) != count:
            raise ValueError(
                f"cat_nested(): input {j} holds {len(nested.tensors)} tensors, "
                f"expected {count}"
            )
    # zip(*...) yields, for each position i, the i-th tensor of every input.
    columns = zip(*(nested.tensors for nested in tensors))
    return NestedTensor([torch.cat(list(column), *args, **kwargs) for column in columns])