# convenience function to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
    """Append one row per image of a test batch to a predictions table.

    Each row passed to ``test_table.add_data`` holds, in order: an id of the
    form ``"<position in batch>_<log_counter>"``, the image wrapped in
    ``wandb.Image``, the predicted label, the true label, and then the
    softmax scores of ``outputs`` (taken over dim 1), one value per class.

    Rows are added in batch order and the loop stops once
    ``NUM_IMAGES_PER_BATCH`` rows have been added; if that count is never
    reached, every image in the batch is logged.

    Args:
        images: batch of input images (tensor; moved to CPU before logging).
        labels: true labels for the batch.
        outputs: raw model outputs for the batch.
        predicted: predicted labels for the batch.
        test_table: table object exposing ``add_data``; mutated in place.
        log_counter: index of the batch being logged, used in the row id.
    """
    # confidence scores for all classes, plus CPU/numpy copies of the rest
    class_scores = F.softmax(outputs.data, dim=1).cpu().numpy()
    image_arrays = images.cpu().numpy()
    true_labels = labels.cpu().numpy()
    guesses = predicted.cpu().numpy()

    rows = zip(image_arrays, true_labels, guesses, class_scores)
    for position, (pixels, truth, guess, row_scores) in enumerate(rows):
        # ids are based on the order of the images within the batch
        row_id = f"{position}_{log_counter!s}"
        test_table.add_data(row_id, wandb.Image(pixels), guess, truth, *row_scores)
        if position + 1 == NUM_IMAGES_PER_BATCH:
            break
# W&B: Initialize a new run to track this model's training.
# For every epoch the script trains over train_loader, then evaluates on
# test_loader, logging the training loss per step, the test accuracy per epoch,
# and a table of per-image predictions for the first NUM_BATCHES_TO_LOG test
# batches.
with wandb.init(project="table-quickstart") as run:
    # W&B: Log hyperparameters using config
    cfg = run.config
    cfg.update({"epochs" : EPOCHS, "batch_size": BATCH_SIZE, "lr" : LEARNING_RATE,
                "l1_size" : L1_SIZE, "l2_size": L2_SIZE,
                "conv_kernel" : CONV_KERNEL_SIZE,
                "img_count" : min(10000, NUM_IMAGES_PER_BATCH*NUM_BATCHES_TO_LOG)})
    # define model, loss, and optimizer
    model = ConvNet(NUM_CLASSES).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # train the model
    total_step = len(train_loader)
    for epoch in range(EPOCHS):
        # The evaluation pass at the end of each epoch puts the model in eval
        # mode; switch back to training mode so that every epoch (not only the
        # first) trains with train-time layer behavior.
        model.train()
        # training step
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            # forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # W&B: Log loss over training steps, visualized in the UI live.
            # Log a plain Python number rather than the loss tensor itself.
            run.log({"loss" : loss.item()})
            if (i+1) % 100 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                    .format(epoch+1, EPOCHS, i+1, total_step, loss.item()))

        # W&B: Create a Table to store predictions for each test step.
        # log_test_predictions appends one softmax score per model output, so
        # there must be one score column per class (assumes ConvNet emits
        # NUM_CLASSES outputs -- ConvNet is defined outside this block).
        columns = ["id", "image", "guess", "truth"]
        columns.extend("score_" + str(digit) for digit in range(NUM_CLASSES))
        test_table = wandb.Table(columns=columns)
        # test the model
        model.eval()
        log_counter = 0
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                # only the first NUM_BATCHES_TO_LOG batches go into the table
                if log_counter < NUM_BATCHES_TO_LOG:
                    log_test_predictions(images, labels, outputs, predicted, test_table, log_counter)
                    log_counter += 1
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            acc = 100 * correct / total
            # W&B: Log accuracy across training epochs, to visualize in the UI
            run.log({"epoch" : epoch, "acc" : acc})
            print('Test Accuracy of the model on the 10000 test images: {} %'.format(acc))
        # W&B: Log predictions table to wandb
        run.log({"test_predictions" : test_table})