Commit c9d3858d authored by Lanka Naga Sai Deep

Update accuracy graph.png, AlexNet.ipynb, confusion matrix.png, examples predicted and target.png, loss.png, helper_dataset.py, helper_evaluation.py, helper_plotting.py, helper_train.py, minibatch_loss_list.txt, train_acc_list.txt, valid_acc_list.txt files
parent a80dcf30
AlexNet.ipynb

Source diff not displayed (file too large).
accuracy graph.png (20.1 KiB)

confusion matrix.png (8.67 KiB)

examples predicted and target.png (63 KiB)
helper_dataset.py

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Revert channel-wise normalization in place.

        Parameters:
        ------------
        tensor (Tensor): Normalized image of size (C, H, W).

        Returns:
        ------------
        Tensor: Un-normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor
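
A minimal usage sketch for UnNormalize; the per-channel statistics below are illustrative placeholders, not values taken from this commit:

# hypothetical channel statistics, for illustration only
mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
unnorm = UnNormalize(mean=mean, std=std)
img = torch.rand(3, 32, 32)   # stand-in for a normalized image tensor
restored = unnorm(img)        # reverses (x - mean) / std, modifying img in place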
def get_dataloaders_mnist(batch_size, num_workers=0,
                          validation_fraction=None,
                          train_transforms=None,
                          test_transforms=None):

    if train_transforms is None:
        train_transforms = transforms.ToTensor()
    if test_transforms is None:
        test_transforms = transforms.ToTensor()

    train_dataset = datasets.MNIST(root='data',
                                   train=True,
                                   transform=train_transforms,
                                   download=True)
    valid_dataset = datasets.MNIST(root='data',
                                   train=True,
                                   transform=test_transforms)
    test_dataset = datasets.MNIST(root='data',
                                  train=False,
                                  transform=test_transforms)

    if validation_fraction is not None:
        # carve the last `num` training examples out as a validation split
        num = int(validation_fraction * 60000)
        train_indices = torch.arange(0, 60000 - num)
        valid_indices = torch.arange(60000 - num, 60000)
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(valid_indices)

        valid_loader = DataLoader(dataset=valid_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  sampler=valid_sampler)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  sampler=train_sampler)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  shuffle=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=False)

    if validation_fraction is None:
        return train_loader, test_loader
    return train_loader, valid_loader, test_loader
def get_dataloaders_cifar10(batch_size, num_workers=0,
                            validation_fraction=None,
                            train_transforms=None,
                            test_transforms=None):

    if train_transforms is None:
        train_transforms = transforms.ToTensor()
    if test_transforms is None:
        test_transforms = transforms.ToTensor()

    train_dataset = datasets.CIFAR10(root='data',
                                     train=True,
                                     transform=train_transforms,
                                     download=True)
    valid_dataset = datasets.CIFAR10(root='data',
                                     train=True,
                                     transform=test_transforms)
    test_dataset = datasets.CIFAR10(root='data',
                                    train=False,
                                    transform=test_transforms)

    if validation_fraction is not None:
        # carve the last `num` training examples out as a validation split
        num = int(validation_fraction * 50000)
        train_indices = torch.arange(0, 50000 - num)
        valid_indices = torch.arange(50000 - num, 50000)
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(valid_indices)

        valid_loader = DataLoader(dataset=valid_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  sampler=valid_sampler)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  sampler=train_sampler)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  shuffle=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=False)

    if validation_fraction is None:
        return train_loader, test_loader
    return train_loader, valid_loader, test_loader
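
A minimal sketch of how these loaders might be constructed; the batch size, resize shape, and 10% validation split are illustrative assumptions for an AlexNet-style model, not values confirmed by this commit:

from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.Resize((70, 70)),       # hypothetical resize for larger inputs
    transforms.RandomCrop((64, 64)),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.CenterCrop((64, 64)),
    transforms.ToTensor(),
])

train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
    batch_size=256,                    # hypothetical value
    validation_fraction=0.1,           # hold out 10% of the training set
    train_transforms=train_transforms,
    test_transforms=test_transforms)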
helper_evaluation.py

# imports from installed libraries
import os
import random
from itertools import product

import numpy as np
import torch
from distutils.version import LooseVersion as Version


def set_all_seeds(seed):
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def set_deterministic():
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    if Version(torch.__version__) <= Version("1.7"):
        torch.set_deterministic(True)
    else:
        torch.use_deterministic_algorithms(True)
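
A typical call sequence, run once before building the model (the seed value is arbitrary):

set_all_seeds(123)    # arbitrary seed for reproducibility
set_deterministic()   # note: ops without a deterministic kernel will raise an error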
def compute_accuracy(model, data_loader, device):
    with torch.no_grad():
        correct_pred, num_examples = 0, 0
        for i, (features, targets) in enumerate(data_loader):
            features = features.to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float() / num_examples * 100
def compute_confusion_matrix(model, data_loader, device):

    all_targets, all_predictions = [], []
    with torch.no_grad():
        for i, (features, targets) in enumerate(data_loader):
            features = features.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            all_targets.extend(targets.to('cpu'))
            all_predictions.extend(predicted_labels.to('cpu'))

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)

    class_labels = np.unique(np.concatenate((all_targets, all_predictions)))
    if class_labels.shape[0] == 1:
        # degenerate case: only one label observed; pad to a 2x2 matrix
        if class_labels[0] != 0:
            class_labels = np.array([0, class_labels[0]])
        else:
            class_labels = np.array([class_labels[0], 1])
    n_labels = class_labels.shape[0]

    # count (target, prediction) pairs for every label combination
    lst = []
    z = list(zip(all_targets, all_predictions))
    for combi in product(class_labels, repeat=2):
        lst.append(z.count(combi))
    mat = np.asarray(lst).reshape(n_labels, n_labels)
    return mat
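
A sketch of how these evaluation helpers might be called after training; model, test_loader, and DEVICE are assumed to exist in the calling scope:

model.eval()
test_acc = compute_accuracy(model, test_loader, device=DEVICE)
print(f'Test accuracy: {test_acc:.2f}%')

cm = compute_confusion_matrix(model, test_loader, device=DEVICE)
print(cm)   # rows = true labels, columns = predicted labels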
helper_plotting.py

# imports from installed libraries
import os
import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_training_loss(minibatch_loss_list, num_epochs, iter_per_epoch,
                       results_dir=None, averaging_iterations=100):

    plt.figure()
    ax1 = plt.subplot(1, 1, 1)
    ax1.plot(range(len(minibatch_loss_list)),
             minibatch_loss_list, label='Minibatch Loss')

    if len(minibatch_loss_list) > 1000:
        # zoom in on the stable part of the curve after the initial drop
        ax1.set_ylim([0, np.max(minibatch_loss_list[1000:]) * 1.5])
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')

    ax1.plot(np.convolve(minibatch_loss_list,
                         np.ones(averaging_iterations) / averaging_iterations,
                         mode='valid'),
             label='Running Average')
    ax1.legend()

    ###################
    # Set second x-axis (epochs)
    ax2 = ax1.twiny()
    newlabel = list(range(num_epochs + 1))
    newpos = [e * iter_per_epoch for e in newlabel]
    ax2.set_xticks(newpos[::10])
    ax2.set_xticklabels(newlabel[::10])
    ax2.xaxis.set_ticks_position('bottom')
    ax2.xaxis.set_label_position('bottom')
    ax2.spines['bottom'].set_position(('outward', 45))
    ax2.set_xlabel('Epochs')
    ax2.set_xlim(ax1.get_xlim())
    ###################

    plt.tight_layout()
    if results_dir is not None:
        image_path = os.path.join(results_dir, 'plot_training_loss.pdf')
        plt.savefig(image_path)
def plot_accuracy(train_acc_list, valid_acc_list, results_dir):

    num_epochs = len(train_acc_list)

    plt.plot(np.arange(1, num_epochs + 1),
             train_acc_list, label='Training')
    plt.plot(np.arange(1, num_epochs + 1),
             valid_acc_list, label='Validation')

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()

    if results_dir is not None:
        image_path = os.path.join(
            results_dir, 'plot_acc_training_validation.pdf')
        plt.savefig(image_path)
def show_examples(model, data_loader, unnormalizer=None, class_dict=None):

    # grab a single batch and predict on it
    for batch_idx, (features, targets) in enumerate(data_loader):
        with torch.no_grad():
            logits = model(features)
            predictions = torch.argmax(logits, dim=1)
        break

    fig, axes = plt.subplots(nrows=3, ncols=5,
                             sharex=True, sharey=True)

    if unnormalizer is not None:
        for idx in range(features.shape[0]):
            features[idx] = unnormalizer(features[idx])
    nhwc_img = np.transpose(features, axes=(0, 2, 3, 1))

    if nhwc_img.shape[-1] == 1:
        # grayscale images
        nhw_img = np.squeeze(nhwc_img.numpy(), axis=3)
        for idx, ax in enumerate(axes.ravel()):
            ax.imshow(nhw_img[idx], cmap='binary')
            if class_dict is not None:
                ax.title.set_text(f'P: {class_dict[predictions[idx].item()]}'
                                  f'\nT: {class_dict[targets[idx].item()]}')
            else:
                ax.title.set_text(f'P: {predictions[idx]} | T: {targets[idx]}')
            ax.axison = False
    else:
        # color images
        for idx, ax in enumerate(axes.ravel()):
            ax.imshow(nhwc_img[idx])
            if class_dict is not None:
                ax.title.set_text(f'P: {class_dict[predictions[idx].item()]}'
                                  f'\nT: {class_dict[targets[idx].item()]}')
            else:
                ax.title.set_text(f'P: {predictions[idx]} | T: {targets[idx]}')
            ax.axison = False
    plt.tight_layout()
    plt.show()
def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None):

    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')
    if class_names is not None and len(class_names) != len(conf_mat):
        raise AssertionError('len(class_names) should be equal to the number '
                             'of classes in the dataset')

    total_samples = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = conf_mat.astype('float') / total_samples

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)
    if cmap is None:
        cmap = plt.cm.Blues
    if figsize is None:
        figsize = (len(conf_mat) * 1.25, len(conf_mat) * 1.25)

    if show_normed:
        matshow = ax.matshow(normed_conf_mat, cmap=cmap)
    else:
        matshow = ax.matshow(conf_mat, cmap=cmap)
    if colorbar:
        fig.colorbar(matshow)

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            cell_text = ""
            if show_absolute:
                cell_text += format(conf_mat[i, j], 'd')
                if show_normed:
                    cell_text += "\n" + '('
                    cell_text += format(normed_conf_mat[i, j], '.2f') + ')'
            else:
                cell_text += format(normed_conf_mat[i, j], '.2f')
            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if normed_conf_mat[i, j] > 0.5 else "black")

    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks, class_names, rotation=90)
        plt.yticks(tick_marks, class_names)

    if hide_spines:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if hide_ticks:
        ax.axes.get_yaxis().set_ticks([])
        ax.axes.get_xaxis().set_ticks([])

    plt.xlabel('predicted label')
    plt.ylabel('true label')
    return fig, ax
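
A sketch showing how the plotting helpers fit together; the lists come from train_model below, and the CIFAR-10 class names and 100-epoch count are assumptions for illustration:

import matplotlib.pyplot as plt

cifar10_classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

plot_training_loss(minibatch_loss_list, num_epochs=100,
                   iter_per_epoch=len(train_loader), results_dir=None)
plot_accuracy(train_acc_list, valid_acc_list, results_dir=None)

cm = compute_confusion_matrix(model, test_loader, device=DEVICE)
fig, ax = plot_confusion_matrix(cm, class_names=cifar10_classes)
plt.show()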
helper_train.py

import time
import torch
import numpy as np
from tqdm import tqdm
from helper_evaluation import compute_accuracy


def train_model(model, num_epochs, train_loader,
                valid_loader, optimizer,
                device, logging_interval=50,
                scheduler=None,
                scheduler_on='valid_acc'):

    start_time_total = time.time()
    best_valid_acc = 0.  # best validation accuracy seen so far
    minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        for batch_idx, (features, targets) in enumerate(tqdm(train_loader)):
            features = features.to(device)
            targets = targets.to(device)

            # ## FORWARD AND BACK PROP
            logits = model(features)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()

            # ## UPDATE MODEL PARAMETERS
            optimizer.step()

            # ## LOGGING
            minibatch_loss_list.append(loss.item())
            # if not batch_idx % logging_interval:
            #     print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
            #           f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
            #           f'| Loss: {loss:.4f}')

        model.eval()
        with torch.no_grad():  # save memory during inference
            train_acc = compute_accuracy(model, train_loader, device=device)
            valid_acc = compute_accuracy(model, valid_loader, device=device)
            print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                  f'| Train: {train_acc:.2f}% '
                  f'| Validation: {valid_acc:.2f}%')
            train_acc_list.append(train_acc.item())
            valid_acc_list.append(valid_acc.item())

        # checkpoint whenever the validation accuracy improves
        if best_valid_acc < valid_acc.item():
            print(f'Validation accuracy increased '
                  f'({best_valid_acc:.2f} ---> {valid_acc.item():.2f}). '
                  f'Saving the model.')
            best_valid_acc = valid_acc.item()
            torch.save(model.state_dict(),
                       '/home/user/research/AlexNet/Alexnet_model.pth')

        elapsed = (time.time() - start_time) / 60
        print(f'Time elapsed: {elapsed:.2f} min')

        # if scheduler is not None:
        #     if scheduler_on == 'valid_acc':
        #         scheduler.step(valid_acc_list[-1])
        #     elif scheduler_on == 'minibatch_loss':
        #         scheduler.step(minibatch_loss_list[-1])
        #     else:
        #         raise ValueError('Invalid `scheduler_on` choice.')

    elapsed = (time.time() - start_time_total) / 60
    print(f'Total Training Time: {elapsed:.2f} min')

    # test_acc = compute_accuracy(model, test_loader, device=device)
    # print(f'Test accuracy {test_acc:.2f}%')
    return minibatch_loss_list, train_acc_list, valid_acc_list
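
A minimal end-to-end sketch. The AlexNet definition lives in AlexNet.ipynb and is not shown in this diff, so torchvision's alexnet stands in for it here; the optimizer and its hyperparameters are illustrative assumptions, and num_epochs=100 matches the length of the accuracy lists below:

import torch
from torchvision.models import alexnet   # stand-in for the notebook's own AlexNet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = alexnet(num_classes=10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # hypothetical

minibatch_loss_list, train_acc_list, valid_acc_list = train_model(
    model=model,
    num_epochs=100,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    device=DEVICE)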
loss.png (20.6 KiB)
minibatch_loss_list.txt

Source diff not displayed (file too large).

train_acc_list.txt
52.89007568359375,64.47822570800781,59.98997497558594,67.95857238769531,66.99520874023438,72.55818939208984,75.86590576171875,74.03385925292969,77.20235443115234,76.79585266113281,82.84329986572266,84.62522888183594,82.83773040771484,83.24980163574219,86.21227264404297,86.74127960205078,86.8637924194336,87.59883880615234,88.80721282958984,87.78260040283203,88.49537658691406,89.27497100830078,88.85733032226562,86.39603424072266,90.07684326171875,90.12139129638672,90.76734161376953,89.90978240966797,89.94876861572266,90.4944839477539,90.65596771240234,91.31305694580078,92.04254150390625,91.86991119384766,91.53023529052734,90.42765808105469,91.29078674316406,92.05924987792969,92.37665557861328,91.98128509521484,91.9645767211914,93.2898941040039,93.54605102539062,92.8221435546875,93.6629867553711,93.45695495605469,91.9478759765625,93.07829284667969,94.14188385009766,93.25648498535156,94.21427154541016,93.0838623046875,92.91680145263672,94.37576293945312,93.56832885742188,94.2309799194336,94.58180236816406,94.03050994873047,94.10289764404297,94.4815673828125,94.8769302368164,93.68526458740234,95.09967803955078,94.44258880615234,95.47276306152344,94.69316864013672,94.4203109741211,95.11637878417969,95.86813354492188,95.14978790283203,94.96603393554688,95.46163177490234,94.69873809814453,94.65419006347656,95.36139678955078,95.56743621826172,95.55072784423828,95.66766357421875,95.75675964355469,96.2746353149414,96.2022476196289,95.89041137695312,96.10201263427734,95.8235855102539,96.21337890625,95.9349594116211,95.66210174560547,95.96279907226562,95.84586334228516,96.2022476196289,96.30804443359375,95.91825103759766,96.3860092163086,95.76789855957031,96.72012329101562,96.89274597167969,96.40271759033203,96.43612670898438,96.47510528564453,96.48067474365234,
valid_acc_list.txt
52.56410217285156,58.508155822753906,53.991844177246094,69.46387481689453,66.84149169921875,70.42540740966797,77.33100891113281,72.66899871826172,78.55477905273438,79.02098083496094,83.7995376586914,84.38228607177734,82.92540740966797,85.2272720336914,85.92657470703125,87.93706512451172,86.82984161376953,87.03380584716797,86.30535888671875,86.97552490234375,86.68415069580078,88.75291442871094,87.41259002685547,86.15967559814453,90.20979309082031,87.58741760253906,88.95687866210938,89.10256958007812,89.68531799316406,88.86946105957031,90.41375732421875,90.18065643310547,90.00582885742188,90.18065643310547,90.8216781616211,87.0046615600586,88.89860534667969,91.72494506835938,91.11305236816406,90.15151977539062,89.51049041748047,90.64685821533203,90.87995147705078,91.14219665527344,91.43356323242188,91.02564239501953,90.8216781616211,91.02564239501953,92.33683013916016,92.1911392211914,92.24942016601562,90.70513153076172,92.39511108398438,92.59906768798828,90.15151977539062,91.92890167236328,92.94872283935547,92.74475860595703,92.1911392211914,92.39511108398438,93.00698852539062,91.66667175292969,92.77389526367188,92.01631927490234,92.01631927490234,92.365966796875,91.40443420410156,92.1911392211914,93.0361328125,93.85198211669922,92.71562194824219,92.01631927490234,91.55011749267578,92.62820434570312,92.83216857910156,93.44405364990234,93.0361328125,93.67715454101562,92.94872283935547,92.77389526367188,92.6573486328125,93.06526947021484,92.94872283935547,92.5407943725586,93.56060791015625,92.83216857910156,93.44405364990234,92.94872283935547,92.56993103027344,94.05594635009766,93.38578796386719,92.71562194824219,93.29837036132812,91.43356323242188,92.83216857910156,93.96853637695312,93.06526947021484,93.32750701904297,92.62820434570312,93.2109603881836,