This is a small helper for printing per-class image counts for Keras ImageDataGenerator iterators.
Imagine you have three different data generators:
    from keras.preprocessing.image import ImageDataGenerator

    # train_path, valid_path, test_path and klasses are defined elsewhere
    train_batches = ImageDataGenerator().flow_from_directory(
        train_path, target_size=(512, 512), classes=klasses, batch_size=32)
    valid_batches = ImageDataGenerator().flow_from_directory(
        valid_path, target_size=(512, 512), classes=klasses, batch_size=32)
    test_batches = ImageDataGenerator().flow_from_directory(
        test_path, target_size=(512, 512), classes=klasses, batch_size=32)

    # dataset_stats returns the table as a string, so print it
    print(dataset_stats([train_batches, valid_batches, test_batches], klasses))
Then you can use something like this:
    import os


    def dataset_stats(data_generators=(), klasses=()):
        '''Return a table of image counts per class for each data generator.'''
        # Collect the counts, keyed by the name of each generator's directory.
        data = {}
        for batch in data_generators:
            stats = {}
            # normpath drops a trailing slash so the name is never empty
            batch_name = os.path.basename(os.path.normpath(batch.directory))
            for label in batch.labels:
                stats[klasses[label]] = stats.get(klasses[label], 0) + 1
            data[batch_name] = stats

        # Header row: one right-aligned column per generator.
        line = '{:20}'.format('Image Counts')
        for batch_name in data:
            line += '{:>20}'.format(batch_name)
        lines = line + '\n'

        # One row per class; a class missing from a generator counts as 0.
        for klass in klasses:
            line = '{:20}'.format(klass)
            for batch_name in data:
                line += '{:>20}'.format(data[batch_name].get(klass, 0))
            lines += '\n' + line

        # Totals row.
        lines += '\n\n{:20}'.format('SUM')
        for batch_name in data:
            lines += '{:>20}'.format(sum(data[batch_name].values()))
        return lines
It generates clean output like this:
    Image Counts                        test                 val               train

    male                                1000                1000                1000
    female                              1000                1000                1000
    child                               1000                1000                1000

    SUM                                 3000                3000                3000
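If you want to try the counting step without a dataset on disk, here is a minimal sketch. It assumes the only things read from each generator are its `directory` and `labels` attributes, so plain stand-in objects are enough. The name `count_labels` and the fake generators are hypothetical and only for illustration, and `collections.Counter` is swapped in for the hand-written counting loop.

```python
from collections import Counter
from types import SimpleNamespace


def count_labels(data_generators, klasses):
    """Map each generator's directory name to a Counter of class names."""
    counts = {}
    for batch in data_generators:
        # Last path component, ignoring a trailing slash.
        name = batch.directory.rstrip('/').split('/')[-1]
        # batch.labels holds one class index per image.
        counts[name] = Counter(klasses[label] for label in batch.labels)
    return counts


klasses = ['male', 'female', 'child']

# Hypothetical stand-ins for real generators: only `directory`
# and `labels` are provided, nothing is read from disk.
fake_train = SimpleNamespace(directory='data/train', labels=[0, 0, 1, 2])
fake_val = SimpleNamespace(directory='data/val/', labels=[1, 1, 2])

counts = count_labels([fake_train, fake_val], klasses)
for name, counter in counts.items():
    print(name, dict(counter))
```

A `Counter` returns 0 for a class it has never seen, so the formatting step would not need a separate check for missing classes.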