GAN을 사용하여 CIFAR-10 이미지를 생성하는 프로젝틀르 진행하던 중 plt.imshow()부분에서 오류를 마주했다.
( 해당 프로젝트 노트북 파일 :
https://github.com/YOOHYOJEONG/AIFFEL_LMS_project/blob/master/ex13/ex13_DCGAN_CIFAR10.ipynb )
프로젝트 진행 순서는 다음과 같다.
데이터셋 구성 - 생성자 모델 구현 - 판별자 모델 구현 - 손실 함수와 최적화 함수 구현 - 훈련과정 상세 기능 구현 - 학습 과정 진행 - GAN 훈련 과정 개선
오류를 마주한 부분은 '훈련과정 상세 기능 구현' 부분이었다.
해당 과정에서 샘플을 생성하는 함수를 만들어야 했는데,
이 함수는 tensorflow 형식의 데이터셋을 input으로 입력받아 모델에서 예측한 이미지로 샘플 이미지를 생성하는 함수이다. 즉, 모델에서 예측하여 나온 결과물도 tensorflow 형식이다.
input으로 들어오는 데이터는 이전 단계에서 -1 ~ 1 사이의 값으로 정규화를 해 준 상태였으며
정규화를 한 이미지로 만들어진 샘플 이미지를 시각화 하려고 하면
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
해당 오류 메세지가 뜨며 어두운 이미지를 출력했다.
imshow는 RGB 데이터가 0~1 또는 0~255 사이의 정수 값을 가지고 있어야만 제대로 시각화를 해주기 때문이다.
따라서 역정규화 과정을 추가로 진행하여 샘플 이미지를 시각화 해야 했다.
위의 코드에서 빨간 네모 박스 부분을 고쳐야했다.
처음에는 plt.imshow() 안에서 predictions에 역 정규화만 해 주었는데 제대로 된 결과물이 나오지 않았다.
imshow는 정수형만 표현할 수 있기에 int형으로 바꿔줘야했던 것이다.
문제는 predictions이 tensorflow 데이터 타입이었고 tensorflow를 int로 바꾸는방법을 몰랐기에 구글링을 통해 tf.to_int32라는 것을 사용할 수 있을 것 같았다.
하지만 tf.to_int32를 사용하여 데이터 타입을 변경하려고 했더니
AttributerError : module 'tensorflow' has no attribute 'to_int32' 헤딩 오류를 마주했다.
다시 구글링을 통해 얻은 또 다른 방법은 tf.cast를 사용하는 것이었는데, 이 또한 모듈이 없다는 오류가 떴다.
어떻게 해결할 수 있을지 고민을 하다가 함께 공부하는 분들의 도움을 받아 해결할 수 있었다.
위의 사진에서 빨간 박스 안의 내용처럼 변경을 했다.
tensorflow 데이터 타입인 predictions을 predictions.numpy()를 사용하여 numpy 타입으로 변경 해 준 뒤
역정규화를 진행하고 데이터 타입을 inv_predictions.astypoe(int)를 사용하여 int형으로 바꾸어주었다.
그 결과 역정규화가 잘 된 제대로 된 이미지 결과가 나왔다.