update
This commit is contained in:
parent
18293b45b1
commit
6db4dbc306
369
data.ipynb
369
data.ipynb
File diff suppressed because one or more lines are too long
|
@ -26,6 +26,11 @@ import matplotlib.pyplot as plt
|
||||||
import albumentations as Aug
|
import albumentations as Aug
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
|
||||||
|
# Maps every object-detection category name to its integer class index.
# The index of a name is simply its position in the tuple below.
OBJ_LABELS = {
    name: index
    for index, name in enumerate((
        "truck",
        "bicycle",
        "car",
        "motorcycle",
        "train",
        "bus",
        "traffic sign",
        "rider",
        "person",
        "traffic light NA",
        "traffic light R",
        "traffic light G",
        "traffic light B",
    ))
}
|
||||||
|
|
||||||
def use_device(GPU):
|
def use_device(GPU):
|
||||||
if GPU is not None:
|
if GPU is not None:
|
||||||
|
@ -48,8 +53,28 @@ def tensor2im(tensor=None):
|
||||||
output = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
output = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
OBJ_LABELS = {
|
def padding(image, patch_size, fill_value=0):
    '''
    Pad the two trailing (height, width) axes so both are divisible by ``patch_size``.

    Padding is appended on the bottom and on the right only, so the original
    content keeps its position at the top-left corner.

    Args:
        image: tensor whose last two axes are height and width (any rank >= 2).
        patch_size: positive integer the padded height/width must be a multiple of.
        fill_value: constant written into the padded area.

    Returns:
        ``image`` itself when no padding is required, otherwise a new padded tensor.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    '''
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    H, W = image.shape[-2], image.shape[-1]
    # (-n) % p is the distance from n up to the next multiple of p (0 if already aligned).
    pad_h = -H % patch_size
    pad_w = -W % patch_size
    if pad_h == 0 and pad_w == 0:
        return image
    # F.pad takes pads for the last axis first: (left, right, top, bottom).
    return F.pad(image, (0, pad_w, 0, pad_h), value=fill_value)
|
||||||
|
|
||||||
|
|
||||||
|
def unpadding(image, target_size):
    '''
    Crop away rows/columns that were appended to the bottom/right by padding.

    Args:
        image: tensor whose last two axes are height and width (any rank >= 2).
        target_size: ``(H, W)`` pair giving the height and width to crop back to.

    Returns:
        ``image`` with surplus trailing rows and columns removed. When the
        tensor is not larger than ``target_size`` along an axis, that axis is
        left untouched.
    '''
    H, W = target_size
    extra_h = image.shape[-2] - H
    extra_w = image.shape[-1] - W
    # Only crop when there is a surplus; a zero surplus must not become the
    # slice ``[:-0]``, which would return an empty tensor.
    if extra_h > 0:
        image = image[..., :-extra_h, :]
    if extra_w > 0:
        image = image[..., :-extra_w]
    return image
|
|
@ -1,7 +1,7 @@
|
||||||
from base import *
|
from src.perception.base import *
|
||||||
|
|
||||||
class AutoDriveDataset(Dataset):
|
class AutoDriveDataset(Dataset):
|
||||||
def __init__(self, csv_file, image_dir, lane_dir, da_dir, transform=None):
|
def __init__(self, csv_file, image_dir, lane_dir, da_dir, transform=True):
|
||||||
self.data_frame = pd.read_json(csv_file)
|
self.data_frame = pd.read_json(csv_file)
|
||||||
self.image_dir, self.lane_dir, self.da_dir = image_dir, lane_dir, da_dir
|
self.image_dir, self.lane_dir, self.da_dir = image_dir, lane_dir, da_dir
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
@ -13,7 +13,7 @@ class AutoDriveDataset(Dataset):
|
||||||
image_name = os.path.join(self.image_dir, self.data_frame.iloc[idx, 0])
|
image_name = os.path.join(self.image_dir, self.data_frame.iloc[idx, 0])
|
||||||
label_name = os.path.join(self.da_dir, self.data_frame.iloc[idx, 0]).replace("jpg", "png")
|
label_name = os.path.join(self.da_dir, self.data_frame.iloc[idx, 0]).replace("jpg", "png")
|
||||||
image = cv2.cvtColor(cv2.imread(f"{image_name}"), cv2.COLOR_BGR2RGB)
|
image = cv2.cvtColor(cv2.imread(f"{image_name}"), cv2.COLOR_BGR2RGB)
|
||||||
da = cv2.cvtColor(cv2.imread("{}".format(label_name)), cv2.COLOR_BGR2RGB)
|
drivable = cv2.cvtColor(cv2.imread("{}".format(label_name)), cv2.COLOR_BGR2RGB)
|
||||||
lane = cv2.cvtColor(cv2.imread("{}".format(label_name.replace("drivable", "lane"))), cv2.COLOR_BGR2RGB)
|
lane = cv2.cvtColor(cv2.imread("{}".format(label_name.replace("drivable", "lane"))), cv2.COLOR_BGR2RGB)
|
||||||
label_data = self.data_frame.iloc[idx, 3]
|
label_data = self.data_frame.iloc[idx, 3]
|
||||||
boxes = []
|
boxes = []
|
||||||
|
@ -33,49 +33,66 @@ class AutoDriveDataset(Dataset):
|
||||||
|
|
||||||
boxes = torch.as_tensor(boxes, dtype=torch.float32)
|
boxes = torch.as_tensor(boxes, dtype=torch.float32)
|
||||||
labels = torch.as_tensor(labels, dtype=torch.int64)
|
labels = torch.as_tensor(labels, dtype=torch.int64)
|
||||||
target = {"lane": lane, "drivable": da, "boxes": boxes, "labels": labels}
|
object_det = {"boxes": boxes, "labels": labels}
|
||||||
return image, target
|
return image, lane, drivable, object_det
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
image, target = self.__readdata__(idx=idx)
|
image, lane, drivable, object_det = self.__readdata__(idx=idx)
|
||||||
if self.transform:
|
if self.transform:
|
||||||
image, target = self.__augmentation__(image, target)
|
image, lane, drivable, object_det = self.__augmentation__(image, lane, drivable, object_det)
|
||||||
|
|
||||||
return image, target
|
image = torch.from_numpy(image).float() / 255
|
||||||
|
lane = torch.from_numpy(lane).float() / 255
|
||||||
|
drivable = torch.from_numpy(drivable).float() / 255
|
||||||
|
object_det["boxes"] = torch.from_numpy(np.array(object_det["boxes"]))
|
||||||
|
return image, lane, drivable, object_det
|
||||||
|
|
||||||
def __augmentation__(self, image, lane, drivable, object_det):
    '''
    Run the albumentations pipeline jointly on the image, both masks and the boxes.

    The pipeline resizes to 360x640, then applies a random horizontal flip and
    a random brightness/contrast change (each with probability 0.5). The image
    and masks come back transposed with ``transpose(2, 0, 1)``.
    '''
    bbox_params = Aug.BboxParams(format='pascal_voc', label_fields=['labels'])
    pipeline = Aug.Compose(
        [
            Aug.Resize(360, 640, p=1),
            Aug.HorizontalFlip(p=0.5),
            Aug.RandomBrightnessContrast(p=0.5),
        ],
        bbox_params=bbox_params,
    )
    result = pipeline(
        image=image,
        masks=[lane, drivable],
        bboxes=object_det["boxes"],
        labels=object_det["labels"],
    )

    # Masks are returned in the order they were passed: lane first, drivable second.
    masks = result["masks"]
    image = result["image"].transpose(2, 0, 1)
    lane = masks[0].transpose(2, 0, 1)
    drivable = masks[1].transpose(2, 0, 1)
    object_det = {"boxes": result["bboxes"], "labels": result["labels"]}
    return image, lane, drivable, object_det
|
||||||
|
|
||||||
def collate_fn(self, batch):
    '''
    Merge a list of ``(image, lane, drivable, object_det)`` samples into a batch.

    Images and masks are stacked along a new leading axis. Detection targets
    stay a per-sample list of (shallow-copied) dicts, because each sample may
    hold a different number of boxes.
    '''
    per_image, per_lane, per_drivable, per_det = zip(*batch)
    images = torch.stack(per_image, dim=0)
    lane = torch.stack(per_lane, dim=0)
    drivable = torch.stack(per_drivable, dim=0)
    object_det = [dict(target.items()) for target in per_det]
    return images, lane, drivable, object_det
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Visual smoke test: draw one sample fetched directly from the dataset (top)
    # and one fetched through the DataLoader (bottom), then save the figure.

    def draw_sample(axis, image, lane_mask, drivable_mask, boxes):
        # Overlay both masks on the image and outline every box in red.
        axis.imshow(image.permute(1,2,0).numpy())
        axis.imshow(lane_mask.permute(1,2,0).numpy(), alpha=0.5)
        axis.imshow(drivable_mask.permute(1,2,0).numpy(), alpha=0.2)
        for box in boxes:
            # Boxes are (x_min, y_min, x_max, y_max); Rectangle wants origin + width/height.
            # NOTE(review): `patches` must resolve to matplotlib.patches - confirm the base module exports it.
            rect = patches.Rectangle((box[0], box[1]),
                                     box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor='r', facecolor='none')
            axis.add_patch(rect)
        axis.axis("off")

    dataset = AutoDriveDataset(csv_file="/home/bayes/data/ETS2/bdd100k/label/det_20/det_train.json",
                               image_dir="/home/bayes/data/ETS2/bdd100k/image/100k/train/",
                               lane_dir="/home/bayes/data/ETS2/bdd100k/label/lane/colormaps/train/",
                               da_dir="/home/bayes/data/ETS2/bdd100k/label/drivable/colormaps/train/", transform=True)
    A, B, C, D = dataset.__getitem__(idx=0)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate_fn)
    images, lane, drivable, object_det = next(iter(dataloader))

    fig, ax = plt.subplots(2)
    draw_sample(ax[0], A, B, C, D["boxes"])
    draw_sample(ax[1], images[0], lane[0], drivable[0], object_det[0]["boxes"])
    plt.savefig("sample.png", dpi=250)
|
|
@ -0,0 +1,187 @@
|
||||||
|
import copy
|
||||||
|
from typing import Optional, List
|
||||||
|
from src.perception.base import *
|
||||||
|
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
|
||||||
|
def __clones__(module, N):
    '''
    Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.
    '''
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
|
||||||
|
|
||||||
|
class ResNet(nn.Module):
    '''
    ResNet50 as the Backbone

    Wraps a timm ``resnet50`` in an ``IntermediateLayerGetter`` that exposes the
    stem and the four residual stages under the keys ``stage0`` .. ``stage7``.
    ``forward`` returns only the last one.
    '''
    def __init__(self, pretrained):
        super().__init__()
        stage_names = ("conv1", "bn1", "act1", "maxpool", "layer1", "layer2", "layer3", "layer4")
        resnet = timm.create_model("resnet50", pretrained=pretrained)
        # Map every backbone child name to the key it is published under.
        return_layers = {str(name): f"stage{position}" for position, name in enumerate(stage_names)}
        # Index of the deepest published stage.
        self.idx = len(stage_names) - 1
        self.model = IntermediateLayerGetter(resnet, return_layers=return_layers)

    def forward(self, x):
        features = self.model(x)
        # Only the deepest feature map (output of "layer4") is passed on.
        return features["stage7"]
|
||||||
|
|
||||||
|
class Compression(nn.Module):
    '''
    1x1 convolution that projects ``back_dim`` channels down to ``embed_dim`` channels.
    '''
    def __init__(self, back_dim=2048, embed_dim=256):
        super().__init__()
        self.conv = nn.Conv2d(back_dim, embed_dim, kernel_size=1)

    def forward(self, x):
        compressed = self.conv(x)
        return compressed
|
||||||
|
|
||||||
|
class PositionEmbedding(nn.Module):
    '''
    A Learnable Positional Embedding

    Holds one table for rows and one for columns, each with ``num_queries // 2``
    entries of width ``hidden_dim // 2``. ``forward`` concatenates the column and
    row vectors of every grid cell of ``x``.
    '''
    def __init__(self, num_queries=100, hidden_dim=256):
        super().__init__()
        table_shape = (num_queries // 2, hidden_dim // 2)
        self.row_embed = nn.Parameter(torch.rand(*table_shape))
        self.col_embed = nn.Parameter(torch.rand(*table_shape))

    def forward(self, x):
        batch, _, height, width = x.shape
        # Broadcast the column table down the rows and the row table across the columns.
        col_part = self.col_embed[:width].unsqueeze(0).repeat(height, 1, 1)
        row_part = self.row_embed[:height].unsqueeze(1).repeat(1, width, 1)
        grid = torch.cat([col_part, row_part], dim=-1)
        # Flatten the (height, width) grid row-major and copy it once per batch item.
        pos = grid.flatten(0, 1).unsqueeze(1).repeat(1, batch, 1)
        return pos
|
||||||
|
|
||||||
|
class SinePositionEmbedding(nn.Module):
    '''
    A Static Positional Embedding

    Builds a fixed sine/cosine table of shape ``(num_queries, hidden_dim)``:
    even channels hold ``sin(p * base**(-c / hidden_dim))`` and odd channels the
    matching cosine, where ``p`` is the position and ``c`` the even channel index.
    The module has no learnable parameters.
    '''
    def __init__(self, num_queries=100, hidden_dim=256):
        super(SinePositionEmbedding, self).__init__()
        self.num_queries, self.hidden_dim = num_queries, hidden_dim
        self.base = 10000

    def forward(self, x):
        '''
        Return the table repeated ``x.shape[1]`` times along a new axis 1.

        The result has shape ``(num_queries, x.shape[1], hidden_dim)`` and is
        created on ``x.device``.
        '''
        import math  # local import keeps this block independent of the file's wildcard import

        device = x.device
        position = torch.arange(0, self.num_queries, dtype=torch.float, device=device).unsqueeze(1)
        channel = torch.arange(0, self.hidden_dim, 2, device=device).float()
        div_term = torch.exp(channel * (-math.log(float(self.base)) / self.hidden_dim))
        angles = position * div_term

        pos = torch.zeros(self.num_queries, self.hidden_dim, device=device)
        pos[:, 0::2] = torch.sin(angles)
        # With an odd hidden_dim there is one fewer odd channel than even channels.
        pos[:, 1::2] = torch.cos(angles[:, : self.hidden_dim // 2])
        return pos.unsqueeze(1).repeat(1, x.shape[1], 1)
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
    '''
    Position-wise feed-forward block: two kernel-size-1 convolutions with a ReLU
    in between, dropout, a residual connection and a final LayerNorm.
    '''
    def __init__(self, embed_dim=256, ffn_dim=512, dropout=0.0):
        super().__init__()
        self.fc_1 = nn.Conv1d(embed_dim, ffn_dim, kernel_size=1)
        self.fc_2 = nn.Conv1d(ffn_dim, embed_dim, kernel_size=1)
        self.norm = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # Conv1d convolves over the last axis with channels on axis 1, so the
        # embedding axis (last axis of x) is swapped into position 1 first.
        channels_first = x.transpose(1, 2)
        expanded = F.relu(self.fc_1(channels_first))
        residual = self.drop(self.fc_2(expanded)) + channels_first
        # Swap back so LayerNorm normalises over the embedding axis.
        return self.norm(residual.transpose(1, 2))
|
||||||
|
|
||||||
|
class EncoderLayer(nn.Module):
    '''
    One encoder block: multi-head self-attention followed by the feed-forward block.
    '''
    def __init__(self, embed_dim=256, num_heads=8, ffn_dim=512, dropout=0.0):
        super(EncoderLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.feedforward = FeedForward(embed_dim=embed_dim, ffn_dim=ffn_dim, dropout=dropout)

    def forward(self, x, attn_mask=None):
        '''
        Self-attend over ``x`` and pass the result through the feed-forward block.

        Args:
            x: input sequence; it is used as query, key and value.
            attn_mask: optional attention mask forwarded to ``nn.MultiheadAttention``.

        Returns:
            ``(output, attention)``: the feed-forward output and the attention
            weights returned by ``nn.MultiheadAttention``.
        '''
        # The mask must be passed by keyword: the fourth positional parameter of
        # nn.MultiheadAttention.forward is key_padding_mask, not attn_mask.
        attended, attention = self.attention(x, x, x, attn_mask=attn_mask)
        output = self.feedforward(attended)
        return output, attention
|
||||||
|
|
||||||
|
class DecoderLayer(nn.Module):
    '''
    One decoder block: multi-head attention from the queries onto ``x``,
    followed by the feed-forward block.
    '''
    def __init__(self, embed_dim=256, num_heads=8, ffn_dim=512, dropout=0.0):
        super(DecoderLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.feedforward = FeedForward(embed_dim=embed_dim, ffn_dim=ffn_dim, dropout=dropout)

    def forward(self, x, queries, attn_mask=None):
        '''
        Attend from ``queries`` onto ``x`` (used as both key and value).

        Args:
            x: sequence that is attended over.
            queries: query sequence.
            attn_mask: optional attention mask forwarded to ``nn.MultiheadAttention``.

        Returns:
            ``(output, attention)``: the feed-forward output and the attention
            weights returned by ``nn.MultiheadAttention``.
        '''
        # The mask must be passed by keyword: the fourth positional parameter of
        # nn.MultiheadAttention.forward is key_padding_mask, not attn_mask.
        output, attention = self.attention(queries, x, x, attn_mask=attn_mask)
        output = self.feedforward(output)
        return output, attention
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
    '''
    Stack of ``num_layers`` EncoderLayer copies applied one after another.
    '''
    def __init__(self, embed_dim=256, num_heads=8, ffn_dim=512, dropout=0.0, num_layers=8, hidden_dim=256):
        super().__init__()
        # hidden_dim is accepted but not used by this class.
        prototype = EncoderLayer(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout)
        self.encoder = __clones__(prototype, num_layers)

    def forward(self, x):
        # Returns the final hidden sequence and the attention weights of every layer.
        hidden = x
        attns = []
        for block in self.encoder:
            hidden, weights = block(hidden)
            attns.append(weights)
        return hidden, attns
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
    '''
    Stack of ``num_layers`` DecoderLayer copies applied one after another.
    '''
    def __init__(self, embed_dim=256, num_heads=8, ffn_dim=512, dropout=0.0, num_layers=8, hidden_dim=256):
        super().__init__()
        # hidden_dim is accepted but not used by this class.
        prototype = DecoderLayer(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout)
        self.decoder = __clones__(prototype, num_layers)

    def forward(self, x, queries):
        # Each layer attends from the same `queries` onto the previous layer's
        # output (the first layer attends onto `x`).
        hidden = x
        attns = []
        for block in self.decoder:
            hidden, weights = block(hidden, queries)
            attns.append(weights)
        # Axes 0 and 1 of the final output are swapped before returning.
        return hidden.transpose(0, 1), attns
|
||||||
|
|
||||||
|
class Detector(nn.Module):
    '''
    Object Detection Head

    Two independent linear layers on top of each embedding: one producing
    ``num_classes + 1`` class logits, one producing 4 box values.
    '''
    def __init__(self, num_classes=100, embed_dim=256):
        super().__init__()
        self.linear_class = nn.Linear(embed_dim, num_classes + 1)
        self.linear_boxes = nn.Linear(embed_dim, 4)

    def forward(self, x):
        logits = self.linear_class(x)
        boxes = self.linear_boxes(x)
        return {"pred_logits": logits, "boxes": boxes}
|
||||||
|
|
||||||
|
class Segmentor(nn.Module):
    '''
    Semantic Segmentation Head

    Five stages, each a 1x1 convolution that halves the channel count followed
    by a 2x bilinear upsampling (32x in total), then a bias-free 1x1 convolution
    that produces ``num_classes`` output channels.
    '''
    def __init__(self, num_classes=3, embed_dim=256):
        super().__init__()
        channels = [embed_dim // 2**i for i in range(6)]
        stages = []
        for in_channels, out_channels in zip(channels[:-1], channels[1:]):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1),
                    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                )
            )
        self.upsample = nn.Sequential(*stages)
        self.segmentor = nn.Conv2d(channels[-1], num_classes, kernel_size=1, bias=False)

    def forward(self, x):
        upsampled = self.upsample(x)
        return self.segmentor(upsampled)
|
||||||
|
|
||||||
|
class Perception(nn.Module):
    '''
    Joint detection / segmentation model.

    Pipeline: ResNet backbone -> 1x1 channel compression -> learnable position
    embedding -> encoder -> decoder. The detection head reads the decoder
    output; the segmentation head reads the encoder output reshaped back into
    a feature map.

    Args:
        pretrained: forwarded to the backbone to select pretrained weights.
        num_queries: forwarded to the position embedding (sizes its tables).
        num_classes: dict with keys ``"obj"`` and ``"seg"``; defaults to
            ``{"obj": 13, "seg": 2}`` when omitted.
        embed_dim, num_heads, ffn_dim, dropout, num_layers, hidden_dim:
            forwarded to the encoder and decoder.
    '''
    def __init__(self, pretrained=True, num_queries=100, num_classes=None, embed_dim=256, num_heads=8, ffn_dim=512, dropout=0.0, num_layers=8, hidden_dim=256):
        super(Perception, self).__init__()
        # A fresh dict per instance avoids sharing one mutable default object.
        if num_classes is None:
            num_classes = {"obj": 13, "seg": 2}

        self.backbone = ResNet(pretrained=pretrained)
        # The position embedding is added to the compressed features in forward,
        # so its width has to equal embed_dim.
        self.position = PositionEmbedding(num_queries=num_queries, hidden_dim=embed_dim)
        self.compress = Compression(back_dim=2048, embed_dim=embed_dim)
        self.encoder = Encoder(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout, num_layers=num_layers, hidden_dim=hidden_dim)
        self.decoder = Decoder(embed_dim=embed_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout, num_layers=num_layers, hidden_dim=hidden_dim)

        self.detector = Detector(num_classes=num_classes["obj"], embed_dim=embed_dim)
        self.segmentor = Segmentor(num_classes=num_classes["seg"], embed_dim=embed_dim)

        # NOTE(review): requires_grad is defined outside this view; presumably it
        # configures gradient tracking for the backbone - confirm.
        requires_grad(self.backbone)

    def forward(self, x, queries):
        '''
        Run the full model on an image batch.

        Args:
            x: image batch; axes 2 and 3 are treated as height and width.
            queries: query sequence handed to the decoder.

        Returns:
            ``(outputs, attns_enc, attns_dec)`` where ``outputs`` is a dict with
            keys ``"decoded"``, ``"detection"`` and ``"segment"``, and the other
            two are the per-layer attention lists of the encoder and decoder.
        '''
        h_ori, w_ori = x.size(2), x.size(3)
        # Pad to a multiple of 32 so the segmentation head (which upsamples by
        # 2**5) maps back onto the padded input grid.
        x = padding(x, 32)

        features = self.compress(self.backbone(x))
        # Flatten the feature map into a (positions, batch, channels) sequence
        # and add the position embedding.
        tokens = self.position(features) + features.flatten(2).permute(2, 0, 1)
        output_enc, attns_enc = self.encoder(tokens)
        output, attns_dec = self.decoder(output_enc, queries)

        dect = self.detector(output)
        # Restore the encoder sequence to the feature-map layout, upsample it,
        # then crop away the rows/columns introduced by the padding.
        seg_map = self.segmentor(output_enc.permute(1, 2, 0).reshape(features.shape))
        seg = unpadding(seg_map, (h_ori, w_ori))
        return {"decoded": output, "detection": dect, "segment": seg}, attns_enc, attns_dec
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: push one random batch through a default Perception model.
    from src.perception.model.perception import *

    batch_size, embed_dim, num_queries = 2, 256, 100
    frames = torch.randn([batch_size, 3, 720, 1280])
    model = Perception()
    queries = torch.randn([num_queries, batch_size, embed_dim])
    output, attns_enc, attns_dec = model(frames, queries)
|
Binary file not shown.
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 459 KiB |
Loading…
Reference in New Issue