본문 바로가기
프로젝트/AI명종원

Roboflow Java API

by HWK 2024. 5. 10.

roboflow API를 사용하는 데에는 api-key와 api-endpoint가 필요하다.

나는 이 둘을 application.properties에 저장했다.

roboflow.api-key= 자신의 API 키
roboflow.api-endpoint = https://detect.roboflow.com/모델명/버전번호

 

먼저 완성된 모델에 사진을 보내주고, 사진을 모델에서 인식해서, 받은 정보를 쓸 수 있도록 가공해야 한다.

아래와 같이 컨트롤러를 작성한다.

@RestController
@RequestMapping("/api")
public class ImageController {

    private final ImageService imageService;

    public ImageController(ImageService imageService) {
        this.imageService = imageService;
    }

    @PostMapping("/photo-recognition")
    public String uploadImage(@RequestParam("file") MultipartFile file) {
        return imageService.uploadAndProcessImage(file);
    }
}

이 코드는 @RequestParam을 이용해 사진파일을 받게끔해준다.

위 코드를 동작시키기 위해 작성한 서비스 코드는 아래와 같다.

@Service
public class ImageService {
    // Roboflow API 키
    @Value("${roboflow.api-key}")
    private String ROBOFLOW_API_KEY;

    // Roboflow API Endpoint
    @Value("${roboflow.api-endpoint}")
    private String ROBOFLOW_API_ENDPOINT;

    /**
     * roboflow 모델 API에 이미지 보내서 인식된 재료를 String 형으로 가져옴
     * @return : "파, 마늘, 양파, 삼겹살, ..."
     * */
    public String uploadAndProcessImage(MultipartFile imageFile) {
        try {
            // Base64 인코딩
            String encodedFile = new String(Base64.getEncoder().encode(imageFile.getBytes()), StandardCharsets.US_ASCII);

            // 업로드 URL 생성
            String uploadURL = ROBOFLOW_API_ENDPOINT + "?api_key=" + ROBOFLOW_API_KEY + "&name=YOUR_IMAGE.jpg";

            // HTTP 요청 설정
            HttpURLConnection connection = null;
            try {
                // URL에 연결
                URL url = new URL(uploadURL);
                connection = (HttpURLConnection) url.openConnection();
                connection.setRequestMethod("POST");
                connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded");
                connection.setRequestProperty("Content-Length", Integer.toString(encodedFile.getBytes().length));
                connection.setRequestProperty("Content-Language", "en-US");
                connection.setUseCaches(false);
                connection.setDoOutput(true);

                // 요청 보내기
                DataOutputStream wr = new DataOutputStream(connection.getOutputStream());
                wr.writeBytes(encodedFile);
                wr.close();

                // 응답 받기
                InputStream stream = connection.getInputStream();
                BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
                StringBuilder responseBuilder = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    responseBuilder.append(line);
                }
                reader.close();

                // 예측 결과를 String 형으로 변환하여 반환

                return extractClassIdFromResponse(responseBuilder.toString());
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                if (connection != null) {
                    connection.disconnect();
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        return null; // 예측 결과를 받지 못한 경우
    }

    /**
     * class_id 중복을 없애줌
     * 이후 joinIngredients 사용
     * */
    private String extractClassIdFromResponse(String response) {
        Set<Integer> classIds = new HashSet<>(); // 빠른 중복검사를 위해 HashSet 사용
        String ingredients = "";
        try {
            // JSON 파싱을 위한 ObjectMapper 생성
            ObjectMapper objectMapper = new ObjectMapper();
            // JSON 문자열을 JsonNode로 변환
            JsonNode rootNode = objectMapper.readTree(response);

            // "predictions" 배열의 각 객체에서 "class" 필드 값을 추출하여 리스트에 추가
            if (rootNode.has("predictions")) {
                JsonNode predictionsNode = rootNode.get("predictions");
                if (predictionsNode.isArray()) {
                    for (JsonNode predictionNode : predictionsNode) {
                        if (predictionNode.has("class_id")) {
                            int classId = predictionNode.get("class_id").asInt();
                            System.out.println(predictionNode.get("class").asText() + classId);
                            classIds.add(classId);
                        }
                    }
                }
            }
            ingredients = joinIngredients(classIds);
        } catch (IOException e) {
            e.printStackTrace();
        }

        return ingredients;
    }

    /**
     * IngredientMapper Enum Class를 이용해서 class_id에 맞는 재료의 한글명 가져옴
     * */
    private String joinIngredients(Set<Integer> classIds) {
        StringJoiner resultJoiner = new StringJoiner(", ");

        // classIds에 대해 반복하면서 IngredientMapper를 참조하여 문자열 구성
        for (int classId : classIds) {
            try {
                // 정적(static) 메서드 호출: 클래스 이름으로 직접 호출
                IngredientMapper ingredient = IngredientMapper.getByClassId(classId);
                String className = ingredient.getClassName();

                // StringJoiner에 클래스 이름 추가
                resultJoiner.add(className);
            } catch (IllegalArgumentException e) {
                System.err.println("Invalid class_id: " + classId);
                // 예외 처리: 유효하지 않은 class_id에 대한 처리
            }
        }

        // StringJoiner를 사용하여 최종 결과 문자열 생성
        return resultJoiner.toString();
    }
}
  • uploadAndProcessImage(): 사진 파일을 받아서 Roboflow에 보낼 형식을 작성해준다. roboflow로 부터 response가 오면, extractClassIdFromResponse()와 joinIngredients()를 통해 원하는 반환형으로 바꿔준다.
    ex) "파, 양파, 마늘, 소고기"
  • extractClassIdFromResponse(): 전달받은 response에서 HashSet으로 ClassId를 추출해 중복을 제거한다.
    만들어진 HashSet은 joinIngredients()로 전달한다.
  • joinIngredients(): 전달받은 classId를 중복제거 후 IngredientMapper를 통해 재료 이름으로 바꿔준다.
    이때 재료의 이름을 모두 묶고 ','로 나눠서 반환해준다.

아래는 IngredientMapper이다. 많은 정보를 저장할 필요가 없고, DB와 연결하지 않아도 되므로 좋은 방법이라고 생각하지만, 모델의 클래스가 변경되면 변경시켜줘야 하는 코드이다.

더보기
// 각각에 class_id에 맞는 식재료 한글명
public enum IngredientMapper {
    AVOCADO(0, "아보카도"),
    BEAN_SPROUTS(1, "콩나물"),
    BEEF(2,"소고기"),
    BROCCOLI(3, "브로콜리"),
    CABBAGE(4, "양배추"),
    CARROT(5, "당근"),
    CHEESE(6, "치즈"),
    CHICKEN(7, "닭고기"),
    CHILI(8, "고추"),
    CUCUMBER(9, "오이"),
    DAIKON(10, "무"),
    EGG(11, "계란"),
    EGGPLANT(12, "가지"),
    GARLIC(13, "마늘"),
    GREEN_ONION(14, "파"),
    HAM(15, "햄"),
    KIMCHI(16, "김치"),
    LETTUCE(17, "상추"),
    MUSHROOM(18, "버섯"),
    ONION(19, "양파"),
    PAPRIKA(20, "파프리카"),
    PORK_BELLY(21, "삼겹살"),
    POTATO(22, "감자"),
    SAUSAGE(23, "소시지"),
    SPINACH(24, "시금치"),
    SWEET_POTATO(25, "고구마"),
    TOFU(26, "두부"),
    TOMATO(27, "토마토"),
    ZUCCHINI(28, "애호박");

    private final int classId;
    private final String className;

    IngredientMapper(int classId, String className) {
        this.classId = classId;
        this.className = className;
    }

    public int getClassId() {
        return classId;
    }

    public String getClassName() {
        return className;
    }
    public static IngredientMapper getByClassId(int classId) {
        for (IngredientMapper ingredient : values()) {
            if (ingredient.getClassId() == classId) {
                return ingredient;
            }
        }
        throw new IllegalArgumentException("Invalid class_id: " + classId);
    }
}